From 5e15922cef41c68d0797d46bc41228aa78db0165 Mon Sep 17 00:00:00 2001 From: Dan F-M Date: Wed, 3 Apr 2024 17:51:06 -0400 Subject: [PATCH 1/6] Adding parallel matmuls for strict tri quasisep matrices --- src/tinygp/solvers/quasisep/core.py | 33 +++++++++++++++++++ tests/test_solvers/test_quasisep/test_core.py | 14 ++++++++ 2 files changed, 47 insertions(+) diff --git a/src/tinygp/solvers/quasisep/core.py b/src/tinygp/solvers/quasisep/core.py index 18766014..4b0c4226 100644 --- a/src/tinygp/solvers/quasisep/core.py +++ b/src/tinygp/solvers/quasisep/core.py @@ -68,6 +68,9 @@ def matmul(self, x: JAXArray) -> JAXArray: """ raise NotImplementedError + def parallel_matmul(self, x: JAXArray) -> JAXArray: + return self.matmul(x) + @abstractmethod def scale(self, other: JAXArray) -> QSM: """The multiplication of this matrix times a scalar, as a QSM""" @@ -121,6 +124,14 @@ def __matmul__(self, other: Any) -> Any: from tinygp.solvers.quasisep.ops import qsm_mul return qsm_mul(self, other) + + elif any(d.platform != "cpu" for d in jnp.asarray(other).devices()): + # When using a hardware accelerator, we can sometimes get better + # performance using a special purpose matmul implementation. This + # will fall back on the standard matmul implementation if the + # parallel version doesn't exist. + return self.parallel_matmul(other) + else: return self.matmul(other) @@ -197,6 +208,17 @@ def impl(f, data): # type: ignore _, f = jax.lax.scan(impl, init, (self.q, self.a, x)) return jax.vmap(jnp.dot)(self.p, f) + @jax.jit + @handle_matvec_shapes + def parallel_matmul(self, x: JAXArray) -> JAXArray: + def impl(sm, sn): + return (sn[0] @ sm[0], sn[0] @ sm[1] + sn[1]) + + states = jax.vmap(lambda l, x: (l.a, jnp.outer(l.q, x)))(self, x) + f = jax.lax.associative_scan(impl, states)[1] + f = jnp.concatenate((jnp.zeros_like(f[:1]), f[:-1]), axis=0) + return jax.vmap(jnp.dot)(self.p, f) + def scale(self, other: JAXArray) -> StrictLowerTriQSM: return StrictLowerTriQSM(p=self.p * other, q=self.q, a=self.a) @@ -270,6 +292,17 @@ def impl(f, data): # type: ignore _, f = jax.lax.scan(impl, init, (self.p, self.a, x), reverse=True) return jax.vmap(jnp.dot)(self.q, f) + @jax.jit + @handle_matvec_shapes + def parallel_matmul(self, x: JAXArray) -> JAXArray: + def impl(sm, sn): + return (sn[0] @ sm[0], sn[0] @ sm[1] + sn[1]) + + states = jax.vmap(lambda u, x: (u.a, jnp.outer(u.p, x)))(self, x) + f = jax.lax.associative_scan(impl, states, reverse=True)[1] + f = jnp.concatenate((f[1:], jnp.zeros_like(f[:1])), axis=0) + return jax.vmap(jnp.dot)(self.q, f) + def scale(self, other: JAXArray) -> StrictUpperTriQSM: return StrictUpperTriQSM(p=self.p, q=self.q * other, a=self.a) diff --git a/tests/test_solvers/test_quasisep/test_core.py b/tests/test_solvers/test_quasisep/test_core.py index ae1de4c4..af565693 100644 --- a/tests/test_solvers/test_quasisep/test_core.py +++ b/tests/test_solvers/test_quasisep/test_core.py @@ -166,6 +166,20 @@ def test_strict_tri_matmul(matrices): assert_allclose(mat.T @ m, u @ m) +def test_strict_lower_tri_parallel_matmul(matrices): + _, p, q, a, v, m, _, _ = matrices + mat = StrictLowerTriQSM(p=p, q=q, a=a) + assert_allclose(mat.parallel_matmul(v), mat.matmul(v)) + assert_allclose(mat.parallel_matmul(m), mat.matmul(m)) + + +def test_strict_upper_tri_parallel_matmul(matrices): + _, p, q, a, v, m, _, _ = matrices + mat = StrictLowerTriQSM(p=p, q=q, a=a).T + assert_allclose(mat.parallel_matmul(v), mat.matmul(v)) + assert_allclose(mat.parallel_matmul(m), mat.matmul(m)) + + def test_tri_matmul(matrices): diag, p, q, a, v, m, l, _ = matrices mat = LowerTriQSM(diag=DiagQSM(diag), lower=StrictLowerTriQSM(p=p, q=q, a=a)) From 2c8bce93fcd6c5b9aa0674b12ee2458965c72842 Mon Sep 17 00:00:00 2001 From: Dan F-M Date: Wed, 3 Apr 2024 17:57:04 -0400 Subject: [PATCH 2/6] propagate parallel matmuls --- src/tinygp/solvers/quasisep/core.py | 24 ++++++++++++++++++++++++ 1 file changed, 24 insertions(+) diff --git a/src/tinygp/solvers/quasisep/core.py b/src/tinygp/solvers/quasisep/core.py index 4b0c4226..18e43bf0 100644 --- a/src/tinygp/solvers/quasisep/core.py +++ b/src/tinygp/solvers/quasisep/core.py @@ -336,6 +336,10 @@ def transpose(self) -> UpperTriQSM: def matmul(self, x: JAXArray) -> JAXArray: return self.diag.matmul(x) + self.lower.matmul(x) + @handle_matvec_shapes + def parallel_matmul(self, x: JAXArray) -> JAXArray: + return self.diag.parallel_matmul(x) + self.lower.parallel_matmul(x) + def scale(self, other: JAXArray) -> LowerTriQSM: return LowerTriQSM(diag=self.diag.scale(other), lower=self.lower.scale(other)) @@ -392,6 +396,10 @@ def transpose(self) -> LowerTriQSM: def matmul(self, x: JAXArray) -> JAXArray: return self.diag.matmul(x) + self.upper.matmul(x) + @handle_matvec_shapes + def parallel_matmul(self, x: JAXArray) -> JAXArray: + return self.diag.parallel_matmul(x) + self.upper.parallel_matmul(x) + def scale(self, other: JAXArray) -> UpperTriQSM: return UpperTriQSM(diag=self.diag.scale(other), upper=self.upper.scale(other)) @@ -448,6 +456,14 @@ def transpose(self) -> SquareQSM: def matmul(self, x: JAXArray) -> JAXArray: return self.diag.matmul(x) + self.lower.matmul(x) + self.upper.matmul(x) + @handle_matvec_shapes + def parallel_matmul(self, x: JAXArray) -> JAXArray: + return ( + self.diag.parallel_matmul(x) + + self.lower.parallel_matmul(x) + + self.upper.parallel_matmul(x) + ) + def scale(self, other: JAXArray) -> SquareQSM: return SquareQSM( diag=self.diag.scale(other), @@ -538,6 +554,14 @@ def matmul(self, x: JAXArray) -> JAXArray: + self.lower.transpose().matmul(x) ) + @handle_matvec_shapes + def parallel_matmul(self, x: JAXArray) -> JAXArray: + return ( + self.diag.parallel_matmul(x) + + self.lower.parallel_matmul(x) + + self.lower.transpose().parallel_matmul(x) + ) + def scale(self, other: JAXArray) -> SymmQSM: return SymmQSM(diag=self.diag.scale(other), lower=self.lower.scale(other)) From 93f7db5eb46b607f6619fd3c37620b96ee364ac3 Mon Sep 17 00:00:00 2001 From: Dan F-M Date: Wed, 3 Apr 2024 17:59:10 -0400 Subject: [PATCH 3/6] force general implementation of parallel matmul --- src/tinygp/solvers/quasisep/core.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/src/tinygp/solvers/quasisep/core.py b/src/tinygp/solvers/quasisep/core.py index 18e43bf0..696a1b73 100644 --- a/src/tinygp/solvers/quasisep/core.py +++ b/src/tinygp/solvers/quasisep/core.py @@ -68,8 +68,9 @@ def matmul(self, x: JAXArray) -> JAXArray: """ raise NotImplementedError + @abstractmethod def parallel_matmul(self, x: JAXArray) -> JAXArray: - return self.matmul(x) + raise NotImplementedError @abstractmethod def scale(self, other: JAXArray) -> QSM: @@ -161,6 +162,10 @@ def transpose(self) -> DiagQSM: def matmul(self, x: JAXArray) -> JAXArray: return self.d[:, None] * x + @handle_matvec_shapes + def parallel_matmul(self, x: JAXArray) -> JAXArray: + return self.matmul(x) + def scale(self, other: JAXArray) -> DiagQSM: return DiagQSM(d=self.d * other) From 56503a47e6d1e2aba9302b66d89edacf93792b8f Mon Sep 17 00:00:00 2001 From: Dan F-M Date: Wed, 3 Apr 2024 18:31:14 -0400 Subject: [PATCH 4/6] revert checking devices --- src/tinygp/solvers/quasisep/core.py | 8 -------- 1 file changed, 8 deletions(-) diff --git a/src/tinygp/solvers/quasisep/core.py b/src/tinygp/solvers/quasisep/core.py index 696a1b73..8867761a 100644 --- a/src/tinygp/solvers/quasisep/core.py +++ b/src/tinygp/solvers/quasisep/core.py @@ -125,14 +125,6 @@ def __matmul__(self, other: Any) -> Any: from tinygp.solvers.quasisep.ops import qsm_mul return qsm_mul(self, other) - - elif any(d.platform != "cpu" for d in jnp.asarray(other).devices()): - # When using a hardware accelerator, we can sometimes get better - # performance using a special purpose matmul implementation. This - # will fall back on the standard matmul implementation if the - # parallel version doesn't exist. - return self.parallel_matmul(other) - else: return self.matmul(other) From c447bca936b962bc437ae2e43ca148025cb3ecad Mon Sep 17 00:00:00 2001 From: Dan F-M Date: Thu, 4 Apr 2024 11:26:01 -0400 Subject: [PATCH 5/6] fixing transpose bug --- src/tinygp/solvers/quasisep/core.py | 2 +- tests/test_solvers/test_quasisep/test_core.py | 14 +++++++++++++- 2 files changed, 14 insertions(+), 2 deletions(-) diff --git a/src/tinygp/solvers/quasisep/core.py b/src/tinygp/solvers/quasisep/core.py index 8867761a..814cb216 100644 --- a/src/tinygp/solvers/quasisep/core.py +++ b/src/tinygp/solvers/quasisep/core.py @@ -295,7 +295,7 @@ def parallel_matmul(self, x: JAXArray) -> JAXArray: def impl(sm, sn): return (sn[0] @ sm[0], sn[0] @ sm[1] + sn[1]) - states = jax.vmap(lambda u, x: (u.a, jnp.outer(u.p, x)))(self, x) + states = jax.vmap(lambda u, x: (u.a.T, jnp.outer(u.p, x)))(self, x) f = jax.lax.associative_scan(impl, states, reverse=True)[1] f = jnp.concatenate((f[1:], jnp.zeros_like(f[:1])), axis=0) return jax.vmap(jnp.dot)(self.q, f) diff --git a/tests/test_solvers/test_quasisep/test_core.py b/tests/test_solvers/test_quasisep/test_core.py index af565693..bd033960 100644 --- a/tests/test_solvers/test_quasisep/test_core.py +++ b/tests/test_solvers/test_quasisep/test_core.py @@ -6,6 +6,7 @@ import pytest from numpy import random as np_random +from tinygp.kernels.quasisep import Matern52 from tinygp.solvers.quasisep.core import ( DiagQSM, LowerTriQSM, @@ -17,7 +18,7 @@ from tinygp.test_utils import assert_allclose -@pytest.fixture(params=["random", "celerite"]) +@pytest.fixture(params=["random", "celerite", "matern"]) def name(request): return request.param @@ -104,6 +105,17 @@ def get_matrices(name): a = jnp.stack([jnp.diag(v) for v in jnp.exp(-c[None] * dt[:, None])], axis=0) p = jnp.einsum("ni,nij->nj", p, a) + elif name == "matern": + t = jnp.sort(random.uniform(0, 10, N)) + kernel = Matern52(1.5, 1.0) + matrix = kernel.to_symm_qsm(t) + diag = matrix.diag.d + p = matrix.lower.p + q = matrix.lower.q + a = matrix.lower.a + l = matrix.to_dense() + u = l.T + else: raise AssertionError() From 1fca6b591a26408b9a5026248eeff031c553c2ed Mon Sep 17 00:00:00 2001 From: Dan F-M Date: Thu, 4 Apr 2024 11:35:41 -0400 Subject: [PATCH 6/6] fixing new matern test --- tests/test_solvers/test_quasisep/test_core.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test_solvers/test_quasisep/test_core.py b/tests/test_solvers/test_quasisep/test_core.py index bd033960..80c6a861 100644 --- a/tests/test_solvers/test_quasisep/test_core.py +++ b/tests/test_solvers/test_quasisep/test_core.py @@ -109,11 +109,11 @@ def get_matrices(name): t = jnp.sort(random.uniform(0, 10, N)) kernel = Matern52(1.5, 1.0) matrix = kernel.to_symm_qsm(t) - diag = matrix.diag.d + diag += matrix.diag.d p = matrix.lower.p q = matrix.lower.q a = matrix.lower.a - l = matrix.to_dense() + l = matrix.lower.to_dense() u = l.T else: