diff --git a/src/tinygp/gp.py b/src/tinygp/gp.py
index 08970162..6578617b 100644
--- a/src/tinygp/gp.py
+++ b/src/tinygp/gp.py
@@ -66,6 +66,7 @@ def __init__(
         solver: Optional[Any] = None,
         mean_value: Optional[JAXArray] = None,
         covariance_value: Optional[Any] = None,
+        **solver_kwargs: Any,
     ):
         self.kernel = kernel
         self.X = X
@@ -101,7 +102,11 @@ def __init__(
             else:
                 solver = DirectSolver
         self.solver = solver.init(
-            kernel, self.X, self.noise, covariance=covariance_value
+            kernel,
+            self.X,
+            self.noise,
+            covariance=covariance_value,
+            **solver_kwargs,
         )
 
     @property
diff --git a/src/tinygp/kernels/base.py b/src/tinygp/kernels/base.py
index 621c1487..64890ae0 100644
--- a/src/tinygp/kernels/base.py
+++ b/src/tinygp/kernels/base.py
@@ -118,8 +118,6 @@ def __radd__(self, other: Any) -> "Kernel":
         # We'll hit this first branch when using the `sum` function
         if other == 0:
             return self
-        if isinstance(other, Kernel):
-            return Sum(other, self)
         return Sum(Constant(other), self)
 
     def __mul__(self, other: Union["Kernel", JAXArray]) -> "Kernel":
@@ -128,8 +126,6 @@ def __mul__(self, other: Union["Kernel", JAXArray]) -> "Kernel":
         return Product(self, Constant(other))
 
     def __rmul__(self, other: Any) -> "Kernel":
-        if isinstance(other, Kernel):
-            return Product(other, self)
         return Product(Constant(other), self)
 
 
diff --git a/src/tinygp/solvers/quasisep/solver.py b/src/tinygp/solvers/quasisep/solver.py
index 01c23442..7ab402a0 100644
--- a/src/tinygp/solvers/quasisep/solver.py
+++ b/src/tinygp/solvers/quasisep/solver.py
@@ -4,9 +4,9 @@
 
 __all__ = ["QuasisepSolver"]
 
-from typing import Any, Optional
+from functools import wraps
+from typing import Any, Callable, Optional
 
-import jax
 import jax.numpy as jnp
 import numpy as np
 
@@ -17,6 +17,28 @@
 from tinygp.solvers.solver import Solver
 
 
+def handle_sorting(func: Callable[..., JAXArray]) -> Callable[..., JAXArray]:
+    """Let a solver method work with data given in the original input order
+
+    When the solver holds sorting indices, the first argument ``y`` is
+    permuted into sorted order before ``func`` is called, and the result is
+    permuted back into the input order before it is returned.
+    """
+
+    @wraps(func)
+    def wrapped(
+        self: "QuasisepSolver", y: JAXArray, *args: Any, **kwargs: Any
+    ) -> JAXArray:
+        if self.inds_to_sorted is not None:
+            y = y[self.inds_to_sorted]
+        r = func(self, y, *args, **kwargs)
+        if self.sorted_to_inds is not None:
+            return r[self.sorted_to_inds]
+        return r
+
+    return wrapped
+
+
 @dataclass
 class QuasisepSolver(Solver):
     """A scalable solver that uses quasiseparable matrices
@@ -32,6 +54,8 @@ class QuasisepSolver(Solver):
     X: JAXArray
     matrix: SymmQSM
     factor: LowerTriQSM
+    inds_to_sorted: Optional[JAXArray]
+    sorted_to_inds: Optional[JAXArray]
 
     @classmethod
     def init(
@@ -41,6 +65,7 @@ def init(
         noise: Noise,
         *,
         covariance: Optional[Any] = None,
+        sort: bool = True,
     ) -> "QuasisepSolver":
         """Build a :class:`QuasisepSolver` for a given kernel and coordinates
 
@@ -55,27 +80,58 @@ def init(
         """
         from tinygp.kernels.quasisep import Quasisep
 
+        inds_to_sorted = None
+        sorted_to_inds = None
         if covariance is None:
             assert isinstance(kernel, Quasisep)
+
+            if sort:
+                # Permutation that puts the input coordinates in order
+                inds_to_sorted = jnp.argsort(kernel.coord_to_sortable(X))
+                # Inverse permutation: restores the original input order
+                sorted_to_inds = (
+                    jnp.empty_like(inds_to_sorted)
+                    .at[inds_to_sorted]
+                    .set(jnp.arange(len(inds_to_sorted)))
+                )
+                X = X[inds_to_sorted]
+
             matrix = kernel.to_symm_qsm(X)
             matrix += noise.to_qsm()
+
         else:
             assert isinstance(covariance, SymmQSM)
             matrix = covariance
+
         factor = matrix.cholesky()
-        return cls(X=X, matrix=matrix, factor=factor)
+        return cls(
+            X=X,
+            matrix=matrix,
+            factor=factor,
+            inds_to_sorted=inds_to_sorted,
+            sorted_to_inds=sorted_to_inds,
+        )
 
     def variance(self) -> JAXArray:
-        return self.matrix.diag.d
+        var = self.matrix.diag.d
+        if self.sorted_to_inds is None:
+            return var
+        # The matrix was built from sorted inputs; undo that permutation
+        return var[self.sorted_to_inds]
 
     def covariance(self) -> JAXArray:
-        return self.matrix.to_dense()
+        cov = self.matrix.to_dense()
+        if self.sorted_to_inds is None:
+            return cov
+        # Permute both rows and columns back into the input order
+        return cov[self.sorted_to_inds[:, None], self.sorted_to_inds[None, :]]
 
     def normalization(self) -> JAXArray:
         return jnp.sum(jnp.log(self.factor.diag.d)) + 0.5 * self.factor.shape[
             0
         ] * np.log(2 * np.pi)
 
+    @handle_sorting
     def solve_triangular(
         self, y: JAXArray, *, transpose: bool = False
     ) -> JAXArray:
@@ -84,6 +140,7 @@ def solve_triangular(
         else:
             return self.factor.solve(y)
 
+    @handle_sorting
     def dot_triangular(self, y: JAXArray) -> JAXArray:
         return self.factor @ y
 
diff --git a/tests/test_solvers/test_quasisep/test_solver.py b/tests/test_solvers/test_quasisep/test_solver.py
index 4b3c3527..77c5be73 100644
--- a/tests/test_solvers/test_quasisep/test_solver.py
+++ b/tests/test_solvers/test_quasisep/test_solver.py
@@ -125,3 +125,47 @@ def test_celerite(data):
 
     calc = gp.log_probability(y)
     np.testing.assert_allclose(calc, expected)
+
+
+def test_unsorted(data):
+    """Shuffled inputs must give the same results as the unshuffled ones"""
+    random = np.random.default_rng(0)
+    inds = random.permutation(len(data[0]))
+    inds_t = random.permutation(len(data[2]))
+    x_ = data[0][inds]
+    y_ = data[1][inds]
+    t_ = data[2][inds_t]
+
+    kernel = quasisep.Matern32(sigma=1.8, scale=1.5)
+    gp = GaussianProcess(kernel, data[0], diag=0.1)
+    gp_ = GaussianProcess(kernel, x_, diag=0.1)
+    assert isinstance(gp_.solver, QuasisepSolver)
+
+    assert np.isfinite(gp_.log_probability(y_))
+    np.testing.assert_allclose(
+        gp_.log_probability(y_), gp.log_probability(data[1])
+    )
+
+    np.testing.assert_allclose(
+        gp_.solver.solve_triangular(y_),
+        gp.solver.solve_triangular(data[1])[inds],
+    )
+
+    # Solver-level quantities must come back in the caller's input order
+    np.testing.assert_allclose(
+        gp_.solver.variance(), gp.solver.variance()[inds]
+    )
+    np.testing.assert_allclose(
+        gp_.solver.covariance(),
+        gp.solver.covariance()[inds[:, None], inds[None, :]],
+    )
+
+    cond = gp.condition(data[1]).gp
+    cond_ = gp_.condition(y_).gp
+    np.testing.assert_allclose(cond_.loc, cond.loc[inds])
+    np.testing.assert_allclose(cond_.variance, cond.variance[inds])
+
+    cond = gp.condition(data[1], data[2]).gp
+    cond_ = gp_.condition(y_, t_).gp
+    np.testing.assert_allclose(cond_.loc, cond.loc[inds_t])
+    np.testing.assert_allclose(cond_.variance, cond.variance[inds_t])