diff sicm/bk/utils.clj @ 2:b4de894a1e2e

initial import
author Robert McIntyre <rlm@mit.edu>
date Fri, 28 Oct 2011 00:03:05 -0700
parents
children
line wrap: on
line diff
     1.1 --- /dev/null	Thu Jan 01 00:00:00 1970 +0000
     1.2 +++ b/sicm/bk/utils.clj	Fri Oct 28 00:03:05 2011 -0700
     1.3 @@ -0,0 +1,196 @@
     1.4 +
     1.5 +(ns sicm.utils)
     1.6 +
     1.7 +					;***** GENERIC ARITHMETIC
     1.8 +(ns sicm.utils)
     1.9 +(in-ns 'sicm.utils)
    1.10 +
    1.11 +(defprotocol Arithmetic
    1.12 +  (zero [this])
    1.13 +  (one [this]))
    1.14 +
    1.15 +
    1.16 +(extend-protocol Arithmetic
    1.17 +  java.lang.Number
    1.18 +  (zero [this] 0)
    1.19 +  (one [this] 1))
    1.20 +
    1.21 +(extend-protocol Arithmetic
    1.22 +  clojure.lang.Seqable
    1.23 +  (zero [this] (map zero this))
    1.24 +  (one [this] (map one this)))
    1.25 +
    1.26 +
    1.27 +					;***** TUPLES AND MATRICES
    1.28 +(in-ns 'sicm.utils)
    1.29 +
    1.30 +(defprotocol Spinning
    1.31 +  (up? [this])
    1.32 +  (down? [this]))
    1.33 +
    1.34 +(defn spin
    1.35 +  "Returns the spin of the Spinning s, either :up or :down"
    1.36 +  [#^Spinning s]
    1.37 +  (cond (up? s) :up (down? s) :down))
    1.38 +
    1.39 +
    1.40 +(deftype Tuple
    1.41 +  [spin coll]
    1.42 +  clojure.lang.Seqable
    1.43 +  (seq [this] (seq (.coll this)))
    1.44 +  clojure.lang.Counted
    1.45 +  (count [this] (count (.coll this))))
    1.46 +
    1.47 +(extend-type Tuple
    1.48 +  Spinning
    1.49 +  (up? [this] (= ::up (.spin this)))
    1.50 +  (down? [this] (= ::down (.spin this))))
    1.51 +
    1.52 +(defmethod print-method Tuple
    1.53 +  [o w]
    1.54 +  (print-simple (str (if (up? o) 'u 'd) (.coll o))  w))
    1.55 +
    1.56 +
    1.57 +
    1.58 +(defn up
    1.59 +  "Create a new up-tuple containing the contents of coll."
    1.60 +  [coll]
    1.61 +  (Tuple. ::up coll))       
    1.62 +
    1.63 +(defn down
    1.64 +  "Create a new down-tuple containing the contents of coll."
    1.65 +  [coll]
    1.66 +  (Tuple. ::down coll))
    1.67 +
    1.68 +
    1.69 +(in-ns 'sicm.utils)
    1.70 +
    1.71 +(defn numbers?
    1.72 +  "Returns true if all arguments are numbers, else false."
    1.73 +  [& xs]
    1.74 +  (every? number? xs))
    1.75 +
    1.76 +(defn contractible?
    1.77 +  "Returns true if the tuples a and b are compatible for contraction,
    1.78 +  else false. Tuples are compatible if they have the same number of
    1.79 +  components, they have opposite spins, and their elements are
    1.80 +  pairwise-compatible."
    1.81 +  [a b]
    1.82 +  (and
    1.83 +   (isa? (type a) Tuple)
    1.84 +   (isa? (type b) Tuple)
    1.85 +   (= (count a) (count b))
    1.86 +   (not= (spin a) (spin b))
    1.87 +   
    1.88 +   (not-any? false?
    1.89 +	     (map #(or
    1.90 +		    (numbers? %1 %2)
    1.91 +		    (contractible? %1 %2))
    1.92 +		  a b))))
    1.93 +
    1.94 +
    1.95 +
    1.96 +(defn contract
    1.97 +  "Contracts two tuples, returning the sum of the
    1.98 +  products of the corresponding items. Contraction is recursive on
    1.99 +  nested tuples."
   1.100 +  [a b]
   1.101 +  (if (not (contractible? a b))
   1.102 +    (throw
   1.103 +     (Exception. "Not compatible for contraction."))
   1.104 +    (reduce +
   1.105 +	    (map
   1.106 +	     (fn [x y]
   1.107 +	       (if (numbers? x y)
   1.108 +		 (* x y)
   1.109 +		 (contract x y)))
   1.110 +	     a b))))
   1.111 +
   1.112 +					;***** MATRICES
   1.113 +(in-ns 'sicm.utils)
   1.114 +(require 'incanter.core) ;; use incanter's fast matrices
   1.115 +
   1.116 +(defprotocol Matrix
   1.117 +  (rows [this])
   1.118 +  (cols [this])
   1.119 +  (diagonal [this])
   1.120 +  (trace [this])
   1.121 +  (determinant [this]))
   1.122 +
   1.123 +(extend-protocol Matrix
   1.124 +  incanter.Matrix
   1.125 +  (rows [this] (map down this)))
   1.126 +
   1.127 +
   1.128 +
   1.129 +
   1.130 +(defn count-rows [matrix]
   1.131 +  ((comp count rows) matrix))
   1.132 +
   1.133 +(defn count-cols [matrix]
   1.134 +  ((comp count cols) matrix))
   1.135 +
   1.136 +
   1.137 +(defn matrix-by-rows
   1.138 +  "Define a matrix by giving its rows."
   1.139 +  [& rows]
   1.140 +  (cond
   1.141 +   (not (all-equal? (map count rows)))
   1.142 +   (throw (Exception. "All rows in a matrix must have the same number of elements."))
   1.143 +   :else
   1.144 +   (reify Matrix
   1.145 +	  (rows [this] (map down rows))
   1.146 +	  (cols [this] (map up (apply map vector rows)))
   1.147 +	  (diagonal [this] (map-indexed (fn [i row] (nth row i) rows)))
   1.148 +	  (trace [this]
   1.149 +		 (if (not= (count-rows this) (count-cols this))
   1.150 +		   (throw (Exception.
   1.151 +			   "Cannot take the trace of a non-square matrix."))
   1.152 +		   (reduce + (diagonal this))))
   1.153 +	  
   1.154 +	  (determinant [this]
   1.155 +		       (if (not= (count-rows this) (count-cols this))
   1.156 +			 (throw (Exception.
   1.157 +				 "Cannot take the determinant of a non-square matrix."))
   1.158 +			 (reduce * (diagonal this))))
   1.159 +   )))
   1.160 +
   1.161 +
   1.162 +(defn matrix-by-cols
   1.163 +  "Define a matrix by giving its columns."
   1.164 +  [& cols]
   1.165 +  (cond
   1.166 +   (not (all-equal? (map count cols)))
   1.167 +   (throw (Exception. "All columns in a matrix must have the same number of elements."))
   1.168 +   :else
   1.169 +   (reify Matrix
   1.170 +	  (cols [this] (map up cols))
   1.171 +	  (rows [this] (map down (apply map vector cols)))
   1.172 +	  (diagonal [this] (map-indexed (fn [i col] (nth col i) cols)))
   1.173 +	  (trace [this]
   1.174 +		 (if (not= (count-cols this) (count-rows this))
   1.175 +		   (throw (Exception.
   1.176 +			   "Cannot take the trace of a non-square matrix."))
   1.177 +		   (reduce + (diagonal this))))
   1.178 +	  
   1.179 +	  (determinant [this]
   1.180 +		       (if (not= (count-cols this) (count-rows this))
   1.181 +			 (throw (Exception.
   1.182 +				 "Cannot take the determinant of a non-square matrix."))
   1.183 +			 (reduce * (map-indexed (fn [i col] (nth col i)) cols))))
   1.184 +	  )))
   1.185 +
   1.186 +(extend-protocol Matrix Tuple
   1.187 +		 (rows [this] (if (down? this)
   1.188 +				(list this)
   1.189 +				(map (comp up vector) this)))
   1.190 +
   1.191 +		 (cols [this] (if (up? this)
   1.192 +				(list this)
   1.193 +				(map (comp down vector) this))))
   1.194 +
   1.195 +(defn matrix-multiply [A B]
   1.196 +  (apply matrix-by-rows
   1.197 +	 (for [a (rows A)]
   1.198 +	   (for [b (cols B)]
   1.199 +	     (contract a b)))))