Performance improvements for Differential by sritchie · Pull Request #360 · sicmutils/sicmutils · GitHub

Performance improvements for Differential #360


Merged: 20 commits, merged May 3, 2021
6 changes: 6 additions & 0 deletions CHANGELOG.md
@@ -2,6 +2,12 @@
 
 ## [unreleased]
 
+- #360 introduces a number of performance improvements to the
+  `sicmutils.differential.Differential` implementation, primarily in `terms:+`
+  and `terms:*`. Thanks again to @ptaoussanis and the
+  [Tufte](https://github.com/ptaoussanis/tufte) profiling library for helping me
+  track these down.
+
 - #358:
 
   - Converts the Clojurescript test build and REPL command from `lein-cljsbuild`
150 changes: 94 additions & 56 deletions src/sicmutils/differential.cljc
@@ -397,8 +397,11 @@
 ;; A differential term is implemented as a pair whose first element is a set of
 ;; tags and whose second is the coefficient.
 
-(def ^:private tags first)
-(def ^:private coefficient peek)
+(defn- tags [term]
+  (nth term 0))
+
+(defn- coefficient [term]
+  (nth term 1))
 
 ;; The set of tags is implemented as a "vector set",
 ;; from [[sicmutils.util.vector-set]]. This is a sorted set data structure,
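The accessor change in this hunk can be tried with plain Clojure vectors. This is a sketch, not sicmutils code. It assumes a term is a two-element vector `[tag-set coefficient]` whose tag set is a sorted vector of integers, as the comments above describe. On such a vector, `nth` returns the same values as the replaced `first`/`peek` accessors. `nth` indexes into the vector directly, while `first` has to go through `seq`.

```clojure
;; Hypothetical term: the tag set #{0 2} stored as the sorted vector [0 2],
;; with coefficient 3.
(def term [[0 2] 3])

;; Index-based accessors, mirroring the new definitions in this hunk.
(defn tags [term] (nth term 0))
(defn coefficient [term] (nth term 1))

(tags term)          ;; => [0 2]
(coefficient term)   ;; => 3

;; The replaced accessors agree on any two-element vector: `peek` on a
;; vector returns its last element.
(= (first term) (tags term))         ;; => true
(= (peek term) (coefficient term))   ;; => true
```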
@@ -477,34 +480,43 @@
   in the result with a zero coefficient will be removed.
 
   Each input must be sequence of `[tag-set, coefficient]` pairs, sorted by
-  `tag-set`."
+  `tag-set`.
+
+  NOTE that this function recurs on increasing indices internally instead of
+  walking through the lists directly. This method of traversing vectors is more
+  efficient, and this function is called so often that the performance gain is
+  worth it, and reads almost like the explicit sequence traversal."
   ([] [])
   ([xs] xs)
   ([xs ys]
-   (loop [xs xs, ys ys, result []]
-     (cond (empty? xs) (into result ys)
-           (empty? ys) (into result xs)
-           :else (let [[x-tags x-coef :as x] (first xs)
-                       [y-tags y-coef :as y] (first ys)
-                       compare-flag (v/compare x-tags y-tags)]
-                   (cond
-                     ;; If the terms have the same tag set, add the coefficients
-                     ;; together. Include the term in the result only if the new
-                     ;; coefficient is non-zero.
-                     (zero? compare-flag)
-                     (let [sum (g/+ x-coef y-coef)]
-                       (recur (rest xs)
-                              (rest ys)
-                              (if (v/zero? sum)
-                                result
-                                (conj result (make-term x-tags sum)))))
-
-                     ;; Else, pass the smaller term on unchanged and proceed.
-                     (neg? compare-flag)
-                     (recur (rest xs) ys (conj result x))
-
-                     :else
-                     (recur xs (rest ys) (conj result y))))))))
+   (loop [i (long 0)
+          j (long 0)
+          result (transient [])]
+     (let [x (nth xs i nil)
+           y (nth ys j nil)]
+       (cond (not x) (into (persistent! result) (subvec ys j))
+             (not y) (into (persistent! result) (subvec xs i))
+             :else (let [[x-tags x-coef] x
+                         [y-tags y-coef] y
+                         compare-flag (core-compare x-tags y-tags)]
+                     (cond
+                       ;; If the terms have the same tag set, add the coefficients
+                       ;; together. Include the term in the result only if the new
+                       ;; coefficient is non-zero.
+                       (zero? compare-flag)
+                       (let [sum (g/add x-coef y-coef)]
+                         (recur (inc i)
+                                (inc j)
+                                (if (v/zero? sum)
+                                  result
+                                  (conj! result (make-term x-tags sum)))))
+
+                       ;; Else, pass the smaller term on unchanged and proceed.
+                       (neg? compare-flag)
+                       (recur (inc i) j (conj! result x))
+
+                       :else
+                       (recur i (inc j) (conj! result y)))))))))
 
 ;; Because we've decided to store terms as a vector, we can multiply two vectors
 ;; of terms by:
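The merge performed by the new `terms:+` can be reproduced outside sicmutils. What follows is a simplified, self-contained sketch that rests on a few assumptions. Terms are `[tags coefficient]` vectors. Tags are sorted vectors ordered by `clojure.core/compare`, which puts shorter vectors first and then compares element by element. Coefficients are plain numbers, so `+` and `zero?` stand in for the generic `g/add` and `v/zero?`. Like the code in the hunk, the sketch walks both inputs with integer indices and accumulates into a transient vector.

```clojure
(defn make-term [tags coef] [tags coef])

(defn terms:+
  "Merges two vectors of terms, each sorted by tags. Coefficients of terms
  with equal tags are summed; terms whose sum is zero are dropped."
  [xs ys]
  (loop [i 0, j 0, result (transient [])]
    (let [x (nth xs i nil)
          y (nth ys j nil)]
      (cond
        ;; One input is exhausted: append the rest of the other one.
        (not x) (into (persistent! result) (subvec ys j))
        (not y) (into (persistent! result) (subvec xs i))
        :else
        (let [[x-tags x-coef] x
              [y-tags y-coef] y
              flag (compare x-tags y-tags)]
          (cond
            ;; Same tag set: add coefficients, keep the term if non-zero.
            (zero? flag)
            (let [sum (+ x-coef y-coef)]
              (recur (inc i) (inc j)
                     (if (zero? sum)
                       result
                       (conj! result (make-term x-tags sum)))))
            ;; Otherwise emit the smaller term and advance past it.
            (neg? flag) (recur (inc i) j (conj! result x))
            :else       (recur i (inc j) (conj! result y))))))))

;; (1 + 2ε₀) + (-2ε₀ + 5ε₁) = 1 + 5ε₁
(terms:+ [[[] 1] [[0] 2]]
         [[[0] -2] [[1] 5]])
;; => [[[] 1] [[1] 5]]
```

Because both inputs are already sorted, the merge finishes in a single pass over each vector. In the example, the ε₀ terms cancel and are therefore dropped from the result.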
@@ -534,18 +546,37 @@
 ;; [[terms:*]] implements the first three steps, and calls [[collect-terms]] on
 ;; the resulting sequence:
 
-(defn- terms:*
+(defn- t*ts
+  "Multiplies a single term on the left by a vector of `terms` on the right.
+  Returns a new vector of terms."
+  [[tags coeff] terms]
+  (loop [acc []
+         i 0]
+    (let [t (nth terms i nil)]
+      (if (nil? t)
+        acc
+        (let [[tags1 coeff1] t]
+          (if (empty? (uv/intersection tags tags1))
+            (recur (conj acc (make-term
+                              (uv/union tags tags1)
+                              (g/* coeff coeff1)))
+                   (inc i))
+            (recur acc (inc i))))))))
+
+(defn terms:*
   "Returns a vector of non-zero [[Differential]] terms that represent the product
-  of the differential term lists `xs` and `ys`."
-  ([] [])
-  ([xs] xs)
-  ([xs ys]
-   (collect-terms
-    (for [[x-tags x-coef] xs
-          [y-tags y-coef] ys
-          :when (empty? (uv/intersection x-tags y-tags))]
-      (make-term (uv/union x-tags y-tags)
-                 (g/* x-coef y-coef))))))
+  of the differential term lists `xs` and `ys`.
+
+  NOTE that this function doesn't need to call [[collect-terms]] internally
+  because grouping is accomplished by the internal [[terms:+]] calls."
+  [xlist ylist]
+  (letfn [(call [i]
+            (let [x (nth xlist i nil)]
+              (if (nil? x)
+                []
+                (terms:+ (t*ts x ylist)
+                         (call (inc i))))))]
+    (call 0)))
 
 ;; ## Differential Type Implementation
 ;;
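The new `terms:*` multiplies each left-hand term by the whole right-hand list with `t*ts`, then merges the partial products with `terms:+`. That strategy can be sketched without sicmutils, under assumptions that do not come from the PR:

- Tags are sorted vectors of integers.
- The hypothetical helpers `disjoint?` and `union` stand in for `uv/intersection` and `uv/union`.
- `terms:+` here is a compact sorted-map stand-in, not the index-based merge from the previous hunk.

A product of two terms whose tag sets overlap is dropped, which is how a squared tag (ε² = 0) disappears.

```clojure
(defn disjoint?
  "True if the tag vectors `a` and `b` share no tag."
  [a b]
  (not-any? (set a) b))

(defn union
  "Merges two disjoint tag vectors into one sorted tag vector."
  [a b]
  (vec (sort (concat a b))))

(defn terms:+
  "Stand-in for terms:+: sums coefficients per tag set, drops zero
  coefficients, and returns terms sorted by tags via a sorted map."
  [xs ys]
  (->> (concat xs ys)
       (reduce (fn [m [tags coef]] (update m tags (fnil + 0) coef))
               (sorted-map))
       (remove (fn [[_ coef]] (zero? coef)))
       (mapv (fn [[tags coef]] [tags coef]))))

(defn t*ts
  "Multiplies a single term by every term in `terms`, skipping products
  whose tag sets overlap."
  [[tags coef] terms]
  (loop [acc [], i 0]
    (let [t (nth terms i nil)]
      (if (nil? t)
        acc
        (let [[tags1 coef1] t]
          (recur (if (disjoint? tags tags1)
                   (conj acc [(union tags tags1) (* coef coef1)])
                   acc)
                 (inc i)))))))

(defn terms:*
  "Multiplies two term lists by merging the partial products of each
  left-hand term with the whole right-hand list."
  [xlist ylist]
  (letfn [(call [i]
            (let [x (nth xlist i nil)]
              (if (nil? x)
                []
                (terms:+ (t*ts x ylist)
                         (call (inc i))))))]
    (call 0)))

;; (1 + 2ε₀)(3 + 4ε₁) = 3 + 6ε₀ + 4ε₁ + 8ε₀ε₁
(terms:* [[[] 1] [[0] 2]] [[[] 3] [[1] 4]])
;; => [[[] 3] [[0] 6] [[1] 4] [[0 1] 8]]

;; (1 + ε₀)(1 + ε₀) = 1 + 2ε₀, because the ε₀·ε₀ product is dropped
(terms:* [[[] 1] [[0] 1]] [[[] 1] [[0] 1]])
;; => [[[] 1] [[0] 2]]
```

Grouping equal tag sets happens inside each `terms:+` call, so no separate collection pass is needed. This matches the NOTE in the new docstring.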
@@ -701,13 +732,10 @@
   If you pass a non-[[Differential]], [[->terms]] will return a singleton term
   list (or `[]` if the argument was zero)."
   [dx]
-  (cond (differential? dx)
-        (filterv (fn [term]
-                   (not (v/zero? (coefficient term))))
-                 (bare-terms dx))
-
-        (v/zero? dx) []
-        :else [(make-term dx)]))
+  (cond (differential? dx) (bare-terms dx)
+        (vector? dx) dx
+        (v/zero? dx) []
+        :else [(make-term dx)]))
 
 (defn- terms->differential
   "Returns a differential instance generated from a vector of terms. This method
@@ -812,6 +840,16 @@
      (terms:* (->terms dx)
               (->terms dy)))))
 
+(defn d:+*
+  "Identical to `(d:+ a (d:* b c))`, but _slightly_ more efficient as the function
+  is able to skip creating a [[Differential]] instance during [[d:*]] and then
+  immediately tearing it down for [[d:+]]."
+  [a b c]
+  (terms->differential
+   (terms:+ (->terms a)
+            (terms:* (->terms b)
+                     (->terms c)))))
+
 (defn- d:apply
   "Accepts a [[Differential]] and a sequence of `args`, interprets each
   coefficient as a function and returns a new [[Differential]] generated by
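`d:+*` produces the same terms as `(d:+ a (d:* b c))`, computed directly on term lists. Here is a worked example with hypothetical inputs. ε₀ and ε₁ are differential tags with ε₀² = ε₁² = 0, and terms are listed in tag order:

```latex
\begin{align*}
b\,c     &= (2 + 3\varepsilon_0)(4 + 5\varepsilon_1)
          = 8 + 12\varepsilon_0 + 10\varepsilon_1 + 15\,\varepsilon_0\varepsilon_1 \\
a + b\,c &= (1 + \varepsilon_0) + b\,c
          = 9 + 13\varepsilon_0 + 10\varepsilon_1 + 15\,\varepsilon_0\varepsilon_1
\end{align*}
```

According to the docstring, the fused version only avoids building the intermediate `Differential` for b·c and then tearing it down again. The arithmetic is unchanged.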
@@ -848,7 +886,7 @@
    (bundle-element primal 1 tag))
   ([primal tangent tag]
    (let [term (make-term (uv/make [tag]) 1)]
-     (d:+ primal (d:* tangent (->Differential [term]))))))
+     (d:+* primal tangent [term]))))
 
 ;; ## Differential Parts API
 ;;
@@ -864,7 +902,7 @@
   no non-zero tangent parts, or all non-[[Differential]]s), returns nil."
   ([dx]
    (when (differential? dx)
-     (let [last-term (peek (->terms dx))
+     (let [last-term (peek (bare-terms dx))
            highest-tag (peek (tags last-term))]
        highest-tag)))
   ([dx & dxs]
@@ -898,7 +936,7 @@
   ([dx tag]
    (if (differential? dx)
      (let [sans-tag? #(not (tag-in-term? % tag))]
-       (->> (->terms dx)
+       (->> (bare-terms dx)
             (filterv sans-tag?)
             (terms->differential)))
      dx)))
@@ -923,7 +961,7 @@
   ([dx] (tangent-part dx (max-order-tag dx)))
   ([dx tag]
    (if (differential? dx)
-     (->> (->terms dx)
+     (->> (bare-terms dx)
           (filterv #(tag-in-term? % tag))
           (terms->differential))
      0)))
@@ -944,7 +982,7 @@
      [dx 0]
      (let [[tangent-terms primal-terms]
            (us/separatev #(tag-in-term? % tag)
-                         (->terms dx))]
+                         (bare-terms dx))]
        [(terms->differential primal-terms)
         (terms->differential tangent-terms)]))))
 
@@ -1114,9 +1152,9 @@
         (f x)
         (let [[px tx] (primal-tangent-pair x)
               fx (call px)]
-          (if (and (v/number? tx) (v/zero? tx))
+          (if (v/numeric-zero? tx)
             fx
-            (d:+ fx (d:* (df:dx px) tx))))))))
+            (d:+* fx (df:dx px) tx)))))))
 
 (defn lift-2
   "Given:
@@ -1148,12 +1186,12 @@
           [xe dx] (primal-tangent-pair x tag)
           [ye dy] (primal-tangent-pair y tag)
           a (call xe ye)
-          b (if (and (v/number? dx) (v/zero? dx))
+          b (if (v/numeric-zero? dx)
               a
-              (d:+ a (d:* (df:dx xe ye) dx)))]
-      (if (and (v/number? dy) (v/zero? dy))
+              (d:+* a (df:dx xe ye) dx))]
+      (if (v/numeric-zero? dy)
         b
-        (d:+ b (d:* (df:dy xe ye) dy))))))))
+        (d:+* b (df:dy xe ye) dy)))))))
 
 (defn lift-n
   "Given:
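Both `lift-1` and `lift-2` apply the first-order chain rule. Write an argument as its primal part p plus its tangent part τ. Every term of τ carries the current tag, so τ² = 0, and therefore f(p + τ) = f(p) + f'(p)·τ. `lift-1` computes exactly this as `(d:+* fx (df:dx px) tx)`, and `lift-2` adds one such product per argument. The changes in the `lift-1` and `lift-2` hunks skip a product whenever its tangent is a literal numeric zero.

The sketch below mirrors the two-argument structure using a deliberately simplified, hypothetical single-tag dual number, represented as a map `{:p primal, :t tangent}`. `dual`, `mul` and this `lift-2` are illustrations, not sicmutils functions.

```clojure
(defn numeric-zero?
  "True only for numbers equal to zero (same definition as in value.cljc)."
  [x]
  (and (number? x) (zero? x)))

(defn dual
  "Hypothetical single-tag dual number: primal part `p`, tangent part `t`."
  [p t]
  {:p p :t t})

(defn lift-2
  "Lifts a binary numeric function `f`, given its partial derivatives
  `df:dx` and `df:dy`, to act on two dual numbers."
  [f df:dx df:dy]
  (fn [{xe :p dx :t} {ye :p dy :t}]
    (let [a (f xe ye)
          ;; Skip each tangent contribution when that tangent is literally 0,
          ;; mirroring the `v/numeric-zero?` checks in the PR.
          b (if (numeric-zero? dx) 0 (* (df:dx xe ye) dx))
          c (if (numeric-zero? dy) 0 (* (df:dy xe ye) dy))]
      (dual a (+ b c)))))

;; Lifted multiplication: d(xy) = y dx + x dy.
(def mul (lift-2 * (fn [_ y] y) (fn [x _] x)))

(mul (dual 3 1) (dual 4 0))  ;; => {:p 12, :t 4}
(mul (dual 3 1) (dual 4 1))  ;; => {:p 12, :t 7}
```

Because `dual` carries a single tag, the sketch omits the tag bookkeeping that `primal-tangent-pair` performs in the real code.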
6 changes: 6 additions & 0 deletions src/sicmutils/value.cljc
@@ -132,6 +132,12 @@
              (instance? goog.math.Long x)
              (instance? Complex x))))
 
+(defn numeric-zero?
+  "Returns `true` if `x` is both a [[number?]] and [[zero?]], false otherwise."
+  [x]
+  (and (number? x)
+       (zero? x)))
+
 ;; `::scalar` is a thing that symbolic expressions AND actual numbers both
 ;; derive from.
 
12 changes: 6 additions & 6 deletions test/sicmutils/fdg/bianchi_test.cljc
@@ -150,23 +150,23 @@
   ;; +----------+----------+-----------+
   ;; |          |Bianchi 1 | Bianchi 2 |
   ;; +----------+----------+-----------+
-  ;; |R2        |62ms, ??  | 1.44s, ?? |
+  ;; |R2        |52ms, ??  | 1.29s, ?? |
   ;; +----------+----------+-----------+
-  ;; |R3        |500ms, ?? |1.02m, ??  |
+  ;; |R3        |426ms, ?? |1.02m, ??  |
   ;; +----------+----------+-----------+
-  ;; |R4        |2.3s, ??  |17.44m, ?? |
+  ;; |R4        |2.17s, ?? |17.44m, ?? |
   ;; +----------+----------+-----------+
   ;;
   ;; With a general connection (with torsion):
   ;;
   ;; +----------+----------+-----------+
   ;; |          |Bianchi 1 | Bianchi 2 |
   ;; +----------+----------+-----------+
-  ;; |R2        |220ms, ?? |1.38s, ??  |
+  ;; |R2        |199ms, ?? |1.38s, ??  |
   ;; +----------+----------+-----------+
-  ;; |R3        |1.48s, ?? |26.74s, ?? |
+  ;; |R3        |1.33s, ?? |26.74s, ?? |
   ;; +----------+----------+-----------+
-  ;; |R4        |7.75s, ?? |4.82m, ??  |
+  ;; |R4        |6.96s, ?? |4.82m, ??  |
   ;; +----------+----------+-----------+
 
   (testing "A system with a symmetric connection is torsion-free."