Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
import copy
from abc import abstractmethod


class StateSpaceSubdivisionDistribution:
"""
Represents a joint distribution over a Cartesian product of a grid-based
(periodic/bounded) space and a linear space, where the linear part is
represented as a collection of distributions conditioned on each grid point
of the periodic/bounded part.

The periodic part is stored as an AbstractGridDistribution, which holds
grid_values (unnormalized weights) at each grid point. The linear part is
stored as a list of distributions, one per grid point.
"""

def __init__(self, gd, linear_distributions):
"""
Parameters
----------
gd : AbstractGridDistribution
Grid-based distribution for the periodic/bounded part. Its
grid_values represent (unnormalized) marginal weights over the
grid points.
linear_distributions : list
One distribution per grid point representing the conditional
distribution of the linear state given that grid point.
"""
assert gd.n_grid_points == len(linear_distributions), (
"Number of grid points in gd must match length of linear_distributions."
)
self.gd = copy.deepcopy(gd)
self.linear_distributions = list(copy.deepcopy(linear_distributions))

@property
def bound_dim(self):
"""Dimension of the periodic/bounded space (ambient dimension of grid points)."""
return self.gd.dim

@property
def lin_dim(self):
"""Dimension of the linear space."""
return self.linear_distributions[0].dim

def hybrid_mean(self):
"""
Returns the hybrid mean, i.e. the concatenation of the mean direction
of the periodic part and the mean of the linear marginal.
"""
# pylint: disable=no-name-in-module,no-member
from pyrecest.backend import concatenate

periodic_mean = self.gd.mean_direction()
linear_mean_val = self.marginalize_periodic().mean()
return concatenate([periodic_mean.reshape(-1), linear_mean_val.reshape(-1)])

@abstractmethod
def marginalize_linear(self):
"""Marginalise out the linear dimensions, returning a distribution over
the periodic/bounded part only."""

@abstractmethod
def marginalize_periodic(self):
"""Marginalise out the periodic/bounded dimensions, returning a
distribution over the linear part only."""
Original file line number Diff line number Diff line change
@@ -0,0 +1,240 @@
import copy
import warnings

# pylint: disable=no-name-in-module,no-member
from pyrecest.backend import (
allclose,
any as backend_any,
argmax,
array,
asarray,
concatenate,
empty,
stack,
sum as backend_sum,
zeros,
)

from ..nonperiodic.gaussian_distribution import GaussianDistribution
from ..nonperiodic.gaussian_mixture import GaussianMixture
from .state_space_subdivision_distribution import StateSpaceSubdivisionDistribution


class StateSpaceSubdivisionGaussianDistribution(StateSpaceSubdivisionDistribution):
"""
Joint distribution over a Cartesian product of a grid-based
(periodic/bounded) space and a linear space where every conditional
linear distribution is a Gaussian.

The periodic part is a grid distribution (e.g. HypertoroidalGridDistribution
or HyperhemisphericalGridDistribution). The linear part is a list of
GaussianDistribution objects, one per grid point.
"""

def __init__(self, gd, gaussians):
"""
Parameters
----------
gd : AbstractGridDistribution
Grid-based distribution for the periodic/bounded part.
gaussians : list of GaussianDistribution
One Gaussian per grid point of *gd*.
"""
assert all(isinstance(g, GaussianDistribution) for g in gaussians), (
"All elements of gaussians must be GaussianDistribution instances."
)
super().__init__(gd, gaussians)

# ------------------------------------------------------------------
# Marginalisation
# ------------------------------------------------------------------

def marginalize_linear(self):
"""Return the grid distribution (marginalised over the linear part)."""
return copy.deepcopy(self.gd)

def marginalize_periodic(self):
"""
Marginalise over the periodic/bounded dimensions.

Returns a GaussianMixture whose components are the conditional
Gaussians and whose weights are the (normalised) grid values.
"""
weights = self.gd.grid_values / backend_sum(self.gd.grid_values)
return GaussianMixture(list(self.linear_distributions), weights)

# ------------------------------------------------------------------
# Linear moments
# ------------------------------------------------------------------

def linear_mean(self):
"""
Compute the mean of the marginal linear distribution by treating
the state as a Gaussian mixture.

Returns
-------
mu : array, shape (lin_dim,)
"""
means = array([ld.mu for ld in self.linear_distributions]) # (n, lin_dim)
covs = stack(
[ld.C for ld in self.linear_distributions], axis=2
) # (lin_dim, lin_dim, n)
weights = self.gd.grid_values / backend_sum(self.gd.grid_values)
mu, _ = GaussianMixture.mixture_parameters_to_gaussian_parameters(
means, covs, weights
)
return mu

def linear_covariance(self):
"""
Compute the covariance of the marginal linear distribution by treating
the state as a Gaussian mixture.

Returns
-------
C : array, shape (lin_dim, lin_dim)
"""
means = array([ld.mu for ld in self.linear_distributions]) # (n, lin_dim)
covs = stack(
[ld.C for ld in self.linear_distributions], axis=2
) # (lin_dim, lin_dim, n)
weights = self.gd.grid_values / backend_sum(self.gd.grid_values)
_, C = GaussianMixture.mixture_parameters_to_gaussian_parameters(
means, covs, weights
)
return C

