diff --git a/pyrecest/distributions/__init__.py b/pyrecest/distributions/__init__.py
index 72efac1ad..094565516 100644
--- a/pyrecest/distributions/__init__.py
+++ b/pyrecest/distributions/__init__.py
@@ -71,6 +71,7 @@
 from .cart_prod.partially_wrapped_normal_distribution import (
     PartiallyWrappedNormalDistribution,
 )
+from .cart_prod.se2_bingham_distribution import SE2BinghamDistribution
 from .circle.abstract_circular_distribution import AbstractCircularDistribution
 from .circle.circular_dirac_distribution import CircularDiracDistribution
 from .circle.circular_fourier_distribution import CircularFourierDistribution
@@ -328,4 +329,5 @@
     "LinearMixture",
     "SE3CartProdStackedDistribution",
     "SE3DiracDistribution",
+    "SE2BinghamDistribution",
 ]
diff --git a/pyrecest/distributions/cart_prod/se2_bingham_distribution.py b/pyrecest/distributions/cart_prod/se2_bingham_distribution.py
new file mode 100644
index 000000000..7cc73a5e5
--- /dev/null
+++ b/pyrecest/distributions/cart_prod/se2_bingham_distribution.py
@@ -0,0 +1,421 @@
+# pylint: disable=redefined-builtin,no-name-in-module,no-member
+import numpy as np
+from scipy.integrate import quad
+from scipy.optimize import brentq
+from scipy.special import ive
+
+from pyrecest.backend import (
+    argmax,
+    argsort,
+    array,
+    column_stack,
+    concatenate,
+    exp,
+    linalg,
+    pi,
+    sort,
+    sqrt,
+    sum,
+)
+
+from ..abstract_se2_distribution import AbstractSE2Distribution
+from ..hypersphere_subset.bingham_distribution import BinghamDistribution
+from ..nonperiodic.custom_linear_distribution import CustomLinearDistribution
+
+# Upper bound for the concentration recovered in ``fit``: the Bessel ratio
+# I_1/I_0 saturates at 1, so its numerical inversion needs a finite bracket.
+_MAX_KAPPA = 1.0e3
+
+
+class SE2BinghamDistribution(AbstractSE2Distribution):
+    """
+    Distribution on SE(2) = S^1 x R^2.
+
+    The density is f(x) = (1/NC) * exp(x^T * C * x) where x is the dual
+    quaternion representation (first two components on S^1, last two
+    components in R^2).
+
+    C is a 4x4 symmetric matrix partitioned as::
+
+        C = [ C1  C2^T ]
+            [ C2  C3   ]
+
+    where:
+    - C1 (2x2): symmetric, controls the Bingham (rotational) part
+    - C2 (2x2): coupling between rotation and translation
+    - C3 (2x2): symmetric, negative-definite, controls the Gaussian (translational) part
+
+    Reference:
+        Igor Gilitschenski, Gerhard Kurz, Simon J. Julier, Uwe D. Hanebeck,
+        "A New Probability Distribution for Simultaneous Representation of
+        Uncertain Position and Orientation",
+        Proceedings of the 17th International Conference on Information Fusion
+        (Fusion 2014), Salamanca, Spain, July 2014.
+    """
+
+    def __init__(self, C, C2=None, C3=None):
+        """
+        Create an SE2BinghamDistribution.
+
+        Parameters
+        ----------
+        C : array, shape (4, 4) or (2, 2)
+            If C2 and C3 are not provided, this is the full 4x4 parameter
+            matrix. Otherwise it is the 2x2 Bingham (rotational) part C1.
+        C2 : array, shape (2, 2), optional
+            Coupling matrix between rotation and translation.
+        C3 : array, shape (2, 2), optional
+            Symmetric negative-definite matrix for the translational part.
+
+        Raises
+        ------
+        AssertionError
+            If only one of C2/C3 is given, a shape is wrong, a block is not
+            symmetric, or C3 is not negative definite.
+        """
+        AbstractSE2Distribution.__init__(self)
+
+        assert (C2 is None) == (C3 is None), (
+            "Either both C2 and C3 must be provided, or neither."
+        )
+
+        if C2 is None:
+            assert C.shape == (4, 4), "C must be 4x4 when C2 and C3 are not provided."
+            assert np.allclose(np.array(C), np.array(C).T), "Full C matrix must be symmetric."
+            self.C = C
+            self.C1 = C[:2, :2]
+            self.C2 = C[2:, :2]
+            self.C3 = C[2:, 2:]
+        else:
+            assert C.shape == (2, 2), "C1 must be 2x2."
+            assert C2.shape == (2, 2), "C2 must be 2x2."
+            assert C3.shape == (2, 2), "C3 must be 2x2."
+            assert np.allclose(np.array(C), np.array(C).T), "C1 must be symmetric."
+            assert np.allclose(np.array(C3), np.array(C3).T), "C3 must be symmetric."
+            self.C1 = C
+            self.C2 = C2
+            self.C3 = C3
+            # Assemble [[C1, C2^T], [C2, C3]] from the three blocks.
+            self.C = column_stack(
+                [
+                    column_stack([self.C1, self.C2.T]).T,
+                    column_stack([self.C2, self.C3]).T,
+                ]
+            ).T
+
+        # Every computation below inverts C3, so a zero eigenvalue (a merely
+        # semi-definite C3) has to be rejected up front.
+        assert np.all(np.linalg.eigvalsh(np.array(self.C3)) < 0), (
+            "C3 must be negative definite."
+        )
+
+        self._nc = None  # lazily computed
+
+    @property
+    def nc(self):
+        """Normalization constant (lazily computed)."""
+        if self._nc is None:
+            self._nc = self._compute_nc()
+        return self._nc
+
+    def _schur_parts(self):
+        """
+        Return the quantities shared by marginals, mode and sampling.
+
+        Returns
+        -------
+        bm : array, shape (2, 2)
+            Schur complement C1 - C2^T * C3^{-1} * C2.
+        gain : array, shape (2, 2)
+            C3^{-1} * C2; the conditional mean of the linear part given a
+            rotational part x_rot is -gain @ x_rot.
+        C3_inv : array, shape (2, 2)
+            Inverse of C3.
+        """
+        C1 = array(self.C1, dtype=float)
+        C2 = array(self.C2, dtype=float)
+        C3_inv = linalg.inv(array(self.C3, dtype=float))
+        gain = C3_inv @ C2
+        return C1 - C2.T @ gain, gain, C3_inv
+
+    def _compute_nc(self):
+        """
+        Compute the normalization constant.
+
+        NC = 2*pi * sqrt(det(-0.5 * C3^{-1})) * F_bingham(Z_bm)
+
+        where Z_bm = (z1, z2), z1 <= z2, are the eigenvalues of the Schur
+        complement BM = C1 - C2^T * C3^{-1} * C2, and F_bingham is the 2D
+        Bingham normalization constant
+            F = 2*pi * exp((z1+z2)/2) * I_0((z2-z1)/2).
+        """
+        bm, _, C3_inv = self._schur_parts()
+        z = sort(linalg.eigvalsh(bm))  # ascending
+        z_lo, z_hi = float(z[0]), float(z[1])
+        # exp((z1+z2)/2) * I_0(k) equals exp(z2) * ive(0, k) for k = (z2-z1)/2 >= 0.
+        # The exponentially scaled Bessel function avoids the 0 * inf = nan that
+        # exp() * iv() produces for sharply concentrated rotations.
+        b_nc = float(2.0 * pi * np.exp(z_hi) * ive(0, (z_hi - z_lo) / 2.0))
+        nc = 2.0 * pi * sqrt(linalg.det(-0.5 * C3_inv)) * b_nc
+        return float(nc)
+
+    def pdf(self, xs):
+        """
+        Evaluate the probability density at the given points.
+
+        Parameters
+        ----------
+        xs : array_like, shape (N, 4) or (N, 3)
+            Evaluation points in dual quaternion (N x 4) or angle-pos
+            (N x 3) representation.
+
+        Returns
+        -------
+        p : array, shape (N,)
+            Density values.
+        """
+        xs = array(xs)
+        if xs.ndim == 1:
+            xs = xs.reshape(1, -1)
+        if xs.shape[1] == 3:
+            xs = AbstractSE2Distribution.angle_pos_to_dual_quaternion(xs)
+        assert xs.shape[1] == 4, "Input must have 4 columns (dual quaternion)."
+        # Row-wise quadratic form x_i^T C x_i.
+        return (1.0 / self.nc) * exp(sum(xs * (xs @ self.C.T), axis=1))
+
+    def mode(self):
+        """
+        Compute one mode of the distribution.
+
+        Because of antipodal symmetry, -mode is equally valid.
+
+        Returns
+        -------
+        m : array, shape (4,)
+            Mode in dual quaternion representation.
+        """
+        bm, gain, _ = self._schur_parts()
+        eigenvalues, eigenvectors = linalg.eigh(bm)
+        # Rotational mode: eigenvector of the largest Schur eigenvalue.
+        m_rot = eigenvectors[:, int(argmax(eigenvalues))]
+        # Linear mode: conditional Gaussian mean at that rotation.
+        m_lin = -gain @ m_rot
+        return array(concatenate([m_rot, m_lin]))
+
+    def sample(self, n):
+        """
+        Draw n samples from the distribution.
+
+        Sampling uses a two-step procedure:
+        1. Sample the rotational part from the Bingham marginal.
+        2. Sample the translational part from the Gaussian conditional.
+
+        Parameters
+        ----------
+        n : int
+            Number of samples.
+
+        Returns
+        -------
+        s : array, shape (n, 4)
+            Samples in dual quaternion representation.
+        """
+        assert n > 0, "n must be positive."
+        _, gain, C3_inv = self._schur_parts()
+
+        # Step 1: sample the Bingham marginal of the rotational part
+        bingham_samples = self.marginalize_linear().sample(n)  # (n, 2)
+
+        # Step 2: sample Gaussian conditional
+        # mean_i = -C3^{-1} * C2 * x_rot_i, covariance = -0.5 * C3^{-1}
+        cov = np.array(-0.5 * C3_inv)
+        means = np.array((-gain @ array(bingham_samples).T).T)
+        lin_samples = array(
+            means + np.random.multivariate_normal(np.zeros(2), cov, size=n)
+        )
+
+        return column_stack([bingham_samples, lin_samples])
+
+    def marginalize_linear(self):
+        """
+        Return the marginal distribution over the periodic (rotational) part.
+
+        The marginal is the Bingham distribution corresponding to the Schur
+        complement BM = C1 - C2^T * C3^{-1} * C2.
+
+        Returns
+        -------
+        b : BinghamDistribution
+            Marginal Bingham distribution on S^1.
+        """
+        bm, _, _ = self._schur_parts()
+        eigenvalues, eigenvectors = linalg.eigh(bm)
+        order = argsort(eigenvalues)  # ascending
+        eigenvalues = eigenvalues[order]
+        # Shift so the largest eigenvalue is 0; on the unit circle this only
+        # rescales the density and is absorbed by the normalization.
+        return BinghamDistribution(
+            array(eigenvalues - eigenvalues[-1]), array(eigenvectors[:, order])
+        )
+
+    def marginalize_periodic(self):
+        """
+        Return the marginal distribution over the linear (translational) part.
+
+        The marginal is computed by numerically integrating out the rotational
+        component.
+
+        Returns
+        -------
+        dist : CustomLinearDistribution
+            Marginal distribution over R^2.
+        """
+        C_np = np.array(self.C, dtype=float)
+        nc = self.nc
+
+        def _marginal_pdf(xs):
+            xs = np.atleast_2d(xs)
+            out = np.empty(xs.shape[0])
+            for i, x_lin in enumerate(xs):
+                # Integrate exp(x^T C x) over S^1 using the angle parametrisation
+                # (xl is bound as a default to avoid late binding of x_lin).
+                def integrand(theta, xl=x_lin):
+                    x_rot = np.array([np.cos(theta), np.sin(theta)])
+                    x = np.concatenate([x_rot, xl])
+                    return np.exp(float(x @ C_np @ x))
+
+                val, _ = quad(integrand, 0.0, 2.0 * np.pi)
+                out[i] = val / nc
+            return array(out)
+
+        return CustomLinearDistribution(_marginal_pdf, self.lin_dim)
+
+    @staticmethod
+    def fit(samples, weights=None):
+        """
+        Estimate SE2BinghamDistribution parameters from samples.
+
+        The rotational (Bingham) block is obtained by moment matching on the
+        scatter matrix of the rotational part, the translational block by
+        weighted linear regression of the linear part on the rotational part.
+
+        Parameters
+        ----------
+        samples : array_like, shape (N, 4) or (N, 3)
+            Samples in dual quaternion (N x 4) or angle-pos (N x 3) form.
+        weights : array_like, shape (N,), optional
+            Non-negative weights (need not sum to 1). Defaults to uniform.
+
+        Returns
+        -------
+        dist : SE2BinghamDistribution
+            Fitted distribution.
+
+        Raises
+        ------
+        AssertionError
+            If the samples are not a 2-D array with 3 or 4 columns, or the
+            weights do not match the samples, are negative, or sum to zero.
+        """
+        samples = np.array(samples, dtype=float)
+        assert samples.ndim == 2, "samples must be a 2-D array."
+        if samples.shape[1] == 3:
+            samples = np.array(
+                AbstractSE2Distribution.angle_pos_to_dual_quaternion(array(samples))
+            )
+        assert samples.shape[1] == 4
+
+        n = samples.shape[0]
+        if weights is None:
+            weights = np.ones(n) / n
+        else:
+            weights = np.asarray(weights, dtype=float).ravel()
+            assert weights.shape[0] == n, "weights must have one entry per sample."
+            assert np.all(weights >= 0) and weights.sum() > 0, (
+                "weights must be non-negative and must not all be zero."
+            )
+            weights = weights / weights.sum()
+
+        w = weights[:, np.newaxis]
+        s_rot, s_lin = samples[:, :2], samples[:, 2:]
+
+        # Bingham block: estimate Schur complement from weighted scatter
+        schur_c = SE2BinghamDistribution._schur_from_scatter(s_rot, w)
+
+        # Gaussian block: estimate C2 and C3 via weighted regression
+        c2_est, c3_est = SE2BinghamDistribution._fit_gaussian_block(s_rot, s_lin, w)
+
+        # Recover C1 from Schur complement definition
+        c1_est = schur_c + c2_est.T @ np.linalg.inv(c3_est) @ c2_est
+        c1_est = 0.5 * (c1_est + c1_est.T)
+
+        return SE2BinghamDistribution(array(c1_est), array(c2_est), array(c3_est))
+
+    @staticmethod
+    def _invert_bessel_ratio(ratio):
+        """Return kappa in [0, _MAX_KAPPA] with I_1(kappa) / I_0(kappa) = ratio."""
+
+        def excess(kappa):
+            # Scaled Bessel functions: the exp(-kappa) factors cancel in the
+            # quotient and large kappa does not overflow.
+            return ive(1, kappa) / ive(0, kappa) - ratio
+
+        if ratio <= 0.0:
+            return 0.0
+        if excess(_MAX_KAPPA) <= 0.0:
+            # Ratio is (numerically) at saturation: clamp the concentration.
+            return _MAX_KAPPA
+        return float(brentq(excess, 0.0, _MAX_KAPPA))
+
+    @staticmethod
+    def _schur_from_scatter(s_rot, w):
+        """
+        Return the estimated Schur complement C1 - C2' C3^{-1} C2 from samples.
+
+        The rotational marginal is a Bingham distribution on S^1 with
+        parameters (z, 0), z <= 0, in the eigenbasis of the scatter matrix.
+        Writing x = (cos(phi), sin(phi)) in that basis, its density is
+        proportional to exp(-kappa * cos(2*phi)) with kappa = -z / 2, so the
+        difference of the two (trace-normalised) scatter eigenvalues equals
+        I_1(kappa) / I_0(kappa).  Inverting that ratio gives the moment-
+        matching estimate of z.
+
+        Parameters
+        ----------
+        s_rot : numpy.ndarray, shape (N, 2)
+            Rotational part of the samples.
+        w : numpy.ndarray, shape (N, 1)
+            Sample weights as a column.
+
+        Returns
+        -------
+        numpy.ndarray, shape (2, 2)
+            Symmetric matrix with eigenvalues (z, 0).
+        """
+        scatter = (s_rot * w).T @ s_rot
+        eigenvalues, eigenvectors = np.linalg.eigh(scatter)  # ascending
+        # Dividing by the trace makes the estimate independent of the weight
+        # scale and of small deviations of the samples from unit norm.
+        spread = (eigenvalues[1] - eigenvalues[0]) / (eigenvalues[1] + eigenvalues[0])
+        kappa = SE2BinghamDistribution._invert_bessel_ratio(spread)
+        z = np.array([-2.0 * kappa, 0.0])
+        return eigenvectors @ np.diag(z) @ eigenvectors.T
+
+    @staticmethod
+    def _fit_gaussian_block(s_rot, s_lin, w):
+        """
+        Return estimated (C2, C3) via weighted linear regression.
+
+        The linear part is regressed on the rotational part,
+        s_lin ~ reg_beta @ s_rot; the residual covariance Sigma yields
+        C3 = -0.5 * Sigma^{-1} and the slope yields C2 = -C3 @ reg_beta.
+        """
+        reg_a = (s_rot * w).T @ s_rot
+        # Use pinv for numerical stability when reg_a is nearly singular
+        reg_beta = (s_lin * w).T @ s_rot @ np.linalg.pinv(reg_a)
+        residuals = s_lin - s_rot @ reg_beta.T
+        reg_cov = (residuals * w).T @ residuals
+        # Use pinv: reg_cov may be ill-conditioned when samples cluster on a subspace
+        c3_est = np.linalg.pinv(-2.0 * reg_cov)
+        c3_est = 0.5 * (c3_est + c3_est.T)
+        return -c3_est @ reg_beta, c3_est
diff --git a/pyrecest/tests/distributions/test_se2_bingham_distribution.py b/pyrecest/tests/distributions/test_se2_bingham_distribution.py
new file mode 100644
index 000000000..2d7aad7a6
--- /dev/null
+++ b/pyrecest/tests/distributions/test_se2_bingham_distribution.py
@@ -0,0 +1,180 @@
+import unittest
+
+import numpy as np
+import numpy.testing as npt
+import pyrecest.backend
+
+# pylint: disable=no-name-in-module,no-member
+from pyrecest.backend import array
+from pyrecest.distributions import AbstractSE2Distribution, SE2BinghamDistribution
+
+
+class TestSE2BinghamDistribution(unittest.TestCase):
+    def setUp(self):
+        """Set up a test SE2BinghamDistribution instance."""
+        # Build a valid parameter set: C3 negative definite, C1 symmetric
+        self.C1 = array([[-3.0, 0.5], [0.5, -1.0]])
+        self.C2 = array([[0.1, 0.2], [-0.1, 0.3]])
+        self.C3 = array([[-2.0, 0.1], [0.1, -1.5]])
+        self.dist = SE2BinghamDistribution(self.C1, self.C2, self.C3)
+
+    def test_constructor_from_parts(self):
+        """Distribution can be constructed from C1, C2, C3."""
+        dist = SE2BinghamDistribution(self.C1, self.C2, self.C3)
+        self.assertIsInstance(dist, SE2BinghamDistribution)
+
+    def test_constructor_from_full_matrix(self):
+        """Distribution can be constructed from the full 4x4 matrix."""
+        C_full = self.dist.C
+        dist2 = SE2BinghamDistribution(C_full)
+        self.assertIsInstance(dist2, SE2BinghamDistribution)
+        npt.assert_array_almost_equal(np.array(dist2.C1), np.array(self.dist.C1))
+        npt.assert_array_almost_equal(np.array(dist2.C2), np.array(self.dist.C2))
+        npt.assert_array_almost_equal(np.array(dist2.C3), np.array(self.dist.C3))
+
+    def test_constructor_rejects_singular_c3(self):
+        """A merely semi-definite (singular) C3 cannot be inverted and is rejected."""
+        C3_singular = array([[0.0, 0.0], [0.0, 0.0]])
+        with self.assertRaises(AssertionError):
+            SE2BinghamDistribution(self.C1, self.C2, C3_singular)
+
+    @unittest.skipIf(
+        pyrecest.backend.__backend_name__ == "jax",
+        reason="Not supported on JAX backend",
+    )
+    def test_nc_positive(self):
+        """Normalization constant must be positive."""
+        self.assertGreater(self.dist.nc, 0.0)
+
+    @unittest.skipIf(
+        pyrecest.backend.__backend_name__ == "jax",
+        reason="Not supported on JAX backend",
+    )
+    def test_nc_finite_for_concentrated_rotation(self):
+        """nc must stay finite when the rotational part is sharply concentrated."""
+        C1 = array([[-2000.0, 0.0], [0.0, 0.0]])
+        C2 = array([[0.0, 0.0], [0.0, 0.0]])
+        C3 = array([[-1.0, 0.0], [0.0, -1.0]])
+        nc = SE2BinghamDistribution(C1, C2, C3).nc
+        self.assertTrue(np.isfinite(nc))
+        self.assertGreater(nc, 0.0)
+
+    @unittest.skipIf(
+        pyrecest.backend.__backend_name__ == "jax",
+        reason="Not supported on JAX backend",
+    )
+    def test_pdf_positive(self):
+        """PDF values must be positive."""
+        # A few dual-quaternion-like points (norm of first two not necessarily 1 here,
+        # but pdf is evaluated at arbitrary 4D points)
+        points = array(
+            [
+                [1.0, 0.0, 0.5, -0.3],
+                [0.0, 1.0, 0.1, 0.2],
+                [0.7071, 0.7071, -0.2, 0.4],
+            ]
+        )
+        vals = self.dist.pdf(points)
+        self.assertTrue(np.all(np.array(vals) > 0))
+
+    @unittest.skipIf(
+        pyrecest.backend.__backend_name__ == "jax",
+        reason="Not supported on JAX backend",
+    )
+    def test_pdf_from_angle_pos(self):
+        """PDF should accept angle-pos (N x 3) input and give consistent results."""
+        # Create angle-pos samples
+        angle_pos = array([[0.5, 1.0, -1.0], [1.0, 0.0, 0.5]])
+        p_ap = self.dist.pdf(angle_pos)
+
+        # Convert to dual quaternion manually and evaluate
+        dq = AbstractSE2Distribution.angle_pos_to_dual_quaternion(angle_pos)
+        p_dq = self.dist.pdf(dq)
+        npt.assert_array_almost_equal(np.array(p_ap), np.array(p_dq))
+
+    @unittest.skipIf(
+        pyrecest.backend.__backend_name__ == "jax",
+        reason="Not supported on JAX backend",
+    )
+    def test_mode_shape(self):
+        """Mode should be a 4-element array."""
+        m = self.dist.mode()
+        self.assertEqual(np.array(m).shape, (4,))
+
+    @unittest.skipIf(
+        pyrecest.backend.__backend_name__ == "jax",
+        reason="Not supported on JAX backend",
+    )
+    def test_mode_is_local_maximum(self):
+        """PDF at mode should be >= PDF at nearby on-manifold perturbed points."""
+        # Mode is in dual-quaternion (DQ) representation
+        m_dq = array(np.array(self.dist.mode()).reshape(1, -1))
+
+        # Convert mode DQ → angle-pos
+        angle_arr, pos_arr = AbstractSE2Distribution.dual_quaternion_to_angle_pos(m_dq)
+        angle0 = float(np.array(angle_arr).ravel()[0])
+        pos0 = np.array(pos_arr).ravel()
+
+        p_mode = float(np.array(self.dist.pdf(m_dq)).ravel()[0])
+
+        rng = np.random.default_rng(42)
+        for _ in range(20):
+            # Perturb angle and position (stays on S^1 x R^2 manifold)
+            angle_p = angle0 + rng.normal(0, 0.15)
+            pos_p = pos0 + rng.normal(0, 0.15, size=2)
+            ap = array(np.array([[angle_p, pos_p[0], pos_p[1]]]))
+            dq_p = AbstractSE2Distribution.angle_pos_to_dual_quaternion(ap)
+            p_perturbed = float(np.array(self.dist.pdf(dq_p)).ravel()[0])
+            self.assertGreaterEqual(p_mode, p_perturbed - 1e-6)
+
+    @unittest.skipIf(
+        pyrecest.backend.__backend_name__ == "jax",
+        reason="Not supported on JAX backend",
+    )
+    def test_sample_shape(self):
+        """sample() must return an (n, 4) array."""
+        n = 50
+        s = self.dist.sample(n)
+        self.assertEqual(np.array(s).shape, (n, 4))
+
+    @unittest.skipIf(
+        pyrecest.backend.__backend_name__ == "jax",
+        reason="Not supported on JAX backend",
+    )
+    def test_fit_returns_instance(self):
+        """fit() should return a valid SE2BinghamDistribution."""
+        samples = self.dist.sample(500)
+        fitted = SE2BinghamDistribution.fit(samples)
+        self.assertIsInstance(fitted, SE2BinghamDistribution)
+        self.assertGreater(fitted.nc, 0.0)
+
+    @unittest.skipIf(
+        pyrecest.backend.__backend_name__ == "jax",
+        reason="Not supported on JAX backend",
+    )
+    def test_fit_weighted(self):
+        """fit() should accept explicit weights."""
+        n = 200
+        samples = self.dist.sample(n)
+        # Unnormalised, non-uniform weights exercise the normalisation path
+        weights = np.linspace(1.0, 2.0, n)
+        fitted = SE2BinghamDistribution.fit(samples, weights)
+        self.assertIsInstance(fitted, SE2BinghamDistribution)
+
+    def test_schur_from_scatter_recovers_bingham_parameter(self):
+        """Moment matching must recover the concentration of a circular Bingham."""
+        z_true = -4.0
+        phi = np.linspace(0.0, 2.0 * np.pi, 2000, endpoint=False)
+        s_rot = np.column_stack([np.cos(phi), np.sin(phi)])
+        # Quadrature weights proportional to the Bingham density exp(z * x_1^2)
+        weights = np.exp(z_true * s_rot[:, 0] ** 2)
+        weights /= weights.sum()
+        # pylint: disable=protected-access
+        schur = SE2BinghamDistribution._schur_from_scatter(
+            s_rot, weights[:, np.newaxis]
+        )
+        npt.assert_allclose(schur, np.diag([z_true, 0.0]), atol=1e-6)
+
+
+if __name__ == "__main__":
+    unittest.main()