Adding parallel implementations of (some?) quasisep algorithms by dfm · Pull Request #210 · dfm/tinygp · GitHub
[go: up one dir, main page]
More Web Proxy on the site http://driver.im/
Skip to content

Adding parallel implementations of (some?) quasisep algorithms #210

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 6 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 54 additions & 0 deletions src/tinygp/solvers/quasisep/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,10 @@ def matmul(self, x: JAXArray) -> JAXArray:
"""
raise NotImplementedError

@abstractmethod
def parallel_matmul(self, x: JAXArray) -> JAXArray:
    """The dot product of this matrix with a dense vector or matrix,
    computed with a parallel-friendly algorithm

    Subclasses should return the same result as :meth:`matmul` (up to
    floating-point rounding); only the scheduling of the computation
    differs, e.g. an associative scan instead of a sequential one.

    Args:
        x: A dense vector or matrix whose leading dimension matches this
            matrix.
    """
    raise NotImplementedError

@abstractmethod
def scale(self, other: JAXArray) -> QSM:
"""The multiplication of this matrix times a scalar, as a QSM"""
Expand Down Expand Up @@ -150,6 +154,10 @@ def transpose(self) -> DiagQSM:
def matmul(self, x: JAXArray) -> JAXArray:
    # A diagonal matrix acts row-wise: row i of ``x`` is scaled by d[i].
    row_scale = self.d[:, None]
    return row_scale * x

@handle_matvec_shapes
def parallel_matmul(self, x: JAXArray) -> JAXArray:
    """Parallel product with a diagonal matrix.

    The diagonal product is a pure elementwise scaling with no dependency
    between rows, so the regular ``matmul`` is reused as-is.
    """
    result = self.matmul(x)
    return result

def scale(self, other: JAXArray) -> DiagQSM:
    # Scaling a diagonal matrix only rescales its diagonal entries.
    scaled_d = self.d * other
    return DiagQSM(d=scaled_d)

Expand Down Expand Up @@ -197,6 +205,17 @@ def impl(f, data): # type: ignore
_, f = jax.lax.scan(impl, init, (self.q, self.a, x))
return jax.vmap(jnp.dot)(self.p, f)

@jax.jit
@handle_matvec_shapes
def parallel_matmul(self, x: JAXArray) -> JAXArray:
    """Compute ``self @ x`` using a parallel associative scan.

    Row ``n`` is built from a pair ``(A, B) = (a[n], outer(q[n], x[n]))``,
    read as the affine map ``f -> A @ f + B``. Composing affine maps is
    associative, so ``jax.lax.associative_scan`` can accumulate every
    prefix composition in parallel instead of in a sequential loop.

    Args:
        x: A dense vector or matrix whose leading dimension matches this
            matrix.

    Returns:
        The product of this strictly lower triangular matrix with ``x``.
    """

    def compose(first, second):  # type: ignore
        # Apply ``first`` and then ``second``:
        #   second(first(f)) = A2 @ (A1 @ f + B1) + B2
        # ``associative_scan`` calls this on batches of elements; ``@``
        # broadcasts over the leading batch axis, so no explicit vmap is needed.
        a1, b1 = first
        a2, b2 = second
        return a2 @ a1, a2 @ b1 + b2

    # Same leaves the per-row vmap produced, built directly from the fields.
    states = (self.a, jax.vmap(jnp.outer)(self.q, x))
    # Inclusive scan: entry n already contains row n's own outer-product term.
    inclusive = jax.lax.associative_scan(compose, states)[1]
    # Strict lower triangle: row n may only see columns j < n, so shift the
    # inclusive result down one row and give row 0 a zero state.
    f = jnp.concatenate((jnp.zeros_like(inclusive[:1]), inclusive[:-1]), axis=0)
    return jax.vmap(jnp.dot)(self.p, f)

def scale(self, other: JAXArray) -> StrictLowerTriQSM:
    # Only the left generator ``p`` is rescaled; ``q`` and ``a`` are reused.
    scaled_p = self.p * other
    return StrictLowerTriQSM(p=scaled_p, q=self.q, a=self.a)

Expand Down Expand Up @@ -270,6 +289,17 @@ def impl(f, data): # type: ignore
_, f = jax.lax.scan(impl, init, (self.p, self.a, x), reverse=True)
return jax.vmap(jnp.dot)(self.q, f)

@jax.jit
@handle_matvec_shapes
def parallel_matmul(self, x: JAXArray) -> JAXArray:
    """Compute ``self @ x`` using a reversed parallel associative scan.

    Row ``n`` is built from a pair ``(A, B) = (a[n].T, outer(p[n], x[n]))``,
    read as the affine map ``f -> A @ f + B``. The scan runs with
    ``reverse=True``, so contributions are accumulated from the last row
    towards the first.

    Args:
        x: A dense vector or matrix whose leading dimension matches this
            matrix.

    Returns:
        The product of this strictly upper triangular matrix with ``x``.
    """

    def compose(first, second):  # type: ignore
        # Apply ``first`` and then ``second``. With ``reverse=True``, ``first``
        # covers the higher row indices. Inputs carry a leading batch axis,
        # which ``@`` broadcasts over.
        a1, b1 = first
        a2, b2 = second
        return a2 @ a1, a2 @ b1 + b2

    # vmap(jnp.transpose) matches the per-row ``.T`` of the original exactly.
    states = (jax.vmap(jnp.transpose)(self.a), jax.vmap(jnp.outer)(self.p, x))
    inclusive = jax.lax.associative_scan(compose, states, reverse=True)[1]
    # Strict upper triangle: row n may only see columns j > n, so shift the
    # inclusive result up one row and give the last row a zero state.
    f = jnp.concatenate((inclusive[1:], jnp.zeros_like(inclusive[:1])), axis=0)
    return jax.vmap(jnp.dot)(self.q, f)

def scale(self, other: JAXArray) -> StrictUpperTriQSM:
    # Only the generator ``q`` is rescaled; ``p`` and ``a`` are reused.
    scaled_q = self.q * other
    return StrictUpperTriQSM(p=self.p, q=scaled_q, a=self.a)

Expand Down Expand Up @@ -303,6 +333,10 @@ def transpose(self) -> UpperTriQSM:
def matmul(self, x: JAXArray) -> JAXArray:
    # Lower triangular = diagonal + strictly lower part; multiply each and add.
    diag_part = self.diag.matmul(x)
    lower_part = self.lower.matmul(x)
    return diag_part + lower_part

@handle_matvec_shapes
def parallel_matmul(self, x: JAXArray) -> JAXArray:
    """Parallel product: sum of the diagonal and strictly lower products."""
    diag_part = self.diag.parallel_matmul(x)
    lower_part = self.lower.parallel_matmul(x)
    return diag_part + lower_part

def scale(self, other: JAXArray) -> LowerTriQSM:
    # Scale both pieces of the diagonal + strictly-lower decomposition.
    new_diag = self.diag.scale(other)
    new_lower = self.lower.scale(other)
    return LowerTriQSM(diag=new_diag, lower=new_lower)

Expand Down Expand Up @@ -359,6 +393,10 @@ def transpose(self) -> LowerTriQSM:
def matmul(self, x: JAXArray) -> JAXArray:
    # Upper triangular = diagonal + strictly upper part; multiply each and add.
    diag_part = self.diag.matmul(x)
    upper_part = self.upper.matmul(x)
    return diag_part + upper_part

@handle_matvec_shapes
def parallel_matmul(self, x: JAXArray) -> JAXArray:
    """Parallel product: sum of the diagonal and strictly upper products."""
    diag_part = self.diag.parallel_matmul(x)
    upper_part = self.upper.parallel_matmul(x)
    return diag_part + upper_part

def scale(self, other: JAXArray) -> UpperTriQSM:
    # Scale both pieces of the diagonal + strictly-upper decomposition.
    new_diag = self.diag.scale(other)
    new_upper = self.upper.scale(other)
    return UpperTriQSM(diag=new_diag, upper=new_upper)

Expand Down Expand Up @@ -415,6 +453,14 @@ def transpose(self) -> SquareQSM:
def matmul(self, x: JAXArray) -> JAXArray:
    # Accumulate diagonal, then strictly-lower, then strictly-upper products
    # (left-to-right, the same order as ``d + l + u``).
    total = self.diag.matmul(x)
    total = total + self.lower.matmul(x)
    total = total + self.upper.matmul(x)
    return total

@handle_matvec_shapes
def parallel_matmul(self, x: JAXArray) -> JAXArray:
    """Parallel product: diagonal + strictly lower + strictly upper parts."""
    # Accumulated left-to-right, the same order as ``d + l + u``.
    total = self.diag.parallel_matmul(x)
    total = total + self.lower.parallel_matmul(x)
    total = total + self.upper.parallel_matmul(x)
    return total

def scale(self, other: JAXArray) -> SquareQSM:
return SquareQSM(
diag=self.diag.scale(other),
Expand Down Expand Up @@ -505,6 +551,14 @@ def matmul(self, x: JAXArray) -> JAXArray:
+ self.lower.transpose().matmul(x)
)

@handle_matvec_shapes
def parallel_matmul(self, x: JAXArray) -> JAXArray:
    """Parallel product with a symmetric matrix.

    The strictly upper part is the transpose of the stored strictly lower
    part, so both triangles are derived from ``self.lower``.
    """
    lower = self.lower
    total = self.diag.parallel_matmul(x)
    total = total + lower.parallel_matmul(x)
    total = total + lower.transpose().parallel_matmul(x)
    return total

def scale(self, other: JAXArray) -> SymmQSM:
    # The upper triangle mirrors ``lower``, so scaling diag and lower suffices.
    new_diag = self.diag.scale(other)
    new_lower = self.lower.scale(other)
    return SymmQSM(diag=new_diag, lower=new_lower)

Expand Down
28 changes: 27 additions & 1 deletion tests/test_solvers/test_quasisep/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import pytest
from numpy import random as np_random

from tinygp.kernels.quasisep import Matern52
from tinygp.solvers.quasisep.core import (
DiagQSM,
LowerTriQSM,
Expand All @@ -17,7 +18,7 @@
from tinygp.test_utils import assert_allclose


@pytest.fixture(params=["random", "celerite", "matern"])
def name(request):
    """Matrix family under test; each value selects a branch of ``get_matrices``."""
    return request.param

Expand Down Expand Up @@ -104,6 +105,17 @@ def get_matrices(name):
a = jnp.stack([jnp.diag(v) for v in jnp.exp(-c[None] * dt[:, None])], axis=0)
p = jnp.einsum("ni,nij->nj", p, a)

elif name == "matern":
t = jnp.sort(random.uniform(0, 10, N))
kernel = Matern52(1.5, 1.0)
matrix = kernel.to_symm_qsm(t)
diag += matrix.diag.d
p = matrix.lower.p
q = matrix.lower.q
a = matrix.lower.a
l = matrix.lower.to_dense()
u = l.T

else:
raise AssertionError()

Expand Down Expand Up @@ -166,6 +178,20 @@ def test_strict_tri_matmul(matrices):
assert_allclose(mat.T @ m, u @ m)


def test_strict_lower_tri_parallel_matmul(matrices):
    """The parallel product of a strictly lower QSM matches ``matmul``."""
    _, p, q, a, vec, mat_rhs, _, _ = matrices
    lower = StrictLowerTriQSM(p=p, q=q, a=a)
    for rhs in (vec, mat_rhs):
        assert_allclose(lower.parallel_matmul(rhs), lower.matmul(rhs))

def test_strict_upper_tri_parallel_matmul(matrices):
    """The parallel product of the transposed (strictly upper) QSM matches ``matmul``."""
    _, p, q, a, vec, mat_rhs, _, _ = matrices
    upper = StrictLowerTriQSM(p=p, q=q, a=a).T
    for rhs in (vec, mat_rhs):
        assert_allclose(upper.parallel_matmul(rhs), upper.matmul(rhs))


def test_tri_matmul(matrices):
diag, p, q, a, v, m, l, _ = matrices
mat = LowerTriQSM(diag=DiagQSM(diag), lower=StrictLowerTriQSM(p=p, q=q, a=a))
Expand Down
0