diff --git a/pyrecest/_backend/__init__.py b/pyrecest/_backend/__init__.py index 98dd2a64e..839833713 100644 --- a/pyrecest/_backend/__init__.py +++ b/pyrecest/_backend/__init__.py @@ -248,6 +248,9 @@ def get_backend_name(): "randint", "seed", "uniform", + # For PyRecEst + "get_state", + "set_state", ], "fft": [ # For PyRecEst "rfft", diff --git a/pyrecest/_backend/jax/random.py b/pyrecest/_backend/jax/random.py index 9a06369d6..b8f486ba2 100644 --- a/pyrecest/_backend/jax/random.py +++ b/pyrecest/_backend/jax/random.py @@ -25,12 +25,13 @@ def create_random_state(seed = 0): def global_random_state(): return backend.jax_global_random_state - def set_global_random_state(state): backend.jax_global_random_state = state +get_state = global_random_state +set_state = set_global_random_state -def get_state(**kwargs): +def _get_state(**kwargs): has_state = 'state' in kwargs state = kwargs.pop('state', backend.jax_global_random_state) return state, has_state, kwargs @@ -51,7 +52,7 @@ def _rand(state, size, *args, **kwargs): def rand(size, *args, **kwargs): size = size if hasattr(size, "__iter__") else (size,) - state, has_state, kwargs = get_state(**kwargs) + state, has_state, kwargs = _get_state(**kwargs) state, res = _rand(state, size, *args, **kwargs) return set_state_return(has_state, state, res) @@ -66,7 +67,7 @@ def _randint(state, size, *args, **kwargs): def randint(size, *args, **kwargs): size = size if hasattr(size, "__iter__") else (size,) - state, has_state, kwargs = get_state(**kwargs) + state, has_state, kwargs = _get_state(**kwargs) state, res = _randint(state, size, *args, **kwargs) return set_state_return(has_state, state, res) @@ -78,7 +79,7 @@ def _normal(state, size, *args, **kwargs): def normal(size, *args, **kwargs): size = size if hasattr(size, "__iter__") else (size,) - state, has_state, kwargs = get_state(**kwargs) + state, has_state, kwargs = _get_state(**kwargs) # Check and remove 'mean' and 'cov' from kwargs mean = kwargs.pop('mean', None) @@ -102,7 +103,7 @@ def _choice(state, a, n, *args, **kwargs): def choice(a, n, *args, **kwargs): - state, has_state, kwargs = get_state(**kwargs) + state, has_state, kwargs = _get_state(**kwargs) state, res = _choice(state, a, n, *args, **kwargs) return set_state_return(has_state, state, res) @@ -114,14 +115,19 @@ def _multivariate_normal(state, size, *args, **kwargs): def multivariate_normal(size, *args, **kwargs): size = size if hasattr(size, "__iter__") else (size,) - state, has_state, kwargs = get_state(**kwargs) + state, has_state, kwargs = _get_state(**kwargs) state, res = _multivariate_normal(state, size, *args, **kwargs) return set_state_return(has_state, state, res) -unsupported_functions = [ - 'multinomial', -] -for func_name in unsupported_functions: - exec(f"{func_name} = lambda *args, **kwargs: NotImplementedError('This function is not supported in this JAX backend.')") +def multinomial(n, pvals): + """Sample from a multinomial distribution using JAX.""" + import jax.numpy as _jnp + state, has_state, _ = _get_state() + state, key = jax.random.split(state) + backend.jax_global_random_state = state + pvals = _jnp.asarray(pvals, dtype=_jnp.float32) + pvals = pvals / pvals.sum() + samples = jax.random.categorical(key, _jnp.log(pvals), shape=(n,)) + return _jnp.bincount(samples, minlength=len(pvals)) diff --git a/pyrecest/_backend/numpy/random.py b/pyrecest/_backend/numpy/random.py index 33f479036..2822318af 100644 --- a/pyrecest/_backend/numpy/random.py +++ b/pyrecest/_backend/numpy/random.py @@ -3,5 +3,6 @@ import numpy as _np from numpy.random import default_rng as _default_rng from numpy.random import randint, seed, multinomial +from numpy.random import set_state, get_state # For PyRecEst from .._shared_numpy.random import choice, multivariate_normal, normal, rand, uniform diff --git a/pyrecest/_backend/pytorch/random.py b/pyrecest/_backend/pytorch/random.py index ca405228e..4db7aac0a 100644 --- a/pyrecest/_backend/pytorch/random.py +++ b/pyrecest/_backend/pytorch/random.py @@ -2,6 +2,8 @@ import torch as _torch from torch import rand, randint +from torch import get_rng_state as get_state # For PyRecEst +from torch import set_rng_state as set_state # For PyRecEst from torch.distributions.multivariate_normal import ( MultivariateNormal as _MultivariateNormal, ) diff --git a/pyrecest/distributions/abstract_dirac_distribution.py b/pyrecest/distributions/abstract_dirac_distribution.py index 19a38bfbe..2c8ce6733 100644 --- a/pyrecest/distributions/abstract_dirac_distribution.py +++ b/pyrecest/distributions/abstract_dirac_distribution.py @@ -3,8 +3,6 @@ from collections.abc import Callable from typing import Union -from beartype import beartype - # pylint: disable=redefined-builtin,no-name-in-module,no-member from pyrecest.backend import ( all, @@ -54,8 +52,7 @@ def normalize(self) -> "AbstractDiracDistribution": dist.normalize_in_place() return dist - @beartype - def apply_function(self, f: Callable, function_is_vectorized: bool = True): + def apply_function(self, f: Callable, f_supports_multiple: bool = True): """ Apply a function to the Dirac locations and return a new distribution. @@ -63,7 +60,7 @@ def apply_function(self, f: Callable, function_is_vectorized: bool = True): :returns: A new distribution with the function applied to the locations. """ dist = copy.deepcopy(self) - if function_is_vectorized: + if f_supports_multiple: dist.d = f(dist.d) else: dist.d = apply_along_axis(f, 1, dist.d) diff --git a/pyrecest/distributions/abstract_manifold_specific_distribution.py b/pyrecest/distributions/abstract_manifold_specific_distribution.py index 2e9f2df5d..774de226b 100644 --- a/pyrecest/distributions/abstract_manifold_specific_distribution.py +++ b/pyrecest/distributions/abstract_manifold_specific_distribution.py @@ -1,10 +1,11 @@ from abc import ABC, abstractmethod from collections.abc import Callable from typing import Union +import inspect import pyrecest.backend -# pylint: disable=no-name-in-module,no-member +# pylint: disable=no-name-in-module,no-member,redefined-builtin from pyrecest.backend import empty, int32, int64, log, random, squeeze @@ -70,7 +71,7 @@ def sample(self, n: Union[int, int32, int64]): return self.sample_metropolis_hastings(n) # jscpd:ignore-start - # pylint: disable=too-many-positional-arguments + # pylint: disable=too-many-positional-arguments,too-many-locals def sample_metropolis_hastings( self, n: Union[int, int32, int64], @@ -81,30 +82,48 @@ def sample_metropolis_hastings( ): # jscpd:ignore-end """Metropolis Hastings sampling algorithm.""" - assert ( - pyrecest.backend.__backend_name__ != "jax" - ), "Not supported on this backend" + if pyrecest.backend.__backend_name__ == "jax": + # Get a key from your global JAX random state *outside* of lax.scan + import jax as _jax # pylint: disable=import-error + + key = random.get_state() + key, key_for_mh = _jax.random.split(key) + # Optionally update global state for future calls + random.set_state(key) + + if proposal is None or start_point is None: + raise NotImplementedError( + "Default proposals and starting points should be set in inheriting classes." + ) + _assert_proposal_supports_key(proposal) + + samples, _ = sample_metropolis_hastings_jax( + key=key_for_mh, + log_pdf=self.ln_pdf, + proposal=proposal, # must be (key, x) -> x_prop for JAX + start_point=start_point, + n=int(n), + burn_in=int(burn_in), + skipping=int(skipping), + ) + # You could optionally stash `key_out` somewhere if you want chain continuation. + return squeeze(samples) + + # Non-JAX backends -> your old NumPy/Torch code if proposal is None or start_point is None: raise NotImplementedError( "Default proposals and starting points should be set in inheriting classes." ) total_samples = burn_in + n * skipping - s = empty( - ( - total_samples, - self.input_dim, - ), - ) + s = empty((total_samples, self.input_dim)) x = start_point i = 0 pdfx = self.pdf(x) while i < total_samples: x_new = proposal(x) - assert ( - x_new.shape == x.shape - ), "Proposal must return a vector of same shape as input" + assert x_new.shape == x.shape, "Proposal must return a vector of same shape as input" pdfx_new = self.pdf(x_new) a = pdfx_new / pdfx if a.item() > 1 or a.item() > random.rand(1): @@ -115,3 +134,106 @@ def sample_metropolis_hastings( relevant_samples = s[burn_in::skipping, :] return squeeze(relevant_samples) + +# pylint: disable=too-many-positional-arguments,too-many-locals,too-many-arguments +def sample_metropolis_hastings_jax( + key, + log_pdf, # function: x -> log p(x) + proposal, # function: (key, x) -> x_prop + start_point, + n: int, + burn_in: int = 10, + skipping: int = 5, +): + """ + Metropolis-Hastings sampler in JAX using a plain Python loop. + + Uses a Python loop (rather than lax.scan) so that log_pdf may call + non-JAX-traceable code (e.g. scipy). + + key: jax.random.PRNGKey + log_pdf: callable x -> log p(x) + proposal: callable (key, x) -> x_proposed + start_point: initial state (array) + n: number of samples to return (after burn-in and thinning) + """ + import jax.numpy as _jnp # pylint: disable=import-error + from jax import random as _random # pylint: disable=import-error + + start_point = _jnp.asarray(start_point) + total_steps = burn_in + n * skipping + chain = [] + + x = start_point + + def _to_scalar(val): + """Convert a JAX array of any shape to a Python float.""" + return float(_jnp.asarray(val).ravel()[0]) + + log_px = _to_scalar(log_pdf(x)) + + for _ in range(total_steps): + key, key_prop, key_u = _random.split(key, 3) + + # Propose new state + x_prop = proposal(key_prop, x) + log_px_prop = _to_scalar(log_pdf(x_prop)) + + # Metropolis acceptance + log_alpha = log_px_prop - log_px + log_u = _to_scalar(_jnp.log(_random.uniform(key_u, shape=()))) + + if log_u < min(0.0, log_alpha): + x = x_prop + log_px = log_px_prop + + chain.append(x) + + chain_array = _jnp.stack(chain, axis=0) + samples = chain_array[burn_in::skipping] + return samples, key + + +def _assert_proposal_supports_key(proposal: Callable): + """ + Check that `proposal` can be called as proposal(key, x). + + Raises a TypeError with a helpful message if this is not the case. + """ + # Unwrap jitted / partial / decorated functions if possible + func = proposal + while hasattr(func, "__wrapped__"): + func = func.__wrapped__ + + try: + sig = inspect.signature(func) + except (TypeError, ValueError): + # Can't introspect (e.g. builtins); fall back to a generic error + raise TypeError( + "For the JAX backend, `proposal` must accept (key, x) as arguments, " + "but its signature could not be inspected." + ) from None + + params = list(sig.parameters.values()) + + # Count positional(-or-keyword) parameters + num_positional = sum( + p.kind in (inspect.Parameter.POSITIONAL_ONLY, + inspect.Parameter.POSITIONAL_OR_KEYWORD) + for p in params + ) + has_var_positional = any( + p.kind == inspect.Parameter.VAR_POSITIONAL + for p in params + ) + + if has_var_positional or num_positional >= 2: + # Looks compatible with (key, x) + return + + raise TypeError( + "For the JAX backend, `proposal` must accept `(key, x)` as arguments.\n" + f"Got signature: {sig}\n" + "Hint: change your proposal from `def proposal(x): ...` to\n" + "`def proposal(key, x): ...` and use `jax.random` with the passed key." + ) diff --git a/pyrecest/distributions/abstract_mixture.py b/pyrecest/distributions/abstract_mixture.py index 77fde89ca..1e17aa7c9 100644 --- a/pyrecest/distributions/abstract_mixture.py +++ b/pyrecest/distributions/abstract_mixture.py @@ -3,6 +3,8 @@ import warnings from typing import Union +import pyrecest.backend + # pylint: disable=redefined-builtin,no-name-in-module,no-member from pyrecest.backend import ( count_nonzero, @@ -67,12 +69,28 @@ def input_dim(self) -> int: def sample(self, n: Union[int, int32, int64]): occurrences = random.multinomial(n, self.w) + if pyrecest.backend.__backend_name__ == "jax": + samples = [] + for i, occ in enumerate(occurrences): + occ_val = occ.item() if hasattr(occ, "item") else int(occ) + if occ_val != 0: + try: + sample_i = self.dists[i].sample(occ_val) + except (NotImplementedError, AssertionError, ValueError, TypeError): + sample_i = self.dists[i].sample_metropolis_hastings(occ_val) + sample_i = pyrecest.backend.atleast_2d(sample_i) + samples.append(sample_i) + if not samples: + return empty((0, self.input_dim)) + return pyrecest.backend.concatenate(samples, axis=0) + count = 0 s = empty((n, self.input_dim)) for i, occ in enumerate(occurrences): - if occ != 0: - s[count : count + occ, :] = self.dists[i].sample(occ) # noqa: E203 - count += occ + occ_val = occ.item() if hasattr(occ, "item") else int(occ) + if occ_val != 0: + s[count:count + occ_val, :] = self.dists[i].sample(occ_val) # noqa: E203 + count += occ_val return s diff --git a/pyrecest/distributions/cart_prod/partially_wrapped_normal_distribution.py b/pyrecest/distributions/cart_prod/partially_wrapped_normal_distribution.py index ba4a44fb9..929b50121 100644 --- a/pyrecest/distributions/cart_prod/partially_wrapped_normal_distribution.py +++ b/pyrecest/distributions/cart_prod/partially_wrapped_normal_distribution.py @@ -111,21 +111,9 @@ def mode(self): """ return self.mu - def set_mean(self, new_mean): - """ - Return a copy of this distribution with the location parameter shifted to ``new_mean``. - - For bounded dimensions, the mean is wrapped into [0, 2*pi) to stay on the manifold. - """ - new_dist = copy.deepcopy(self) - wrapped_mean = where( - arange(new_mean.shape[0]) < self.bound_dim, mod(new_mean, 2 * pi), new_mean - ) - new_dist.mu = wrapped_mean - return new_dist - def set_mode(self, new_mode): - return self.set_mean(new_mode) + self.mu = copy.copy(new_mode) + return self def hybrid_moment(self): """ diff --git a/pyrecest/distributions/circle/circular_fourier_distribution.py b/pyrecest/distributions/circle/circular_fourier_distribution.py index 68fdb65dd..9eff1dc3d 100644 --- a/pyrecest/distributions/circle/circular_fourier_distribution.py +++ b/pyrecest/distributions/circle/circular_fourier_distribution.py @@ -100,10 +100,10 @@ def __sub__( cNew = self.c - other.c fdNew = CircularFourierDistribution( c=cNew, - n=self.n, transformation=self.transformation, multiplied_by_n=self.multiplied_by_n, ) + fdNew.n = self.n # Preserve the original n value (should remain unchanged) return fdNew def pdf(self, xs): diff --git a/pyrecest/distributions/circle/wrapped_normal_distribution.py b/pyrecest/distributions/circle/wrapped_normal_distribution.py index 0a620fcc8..b1f121e06 100644 --- a/pyrecest/distributions/circle/wrapped_normal_distribution.py +++ b/pyrecest/distributions/circle/wrapped_normal_distribution.py @@ -50,10 +50,6 @@ def __init__( """ AbstractCircularDistribution.__init__(self) HypertoroidalWrappedNormalDistribution.__init__(self, mu, sigma**2) - if ndim(mu) != 0: - raise ValueError(f"mu must be a scalar, but got shape {mu.shape}.") - if ndim(sigma) != 0: - raise ValueError(f"sigma must be a scalar, but got shape {sigma.shape}.") @property def sigma(self): diff --git a/pyrecest/distributions/hypersphere_subset/abstract_hyperhemispherical_distribution.py b/pyrecest/distributions/hypersphere_subset/abstract_hyperhemispherical_distribution.py index a16e685b1..d643f1566 100644 --- a/pyrecest/distributions/hypersphere_subset/abstract_hyperhemispherical_distribution.py +++ b/pyrecest/distributions/hypersphere_subset/abstract_hyperhemispherical_distribution.py @@ -5,7 +5,7 @@ import matplotlib.pyplot as plt import pyrecest.backend -# pylint: disable=no-name-in-module,no-member +# pylint: disable=no-name-in-module,no-member,duplicate-code from pyrecest.backend import ( array, concatenate, @@ -61,8 +61,29 @@ def sample_metropolis_hastings( HyperhemisphericalUniformDistribution, ) - def proposal(_): - return HyperhemisphericalUniformDistribution(self.dim).sample(1) + if pyrecest.backend.__backend_name__ in ("numpy", "pytorch"): + def proposal_np(_): + return HyperhemisphericalUniformDistribution(self.dim).sample(1) + + proposal = proposal_np + else: + # JAX backend: proposal(key, x) -> x_prop + import jax as _jax # pylint: disable=import-error + import jax.numpy as _jnp # pylint: disable=import-error + + def proposal_jax(key, _): + """JAX independence proposal: uniform on upper hemisphere.""" + key, subkey = _jax.random.split(key) + s = _jax.random.normal(subkey, shape=(1, self.dim + 1)) + # Project to upper hemisphere: last coordinate >= 0 + sign = _jnp.where(s[..., -1:] < 0.0, -1.0, 1.0) + s = sign * s + + # Ensure exact unit norm (avoids float32 rounding errors) + s = s / _jnp.linalg.norm(s, axis=-1, keepdims=True) + return s + + proposal = proposal_jax if start_point is None: start_point = HyperhemisphericalUniformDistribution(self.dim).sample(1) diff --git a/pyrecest/distributions/hypersphere_subset/abstract_hypersphere_subset_grid_distribution.py b/pyrecest/distributions/hypersphere_subset/abstract_hypersphere_subset_grid_distribution.py index 890e28cff..af8eda8e7 100644 --- a/pyrecest/distributions/hypersphere_subset/abstract_hypersphere_subset_grid_distribution.py +++ b/pyrecest/distributions/hypersphere_subset/abstract_hypersphere_subset_grid_distribution.py @@ -37,7 +37,7 @@ def __init__(self, grid, grid_values, enforce_pdf_nonnegative=True): enforce_pdf_nonnegative=enforce_pdf_nonnegative, ) AbstractHypersphereSubsetDistribution.__init__(self, dim=grid.shape[1]) - self.normalize(warn_unnorm=False) + self.normalize() def mean_direction(self): warnings.warn( diff --git a/pyrecest/distributions/hypersphere_subset/abstract_hyperspherical_distribution.py b/pyrecest/distributions/hypersphere_subset/abstract_hyperspherical_distribution.py index 378b6fff7..8925a05f4 100644 --- a/pyrecest/distributions/hypersphere_subset/abstract_hyperspherical_distribution.py +++ b/pyrecest/distributions/hypersphere_subset/abstract_hyperspherical_distribution.py @@ -74,8 +74,24 @@ def sample_metropolis_hastings( HypersphericalUniformDistribution, ) - def proposal(_): - return HypersphericalUniformDistribution(self.dim).sample(1) + if pyrecest.backend.__backend_name__ in ("numpy", "pytorch"): + def proposal_np(_): + return HypersphericalUniformDistribution(self.dim).sample(1) + + proposal = proposal_np + else: + import jax as _jax # pylint: disable=import-error + import jax.numpy as _jnp # pylint: disable=import-error + + def proposal_jax(key, _): + """JAX independence proposal: uniform on hypersphere.""" + key, subkey = _jax.random.split(key) + s = _jax.random.normal(subkey, shape=(1, self.dim + 1)) + # Ensure exact unit norm (avoids float32 rounding errors) + s = s / _jnp.linalg.norm(s, axis=-1, keepdims=True) + return s + + proposal = proposal_jax if start_point is None: start_point = HypersphericalUniformDistribution(self.dim).sample(1) diff --git a/pyrecest/distributions/hypersphere_subset/abstract_spherical_harmonics_distribution.py b/pyrecest/distributions/hypersphere_subset/abstract_spherical_harmonics_distribution.py index b1694fe8a..2c8cf101e 100644 --- a/pyrecest/distributions/hypersphere_subset/abstract_spherical_harmonics_distribution.py +++ b/pyrecest/distributions/hypersphere_subset/abstract_spherical_harmonics_distribution.py @@ -52,7 +52,7 @@ def __init__(self, coeff_mat, transformation="identity"): def pdf(self, xs): return AbstractOrthogonalBasisDistribution.pdf(self, xs) - def normalize_in_place(self, warn_unnorm=True): + def normalize_in_place(self): int_val = self.integrate() if int_val < 0: warnings.warn( @@ -66,11 +66,10 @@ def normalize_in_place(self, warn_unnorm=True): "this usually points to a user error" ) elif abs(int_val - 1) > 1e-5: - if warn_unnorm: - warnings.warn( - "Warning: Normalization:notNormalized - Coefficients apparently do not belong " - "to normalized density. Normalizing..." - ) + warnings.warn( + "Warning: Normalization:notNormalized - Coefficients apparently do not belong " + "to normalized density. Normalizing..." + ) else: return diff --git a/pyrecest/distributions/hypertorus/abstract_hypertoroidal_distribution.py b/pyrecest/distributions/hypertorus/abstract_hypertoroidal_distribution.py index 23ad551a6..3c6b799b6 100644 --- a/pyrecest/distributions/hypertorus/abstract_hypertoroidal_distribution.py +++ b/pyrecest/distributions/hypertorus/abstract_hypertoroidal_distribution.py @@ -268,9 +268,21 @@ def sample_metropolis_hastings( ): # jscpd:ignore-end if proposal is None: + if pyrecest.backend.__backend_name__ == "jax": + import jax as _jax # pylint: disable=import-error + import jax.numpy as _jnp # pylint: disable=import-error - def proposal(x): - return mod(x + random.normal(0.0, 1.0, (self.dim,)), 2.0 * pi) + def proposal_jax(key, x): + key, subkey = _jax.random.split(key) + noise = _jax.random.normal(subkey, shape=(self.dim,)) + return _jnp.mod(x + noise, 2.0 * _jnp.pi) + + proposal = proposal_jax + else: + def proposal_np(x): + return mod(x + random.normal(0.0, 1.0, (self.dim,)), 2.0 * pi) + + proposal = proposal_np if start_point is None: start_point = self.mean_direction() diff --git a/pyrecest/distributions/hypertorus/hypertoroidal_dirac_distribution.py b/pyrecest/distributions/hypertorus/hypertoroidal_dirac_distribution.py index 50b9e405c..6e45dc5eb 100644 --- a/pyrecest/distributions/hypertorus/hypertoroidal_dirac_distribution.py +++ b/pyrecest/distributions/hypertorus/hypertoroidal_dirac_distribution.py @@ -79,8 +79,8 @@ def trigonometric_moment(self, n: Union[int, int32, int64]): """ return sum(exp(1j * n * self.d.T) * tile(self.w, (self.dim, 1)), axis=1) - def apply_function(self, f: Callable, function_is_vectorized: bool = True): - dist = super().apply_function(f, function_is_vectorized) + def apply_function(self, f: Callable, f_supports_multiple: bool = True): + dist = super().apply_function(f, f_supports_multiple) dist.d = mod(dist.d, 2.0 * pi) return dist diff --git a/pyrecest/filters/random_matrix_tracker.py b/pyrecest/filters/random_matrix_tracker.py index 1024d83ce..ae6388a15 100644 --- a/pyrecest/filters/random_matrix_tracker.py +++ b/pyrecest/filters/random_matrix_tracker.py @@ -110,12 +110,14 @@ def update(self, measurements, meas_mat, meas_noise_cov): def plot_point_estimate(self, scaling_factor=1, color=(0, 0.4470, 0.7410)): if self.kinematic_state_to_pos_matrix is None: - raise ValueError("""No kinematic_state_to_pos_matrix + raise ValueError( + """No kinematic_state_to_pos_matrix matrix was set, so it is unclear what the individual components of the kinematic state are (position, velocity, etc.). Please set it directly or perform an update step - before plotting.""") + before plotting.""" + ) position_estimate = self.kinematic_state_to_pos_matrix @ self.kinematic_state plot_ellipsoid(position_estimate, self.extent, scaling_factor, color) diff --git a/pyrecest/tests/distributions/test_abstract_mixture.py b/pyrecest/tests/distributions/test_abstract_mixture.py index 950e1c62a..d3fe3183e 100644 --- a/pyrecest/tests/distributions/test_abstract_mixture.py +++ b/pyrecest/tests/distributions/test_abstract_mixture.py @@ -25,10 +25,6 @@ def _test_sample(self, mix, n): self.assertEqual(s.shape, (n, mix.input_dim)) return s - @unittest.skipIf( - pyrecest.backend.__backend_name__ == "jax", - reason="Not supported on this backend", - ) def test_sample_metropolis_hastings_basics_only_t2(self): vmf = ToroidalWrappedNormalDistribution(array([1.0, 0.0]), eye(2)) mix = HypertoroidalMixture( @@ -37,33 +33,33 @@ def test_sample_metropolis_hastings_basics_only_t2(self): self._test_sample(mix, 10) @unittest.skipIf( - pyrecest.backend.__backend_name__ in ("pytorch", "jax"), + pyrecest.backend.__backend_name__ in ("pytorch",), reason="Not supported on this backend", ) def test_sample_metropolis_hastings_basics_only_s2(self): vmf1 = VonMisesFisherDistribution( array([1.0, 0.0, 0.0]), 2.0 - ) # Needs to be float for scipy + ) vmf2 = VonMisesFisherDistribution( array([0.0, 1.0, 0.0]), 2.0 - ) # Needs to be float for scipy + ) mix = HypersphericalMixture([vmf1, vmf2], array([0.5, 0.5])) s = self._test_sample(mix, 10) - self.assertTrue(allclose(linalg.norm(s, axis=1), ones(10), rtol=1e-10)) + self.assertTrue(allclose(linalg.norm(s, axis=1), ones(10), rtol=5e-7)) @unittest.skipIf( - pyrecest.backend.__backend_name__ in ("pytorch", "jax"), + pyrecest.backend.__backend_name__ in ("pytorch",), reason="Not supported on this backend", ) def test_sample_metropolis_hastings_basics_only_h2(self): vmf = VonMisesFisherDistribution( array([1.0, 0.0, 0.0]), 2.0 - ) # Needs to be float for scipy + ) mix = CustomHyperhemisphericalDistribution( lambda x: vmf.pdf(x) + vmf.pdf(-x), 2 ) s = self._test_sample(mix, 10) - self.assertTrue(allclose(linalg.norm(s, axis=1), ones(10), rtol=1e-10)) + self.assertTrue(allclose(linalg.norm(s, axis=1), ones(10), rtol=5e-7)) if __name__ == "__main__": diff --git a/pyrecest/tests/distributions/test_hypertoroidal_fourier_distribution.py b/pyrecest/tests/distributions/test_hypertoroidal_fourier_distribution.py index b9346d909..5eb942e17 100644 --- a/pyrecest/tests/distributions/test_hypertoroidal_fourier_distribution.py +++ b/pyrecest/tests/distributions/test_hypertoroidal_fourier_distribution.py @@ -93,7 +93,7 @@ def test_normalization_nd(self, _, transform, shape, index): unnormalizedCoeffs = fft.fftshift(fft.fftn(arr)) unnormalizedCoeffs[index] = 1 with warnings.catch_warnings(): - warnings.simplefilter("ignore", RuntimeWarning) + warnings.simplefilter("ignore", UserWarning) hfd = HypertoroidalFourierDistribution( unnormalizedCoeffs, transformation=transform ) diff --git a/pyrecest/tests/distributions/test_spherical_harmonics_distribution_complex.py b/pyrecest/tests/distributions/test_spherical_harmonics_distribution_complex.py index 2383b78d2..b1e5d0c43 100644 --- a/pyrecest/tests/distributions/test_spherical_harmonics_distribution_complex.py +++ b/pyrecest/tests/distributions/test_spherical_harmonics_distribution_complex.py @@ -1,6 +1,4 @@ import unittest -import warnings - import numpy.testing as npt import pyrecest.backend from parameterized import parameterized @@ -121,11 +119,9 @@ def test_integral_analytical(self, transformation): ] ) # First initialize and overwrite afterward to prevent normalization - with warnings.catch_warnings(): - warnings.simplefilter("ignore", UserWarning) - shd = SphericalHarmonicsDistributionComplex( - array([[1.0, float("NaN"), float("NaN")], [0.0, 0.0, 0.0]]) - ) + shd = SphericalHarmonicsDistributionComplex( + array([[1.0, float("NaN"), float("NaN")], [0.0, 0.0, 0.0]]) + ) shd.coeff_mat = unnormalized_coeffs shd.transformation = transformation int_val_num = shd.integrate_numerically() @@ -133,16 +129,13 @@ def test_integral_analytical(self, transformation): npt.assert_almost_equal(int_val_ana, int_val_num) def test_truncation(self): - with warnings.catch_warnings(): - warnings.simplefilter("ignore", UserWarning) - shd = SphericalHarmonicsDistributionComplex(self.unnormalized_coeffs) + shd = SphericalHarmonicsDistributionComplex(self.unnormalized_coeffs) with self.assertWarns(UserWarning): shd2 = shd.truncate(4) self.assertEqual(shd2.coeff_mat.shape, (5, 9)) self.assertTrue(all(isnan(shd2.coeff_mat[4, :]) | (shd2.coeff_mat[4, :] == 0))) - with self.assertWarns(UserWarning): - shd3 = shd.truncate(5) + shd3 = shd.truncate(5) self.assertEqual(shd3.coeff_mat.shape, (6, 11)) self.assertTrue( all( @@ -1081,9 +1074,7 @@ def test_basis_function_complex(self, name, coeff_mat, expected_func): "Test not supported for this backend", ) def test_conversion(self, _, coeff_mat): - with warnings.catch_warnings(): - warnings.simplefilter("ignore", UserWarning) - shd = SphericalHarmonicsDistributionComplex(coeff_mat) + shd = SphericalHarmonicsDistributionComplex(coeff_mat) rshd = shd.to_spherical_harmonics_distribution_real() phi, theta = meshgrid(linspace(0.0, 2 * pi, 10), linspace(-pi / 2, pi / 2, 10)) x, y, z = AbstractSphericalDistribution.sph_to_cart(phi.ravel(), theta.ravel()) @@ -1185,9 +1176,7 @@ def test_conversion(self, _, coeff_mat): "Test not supported for this backend", ) def test_mean_direction(self, _, input_array, expected_output, fun_to_test): - with warnings.catch_warnings(): - warnings.simplefilter("ignore", UserWarning) - shd = SphericalHarmonicsDistributionComplex(array(input_array)) + shd = SphericalHarmonicsDistributionComplex(array(input_array)) npt.assert_allclose(fun_to_test(shd), expected_output, atol=1e-10) @unittest.skipIf( @@ -1235,11 +1224,9 @@ def test_from_distribution_via_integral_uniform(self): ) def test_transformation_via_integral_shd(self): # Test approximating a spherical harmonic distribution - with warnings.catch_warnings(): - warnings.simplefilter("ignore", UserWarning) - dist = SphericalHarmonicsDistributionComplex( - array([[1, float("NaN"), float("NaN")], [0.0, 1, 0]]) - ) + dist = SphericalHarmonicsDistributionComplex( + array([[1, float("NaN"), float("NaN")], [0.0, 1, 0]]) + ) shd = SphericalHarmonicsDistributionComplex.from_function_via_integral_cart( dist.pdf, 1 diff --git a/pyrecest/tests/distributions/test_spherical_harmonics_distribution_real.py b/pyrecest/tests/distributions/test_spherical_harmonics_distribution_real.py index b29e2602d..b303b92b6 100644 --- a/pyrecest/tests/distributions/test_spherical_harmonics_distribution_real.py +++ b/pyrecest/tests/distributions/test_spherical_harmonics_distribution_real.py @@ -45,8 +45,7 @@ def testNormalizationWarning(self): ) def testNormalization(self): unnormalized_coeffs = random.uniform(size=(3, 5)) - with self.assertWarns(UserWarning): - shd = SphericalHarmonicsDistributionReal(unnormalized_coeffs) + shd = SphericalHarmonicsDistributionReal(unnormalized_coeffs) self.assertAlmostEqual(shd.integrate(), 1.0, delta=1e-6) x, y, z = SphericalHarmonicsDistributionRealTest._gen_naive_grid(10) @@ -454,9 +453,7 @@ def _gen_naive_grid(n_per_dim): "Test not supported for this backend", ) def test_conversion(self, _, coeff_mat): - with warnings.catch_warnings(): - warnings.simplefilter("ignore", UserWarning) - rshd = SphericalHarmonicsDistributionReal(coeff_mat) + rshd = SphericalHarmonicsDistributionReal(coeff_mat) cshd = rshd.to_spherical_harmonics_distribution_complex() phi_to_test, theta_to_test = ( random.uniform(size=10) * 2 * pi, diff --git a/pyrecest/tests/test_evaluation_basic.py b/pyrecest/tests/test_evaluation_basic.py index e823feb8b..fbe83b116 100644 --- a/pyrecest/tests/test_evaluation_basic.py +++ b/pyrecest/tests/test_evaluation_basic.py @@ -2,7 +2,6 @@ import os import tempfile import unittest -import warnings from typing import Optional import matplotlib @@ -73,13 +72,11 @@ def test_plot_results(self): self.test_evaluate_for_simulation_config_R2_random_walk() filename = self._get_single_evaluation_file() - with warnings.catch_warnings(): - warnings.simplefilter("ignore", UserWarning) - figs, _ = plot_results( - filename=filename, - plot_log=False, - plot_stds=False, - ) + figs, _ = plot_results( + filename=filename, + plot_log=False, + plot_stds=False, + ) try: for fig in figs: @@ -613,9 +610,7 @@ def test_group_results_by_filter(self): ) def test_summarize_filter_results(self): data = self._load_evaluation_data() - with warnings.catch_warnings(): - warnings.simplefilter("ignore", UserWarning) - results_summarized = summarize_filter_results(**data) + results_summarized = summarize_filter_results(**data) for result in results_summarized: error_mean = result["error_mean"]