Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion src/tinygp/gp.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ def __init__(
solver: Optional[Any] = None,
mean_value: Optional[JAXArray] = None,
covariance_value: Optional[Any] = None,
**solver_kwargs: Any,
):
self.kernel = kernel
self.X = X
Expand Down Expand Up @@ -101,7 +102,11 @@ def __init__(
else:
solver = DirectSolver
self.solver = solver.init(
kernel, self.X, self.noise, covariance=covariance_value
kernel,
self.X,
self.noise,
covariance=covariance_value,
**solver_kwargs,
)

@property
Expand Down
4 changes: 0 additions & 4 deletions src/tinygp/kernels/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,8 +118,6 @@ def __radd__(self, other: Any) -> "Kernel":
# We'll hit this first branch when using the `sum` function
if other == 0:
return self
if isinstance(other, Kernel):
return Sum(other, self)
return Sum(Constant(other), self)

def __mul__(self, other: Union["Kernel", JAXArray]) -> "Kernel":
Expand All @@ -128,8 +126,6 @@ def __mul__(self, other: Union["Kernel", JAXArray]) -> "Kernel":
return Product(self, Constant(other))

def __rmul__(self, other: Any) -> "Kernel":
if isinstance(other, Kernel):
return Product(other, self)
return Product(Constant(other), self)


Expand Down
55 changes: 51 additions & 4 deletions src/tinygp/solvers/quasisep/solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@

__all__ = ["QuasisepSolver"]

from typing import Any, Optional
from functools import wraps
from typing import Any, Callable, Optional

import jax
import jax.numpy as jnp
import numpy as np

Expand All @@ -17,6 +17,21 @@
from tinygp.solvers.solver import Solver


def handle_sorting(func: Callable[..., JAXArray]) -> Callable[..., JAXArray]:
@wraps(func)
def wrapped(
self: "QuasisepSolver", y: JAXArray, *args: Any, **kwargs: Any
) -> JAXArray:
if self.inds_to_sorted is not None:
y = y[self.inds_to_sorted]
r = func(self, y, *args, **kwargs)
if self.sorted_to_inds is not None:
return r[self.sorted_to_inds]
return r

return wrapped


@dataclass
class QuasisepSolver(Solver):
"""A scalable solver that uses quasiseparable matrices
Expand All @@ -32,6 +47,8 @@ class QuasisepSolver(Solver):
X: JAXArray
matrix: SymmQSM
factor: LowerTriQSM
inds_to_sorted: Optional[JAXArray]
sorted_to_inds: Optional[JAXArray]

@classmethod
def init(
Expand All @@ -41,6 +58,7 @@ def init(
noise: Noise,
*,
covariance: Optional[Any] = None,
sort: bool = True,
) -> "QuasisepSolver":
"""Build a :class:`QuasisepSolver` for a given kernel and coordinates

Expand All @@ -55,27 +73,55 @@ def init(
"""
from tinygp.kernels.quasisep import Quasisep

inds_to_sorted = None
sorted_to_inds = None
if covariance is None:
assert isinstance(kernel, Quasisep)

if sort:
inds_to_sorted = jnp.argsort(kernel.coord_to_sortable(X))
sorted_to_inds = (
jnp.empty_like(inds_to_sorted)
.at[inds_to_sorted]
.set(jnp.arange(len(inds_to_sorted)))
)
X = X[inds_to_sorted]

matrix = kernel.to_symm_qsm(X)
matrix += noise.to_qsm()

else:
assert isinstance(covariance, SymmQSM)
matrix = covariance

factor = matrix.cholesky()
return cls(X=X, matrix=matrix, factor=factor)
return cls(
X=X,
matrix=matrix,
factor=factor,
inds_to_sorted=inds_to_sorted,
sorted_to_inds=sorted_to_inds,
)

def variance(self) -> JAXArray:
return self.matrix.diag.d
if self.sorted_to_inds is None:
return self.matrix.diag.d
return self.matrix.diag.d[self.sorted_to_inds]

def covariance(self) -> JAXArray:
return self.matrix.to_dense()
cov = self.matrix.to_dense()
return cov
if self.sorted_to_inds is None:
return cov
return cov[self.sorted_to_inds[:, None], self.sorted_to_inds[None, :]]

def normalization(self) -> JAXArray:
return jnp.sum(jnp.log(self.factor.diag.d)) + 0.5 * self.factor.shape[
0
] * np.log(2 * np.pi)

@handle_sorting
def solve_triangular(
self, y: JAXArray, *, transpose: bool = False
) -> JAXArray:
Expand All @@ -84,6 +130,7 @@ def solve_triangular(
else:
return self.factor.solve(y)

@handle_sorting
def dot_triangular(self, y: JAXArray) -> JAXArray:
return self.factor @ y

Expand Down
34 changes: 34 additions & 0 deletions tests/test_solvers/test_quasisep/test_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,3 +125,37 @@ def test_celerite(data):
calc = gp.log_probability(y)

np.testing.assert_allclose(calc, expected)


def test_unsorted(data):
random = np.random.default_rng(0)
inds = random.permutation(len(data[0]))
inds_t = random.permutation(len(data[2]))
x_ = data[0][inds]
y_ = data[1][inds]
t_ = data[2][inds_t]

kernel = quasisep.Matern32(sigma=1.8, scale=1.5)
gp = GaussianProcess(kernel, data[0], diag=0.1)
gp_ = GaussianProcess(kernel, x_, diag=0.1)
assert isinstance(gp_.solver, QuasisepSolver)

assert np.isfinite(gp_.log_probability(y_))
np.testing.assert_allclose(
gp_.log_probability(y_), gp.log_probability(data[1])
)

np.testing.assert_allclose(
gp_.solver.solve_triangular(y_),
gp.solver.solve_triangular(data[1])[inds],
)

cond = gp.condition(data[1]).gp
cond_ = gp_.condition(y_).gp
np.testing.assert_allclose(cond_.loc, cond.loc[inds])
np.testing.assert_allclose(cond_.variance, cond.variance[inds])

cond = gp.condition(data[1], data[2]).gp
cond_ = gp_.condition(y_, t_).gp
np.testing.assert_allclose(cond_.loc, cond.loc[inds_t])
np.testing.assert_allclose(cond_.variance, cond.variance[inds_t])