diff sicm/utils.clj @ 2:b4de894a1e2e

initial import
author Robert McIntyre <rlm@mit.edu>
date Fri, 28 Oct 2011 00:03:05 -0700
parents
children
line wrap: on
line diff
     1.1 --- /dev/null	Thu Jan 01 00:00:00 1970 +0000
     1.2 +++ b/sicm/utils.clj	Fri Oct 28 00:03:05 2011 -0700
     1.3 @@ -0,0 +1,496 @@
     1.4 +
     1.5 +(ns sicm.utils)
     1.6 +
     1.7 +(in-ns 'sicm.utils)
     1.8 +
     1.9 +;; Let some objects have spin
    1.10 +
    1.11 +(defprotocol Spinning
    1.12 +  (up? [this])
    1.13 +  (down? [this]))
    1.14 +
    1.15 +(defn spin
    1.16 +  "Returns the spin of the Spinning s, either :up or :down"
    1.17 +  [#^Spinning s]
    1.18 +  (cond (up? s) :up (down? s) :down))
    1.19 +
    1.20 +
    1.21 +;; DEFINITION: A tuple is a sequence with spin
    1.22 +
    1.23 +(deftype Tuple
    1.24 +  [spin coll]
    1.25 +
    1.26 +  clojure.lang.Seqable
    1.27 +  (seq [this] (seq (.coll this)))
    1.28 +
    1.29 +  clojure.lang.Counted
    1.30 +  (count [this] (count (.coll this)))
    1.31 +
    1.32 +  Spinning
    1.33 +  (up? [this] (= ::up (.spin this)))
    1.34 +  (down? [this] (= ::down (.spin this))))
    1.35 +
    1.36 +(defmethod print-method Tuple
    1.37 +  [o w]
    1.38 +  (print-simple
    1.39 +   (if (up? o)
    1.40 +     (str "u" (.coll o))
    1.41 +     (str "d" (vec(.coll o))))
    1.42 +   w))
    1.43 +
    1.44 +(def tuple? #(= (type %) Tuple))
    1.45 +
    1.46 +;; CONSTRUCTORS
    1.47 +
    1.48 +(defn up
    1.49 +  "Create a new up-tuple containing the contents of coll."
    1.50 +  [coll]
    1.51 +  (Tuple. ::up coll))       
    1.52 +
    1.53 +(defn down
    1.54 +  "Create a new down-tuple containing the contents of coll."
    1.55 +  [coll]
    1.56 +  (Tuple. ::down coll))
    1.57 +
    1.58 +(defn same-spin
    1.59 +  "Creates a tuple which has the same spin as tuple and which contains
    1.60 +the contents of coll."
    1.61 +  [tuple coll]
    1.62 +  (if (up? tuple)
    1.63 +    (up coll)
    1.64 +    (down coll)))
    1.65 +
    1.66 +(defn opposite-spin
    1.67 +  "Create a tuple which has opposite spin to tuple and which contains
    1.68 +the contents of coll."
    1.69 +  [tuple coll]
    1.70 +  (if (up? tuple)
    1.71 +    (down coll)
    1.72 +    (up coll)))
    1.73 +(in-ns 'sicm.utils)
    1.74 +(require 'incanter.core) ;; use incanter's fast matrices
    1.75 +
    1.76 +
    1.77 +(defn all-equal? [coll]
    1.78 +  (if (empty? (rest coll)) true
    1.79 +      (and (= (first coll) (second coll))
    1.80 +	   (recur (rest coll)))))
    1.81 +
    1.82 +
    1.83 +(defprotocol Matrix
    1.84 +  (rows [matrix])
    1.85 +  (cols [matrix])
    1.86 +  (diagonal [matrix])
    1.87 +  (trace [matrix])
    1.88 +  (determinant [matrix])
    1.89 +  (transpose [matrix])
    1.90 +  (conjugate [matrix])
    1.91 +)
    1.92 +
    1.93 +(extend-protocol Matrix incanter.Matrix
    1.94 +  (rows [rs] (map down (apply map vector (apply map vector rs))))
    1.95 +  (cols [rs] (map up (apply map vector rs)))
    1.96 +  (diagonal [matrix] (incanter.core/diag matrix) )
    1.97 +  (determinant [matrix] (incanter.core/det matrix))
    1.98 +  (trace [matrix] (incanter.core/trace matrix))
    1.99 +  (transpose [matrix] (incanter.core/trans matrix)))
   1.100 +
   1.101 +(defn count-rows [matrix]
   1.102 +  ((comp count rows) matrix))
   1.103 +
   1.104 +(defn count-cols [matrix]
   1.105 +  ((comp count cols) matrix))
   1.106 +
   1.107 +(defn square? [matrix]
   1.108 +  (= (count-rows matrix) (count-cols matrix)))
   1.109 +
   1.110 +(defn identity-matrix
   1.111 +  "Define a square matrix of size n-by-n with 1s along the diagonal and
   1.112 +  0s everywhere else."
   1.113 +  [n]
   1.114 +  (incanter.core/identity-matrix n))
   1.115 +
   1.116 +
   1.117 +(defn matrix-by-rows
   1.118 +  "Define a matrix by giving its rows."
   1.119 +  [& rows]
   1.120 +  (if
   1.121 +   (not (all-equal? (map count rows)))
   1.122 +   (throw (Exception. "All rows in a matrix must have the same number of elements."))
   1.123 +   (incanter.core/matrix (vec rows))))
   1.124 +
   1.125 +(defn matrix-by-cols
   1.126 +  "Define a matrix by giving its columns"
   1.127 +  [& cols]
   1.128 +  (if (not (all-equal? (map count cols)))
   1.129 +   (throw (Exception. "All columns in a matrix must have the same number of elements."))
   1.130 +   (incanter.core/matrix (vec (apply map vector cols)))))
   1.131 +
   1.132 +(defn identity-matrix
   1.133 +  "Define a square matrix of size n-by-n with 1s along the diagonal and
   1.134 +  0s everywhere else."
   1.135 +  [n]
   1.136 +  (incanter.core/identity-matrix n))
   1.137 +
   1.138 +(in-ns 'sicm.utils)
   1.139 +(use 'clojure.contrib.generic.arithmetic
   1.140 +     'clojure.contrib.generic.collection
   1.141 +     'clojure.contrib.generic.functor
   1.142 +     'clojure.contrib.generic.math-functions)
   1.143 +
   1.144 +(defn numbers?
   1.145 +  "Returns true if all arguments are numbers, else false."
   1.146 +  [& xs]
   1.147 +  (every? number? xs))
   1.148 +
   1.149 +(defn tuple-surgery
   1.150 +  "Applies the function f to the items of tuple and the additional
   1.151 +  arguments, if any. Returns a Tuple of the same type as tuple."
   1.152 +  [tuple f & xs]
   1.153 +  ((if (up? tuple) up down)
   1.154 +   (apply f (seq tuple) xs)))
   1.155 +
   1.156 +
   1.157 +
   1.158 +;;; CONTRACTION collapses two compatible tuples into a number.
   1.159 +
   1.160 +(defn contractible?
   1.161 +  "Returns true if the tuples a and b are compatible for contraction,
   1.162 +  else false. Tuples are compatible if they have the same number of
   1.163 +  components, they have opposite spins, and their elements are
   1.164 +  pairwise-compatible."
   1.165 +  [a b]
   1.166 +  (and
   1.167 +   (isa? (type a) Tuple)
   1.168 +   (isa? (type b) Tuple)
   1.169 +   (= (count a) (count b))
   1.170 +   (not= (spin a) (spin b))
   1.171 +   
   1.172 +   (not-any? false?
   1.173 +	     (map #(or
   1.174 +		    (numbers? %1 %2)
   1.175 +		    (contractible? %1 %2))
   1.176 +		  a b))))
   1.177 +
   1.178 +(defn contract
   1.179 +  "Contracts two tuples, returning the sum of the
   1.180 +  products of the corresponding items. Contraction is recursive on
   1.181 +  nested tuples."
   1.182 +  [a b]
   1.183 +  (if (not (contractible? a b))
   1.184 +    (throw
   1.185 +     (Exception. "Not compatible for contraction."))
   1.186 +    (reduce +
   1.187 +	    (map
   1.188 +	     (fn [x y]
   1.189 +	       (if (numbers? x y)
   1.190 +		 (* x y)
   1.191 +		 (contract x y)))
   1.192 +	     a b))))
   1.193 +
   1.194 +
   1.195 +
   1.196 +
   1.197 +
   1.198 +(defmethod conj Tuple
   1.199 +  [tuple & xs]
   1.200 +  (tuple-surgery tuple #(apply conj % xs)))
   1.201 +
   1.202 +(defmethod fmap Tuple
   1.203 +  [f tuple]
   1.204 +  (tuple-surgery tuple (partial map f)))
   1.205 +
   1.206 +
   1.207 +
   1.208 +;; TODO: define Scalar, and add it to the hierarchy above Number and Complex
   1.209 +
   1.210 +					
   1.211 +(defmethod * [Tuple Tuple]             ; tuple*tuple
   1.212 +  [a b]
   1.213 +  (if (contractible? a b)
   1.214 +    (contract a b)
   1.215 +    (map (partial * a) b)))
   1.216 +
   1.217 +
   1.218 +(defmethod * [java.lang.Number Tuple]  ;; scalar *  tuple
   1.219 +  [a x] (fmap (partial * a) x))
   1.220 +
   1.221 +(defmethod * [Tuple java.lang.Number]
   1.222 +  [x a] (* a x))
   1.223 +
   1.224 +(defmethod * [java.lang.Number incanter.Matrix] ;; scalar *  matrix
   1.225 +  [x M] (incanter.core/mult x M))
   1.226 +
   1.227 +(defmethod * [incanter.Matrix java.lang.Number]
   1.228 +  [M x] (* x M))
   1.229 +
   1.230 +(defmethod * [incanter.Matrix incanter.Matrix] ;; matrix * matrix
   1.231 +  [M1 M2]
   1.232 +  (incanter.core/mmult M1 M2))
   1.233 +
   1.234 +(defmethod * [incanter.Matrix Tuple] ;; matrix * tuple
   1.235 +  [M v]
   1.236 +  (if (and (apply numbers? v) (up? v)) 
   1.237 +    (* M (matrix-by-cols v))
   1.238 +    (throw (Exception. "Currently, you can only multiply a matrix by a tuple of *numbers*"))
   1.239 +    ))
   1.240 +
   1.241 +(defmethod * [Tuple incanter.Matrix] ;; tuple * Matrix
   1.242 +  [v M]
   1.243 +  (if (and (apply numbers? v) (down? v))
   1.244 +    (* (matrix-by-rows v) M)
   1.245 +    (throw (Exception. "Currently, you can only multiply a matrix by a tuple of *numbers*"))
   1.246 +    ))
   1.247 +
   1.248 +
   1.249 +(defmethod exp incanter.Matrix
   1.250 +  [M]
   1.251 +  (incanter.core/exp M))
   1.252 +
   1.253 +
   1.254 +(in-ns 'sicm.utils)
   1.255 +(use 'clojure.contrib.seq
   1.256 +     'clojure.contrib.generic.arithmetic
   1.257 +     'clojure.contrib.generic.collection
   1.258 +     'clojure.contrib.generic.math-functions)
   1.259 +
   1.260 +;;∂
   1.261 +
   1.262 +;; DEFINITION : Differential Term
   1.263 +
   1.264 +;; A quantity with infinitesimal components, e.g. x, dxdy, 4dydz. The
   1.265 +;; coefficient of the quantity is returned by the 'coefficient' method,
   1.266 +;; while the sequence of differential parameters is returned by the
   1.267 +;; method 'partials'.
   1.268 +
   1.269 +;; Instead of using (potentially ambiguous) letters to denote
   1.270 +;; differential parameters (dx,dy,dz), we use integers. So, dxdz becomes [0 2].
   1.271 +
   1.272 +;; The coefficient can be any arithmetic object; the
   1.273 +;; partials must be a nonrepeating sorted sequence of nonnegative
   1.274 +;; integers.
   1.275 +
   1.276 +(deftype DifferentialTerm [coefficient partials])
   1.277 +
   1.278 +(defn differential-term
   1.279 +  "Make a differential term from a  coefficient and list of partials."
   1.280 +  [coefficient partials]
   1.281 +  (if (and (coll? partials) (every? #(and (integer? %) (not(neg? %))) partials)) 
   1.282 +    (DifferentialTerm. coefficient (set partials))
   1.283 +    (throw (java.lang.IllegalArgumentException. "Partials must be a collection of integers."))))
   1.284 +
   1.285 +
   1.286 +;; DEFINITION : Differential Sequence
   1.287 +;; A differential sequence is a sequence of differential terms, all with different partials.
   1.288 +;; Internally, it is a map from the partials of each term to their coefficients.
   1.289 +
   1.290 +(deftype DifferentialSeq
   1.291 +  [terms]
   1.292 +  ;;clojure.lang.IPersistentMap
   1.293 +  clojure.lang.Associative
   1.294 +  (assoc [this key val]
   1.295 +    (DifferentialSeq.
   1.296 +     (cons (differential-term val key) terms)))
   1.297 +  (cons [this x]
   1.298 +	(DifferentialSeq. (cons x terms)))
   1.299 +  (containsKey [this key]
   1.300 +	       (not(nil? (find-first #(= (.partials %) key) terms))))
   1.301 +  (count [this] (count (.terms this)))
   1.302 +  (empty [this] (DifferentialSeq. []))
   1.303 +  (entryAt [this key]
   1.304 +	   ((juxt #(.partials %) #(.coefficient %))
   1.305 +	    (find-first #(= (.partials %) key) terms)))
   1.306 +  (seq [this] (seq (.terms this))))
   1.307 +
   1.308 +(def differential? #(= (type %) DifferentialSeq))
   1.309 +
   1.310 +(defn zeroth-order?
   1.311 +  "Returns true if the differential sequence has at most a constant term."
   1.312 +  [dseq]
   1.313 +  (and
   1.314 +   (differential? dseq)
   1.315 +   (every?
   1.316 +    #(= #{} %)
   1.317 +    (keys (.terms dseq)))))
   1.318 +
   1.319 +(defmethod fmap DifferentialSeq
   1.320 +  [f dseq]
   1.321 +  (DifferentialSeq.
   1.322 +   (fmap f (.terms dseq))))
   1.323 +
   1.324 +
   1.325 +
   1.326 +
   1.327 +;; BUILDING DIFFERENTIAL OBJECTS
   1.328 +
   1.329 +(defn differential-seq
   1.330 +    "Define a differential sequence by specifying an alternating
   1.331 +sequence of coefficients and lists of partials."
   1.332 +  ([coefficient partials]
   1.333 +     (DifferentialSeq. {(set partials) coefficient}))
   1.334 +  ([coefficient partials & cps]
   1.335 +     (if (odd? (count cps))
   1.336 +       (throw (Exception. "differential-seq requires an even number of terms."))
   1.337 +       (DifferentialSeq.
   1.338 +	(reduce
   1.339 +	 #(assoc %1 (set (second %2)) (first %2))
   1.340 +	 {(set partials) coefficient}
   1.341 +	 (partition 2 cps))))))
   1.342 +  
   1.343 +
   1.344 +
   1.345 +(defn big-part
   1.346 +  "Returns the part of the differential sequence that is finite,
   1.347 +  i.e. not infinitely small. If the sequence is zeroth-order, returns
   1.348 +  the coefficient of the zeroth-order term instead. "
   1.349 +  [dseq]
   1.350 +  (if (zeroth-order? dseq) (get (.terms dseq) #{})
   1.351 +      (let [m (.terms dseq)
   1.352 +	    keys (sort-by count (keys m))
   1.353 +	    smallest-var (last (last keys))]
   1.354 +	(DifferentialSeq.
   1.355 +	 (reduce
   1.356 +	  #(assoc %1 (first %2) (second %2))
   1.357 +	  {}
   1.358 +	  (remove #((first %) smallest-var) m))))))
   1.359 +
   1.360 +
   1.361 +(defn small-part
   1.362 +  "Returns the part of the differential sequence that infinitely
   1.363 +  small. If the sequence is zeroth-order, returns zero."
   1.364 +  [dseq]
   1.365 +  (if (zeroth-order? dseq) 0
   1.366 +      (let [m (.terms dseq)
   1.367 +	    keys (sort-by count (keys m))
   1.368 +	    smallest-var (last (last keys))]
   1.369 +	(DifferentialSeq.
   1.370 +	 (reduce
   1.371 +	  #(assoc %1 (first %2) (second %2)) {}
   1.372 +	  (filter #((first %) smallest-var) m))))))
   1.373 +
   1.374 +
   1.375 +
   1.376 +(defn cartesian-product [set1 set2]
   1.377 +  (reduce concat
   1.378 +	  (for [x set1]
   1.379 +	    (for [y set2]
   1.380 +	      [x y]))))
   1.381 +
   1.382 +(defn nth-subset [n]
   1.383 +  (if (zero? n) []
   1.384 +      (let [lg2 #(/ (log %) (log 2))
   1.385 +	    k (int(java.lang.Math/floor (lg2 n)))
   1.386 +	    ]
   1.387 +	(cons k
   1.388 +	 (nth-subset  (- n (pow 2 k)))))))
   1.389 +    
   1.390 +(def all-partials
   1.391 +     (lazy-seq (map nth-subset (range))))
   1.392 +
   1.393 +
   1.394 +(defn differential-multiply
   1.395 +  "Multiply two differential sequences. The square of any differential
   1.396 +  variable is zero since differential variables are infinitesimally
   1.397 +  small."
   1.398 +  [dseq1 dseq2]
   1.399 +  (DifferentialSeq.
   1.400 +   (reduce
   1.401 +    (fn [m [[vars1 coeff1] [vars2 coeff2]]]
   1.402 +      (if (not (empty? (clojure.set/intersection vars1 vars2)))
   1.403 +	m
   1.404 +	(assoc m (clojure.set/union vars1 vars2) (* coeff1 coeff2))))
   1.405 +    {}
   1.406 +    (cartesian-product (.terms dseq1) (.terms dseq2)))))
   1.407 +
   1.408 +
   1.409 +
   1.410 +(defmethod * [DifferentialSeq DifferentialSeq]
   1.411 +  [dseq1 dseq2]
   1.412 +   (differential-multiply dseq1 dseq2))
   1.413 +
   1.414 +(defmethod + [DifferentialSeq DifferentialSeq]
   1.415 +  [dseq1 dseq2]
   1.416 +  (DifferentialSeq.
   1.417 +   (merge-with + (.terms dseq1) (.terms dseq2))))
   1.418 +
   1.419 +(defmethod * [java.lang.Number DifferentialSeq]
   1.420 +  [x dseq]
   1.421 +  (fmap (partial * x) dseq))
   1.422 +
   1.423 +(defmethod * [DifferentialSeq java.lang.Number]
   1.424 +  [dseq x]
   1.425 +  (fmap (partial * x) dseq))
   1.426 +
   1.427 +(defmethod + [java.lang.Number DifferentialSeq]
   1.428 +  [x dseq]
   1.429 +  (+ (differential-seq x []) dseq))
   1.430 +(defmethod + [DifferentialSeq java.lang.Number]
   1.431 +  [dseq x]
   1.432 +  (+ dseq (differential-seq x [])))
   1.433 +
   1.434 +(defmethod - DifferentialSeq
   1.435 +  [x]
   1.436 +  (fmap - x))
   1.437 +
   1.438 +
   1.439 +;; DERIVATIVES
   1.440 +
   1.441 +	      
   1.442 +
   1.443 +(defn linear-approximator
   1.444 +  "Returns an operator that linearly approximates the given function."
   1.445 +  ([f df|dx]
   1.446 +      (fn [x]
   1.447 +	(let [big-part (big-part x)
   1.448 +	      small-part (small-part x)]
   1.449 +	  ;; f(x+dx) ~= f(x) + f'(x)dx
   1.450 +	  (+ (f big-part)
   1.451 +	     (* (df|dx big-part) small-part)
   1.452 +	  ))))
   1.453 +     
   1.454 +  ([f df|dx df|dy]
   1.455 +     (fn [x y]
   1.456 +       (let [X (big-part x)
   1.457 +	     Y (big-part y)
   1.458 +	     DX (small-part x)
   1.459 +	     DY (small-part y)]
   1.460 +	 (+ (f X Y)
   1.461 +	    (* DX (f df|dx X Y))
   1.462 +	    (* DY (f df|dy X Y)))))))
   1.463 +
   1.464 +
   1.465 +
   1.466 +
   1.467 +
   1.468 +(defn D[f]
   1.469 +  (fn[x] (f (+ x (differential-seq  1 [0] 1 [1] 1 [2])))))
   1.470 +  
   1.471 +(defn d[partials f]
   1.472 +  (fn [x]
   1.473 +    (get 
   1.474 +     (.terms ((D f)x))
   1.475 +     (set partials)
   1.476 +     0
   1.477 +    )))
   1.478 +
   1.479 +(defmethod exp DifferentialSeq [x]
   1.480 +	   ((linear-approximator exp exp) x))
   1.481 +
   1.482 +(defmethod sin DifferentialSeq
   1.483 +  [x]
   1.484 +  ((linear-approximator sin cos) x))
   1.485 +
   1.486 +(defmethod cos DifferentialSeq
   1.487 +  [x]
   1.488 +  ((linear-approximator cos #(- (sin %))) x))
   1.489 +
   1.490 +(defmethod log DifferentialSeq
   1.491 +  [x]
   1.492 +  ((linear-approximator log (fn [x] (/ x)) ) x))
   1.493 +
   1.494 +(defmethod / [DifferentialSeq DifferentialSeq]
   1.495 +  [x y]
   1.496 +  ((linear-approximator /
   1.497 +			(fn [x y] (/ 1 y))
   1.498 +			(fn [x y] (- (/ x (* y y)))))
   1.499 +   x y))