Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
20 commits
Select commit Hold shift + click to select a range
ec4a697
Implement TdCondTdGridDistribution with tests
Copilot Mar 31, 2026
42352bc
Fix MegaLinter errors: remove unused imports, reduce locals, fix prot…
Copilot Mar 31, 2026
7998df6
Increase tolerance to fix
FlorianPfaff Mar 31, 2026
e071052
[MegaLinter] Apply linters automatic fixes
FlorianPfaff Mar 31, 2026
9673238
Fix 3 bugs in EOT shape database: sample_within return type, visibili…
Copilot Mar 31, 2026
32930f7
Implement SdCondSdGridDistribution with tests
Copilot Mar 31, 2026
f33c532
Fix MegaLinter errors: remove unused import, add pylint disable comme…
Copilot Mar 31, 2026
cba5d75
Fix tile/repeat swap in pdf() and strengthen test_pdf with distinct rows
Copilot Mar 31, 2026
e598439
Fixed linter warning
FlorianPfaff Mar 31, 2026
3295199
Create only temporary npy file during tests
FlorianPfaff Mar 31, 2026
e541651
Prevent temporary plots from remaining in the file system
FlorianPfaff Mar 31, 2026
dd00bff
[MegaLinter] Apply linters automatic fixes
FlorianPfaff Mar 31, 2026
67612c0
Rebase onto main, move to conditional/, harmonize with sd_cond_sd_gri…
Copilot Mar 31, 2026
c46a799
Merge branch 'main' into copilot/add-conditional-distribution-class
FlorianPfaff Mar 31, 2026
5d64c77
Removed axis= keyword for backend compatibility
FlorianPfaff Apr 2, 2026
5a2b2ea
Fix R0801 duplicate-code: move shared logic to AbstractConditionalDis…
Copilot Apr 3, 2026
0c8f19d
Remove unused abs import from abstract_conditional_distribution
Copilot Apr 3, 2026
108cfa2
Replace numpy with pyrecest.backend in test_td_cond_td_grid_distribution
Copilot Apr 3, 2026
7e8dfbf
Suppress W0622 redefined-builtin for abs import in test file
Copilot Apr 3, 2026
136db1c
Fix random.uniform call to use cross-backend compatible size= kwarg
Copilot Apr 3, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion pyrecest/distributions/conditional/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from .sd_cond_sd_grid_distribution import SdCondSdGridDistribution
from .td_cond_td_grid_distribution import TdCondTdGridDistribution

__all__ = ["SdCondSdGridDistribution"]
__all__ = ["SdCondSdGridDistribution", "TdCondTdGridDistribution"]
Original file line number Diff line number Diff line change
@@ -1,5 +1,164 @@
import copy
import warnings
from abc import ABC

# pylint: disable=redefined-builtin,no-name-in-module,no-member
from pyrecest.backend import (
any,
arange,
argmin,
array_equal,
linalg,
meshgrid,
)


class AbstractConditionalDistribution(ABC):
pass
"""Abstract base class for conditional grid distributions on manifolds.

Subclasses represent distributions of the form f(a | b) where both a and b
live on the same manifold. The joint state is stored as a square matrix
``grid_values`` where ``grid_values[i, j] = f(grid[i] | grid[j])``.
"""

def __init__(self, grid, grid_values, enforce_pdf_nonnegative=True):
"""Common initialisation for conditional grid distributions.

Parameters
----------
grid : array of shape (n_points, d)
Grid points on the individual manifold.
grid_values : array of shape (n_points, n_points)
Conditional pdf values; ``grid_values[i, j] = f(grid[i] | grid[j])``.
enforce_pdf_nonnegative : bool
Whether to require non-negative ``grid_values``.
"""
if grid.ndim != 2:
raise ValueError("grid must be a 2D array of shape (n_points, d).")

n_points, d = grid.shape

if grid_values.ndim != 2 or grid_values.shape != (n_points, n_points):
raise ValueError(
f"grid_values must be a square 2D array of shape ({n_points}, {n_points})."
)

if enforce_pdf_nonnegative and any(grid_values < 0):
raise ValueError("grid_values must be non-negative.")