# ------------------------------------------------------------------
# Multiplication
# ------------------------------------------------------------------

def multiply(self, other):
"""
Multiply two StateSpaceSubdivisionGaussianDistributions.

Both operands must be defined on the same grid. For each grid point
the conditional Gaussians are multiplied (Bayesian update). The grid
weights are updated by the likelihood factors that arise from the
overlap of the two conditional Gaussians.

Parameters
----------
other : StateSpaceSubdivisionGaussianDistribution

Returns
-------
StateSpaceSubdivisionGaussianDistribution
"""
assert isinstance(other, StateSpaceSubdivisionGaussianDistribution)
assert self.gd.n_grid_points == other.gd.n_grid_points, (
"Can only multiply distributions defined on grids with the same "
"number of grid points."
)
self_grid = asarray(self.gd.get_grid())
other_grid = asarray(other.gd.get_grid())
assert allclose(self_grid, other_grid), (
"Can only multiply for equal grids."
)

n = len(self.linear_distributions)
new_linear_distributions = []
factors_linear = empty(n)

for i in range(n):
ld_self = self.linear_distributions[i]
ld_other = other.linear_distributions[i]

# The likelihood factor for grid point i is the pdf of
# N(mu_self_i, C_self_i + C_other_i) evaluated at mu_other_i.
# This is equivalent to N(0, C_self_i + C_other_i) at 0.
combined_cov = ld_self.C + ld_other.C
temp_g = GaussianDistribution(ld_other.mu, combined_cov, check_validity=False)
factors_linear[i] = float(temp_g.pdf(ld_self.mu))

new_linear_distributions.append(ld_self.multiply(ld_other))

# Build result
result = copy.deepcopy(self)
result.linear_distributions = new_linear_distributions
result.gd = copy.deepcopy(self.gd)
result.gd.grid_values = (
self.gd.grid_values * other.gd.grid_values * array(factors_linear)
)
result.gd.normalize_in_place(warn_unnorm=False)
return result

# ------------------------------------------------------------------
# Mode
# ------------------------------------------------------------------

def mode(self):
"""
Compute the (approximate) joint mode.

The mode is found by maximising the product of the conditional
Gaussian peak value and the grid weight at each grid point. Only
the discrete grid is searched (no interpolation).

Returns
-------
m : array, shape (bound_dim + lin_dim,)
Concatenation of the periodic mode (grid point) and the linear
mode (mean of the conditional Gaussian at that grid point).

Warns
-----
UserWarning
If the density appears multimodal (i.e. another grid point has a
joint value within a factor of 1.001 of the maximum).
"""
lin_dim = self.linear_distributions[0].dim
zeros_d = zeros(lin_dim)

# Peak value of N(mu_i, C_i) depends only on C_i; it equals
# N(0 | 0, C_i). We evaluate each conditional Gaussian at its own
# mean to obtain the maximum pdf value.
peak_vals = array(
[
float(
GaussianDistribution(zeros_d, ld.C, check_validity=False).pdf(
zeros_d
)
)
for ld in self.linear_distributions
]
)

fun_vals_joint = peak_vals * asarray(self.gd.grid_values)
index = int(argmax(fun_vals_joint))
max_val = float(fun_vals_joint[index])

# Remove the maximum entry to check for multimodality
remaining = concatenate(
[fun_vals_joint[:index], fun_vals_joint[index + 1:]]
)
if len(remaining) > 0 and (
backend_any((max_val - remaining) < 1e-15)
or backend_any((max_val / remaining) < 1.001)
):
warnings.warn(
"Density may not be unimodal. However, this can also be caused "
"by a high grid resolution and thus very similar function values "
"at the grid points.",
UserWarning,
stacklevel=2,
)

periodic_mode = self.gd.get_grid_point(index) # shape (bound_dim,)
linear_mode = self.linear_distributions[index].mu # shape (lin_dim,)
return concatenate([periodic_mode.reshape(-1), linear_mode.reshape(-1)])

# ------------------------------------------------------------------
# Unsupported operations
# ------------------------------------------------------------------

def convolve(self, _other):
raise NotImplementedError(
"convolve is not supported for "
"StateSpaceSubdivisionGaussianDistribution."
)
5 changes: 3 additions & 2 deletions pyrecest/distributions/nonperiodic/gaussian_mixture.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# pylint: disable=redefined-builtin,no-name-in-module,no-member
# pylint: disable=no-name-in-module,no-member
from pyrecest.backend import array, dot, ones, stack, sum
from pyrecest.backend import array, ones, reshape, stack, sum

from .abstract_linear_distribution import AbstractLinearDistribution
from .gaussian_distribution import GaussianDistribution
Expand All @@ -15,7 +15,8 @@ def __init__(self, dists: list[GaussianDistribution], w):

def mean(self):
gauss_array = self.dists
return dot(array([g.mu for g in gauss_array]), self.w)
means = array([g.mu for g in gauss_array]) # shape (n, dim)
return sum(means * reshape(self.w, (-1, 1)), axis=0)

def set_mean(self, new_mean):
mean_offset = new_mean - self.mean()
Expand Down
Loading
Loading