Skip to content
3 changes: 3 additions & 0 deletions pyrecest/_backend/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,6 +248,9 @@ def get_backend_name():
"randint",
"seed",
"uniform",
# For PyRecEst
"get_state",
"set_state",
],
"fft": [ # For PyRecEst
"rfft",
Expand Down
30 changes: 18 additions & 12 deletions pyrecest/_backend/jax/random.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,13 @@ def create_random_state(seed = 0):
def global_random_state():
return backend.jax_global_random_state


def set_global_random_state(state):
backend.jax_global_random_state = state

get_state = global_random_state
set_state = set_global_random_state

def get_state(**kwargs):
def _get_state(**kwargs):
has_state = 'state' in kwargs
state = kwargs.pop('state', backend.jax_global_random_state)
return state, has_state, kwargs
Expand All @@ -51,7 +52,7 @@ def _rand(state, size, *args, **kwargs):

def rand(size, *args, **kwargs):
size = size if hasattr(size, "__iter__") else (size,)
state, has_state, kwargs = get_state(**kwargs)
state, has_state, kwargs = _get_state(**kwargs)
state, res = _rand(state, size, *args, **kwargs)
return set_state_return(has_state, state, res)

Expand All @@ -66,7 +67,7 @@ def _randint(state, size, *args, **kwargs):

def randint(size, *args, **kwargs):
size = size if hasattr(size, "__iter__") else (size,)
state, has_state, kwargs = get_state(**kwargs)
state, has_state, kwargs = _get_state(**kwargs)
state, res = _randint(state, size, *args, **kwargs)
return set_state_return(has_state, state, res)

Expand All @@ -78,7 +79,7 @@ def _normal(state, size, *args, **kwargs):

def normal(size, *args, **kwargs):
size = size if hasattr(size, "__iter__") else (size,)
state, has_state, kwargs = get_state(**kwargs)
state, has_state, kwargs = _get_state(**kwargs)

# Check and remove 'mean' and 'cov' from kwargs
mean = kwargs.pop('mean', None)
Expand All @@ -102,7 +103,7 @@ def _choice(state, a, n, *args, **kwargs):


def choice(a, n, *args, **kwargs):
state, has_state, kwargs = get_state(**kwargs)
state, has_state, kwargs = _get_state(**kwargs)
state, res = _choice(state, a, n, *args, **kwargs)
return set_state_return(has_state, state, res)

Expand All @@ -114,14 +115,19 @@ def _multivariate_normal(state, size, *args, **kwargs):

def multivariate_normal(size, *args, **kwargs):
size = size if hasattr(size, "__iter__") else (size,)
state, has_state, kwargs = get_state(**kwargs)
state, has_state, kwargs = _get_state(**kwargs)
state, res = _multivariate_normal(state, size, *args, **kwargs)
return set_state_return(has_state, state, res)


unsupported_functions = [
'multinomial',
]
for func_name in unsupported_functions:
exec(f"{func_name} = lambda *args, **kwargs: NotImplementedError('This function is not supported in this JAX backend.')")
def multinomial(n, pvals):
"""Sample from a multinomial distribution using JAX."""
import jax.numpy as _jnp
state, has_state, _ = _get_state()
state, key = jax.random.split(state)
backend.jax_global_random_state = state
pvals = _jnp.asarray(pvals, dtype=_jnp.float32)
pvals = pvals / pvals.sum()
samples = jax.random.categorical(key, _jnp.log(pvals), shape=(n,))
return _jnp.bincount(samples, minlength=len(pvals))

1 change: 1 addition & 0 deletions pyrecest/_backend/numpy/random.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,5 +3,6 @@
import numpy as _np
from numpy.random import default_rng as _default_rng
from numpy.random import randint, seed, multinomial
from numpy.random import set_state, get_state # For PyRecEst

from .._shared_numpy.random import choice, multivariate_normal, normal, rand, uniform
2 changes: 2 additions & 0 deletions pyrecest/_backend/pytorch/random.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

import torch as _torch
from torch import rand, randint
from torch import get_rng_state as get_state # For PyRecEst
from torch import set_rng_state as set_state # For PyRecEst
from torch.distributions.multivariate_normal import (
MultivariateNormal as _MultivariateNormal,
)
Expand Down
7 changes: 2 additions & 5 deletions pyrecest/distributions/abstract_dirac_distribution.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,6 @@
from collections.abc import Callable
from typing import Union

from beartype import beartype

# pylint: disable=redefined-builtin,no-name-in-module,no-member
from pyrecest.backend import (
all,
Expand Down Expand Up @@ -54,16 +52,15 @@ def normalize(self) -> "AbstractDiracDistribution":
dist.normalize_in_place()
return dist

@beartype
def apply_function(self, f: Callable, function_is_vectorized: bool = True):
def apply_function(self, f: Callable, f_supports_multiple: bool = True):
"""
Apply a function to the Dirac locations and return a new distribution.

:param f: Function to apply.
:returns: A new distribution with the function applied to the locations.
"""
dist = copy.deepcopy(self)
if function_is_vectorized:
if f_supports_multiple:
dist.d = f(dist.d)
else:
dist.d = apply_along_axis(f, 1, dist.d)
Expand Down
150 changes: 136 additions & 14 deletions pyrecest/distributions/abstract_manifold_specific_distribution.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
from abc import ABC, abstractmethod
from collections.abc import Callable
from typing import Union
import inspect

import pyrecest.backend

# pylint: disable=no-name-in-module,no-member
# pylint: disable=no-name-in-module,no-member,redefined-builtin
from pyrecest.backend import empty, int32, int64, log, random, squeeze


Expand Down Expand Up @@ -70,7 +71,7 @@ def sample(self, n: Union[int, int32, int64]):
return self.sample_metropolis_hastings(n)

# jscpd:ignore-start
# pylint: disable=too-many-positional-arguments
# pylint: disable=too-many-positional-arguments,too-many-locals
def sample_metropolis_hastings(
self,
n: Union[int, int32, int64],
Expand All @@ -81,30 +82,48 @@ def sample_metropolis_hastings(
):
# jscpd:ignore-end
"""Metropolis Hastings sampling algorithm."""
assert (
pyrecest.backend.__backend_name__ != "jax"
), "Not supported on this backend"
if pyrecest.backend.__backend_name__ == "jax":
# Get a key from your global JAX random state *outside* of lax.scan
import jax as _jax # pylint: disable=import-error

key = random.get_state()
key, key_for_mh = _jax.random.split(key)
# Optionally update global state for future calls
random.set_state(key)

if proposal is None or start_point is None:
raise NotImplementedError(
"Default proposals and starting points should be set in inheriting classes."
)
_assert_proposal_supports_key(proposal)

samples, _ = sample_metropolis_hastings_jax(
key=key_for_mh,
log_pdf=self.ln_pdf,
proposal=proposal, # must be (key, x) -> x_prop for JAX
start_point=start_point,
n=int(n),
burn_in=int(burn_in),
skipping=int(skipping),
)
# You could optionally stash `key_out` somewhere if you want chain continuation.
return squeeze(samples)

# Non-JAX backends -> your old NumPy/Torch code
if proposal is None or start_point is None:
raise NotImplementedError(
"Default proposals and starting points should be set in inheriting classes."
)

total_samples = burn_in + n * skipping
s = empty(
(
total_samples,
self.input_dim,
),
)
s = empty((total_samples, self.input_dim))
x = start_point
i = 0
pdfx = self.pdf(x)

while i < total_samples:
x_new = proposal(x)
assert (
x_new.shape == x.shape
), "Proposal must return a vector of same shape as input"
assert x_new.shape == x.shape, "Proposal must return a vector of same shape as input"
pdfx_new = self.pdf(x_new)
a = pdfx_new / pdfx
if a.item() > 1 or a.item() > random.rand(1):
Expand All @@ -115,3 +134,106 @@ def sample_metropolis_hastings(

relevant_samples = s[burn_in::skipping, :]
return squeeze(relevant_samples)

# pylint: disable=too-many-positional-arguments,too-many-locals,too-many-arguments
def sample_metropolis_hastings_jax(
key,
log_pdf, # function: x -> log p(x)
proposal, # function: (key, x) -> x_prop
start_point,
n: int,
burn_in: int = 10,
skipping: int = 5,
):
"""
Metropolis-Hastings sampler in JAX using a plain Python loop.

Uses a Python loop (rather than lax.scan) so that log_pdf may call
non-JAX-traceable code (e.g. scipy).

key: jax.random.PRNGKey
log_pdf: callable x -> log p(x)
proposal: callable (key, x) -> x_proposed
start_point: initial state (array)
n: number of samples to return (after burn-in and thinning)
"""
import jax.numpy as _jnp # pylint: disable=import-error
from jax import random as _random # pylint: disable=import-error

start_point = _jnp.asarray(start_point)
total_steps = burn_in + n * skipping
chain = []

x = start_point

def _to_scalar(val):
"""Convert a JAX array of any shape to a Python float."""
return float(_jnp.asarray(val).ravel()[0])

log_px = _to_scalar(log_pdf(x))

for _ in range(total_steps):
key, key_prop, key_u = _random.split(key, 3)

# Propose new state
x_prop = proposal(key_prop, x)
log_px_prop = _to_scalar(log_pdf(x_prop))

# Metropolis acceptance
log_alpha = log_px_prop - log_px
log_u = _to_scalar(_jnp.log(_random.uniform(key_u, shape=())))

if log_u < min(0.0, log_alpha):
x = x_prop
log_px = log_px_prop

chain.append(x)

chain_array = _jnp.stack(chain, axis=0)
samples = chain_array[burn_in::skipping]
return samples, key


def _assert_proposal_supports_key(proposal: Callable):
"""
Check that `proposal` can be called as proposal(key, x).

Raises a TypeError with a helpful message if this is not the case.
"""
# Unwrap jitted / partial / decorated functions if possible
func = proposal
while hasattr(func, "__wrapped__"):
func = func.__wrapped__

try:
sig = inspect.signature(func)
except (TypeError, ValueError):
# Can't introspect (e.g. builtins); fall back to a generic error
raise TypeError(
"For the JAX backend, `proposal` must accept (key, x) as arguments, "
"but its signature could not be inspected."
) from None

params = list(sig.parameters.values())

# Count positional(-or-keyword) parameters
num_positional = sum(
p.kind in (inspect.Parameter.POSITIONAL_ONLY,
inspect.Parameter.POSITIONAL_OR_KEYWORD)
for p in params
)
has_var_positional = any(
p.kind == inspect.Parameter.VAR_POSITIONAL
for p in params
)

if has_var_positional or num_positional >= 2:
# Looks compatible with (key, x)
return

raise TypeError(
"For the JAX backend, `proposal` must accept `(key, x)` as arguments.\n"
f"Got signature: {sig}\n"
"Hint: change your proposal from `def proposal(x): ...` to\n"
"`def proposal(key, x): ...` and use `jax.random` with the passed key."
)
24 changes: 21 additions & 3 deletions pyrecest/distributions/abstract_mixture.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
import warnings
from typing import Union

import pyrecest.backend

# pylint: disable=redefined-builtin,no-name-in-module,no-member
from pyrecest.backend import (
count_nonzero,
Expand Down Expand Up @@ -67,12 +69,28 @@ def input_dim(self) -> int:

def sample(self, n: Union[int, int32, int64]):
occurrences = random.multinomial(n, self.w)
if pyrecest.backend.__backend_name__ == "jax":
samples = []
for i, occ in enumerate(occurrences):
occ_val = occ.item() if hasattr(occ, "item") else int(occ)
if occ_val != 0:
try:
sample_i = self.dists[i].sample(occ_val)
except (NotImplementedError, AssertionError, ValueError, TypeError):
sample_i = self.dists[i].sample_metropolis_hastings(occ_val)
sample_i = pyrecest.backend.atleast_2d(sample_i)
samples.append(sample_i)
if not samples:
return empty((0, self.input_dim))
return pyrecest.backend.concatenate(samples, axis=0)

count = 0
s = empty((n, self.input_dim))
for i, occ in enumerate(occurrences):
if occ != 0:
s[count : count + occ, :] = self.dists[i].sample(occ) # noqa: E203
count += occ
occ_val = occ.item() if hasattr(occ, "item") else int(occ)
if occ_val != 0:
s[count:count + occ_val, :] = self.dists[i].sample(occ_val)
count += occ_val

return s

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -111,21 +111,9 @@ def mode(self):
"""
return self.mu

def set_mean(self, new_mean):
"""
Return a copy of this distribution with the location parameter shifted to ``new_mean``.

For bounded dimensions, the mean is wrapped into [0, 2*pi) to stay on the manifold.
"""
new_dist = copy.deepcopy(self)
wrapped_mean = where(
arange(new_mean.shape[0]) < self.bound_dim, mod(new_mean, 2 * pi), new_mean
)
new_dist.mu = wrapped_mean
return new_dist

def set_mode(self, new_mode):
return self.set_mean(new_mode)
self.mu = copy.copy(new_mode)
return self

def hybrid_moment(self):
"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -100,10 +100,10 @@ def __sub__(
cNew = self.c - other.c
fdNew = CircularFourierDistribution(
c=cNew,
n=self.n,
transformation=self.transformation,
multiplied_by_n=self.multiplied_by_n,
)
fdNew.n = self.n # Preserve the original n value (should remain unchanged)
return fdNew

def pdf(self, xs):
Expand Down
Loading
Loading