self.grid = grid
self.grid_values = grid_values
self.enforce_pdf_nonnegative = enforce_pdf_nonnegative
# Embedding dimension of the Cartesian product space (convention from
# libDirectional: dim = 2 * dim_of_individual_manifold).
self.dim = 2 * d

# ------------------------------------------------------------------
# Normalization
# ------------------------------------------------------------------

def normalize(self):
"""No-op – returns ``self`` for compatibility."""
return self

# ------------------------------------------------------------------
# Arithmetic
# ------------------------------------------------------------------

def multiply(self, other):
"""Element-wise multiply two conditional grid distributions.

The resulting distribution is *not* normalized.

Parameters
----------
other : AbstractConditionalDistribution
Must be defined on the same grid.

Returns
-------
AbstractConditionalDistribution
Same concrete type as ``self``.
"""
if not array_equal(self.grid, other.grid):
raise ValueError(
"Multiply:IncompatibleGrid: Can only multiply distributions "
"defined on identical grids."
)
warnings.warn(
"Multiply:UnnormalizedResult: Multiplication does not yield a "
"normalized result.",
UserWarning,
)
result = copy.deepcopy(self)
result.grid_values = result.grid_values * other.grid_values
return result

# ------------------------------------------------------------------
# Protected helpers
# ------------------------------------------------------------------

def _get_grid_slice(self, first_or_second, point):
"""Return the ``grid_values`` slice for a fixed grid point.

Parameters
----------
first_or_second : int (1 or 2)
Which variable to fix.
point : array of shape (d,)
Must be an existing grid point.

Returns
-------
array of shape (n_points,)
"""
d = self.grid.shape[1]
if point.shape[0] != d:
raise ValueError(
f"point must have length {d} (grid dimension)."
)
diffs = linalg.norm(self.grid - point[None, :], axis=1)
locb = argmin(diffs)
if diffs[locb] > 1e-10:
raise ValueError(
"Cannot fix value at this point because it is not on the grid."
)
if first_or_second == 1:
return self.grid_values[locb, :]
if first_or_second == 2:
return self.grid_values[:, locb]
raise ValueError("first_or_second must be 1 or 2.")

@staticmethod
def _evaluate_on_grid(fun, grid, n, fun_does_cartesian_product):
"""Evaluate ``fun`` on all grid point pairs and return an (n, n) array.

Parameters
----------
fun : callable
``f(a, b)`` with the semantics described in ``from_function``.
grid : array of shape (n, d)
Grid points on the individual manifold.
n : int
Number of grid points (``grid.shape[0]``).
fun_does_cartesian_product : bool
Whether *fun* handles all grid combinations internally.

Returns
-------
array of shape (n, n)
"""
if fun_does_cartesian_product:
fvals = fun(grid, grid)
return fvals.reshape(n, n)
idx_a, idx_b = meshgrid(arange(n), arange(n), indexing="ij")
grid_a = grid[idx_a.ravel()]
grid_b = grid[idx_b.ravel()]
fvals = fun(grid_a, grid_b)
if fvals.shape == (n**2, n**2):
raise ValueError(
"Function apparently performs the Cartesian product itself. "
"Set fun_does_cartesian_product=True."
)
return fvals.reshape(n, n)

110 changes: 6 additions & 104 deletions pyrecest/distributions/conditional/sd_cond_sd_grid_distribution.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,11 @@
import copy
import warnings

