diff --git a/pyrecest/distributions/cart_prod/state_space_subdivision_distribution.py b/pyrecest/distributions/cart_prod/state_space_subdivision_distribution.py new file mode 100644 index 000000000..980b21a89 --- /dev/null +++ b/pyrecest/distributions/cart_prod/state_space_subdivision_distribution.py @@ -0,0 +1,65 @@ +import copy +from abc import abstractmethod + + +class StateSpaceSubdivisionDistribution: + """ + Represents a joint distribution over a Cartesian product of a grid-based + (periodic/bounded) space and a linear space, where the linear part is + represented as a collection of distributions conditioned on each grid point + of the periodic/bounded part. + + The periodic part is stored as an AbstractGridDistribution, which holds + grid_values (unnormalized weights) at each grid point. The linear part is + stored as a list of distributions, one per grid point. + """ + + def __init__(self, gd, linear_distributions): + """ + Parameters + ---------- + gd : AbstractGridDistribution + Grid-based distribution for the periodic/bounded part. Its + grid_values represent (unnormalized) marginal weights over the + grid points. + linear_distributions : list + One distribution per grid point representing the conditional + distribution of the linear state given that grid point. + """ + assert gd.n_grid_points == len(linear_distributions), ( + "Number of grid points in gd must match length of linear_distributions." + ) + self.gd = copy.deepcopy(gd) + self.linear_distributions = list(copy.deepcopy(linear_distributions)) + + @property + def bound_dim(self): + """Dimension of the periodic/bounded space (ambient dimension of grid points).""" + return self.gd.dim + + @property + def lin_dim(self): + """Dimension of the linear space.""" + return self.linear_distributions[0].dim + + def hybrid_mean(self): + """ + Returns the hybrid mean, i.e. the concatenation of the mean direction + of the periodic part and the mean of the linear marginal. + """ + # pylint: disable=no-name-in-module,no-member + from pyrecest.backend import concatenate + + periodic_mean = self.gd.mean_direction() + linear_mean_val = self.marginalize_periodic().mean() + return concatenate([periodic_mean.reshape(-1), linear_mean_val.reshape(-1)]) + + @abstractmethod + def marginalize_linear(self): + """Marginalise out the linear dimensions, returning a distribution over + the periodic/bounded part only.""" + + @abstractmethod + def marginalize_periodic(self): + """Marginalise out the periodic/bounded dimensions, returning a + distribution over the linear part only.""" diff --git a/pyrecest/distributions/cart_prod/state_space_subdivision_gaussian_distribution.py b/pyrecest/distributions/cart_prod/state_space_subdivision_gaussian_distribution.py new file mode 100644 index 000000000..d16ddb929 --- /dev/null +++ b/pyrecest/distributions/cart_prod/state_space_subdivision_gaussian_distribution.py @@ -0,0 +1,244 @@ +import copy +import warnings + +# pylint: disable=no-name-in-module,no-member +from pyrecest.backend import ( + allclose, + any as backend_any, + argmax, + array, + asarray, + concatenate, + stack, + sum as backend_sum, + zeros, +) + +from ..nonperiodic.gaussian_distribution import GaussianDistribution +from ..nonperiodic.gaussian_mixture import GaussianMixture +from .state_space_subdivision_distribution import StateSpaceSubdivisionDistribution + + +class StateSpaceSubdivisionGaussianDistribution(StateSpaceSubdivisionDistribution): + """ + Joint distribution over a Cartesian product of a grid-based + (periodic/bounded) space and a linear space where every conditional + linear distribution is a Gaussian. + + The periodic part is a grid distribution (e.g. HypertoroidalGridDistribution + or HyperhemisphericalGridDistribution). The linear part is a list of + GaussianDistribution objects, one per grid point. + """ + + def __init__(self, gd, gaussians): + """ + Parameters + ---------- + gd : AbstractGridDistribution + Grid-based distribution for the periodic/bounded part. + gaussians : list of GaussianDistribution + One Gaussian per grid point of *gd*. + """ + assert all(isinstance(g, GaussianDistribution) for g in gaussians), ( + "All elements of gaussians must be GaussianDistribution instances." + ) + super().__init__(gd, gaussians) + + # ------------------------------------------------------------------ + # Marginalisation + # ------------------------------------------------------------------ + + def marginalize_linear(self): + """Return the grid distribution (marginalised over the linear part).""" + return copy.deepcopy(self.gd) + + def marginalize_periodic(self): + """ + Marginalise over the periodic/bounded dimensions. + + Returns a GaussianMixture whose components are the conditional + Gaussians and whose weights are the (normalised) grid values. + """ + weights = self.gd.grid_values / backend_sum(self.gd.grid_values) + return GaussianMixture(list(self.linear_distributions), weights) + + # ------------------------------------------------------------------ + # Linear moments + # ------------------------------------------------------------------ + + def linear_mean(self): + """ + Compute the mean of the marginal linear distribution by treating + the state as a Gaussian mixture. + + Returns + ------- + mu : array, shape (lin_dim,) + """ + means = array([ld.mu for ld in self.linear_distributions]) # (n, lin_dim) + covs = stack( + [ld.C for ld in self.linear_distributions], axis=2 + ) # (lin_dim, lin_dim, n) + weights = self.gd.grid_values / backend_sum(self.gd.grid_values) + mu, _ = GaussianMixture.mixture_parameters_to_gaussian_parameters( + means, covs, weights + ) + return mu + + def linear_covariance(self): + """ + Compute the covariance of the marginal linear distribution by treating + the state as a Gaussian mixture. + + Returns + ------- + C : array, shape (lin_dim, lin_dim) + """ + means = array([ld.mu for ld in self.linear_distributions]) # (n, lin_dim) + covs = stack( + [ld.C for ld in self.linear_distributions], axis=2 + ) # (lin_dim, lin_dim, n) + weights = self.gd.grid_values / backend_sum(self.gd.grid_values) + _, C = GaussianMixture.mixture_parameters_to_gaussian_parameters( + means, covs, weights + ) + return C + + # ------------------------------------------------------------------ + # Multiplication + # ------------------------------------------------------------------ + + def multiply(self, other): + """ + Multiply two StateSpaceSubdivisionGaussianDistributions. + + Both operands must be defined on the same grid. For each grid point + the conditional Gaussians are multiplied (Bayesian update). The grid + weights are updated by the likelihood factors that arise from the + overlap of the two conditional Gaussians. + + Parameters + ---------- + other : StateSpaceSubdivisionGaussianDistribution + + Returns + ------- + StateSpaceSubdivisionGaussianDistribution + """ + assert isinstance(other, StateSpaceSubdivisionGaussianDistribution) + assert self.gd.n_grid_points == other.gd.n_grid_points, ( + "Can only multiply distributions defined on grids with the same " + "number of grid points." + ) + self_grid = asarray(self.gd.get_grid()) + other_grid = asarray(other.gd.get_grid()) + assert allclose(self_grid, other_grid), ( + "Can only multiply for equal grids." + ) + + n = len(self.linear_distributions) + new_linear_distributions = [] + pdf_values = [] + + for i in range(n): + ld_self = self.linear_distributions[i] + ld_other = other.linear_distributions[i] + + # The likelihood factor for grid point i is the pdf of + # N(mu_self_i, C_self_i + C_other_i) evaluated at mu_other_i. + # This is equivalent to N(0, C_self_i + C_other_i) at 0. + combined_cov = ld_self.C + ld_other.C + temp_g = GaussianDistribution(ld_other.mu, combined_cov, check_validity=False) + pdf_values.append(temp_g.pdf(ld_self.mu)) + + new_linear_distributions.append(ld_self.multiply(ld_other)) + + # Build a 1-D factors array. pdf() may return shape () or (1,) depending + # on backend and Gaussian dimension; reshape each value to (1,) before + # concatenating so the result is always shape (n,). + factors_linear = concatenate([asarray(v).reshape((1,)) for v in pdf_values]) + + # Build result + result = copy.deepcopy(self) + result.linear_distributions = new_linear_distributions + result.gd = copy.deepcopy(self.gd) + result.gd.grid_values = ( + self.gd.grid_values * other.gd.grid_values * array(factors_linear) + ) + result.gd.normalize_in_place(warn_unnorm=False) + return result + + # ------------------------------------------------------------------ + # Mode + # ------------------------------------------------------------------ + + def mode(self): + """ + Compute the (approximate) joint mode. + + The mode is found by maximising the product of the conditional + Gaussian peak value and the grid weight at each grid point. Only + the discrete grid is searched (no interpolation). + + Returns + ------- + m : array, shape (bound_dim + lin_dim,) + Concatenation of the periodic mode (grid point) and the linear + mode (mean of the conditional Gaussian at that grid point). + + Warns + ----- + UserWarning + If the density appears multimodal (i.e. another grid point has a + joint value within a factor of 1.001 of the maximum). + """ + lin_dim = self.linear_distributions[0].dim + zeros_d = zeros(lin_dim) + + # Peak value of N(mu_i, C_i) depends only on C_i; it equals + # N(0 | 0, C_i). We evaluate each conditional Gaussian at its own + # mean to obtain the maximum pdf value. + peak_vals = array( + [ + float( + GaussianDistribution(zeros_d, ld.C, check_validity=False).pdf( + zeros_d + ) + ) + for ld in self.linear_distributions + ] + ) + + fun_vals_joint = peak_vals * asarray(self.gd.grid_values) + index = int(argmax(fun_vals_joint)) + max_val = float(fun_vals_joint[index]) + + # Remove the maximum entry to check for multimodality + remaining = concatenate( + [fun_vals_joint[:index], fun_vals_joint[index + 1:]] # noqa: E203 + ) + if len(remaining) > 0 and ( + backend_any((max_val - remaining) < 1e-15) + or backend_any((max_val / remaining) < 1.001) + ): + warnings.warn( + "Density may not be unimodal. However, this can also be caused " + "by a high grid resolution and thus very similar function values " + "at the grid points.", + UserWarning, + stacklevel=2, + ) + + periodic_mode = self.gd.get_grid_point(index) # shape (bound_dim,) + linear_mode = self.linear_distributions[index].mu # shape (lin_dim,) + return concatenate([periodic_mode.reshape(-1), linear_mode.reshape(-1)]) + + # ------------------------------------------------------------------ + # Unsupported operations + # ------------------------------------------------------------------ + + def convolve(self, _other): + raise NotImplementedError( + "convolve is not supported for " + "StateSpaceSubdivisionGaussianDistribution." + ) diff --git a/pyrecest/distributions/nonperiodic/gaussian_mixture.py b/pyrecest/distributions/nonperiodic/gaussian_mixture.py index 89be400e9..ce07ef961 100644 --- a/pyrecest/distributions/nonperiodic/gaussian_mixture.py +++ b/pyrecest/distributions/nonperiodic/gaussian_mixture.py @@ -1,6 +1,6 @@ # pylint: disable=redefined-builtin,no-name-in-module,no-member # pylint: disable=no-name-in-module,no-member -from pyrecest.backend import array, dot, ones, stack, sum +from pyrecest.backend import array, ones, reshape, stack, sum from .abstract_linear_distribution import AbstractLinearDistribution from .gaussian_distribution import GaussianDistribution @@ -15,7 +15,8 @@ def __init__(self, dists: list[GaussianDistribution], w): def mean(self): gauss_array = self.dists - return dot(array([g.mu for g in gauss_array]), self.w) + means = array([g.mu for g in gauss_array]) # shape (n, dim) + return sum(means * reshape(self.w, (-1, 1)), axis=0) def set_mean(self, new_mean): mean_offset = new_mean - self.mean() diff --git a/pyrecest/tests/distributions/test_state_space_subdivision_gaussian_distribution.py b/pyrecest/tests/distributions/test_state_space_subdivision_gaussian_distribution.py new file mode 100644 index 000000000..eb01ceadd --- /dev/null +++ b/pyrecest/tests/distributions/test_state_space_subdivision_gaussian_distribution.py @@ -0,0 +1,164 @@ +import unittest + +import numpy.testing as npt +import pyrecest + +# pylint: disable=no-name-in-module,no-member +from pyrecest.backend import array, eye, linalg, pi +from pyrecest.distributions.cart_prod.state_space_subdivision_gaussian_distribution import ( + StateSpaceSubdivisionGaussianDistribution, +) +from pyrecest.distributions.circle.circular_uniform_distribution import ( + CircularUniformDistribution, +) +from pyrecest.distributions.circle.von_mises_distribution import VonMisesDistribution +from pyrecest.distributions.hypersphere_subset.hyperhemispherical_grid_distribution import ( + HyperhemisphericalGridDistribution, +) +from pyrecest.distributions.hypersphere_subset.hyperhemispherical_uniform_distribution import ( + HyperhemisphericalUniformDistribution, +) +from pyrecest.distributions.hypertorus.hypertoroidal_grid_distribution import ( + HypertoroidalGridDistribution, +) +from pyrecest.distributions.nonperiodic.gaussian_distribution import ( + GaussianDistribution, +) + + +class TestStateSpaceSubdivisionGaussianDistribution(unittest.TestCase): + def test_multiply_s1_x_r1_identical_precise(self): + """Multiply two S1xR1 distributions; linear uncertainty must decrease.""" + n = 100 + gd = HypertoroidalGridDistribution.from_distribution( + CircularUniformDistribution(), (n,) + ) + gaussians = [GaussianDistribution(array([0.0]), array([[1.0]])) for _ in range(n)] + rbd1 = StateSpaceSubdivisionGaussianDistribution(gd, gaussians) + + gaussians2 = [ + GaussianDistribution(array([2.0]), array([[1.0]])) for _ in range(n) + ] + rbd2 = StateSpaceSubdivisionGaussianDistribution(gd, gaussians2) + + rbd_up = rbd1.multiply(rbd2) + + for i in range(n): + npt.assert_array_less( + linalg.det(rbd_up.linear_distributions[i].C), + linalg.det(rbd1.linear_distributions[i].C), + ) + npt.assert_allclose( + rbd_up.linear_distributions[i].mu, array([1.0]), atol=1e-14 + ) + + @unittest.skipIf( + pyrecest.backend.__backend_name__ == "jax", # pylint: disable=no-member + reason="Not supported on this backend", + ) + def test_multiply_s2_x_r3_rough(self): + """Multiply two S2xR3 distributions; linear uncertainty must decrease.""" + n = 100 + gd = HyperhemisphericalGridDistribution.from_distribution( + HyperhemisphericalUniformDistribution(2), n, "leopardi_symm" + ) + gaussians = [ + GaussianDistribution(array([0.0, 0.0, 0.0]), 1000.0 * eye(3)) + for _ in range(n) + ] + rbd1 = StateSpaceSubdivisionGaussianDistribution(gd, gaussians) + + gaussians2 = [ + GaussianDistribution(array([2.0, 2.0, 2.0]), 1000.0 * eye(3)) + for _ in range(n) + ] + rbd2 = StateSpaceSubdivisionGaussianDistribution(gd, gaussians2) + + rbd_up = rbd1.multiply(rbd2) + + for i in range(n): + npt.assert_array_less( + linalg.det(rbd_up.linear_distributions[i].C), + linalg.det(rbd1.linear_distributions[i].C), + ) + npt.assert_allclose( + rbd_up.linear_distributions[i].mu, + array([1.0, 1.0, 1.0]), + atol=1e-10, + ) + + def test_hybrid_mean(self): + """hybridMean returns concatenation of periodic and linear means.""" + n = 100 + mu_periodic = 4.0 + mu_linear = array([1.0, 2.0, 3.0]) + gd = HypertoroidalGridDistribution.from_distribution( + VonMisesDistribution(mu_periodic, 1.0), (n,) + ) + gaussians = [ + GaussianDistribution(mu_linear, 1000.0 * eye(3)) for _ in range(n) + ] + rbd = StateSpaceSubdivisionGaussianDistribution(gd, gaussians) + npt.assert_allclose( + rbd.hybrid_mean(), + array([mu_periodic, 1.0, 2.0, 3.0]), + atol=1e-4, + ) + + def test_linear_mean(self): + """linearMean returns the correct linear mean.""" + n = 100 + mu_linear = array([1.0, 2.0, 3.0]) + gd = HypertoroidalGridDistribution.from_distribution( + VonMisesDistribution(4.0, 1.0), (n,) + ) + gaussians = [ + GaussianDistribution(mu_linear, 1000.0 * eye(3)) for _ in range(n) + ] + rbd = StateSpaceSubdivisionGaussianDistribution(gd, gaussians) + npt.assert_allclose(rbd.linear_mean(), mu_linear, rtol=5e-7) + + def test_mode_warning_uniform(self): + """mode() warns about potential multimodality for a uniform periodic part.""" + n = 100 + mu_linear = array([1.0, 2.0, 3.0]) + gd = HypertoroidalGridDistribution.from_distribution( + CircularUniformDistribution(), (n,) + ) + gaussians = [ + GaussianDistribution(mu_linear, 1000.0 * eye(3)) for _ in range(n) + ] + rbd = StateSpaceSubdivisionGaussianDistribution(gd, gaussians) + with self.assertWarns(UserWarning): + rbd.mode() + + def test_mode(self): + """mode() returns the correct mode without warning for a unimodal density.""" + n = 100 + mu_periodic = 4.0 + mu_linear = array([1.0, 2.0, 3.0]) + gd = HypertoroidalGridDistribution.from_distribution( + VonMisesDistribution(mu_periodic, 10.0), (n,) + ) + gaussians = [ + GaussianDistribution(mu_linear, eye(3)) for _ in range(n) + ] + rbd = StateSpaceSubdivisionGaussianDistribution(gd, gaussians) + + # Should not warn + with self.assertNoLogs(level="WARNING"): + import warnings as _warnings + with _warnings.catch_warnings(): + _warnings.simplefilter("error", UserWarning) + m = rbd.mode() + + # Mode should be close to [mu_periodic, mu_linear]; tolerance is grid resolution + npt.assert_allclose( + m, + array([mu_periodic, 1.0, 2.0, 3.0]), + atol=pi / n, + ) + + +if __name__ == "__main__": + unittest.main()