Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions src/tinygp/means.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,18 +39,18 @@ class Mean(MeanBase):
signature.
"""

value: JAXArray | None = None
value: JAXArray
func: Callable[[JAXArray], JAXArray] | None = eqx.field(default=None, static=True)

def __init__(self, value: JAXArray | Callable[[JAXArray], JAXArray]):
if callable(value):
self.func = value
self.value = jax.numpy.zeros(())
else:
self.value = value

def __call__(self, X: JAXArray) -> JAXArray:
if self.value is None:
assert self.func is not None
if self.func is not None:
return self.func(X)
return self.value

Expand Down
27 changes: 27 additions & 0 deletions src/tinygp/test_utils.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from contextlib import contextmanager
from typing import Any

import jax
Expand Down Expand Up @@ -30,3 +31,29 @@ def assert_pytrees_allclose(calculated: Any, expected: Any, *args: Any, **kwargs
jax.tree_util.tree_map(
lambda a, b: assert_allclose(a, b, *args, **kwargs), calculated, expected
)


def _as_context_manager(obj):
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this is more complicated than it needs to be. Let's just switch to using jax.enable_x64 everywhere where we were using the experimental version!

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree. If you are okay with dropping Python 3.10 I will make that change.

Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's do it - thanks!

# If it's already a context manager
if hasattr(obj, "__enter__") and hasattr(obj, "__exit__"):
return obj

# If it's a generator, wrap it
if hasattr(obj, "__iter__") and hasattr(obj, "send"):
return contextmanager(lambda: obj)()

raise TypeError("Object is neither a context manager nor a generator")


@contextmanager
def jax_enable_x64():
if hasattr(jax, "enable_x64"):
cm = jax.enable_x64(True)
else:
# deprecated in jax>=0.9
from jax.experimental import enable_x64 as _enable_x64

cm = _enable_x64()

with _as_context_manager(cm):
yield
4 changes: 2 additions & 2 deletions tests/test_gp.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from numpy import random as np_random

from tinygp import GaussianProcess, kernels
from tinygp.test_utils import assert_allclose
from tinygp.test_utils import assert_allclose, jax_enable_x64


@pytest.fixture
Expand All @@ -24,7 +24,7 @@ def data(random):
def test_sample(data):
X, _ = data

with jax.experimental.enable_x64(True):
with jax_enable_x64():
gp = GaussianProcess(
kernels.Matern32(1.5), X, diag=0.01, mean=lambda x: jnp.sum(x)
)
Expand Down
4 changes: 2 additions & 2 deletions tests/test_kernels/test_kernels.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

from tinygp import kernels, noise
from tinygp.solvers import DirectSolver
from tinygp.test_utils import assert_allclose
from tinygp.test_utils import assert_allclose, jax_enable_x64


@pytest.fixture
Expand Down Expand Up @@ -71,7 +71,7 @@ def test_ops(data):

def test_conditioned(data):
x1, x2 = data
with jax.experimental.enable_x64(): # type: ignore
with jax_enable_x64(): # type: ignore
k1 = 1.5 * kernels.Matern32(2.5)
k2 = 0.9 * kernels.ExpSineSquared(scale=1.5, gamma=0.3)
K = k1(x1, x1) + 0.1 * jnp.eye(x1.shape[0])
Expand Down