# pylint: disable=redefined-builtin,no-name-in-module,no-member
from pyrecest.backend import (
abs,
all,
any,
arange,
argmin,
array_equal,
linalg,
mean,
meshgrid,
sum,
)
from pyrecest.distributions.hypersphere_subset.abstract_hypersphere_subset_distribution import (
Expand Down Expand Up @@ -50,31 +44,11 @@ def __init__(self, grid, grid_values, enforce_pdf_nonnegative=True):
enforce_pdf_nonnegative : bool
Whether non-negativity of ``grid_values`` is required.
"""
if grid.ndim != 2:
raise ValueError("grid must be a 2D array of shape (n_points, d).")

n_points, d = grid.shape

if grid_values.ndim != 2 or grid_values.shape != (n_points, n_points):
raise ValueError(
f"grid_values must be a square 2D array of shape ({n_points}, {n_points})."
)

if any(abs(grid) > 1 + 1e-12):
super().__init__(grid, grid_values, enforce_pdf_nonnegative)
if any(abs(self.grid) > 1 + 1e-12):
raise ValueError(
"Grid points must have coordinates in [-1, 1] (unit sphere)."
)

if enforce_pdf_nonnegative and any(grid_values < 0):
raise ValueError("grid_values must be non-negative.")

self.grid = grid
self.grid_values = grid_values
self.enforce_pdf_nonnegative = enforce_pdf_nonnegative
# Embedding dimension of the Cartesian product space (convention from
# libDirectional: dim = 2 * embedding_dim_of_individual_sphere).
self.dim = 2 * d

self._check_normalization()

# ------------------------------------------------------------------
Expand Down Expand Up @@ -107,43 +81,6 @@ def _check_normalization(self, tol=0.01):
UserWarning,
)

def normalize(self):
"""No-op – returns ``self`` for compatibility."""
return self

# ------------------------------------------------------------------
# Arithmetic
# ------------------------------------------------------------------

def multiply(self, other):
"""
Element-wise multiply two conditional grid distributions.

The resulting distribution is *not* normalized.

Parameters
----------
other : SdCondSdGridDistribution
Must be defined on the same grid.

Returns
-------
SdCondSdGridDistribution
"""
if not array_equal(self.grid, other.grid):
raise ValueError(
"Multiply:IncompatibleGrid: Can only multiply distributions "
"defined on identical grids."
)
warnings.warn(
"Multiply:UnnormalizedResult: Multiplication does not yield a "
"normalized result.",
UserWarning,
)
result = copy.deepcopy(self)
result.grid_values = result.grid_values * other.grid_values
return result

# ------------------------------------------------------------------
# Marginalisation and conditioning
# ------------------------------------------------------------------
Expand Down Expand Up @@ -201,26 +138,7 @@ def fix_dim(self, first_or_second, point):
HypersphericalGridDistribution,
)

d = self.grid.shape[1]
if point.shape[0] != d:
raise ValueError(
f"point must have length {d} (embedding dimension of the sphere)."
)

diffs = linalg.norm(self.grid - point[None, :], axis=1)
locb = argmin(diffs)
if diffs[locb] > 1e-10:
raise ValueError(
"Cannot fix value at this point because it is not on the grid."
)

if first_or_second == 1:
grid_values_slice = self.grid_values[locb, :]
elif first_or_second == 2:
grid_values_slice = self.grid_values[:, locb]
else:
raise ValueError("first_or_second must be 1 or 2.")

grid_values_slice = self._get_grid_slice(first_or_second, point)
return HypersphericalGridDistribution(self.grid, grid_values_slice)

# ------------------------------------------------------------------
Expand Down Expand Up @@ -276,24 +194,8 @@ def from_function(
# manifold dim: embedding_dim = dim // 2, manifold_dim = embedding_dim - 1.
manifold_dim = dim // 2 - 1
grid, _ = get_grid_hypersphere(grid_type, n, manifold_dim)
# grid is (n, dim//2)

if fun_does_cartesian_product:
fvals = fun(grid, grid)
grid_values = fvals.reshape(n, n)
else:
# Build index pairs: idx_a[i, j] = i, idx_b[i, j] = j
idx_a, idx_b = meshgrid(arange(n), arange(n), indexing="ij")
grid_a = grid[idx_a.ravel()] # (n*n, d)
grid_b = grid[idx_b.ravel()] # (n*n, d)
fvals = fun(grid_a, grid_b) # (n*n,)

if fvals.shape == (n**2, n**2):
raise ValueError(
"Function apparently performs the Cartesian product itself. "
"Set fun_does_cartesian_product=True."
)

grid_values = fvals.reshape(n, n)

grid_values = SdCondSdGridDistribution._evaluate_on_grid(
fun, grid, n, fun_does_cartesian_product
)
return SdCondSdGridDistribution(grid, grid_values)
Loading
Loading