diff --git a/pyrecest/distributions/conditional/__init__.py b/pyrecest/distributions/conditional/__init__.py index 986f13346..eb78db61d 100644 --- a/pyrecest/distributions/conditional/__init__.py +++ b/pyrecest/distributions/conditional/__init__.py @@ -1,3 +1,4 @@ from .sd_cond_sd_grid_distribution import SdCondSdGridDistribution +from .td_cond_td_grid_distribution import TdCondTdGridDistribution -__all__ = ["SdCondSdGridDistribution"] +__all__ = ["SdCondSdGridDistribution", "TdCondTdGridDistribution"] diff --git a/pyrecest/distributions/conditional/abstract_conditional_distribution.py b/pyrecest/distributions/conditional/abstract_conditional_distribution.py index 840021129..c5e332177 100644 --- a/pyrecest/distributions/conditional/abstract_conditional_distribution.py +++ b/pyrecest/distributions/conditional/abstract_conditional_distribution.py @@ -1,5 +1,164 @@ +import copy +import warnings from abc import ABC +# pylint: disable=redefined-builtin,no-name-in-module,no-member +from pyrecest.backend import ( + any, + arange, + argmin, + array_equal, + linalg, + meshgrid, +) + class AbstractConditionalDistribution(ABC): - pass + """Abstract base class for conditional grid distributions on manifolds. + + Subclasses represent distributions of the form f(a | b) where both a and b + live on the same manifold. The joint state is stored as a square matrix + ``grid_values`` where ``grid_values[i, j] = f(grid[i] | grid[j])``. + """ + + def __init__(self, grid, grid_values, enforce_pdf_nonnegative=True): + """Common initialisation for conditional grid distributions. + + Parameters + ---------- + grid : array of shape (n_points, d) + Grid points on the individual manifold. + grid_values : array of shape (n_points, n_points) + Conditional pdf values; ``grid_values[i, j] = f(grid[i] | grid[j])``. + enforce_pdf_nonnegative : bool + Whether to require non-negative ``grid_values``. + """ + if grid.ndim != 2: + raise ValueError("grid must be a 2D array of shape (n_points, d).") + + n_points, d = grid.shape + + if grid_values.ndim != 2 or grid_values.shape != (n_points, n_points): + raise ValueError( + f"grid_values must be a square 2D array of shape ({n_points}, {n_points})." + ) + + if enforce_pdf_nonnegative and any(grid_values < 0): + raise ValueError("grid_values must be non-negative.") + + self.grid = grid + self.grid_values = grid_values + self.enforce_pdf_nonnegative = enforce_pdf_nonnegative + # Embedding dimension of the Cartesian product space (convention from + # libDirectional: dim = 2 * dim_of_individual_manifold). + self.dim = 2 * d + + # ------------------------------------------------------------------ + # Normalization + # ------------------------------------------------------------------ + + def normalize(self): + """No-op – returns ``self`` for compatibility.""" + return self + + # ------------------------------------------------------------------ + # Arithmetic + # ------------------------------------------------------------------ + + def multiply(self, other): + """Element-wise multiply two conditional grid distributions. + + The resulting distribution is *not* normalized. + + Parameters + ---------- + other : AbstractConditionalDistribution + Must be defined on the same grid. + + Returns + ------- + AbstractConditionalDistribution + Same concrete type as ``self``. + """ + if not array_equal(self.grid, other.grid): + raise ValueError( + "Multiply:IncompatibleGrid: Can only multiply distributions " + "defined on identical grids." + ) + warnings.warn( + "Multiply:UnnormalizedResult: Multiplication does not yield a " + "normalized result.", + UserWarning, + ) + result = copy.deepcopy(self) + result.grid_values = result.grid_values * other.grid_values + return result + + # ------------------------------------------------------------------ + # Protected helpers + # ------------------------------------------------------------------ + + def _get_grid_slice(self, first_or_second, point): + """Return the ``grid_values`` slice for a fixed grid point. + + Parameters + ---------- + first_or_second : int (1 or 2) + Which variable to fix. + point : array of shape (d,) + Must be an existing grid point. + + Returns + ------- + array of shape (n_points,) + """ + d = self.grid.shape[1] + if point.shape[0] != d: + raise ValueError( + f"point must have length {d} (grid dimension)." + ) + diffs = linalg.norm(self.grid - point[None, :], axis=1) + locb = argmin(diffs) + if diffs[locb] > 1e-10: + raise ValueError( + "Cannot fix value at this point because it is not on the grid." + ) + if first_or_second == 1: + return self.grid_values[locb, :] + if first_or_second == 2: + return self.grid_values[:, locb] + raise ValueError("first_or_second must be 1 or 2.") + + @staticmethod + def _evaluate_on_grid(fun, grid, n, fun_does_cartesian_product): + """Evaluate ``fun`` on all grid point pairs and return an (n, n) array. + + Parameters + ---------- + fun : callable + ``f(a, b)`` with the semantics described in ``from_function``. + grid : array of shape (n, d) + Grid points on the individual manifold. + n : int + Number of grid points (``grid.shape[0]``). + fun_does_cartesian_product : bool + Whether *fun* handles all grid combinations internally. + + Returns + ------- + array of shape (n, n) + """ + if fun_does_cartesian_product: + fvals = fun(grid, grid) + return fvals.reshape(n, n) + idx_a, idx_b = meshgrid(arange(n), arange(n), indexing="ij") + grid_a = grid[idx_a.ravel()] + grid_b = grid[idx_b.ravel()] + fvals = fun(grid_a, grid_b) + if fvals.shape == (n**2, n**2): + raise ValueError( + "Function apparently performs the Cartesian product itself. " + "Set fun_does_cartesian_product=True." + ) + return fvals.reshape(n, n) + diff --git a/pyrecest/distributions/conditional/sd_cond_sd_grid_distribution.py b/pyrecest/distributions/conditional/sd_cond_sd_grid_distribution.py index 50e6e0244..f3d40582b 100644 --- a/pyrecest/distributions/conditional/sd_cond_sd_grid_distribution.py +++ b/pyrecest/distributions/conditional/sd_cond_sd_grid_distribution.py @@ -1,4 +1,3 @@ -import copy import warnings # pylint: disable=redefined-builtin,no-name-in-module,no-member @@ -6,12 +5,7 @@ abs, all, any, - arange, - argmin, - array_equal, - linalg, mean, - meshgrid, sum, ) from pyrecest.distributions.hypersphere_subset.abstract_hypersphere_subset_distribution import ( @@ -50,31 +44,11 @@ def __init__(self, grid, grid_values, enforce_pdf_nonnegative=True): enforce_pdf_nonnegative : bool Whether non-negativity of ``grid_values`` is required. """ - if grid.ndim != 2: - raise ValueError("grid must be a 2D array of shape (n_points, d).") - - n_points, d = grid.shape - - if grid_values.ndim != 2 or grid_values.shape != (n_points, n_points): - raise ValueError( - f"grid_values must be a square 2D array of shape ({n_points}, {n_points})." - ) - - if any(abs(grid) > 1 + 1e-12): + super().__init__(grid, grid_values, enforce_pdf_nonnegative) + if any(abs(self.grid) > 1 + 1e-12): raise ValueError( "Grid points must have coordinates in [-1, 1] (unit sphere)." ) - - if enforce_pdf_nonnegative and any(grid_values < 0): - raise ValueError("grid_values must be non-negative.") - - self.grid = grid - self.grid_values = grid_values - self.enforce_pdf_nonnegative = enforce_pdf_nonnegative - # Embedding dimension of the Cartesian product space (convention from - # libDirectional: dim = 2 * embedding_dim_of_individual_sphere). - self.dim = 2 * d - self._check_normalization() # ------------------------------------------------------------------ @@ -107,43 +81,6 @@ def _check_normalization(self, tol=0.01): UserWarning, ) - def normalize(self): - """No-op – returns ``self`` for compatibility.""" - return self - - # ------------------------------------------------------------------ - # Arithmetic - # ------------------------------------------------------------------ - - def multiply(self, other): - """ - Element-wise multiply two conditional grid distributions. - - The resulting distribution is *not* normalized. - - Parameters - ---------- - other : SdCondSdGridDistribution - Must be defined on the same grid. - - Returns - ------- - SdCondSdGridDistribution - """ - if not array_equal(self.grid, other.grid): - raise ValueError( - "Multiply:IncompatibleGrid: Can only multiply distributions " - "defined on identical grids." - ) - warnings.warn( - "Multiply:UnnormalizedResult: Multiplication does not yield a " - "normalized result.", - UserWarning, - ) - result = copy.deepcopy(self) - result.grid_values = result.grid_values * other.grid_values - return result - # ------------------------------------------------------------------ # Marginalisation and conditioning # ------------------------------------------------------------------ @@ -201,26 +138,7 @@ def fix_dim(self, first_or_second, point): HypersphericalGridDistribution, ) - d = self.grid.shape[1] - if point.shape[0] != d: - raise ValueError( - f"point must have length {d} (embedding dimension of the sphere)." - ) - - diffs = linalg.norm(self.grid - point[None, :], axis=1) - locb = argmin(diffs) - if diffs[locb] > 1e-10: - raise ValueError( - "Cannot fix value at this point because it is not on the grid." - ) - - if first_or_second == 1: - grid_values_slice = self.grid_values[locb, :] - elif first_or_second == 2: - grid_values_slice = self.grid_values[:, locb] - else: - raise ValueError("first_or_second must be 1 or 2.") - + grid_values_slice = self._get_grid_slice(first_or_second, point) return HypersphericalGridDistribution(self.grid, grid_values_slice) # ------------------------------------------------------------------ @@ -276,24 +194,8 @@ def from_function( # manifold dim: embedding_dim = dim // 2, manifold_dim = embedding_dim - 1. manifold_dim = dim // 2 - 1 grid, _ = get_grid_hypersphere(grid_type, n, manifold_dim) - # grid is (n, dim//2) - - if fun_does_cartesian_product: - fvals = fun(grid, grid) - grid_values = fvals.reshape(n, n) - else: - # Build index pairs: idx_a[i, j] = i, idx_b[i, j] = j - idx_a, idx_b = meshgrid(arange(n), arange(n), indexing="ij") - grid_a = grid[idx_a.ravel()] # (n*n, d) - grid_b = grid[idx_b.ravel()] # (n*n, d) - fvals = fun(grid_a, grid_b) # (n*n,) - - if fvals.shape == (n**2, n**2): - raise ValueError( - "Function apparently performs the Cartesian product itself. " - "Set fun_does_cartesian_product=True." - ) - - grid_values = fvals.reshape(n, n) + grid_values = SdCondSdGridDistribution._evaluate_on_grid( + fun, grid, n, fun_does_cartesian_product + ) return SdCondSdGridDistribution(grid, grid_values) diff --git a/pyrecest/distributions/conditional/td_cond_td_grid_distribution.py b/pyrecest/distributions/conditional/td_cond_td_grid_distribution.py new file mode 100644 index 000000000..86007b153 --- /dev/null +++ b/pyrecest/distributions/conditional/td_cond_td_grid_distribution.py @@ -0,0 +1,202 @@ +import warnings + +# pylint: disable=redefined-builtin,no-name-in-module,no-member +from pyrecest.backend import ( + abs, + all, + any, + mean, + pi, + sum, +) + +from .abstract_conditional_distribution import AbstractConditionalDistribution + + +class TdCondTdGridDistribution(AbstractConditionalDistribution): + """ + Conditional distribution on Td x Td represented by a grid of values. + + For a conditional distribution f(a|b), ``grid_values[i, j]`` stores + the value f(grid[i] | grid[j]). + + Convention + ---------- + - ``grid`` has shape ``(n_points, d)`` where ``d`` is the number of + dimensions of the individual torus (e.g. d=1 for T1). + - ``grid_values`` has shape ``(n_points, n_points)``. + - ``dim = 2 * d`` is the dimension of the Cartesian product space + (convention inherited from libDirectional). + """ + + def __init__(self, grid, grid_values, enforce_pdf_nonnegative=True): + """ + Parameters + ---------- + grid : array of shape (n_points, d) + Grid points on the torus. + grid_values : array of shape (n_points, n_points) + Conditional pdf values: ``grid_values[i, j] = f(grid[i] | grid[j])``. + Must be non-negative when ``enforce_pdf_nonnegative=True``. + enforce_pdf_nonnegative : bool + Whether non-negativity of ``grid_values`` is required. + """ + super().__init__(grid, grid_values, enforce_pdf_nonnegative) + self._check_normalization() + + # ------------------------------------------------------------------ + # Normalization + # ------------------------------------------------------------------ + + def _check_normalization(self, tol=0.01): + """Warn if any column is not normalized to 1 over the torus.""" + d = self.dim // 2 + manifold_size = float((2.0 * pi) ** d) + # For each fixed second argument j, the mean over i times the torus + # volume should equal 1. + ints = mean(self.grid_values, 0) * manifold_size + if any(abs(ints - 1) > tol): + # Check whether swapping the two arguments would yield normalisation. + ints_swapped = mean(self.grid_values, 1) * manifold_size + if all(abs(ints_swapped - 1) <= tol): + raise ValueError( + "Normalization:maybeWrongOrder: Not normalized but would be if " + "the order of the two tori were swapped. Check input." + ) + warnings.warn( + "Normalization:notNormalized: When conditioning values for the first " + "torus on the second, normalisation is not ensured. " + "Check input or increase tolerance. " + "No normalisation is performed; you may want to do this manually.", + UserWarning, + ) + + # ------------------------------------------------------------------ + # Marginalisation and conditioning + # ------------------------------------------------------------------ + + def marginalize_out(self, first_or_second): + """ + Marginalize out one of the two tori. + + Parameters + ---------- + first_or_second : int (1 or 2) + ``1`` marginalizes out the *first* dimension (sums over rows); + ``2`` marginalizes out the *second* dimension (sums over columns). + + Returns + ------- + HypertoroidalGridDistribution + """ + # Import here to avoid circular imports + from pyrecest.distributions.hypertorus.hypertoroidal_grid_distribution import ( # pylint: disable=import-outside-toplevel + HypertoroidalGridDistribution, + ) + + if first_or_second == 1: + grid_values_sgd = sum(self.grid_values, 0) + elif first_or_second == 2: + grid_values_sgd = sum(self.grid_values, 1) + else: + raise ValueError("first_or_second must be 1 or 2.") + + return HypertoroidalGridDistribution(grid_values_sgd, "custom", self.grid) + + def fix_dim(self, first_or_second, point): + """ + Return the conditional slice for a fixed grid point. + + The supplied ``point`` must be an existing grid point. + + Parameters + ---------- + first_or_second : int (1 or 2) + ``1`` fixes the *first* argument at ``point`` and returns values + over the second torus; + ``2`` fixes the *second* argument and returns values over the + first torus. + point : array of shape (d,) + Grid point at which to evaluate. + + Returns + ------- + HypertoroidalGridDistribution + """ + # Import here to avoid circular imports + from pyrecest.distributions.hypertorus.hypertoroidal_grid_distribution import ( # pylint: disable=import-outside-toplevel + HypertoroidalGridDistribution, + ) + + grid_values_slice = self._get_grid_slice(first_or_second, point) + return HypertoroidalGridDistribution(grid_values_slice, "custom", self.grid) + + # ------------------------------------------------------------------ + # Factory + # ------------------------------------------------------------------ + + @staticmethod + def from_function( + fun, + no_of_grid_points, + fun_does_cartesian_product=False, + grid_type="CartesianProd", + dim=2, + ): + """ + Construct a :class:`TdCondTdGridDistribution` from a callable. + + Parameters + ---------- + fun : callable + Conditional pdf function ``f(a, b)``. + + * When ``fun_does_cartesian_product=False`` (default): ``fun`` + is called once with two arrays of shape ``(n_pairs, d)`` + containing all ``n_points²`` pairs, and must return a 1-D + array of length ``n_points²``. + * When ``fun_does_cartesian_product=True``: ``fun`` is called + with both full grids of shape ``(n_points, d)`` and must + return an array of shape ``(n_points, n_points)``. + no_of_grid_points : int + Number of grid points for each torus. + fun_does_cartesian_product : bool + See ``fun`` description above. + grid_type : str + Grid type passed to + :meth:`~pyrecest.distributions.hypertorus.hypertoroidal_grid_distribution.HypertoroidalGridDistribution.generate_cartesian_product_grid`. + Defaults to ``'CartesianProd'``. + dim : int + Dimension of the Cartesian product space + (``2 * dim_of_individual_torus``). + Defaults to 2 (T1 × T1). + + Returns + ------- + TdCondTdGridDistribution + """ + # Import inside the function to avoid circular imports at module level. + from pyrecest.distributions.hypertorus.hypertoroidal_grid_distribution import ( # pylint: disable=import-outside-toplevel + HypertoroidalGridDistribution, + ) + + if dim % 2 != 0: + raise ValueError( + "dim must be even (it represents two copies of a hypertorus)." + ) + if grid_type not in ("CartesianProd", "CartesianProduct"): + raise ValueError( + "Grid scheme not recognized; only 'CartesianProd' / " + "'CartesianProduct' is currently supported." + ) + + dim_half = dim // 2 + n = no_of_grid_points + grid = HypertoroidalGridDistribution.generate_cartesian_product_grid( + [n] * dim_half + ) + + grid_values = TdCondTdGridDistribution._evaluate_on_grid( + fun, grid, n, fun_does_cartesian_product + ) + return TdCondTdGridDistribution(grid, grid_values) diff --git a/pyrecest/tests/distributions/test_td_cond_td_grid_distribution.py b/pyrecest/tests/distributions/test_td_cond_td_grid_distribution.py new file mode 100644 index 000000000..100007edc --- /dev/null +++ b/pyrecest/tests/distributions/test_td_cond_td_grid_distribution.py @@ -0,0 +1,288 @@ +import unittest + +import numpy.testing as npt + +from pyrecest.backend import ( # pylint: disable=redefined-builtin + abs, + array, + asarray, + exp, + linspace, + minimum, + ones, + pi, + random, + zeros, +) +from pyrecest.distributions.conditional.td_cond_td_grid_distribution import ( + TdCondTdGridDistribution, +) +from pyrecest.distributions.hypertorus.hypertoroidal_grid_distribution import ( + HypertoroidalGridDistribution, +) + + +def _make_normalized_grid_values(n: int): + """Return an (n x n) matrix whose columns are normalized (integrate to 1).""" + random.seed(0) + gv = 0.5 + random.uniform(size=(n, n)) + # Normalize each column so that mean(col) * (2*pi)^1 == 1 + gv = gv / (gv.mean(axis=0) * 2.0 * pi) + return gv + + +class TdCondTdGridDistributionTest(unittest.TestCase): + # -------------------------------------------------------------- construction + + def test_construction_t1(self): + """Basic construction for T1 x T1.""" + n = 5 + grid = linspace(0.0, 2.0 * pi - 2.0 * pi / n, n).reshape(-1, 1) + gv = _make_normalized_grid_values(n) + td = TdCondTdGridDistribution(grid, gv) + self.assertEqual(td.dim, 2) + npt.assert_allclose(td.grid, grid, rtol=1e-6) + npt.assert_allclose(td.grid_values, gv, rtol=1e-6) + + def test_construction_wrong_shape_raises(self): + n = 4 + grid = zeros((n, 1)) + with self.assertRaises(ValueError): + # Non-square grid_values + TdCondTdGridDistribution( + grid, ones((n, n + 1)) / (n * 2 * pi) + ) + + def test_construction_wrong_order_raises(self): + """Transposed (row-normalized) matrix should raise an error.""" + n = 6 + grid = linspace(0.0, 2.0 * pi - 2.0 * pi / n, n).reshape(-1, 1) + gv = _make_normalized_grid_values(n) + # Transpose → rows are normalized, columns are not + with self.assertRaises(ValueError): + TdCondTdGridDistribution(grid, gv.T) + + def test_construction_unnormalized_warns(self): + """An unnormalized matrix that cannot be fixed by transposing should warn.""" + n = 5 + grid = linspace(0.0, 2.0 * pi - 2.0 * pi / n, n).reshape(-1, 1) + gv = ones((n, n)) # neither rows nor cols sum to 1/(2pi) + with self.assertWarns(UserWarning): + TdCondTdGridDistribution(grid, gv) + + # -------------------------------------------------------------- normalize + + def test_normalize_returns_self(self): + n = 4 + grid = linspace(0.0, 2.0 * pi - 2.0 * pi / n, n).reshape(-1, 1) + gv = _make_normalized_grid_values(n) + td = TdCondTdGridDistribution(grid, gv) + self.assertIs(td.normalize(), td) + + # -------------------------------------------------------------- multiply + + def test_multiply_same_grid(self): + import warnings + + n = 6 + grid = linspace(0.0, 2.0 * pi - 2.0 * pi / n, n).reshape(-1, 1) + gv = _make_normalized_grid_values(n) + td1 = TdCondTdGridDistribution(grid, gv) + td2 = TdCondTdGridDistribution(grid, gv) + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + result = td1.multiply(td2) + npt.assert_allclose( + asarray(result.grid_values), + asarray(td1.grid_values) * asarray(td2.grid_values), + rtol=1e-10, + ) + + def test_multiply_incompatible_grid_raises(self): + n1, n2 = 4, 6 + grid1 = linspace(0.0, 2.0 * pi - 2.0 * pi / n1, n1).reshape(-1, 1) + grid2 = linspace(0.0, 2.0 * pi - 2.0 * pi / n2, n2).reshape(-1, 1) + gv1 = _make_normalized_grid_values(n1) + gv2 = _make_normalized_grid_values(n2) + td1 = TdCondTdGridDistribution(grid1, gv1) + td2 = TdCondTdGridDistribution(grid2, gv2) + with self.assertRaises(ValueError) as ctx: + td1.multiply(td2) + self.assertIn("IncompatibleGrid", str(ctx.exception)) + + # -------------------------------------------------------------- from_function + + def test_from_function_t1(self): + """from_function should recover a wrapped-normal-like conditional.""" + n = 20 + dim = 2 # T1 x T1 + + def cond_fun(a, b): + # Simple Gaussian-like conditional (unnormalized, normalized per column) + diff = asarray(a)[:, 0] - asarray(b)[:, 0] + return exp(-0.5 * minimum(diff**2, (2 * pi - abs(diff)) ** 2)) + + td = TdCondTdGridDistribution.from_function( + cond_fun, n, fun_does_cartesian_product=False, grid_type="CartesianProd", dim=dim + ) + self.assertIsInstance(td, TdCondTdGridDistribution) + self.assertEqual(td.dim, dim) + self.assertEqual(asarray(td.grid_values).shape, (n, n)) + self.assertEqual(asarray(td.grid).shape, (n, 1)) + + def test_from_function_cartesian_product_flag(self): + """from_function with fun_does_cartesian_product=True.""" + n = 8 + dim = 2 + + def cond_fun_cp(a, b): + # a: (n, 1), b: (n, 1) → return (n, n) + a_arr = asarray(a)[:, 0] + b_arr = asarray(b)[:, 0] + diff = a_arr[:, None] - b_arr[None, :] + return exp(-0.5 * minimum(diff**2, (2 * pi - abs(diff)) ** 2)) + + td = TdCondTdGridDistribution.from_function( + cond_fun_cp, + n, + fun_does_cartesian_product=True, + grid_type="CartesianProd", + dim=dim, + ) + self.assertIsInstance(td, TdCondTdGridDistribution) + self.assertEqual(asarray(td.grid_values).shape, (n, n)) + + def test_from_function_unknown_grid_raises(self): + n = 4 + with self.assertRaises(ValueError): + TdCondTdGridDistribution.from_function( + lambda a, b: ones(len(a)), + n, + fun_does_cartesian_product=False, + grid_type="unknownGrid", + dim=2, + ) + + def test_from_function_odd_dim_raises(self): + with self.assertRaises(ValueError): + TdCondTdGridDistribution.from_function( + lambda a, b: ones(len(a)), 4, False, "CartesianProd", dim=3 + ) + + # --------------------------------------------------------- marginalize_out + + def test_marginalize_out_returns_hgd(self): + n = 6 + grid = linspace(0.0, 2.0 * pi - 2.0 * pi / n, n).reshape(-1, 1) + gv = _make_normalized_grid_values(n) + td = TdCondTdGridDistribution(grid, gv) + + for first_or_second in (1, 2): + with self.subTest(first_or_second=first_or_second): + marginal = td.marginalize_out(first_or_second) + self.assertIsInstance(marginal, HypertoroidalGridDistribution) + self.assertEqual(marginal.dim, 1) + + def test_marginalize_out_sums(self): + n = 5 + grid = linspace(0.0, 2.0 * pi - 2.0 * pi / n, n).reshape(-1, 1) + gv = _make_normalized_grid_values(n) + td = TdCondTdGridDistribution(grid, gv) + + # marginalize_out(1) sums rows; HGD normalizes, so check proportionality + m1 = td.marginalize_out(1) + expected_unnorm = gv.sum(axis=0) + actual_unnorm = asarray(m1.grid_values) * float(m1.integrate()) + # Check proportionality (ratio should be constant) + ratio = actual_unnorm / expected_unnorm + npt.assert_allclose(ratio, ratio[0] * ones(n), atol=1e-12) + + # marginalize_out(2) sums cols + m2 = td.marginalize_out(2) + expected_unnorm2 = gv.sum(axis=1) + actual_unnorm2 = asarray(m2.grid_values) * float(m2.integrate()) + ratio2 = actual_unnorm2 / expected_unnorm2 + npt.assert_allclose(ratio2, ratio2[0] * ones(n), atol=1e-12) + + def test_marginalize_out_invalid_raises(self): + n = 5 + grid = linspace(0.0, 2.0 * pi - 2.0 * pi / n, n).reshape(-1, 1) + gv = _make_normalized_grid_values(n) + td = TdCondTdGridDistribution(grid, gv) + with self.assertRaises(ValueError): + td.marginalize_out(0) + with self.assertRaises(ValueError): + td.marginalize_out(3) + + # -------------------------------------------------------------- fix_dim + + def test_fix_dim_returns_hgd(self): + n = 5 + grid_np = linspace(0.0, 2.0 * pi - 2.0 * pi / n, n).reshape(-1, 1) + gv = _make_normalized_grid_values(n) + td = TdCondTdGridDistribution(grid_np, gv) + + for first_or_second in (1, 2): + with self.subTest(first_or_second=first_or_second): + point = grid_np[2] # third grid point + result = td.fix_dim(first_or_second, point) + self.assertIsInstance(result, HypertoroidalGridDistribution) + self.assertEqual(result.dim, 1) + + def test_fix_dim_off_grid_raises(self): + n = 5 + grid_np = linspace(0.0, 2.0 * pi - 2.0 * pi / n, n).reshape(-1, 1) + gv = _make_normalized_grid_values(n) + td = TdCondTdGridDistribution(grid_np, gv) + with self.assertRaises(ValueError): + td.fix_dim(1, array([1.23456789])) + + def test_fix_dim_values_correct(self): + """fix_dim(2, grid[j]) should give a distribution proportional to col j.""" + n = 6 + grid_np = linspace(0.0, 2.0 * pi - 2.0 * pi / n, n).reshape(-1, 1) + gv = _make_normalized_grid_values(n) + td = TdCondTdGridDistribution(grid_np, gv) + + j = 3 + slice_dist = td.fix_dim(2, grid_np[j]) + expected = gv[:, j] + expected = expected / expected.mean() / (2.0 * pi) # normalize + npt.assert_allclose( + asarray(slice_dist.grid_values), + expected, + atol=1e-12, + ) + + def test_fix_dim_invalid_raises(self): + n = 5 + grid_np = linspace(0.0, 2.0 * pi - 2.0 * pi / n, n).reshape(-1, 1) + gv = _make_normalized_grid_values(n) + td = TdCondTdGridDistribution(grid_np, gv) + point = grid_np[0] + with self.assertRaises(ValueError): + td.fix_dim(0, point) + with self.assertRaises(ValueError): + td.fix_dim(3, point) + + # --------------------------------------------------------- from_function + fix_dim round-trip + + def test_from_function_fix_dim_roundtrip(self): + """fix_dim on a from_function object should return a HypertoroidalGridDistribution.""" + n = 10 + dim = 2 + + def cond_fun(a, b): + diff = asarray(a)[:, 0] - asarray(b)[:, 0] + return exp(-0.5 * diff**2) + + td = TdCondTdGridDistribution.from_function( + cond_fun, n, fun_does_cartesian_product=False, dim=dim + ) + grid_np = asarray(td.grid) + slice_dist = td.fix_dim(2, grid_np[0]) + self.assertIsInstance(slice_dist, HypertoroidalGridDistribution) + + +if __name__ == "__main__": + unittest.main()