Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
158 changes: 154 additions & 4 deletions pyrecest/distributions/hypersphere_subset/bingham_distribution.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,26 @@
# pylint: disable=redefined-builtin,no-name-in-module,no-member
# pylint: disable=no-name-in-module,no-member
from scipy.integrate import quad
from scipy.optimize import fsolve
from scipy.special import iv

from pyrecest.backend import (
abs,
all,
argsort,
array,
concatenate,
diag,
exp,
eye,
linalg,
max,
maximum,
pi,
ones,
sum,
zeros,
sort,
)
from scipy.integrate import quad
from scipy.special import iv

from .abstract_hyperspherical_distribution import AbstractHypersphericalDistribution

Expand Down Expand Up @@ -54,7 +59,12 @@ def F(self, value):

@staticmethod
def calculate_F(Z):
"""Uses method by wood. Only supports 4-D distributions."""
"""Uses analytical method. Supports 2-D and 4-D distributions."""
if Z.shape[0] == 2:
# F = exp((Z[0]+Z[1])/2) * 2*pi * I_0(|Z[0]-Z[1]|/2)
return float(
exp((Z[0] + Z[1]) / 2) * 2 * pi * iv(0, abs(float(Z[0] - Z[1])) / 2)
)
assert Z.shape[0] == 4

def J(Z, u):
Expand Down Expand Up @@ -157,3 +167,143 @@ def moment(self):
S = self.M @ D @ self.M.T
S = (S + S.T) / 2 # Enforce symmetry
return S

def mode(self):
"""Returns the mode of the Bingham distribution.

The mode is the eigenvector corresponding to Z=0 (the maximum), i.e.,
the last column of M.

Returns:
mode (numpy.ndarray): mode as a unit vector in R^{dim+1}
"""
return self.M[:, -1]

def sample_deterministic(self, _spread=0.5):
"""Returns deterministic sigma-point samples and weights.

Generates 2*(dim+1) sigma points as ±columns of M with weights
derived from the normalized moments, so that the weighted scatter
matrix equals the distribution's moment matrix.

Parameters:
_spread (float): spread parameter reserved for future use (e.g., tuning
the sigma-point placement); currently the samples are always ±M columns

Returns:
samples (numpy.ndarray): shape (dim+1, 2*(dim+1)), columns are samples
weights (numpy.ndarray): shape (2*(dim+1),), non-negative weights summing to 1
"""
d = self.dF / self.F
d = d / sum(d) # normalize
# ±columns of M with equal weight d_i/2 for both signs
samples = concatenate([self.M, -self.M], axis=1)
weights = concatenate([d / 2, d / 2])
return samples, weights

@staticmethod
def _right_mult_matrix(q):
"""Right multiplication matrix for complex (2D) or quaternion (4D).

For 2D complex q = [a, b]: z * q corresponds to [[a, -b], [b, a]] * z
For 4D quaternion q = [w, x, y, z]: p * q = R(q) * p where R is returned.
"""
if q.shape[0] == 2:
return array([[q[0], -q[1]], [q[1], q[0]]])
if q.shape[0] == 4:
w, x, y, z = q[0], q[1], q[2], q[3]
return array(
[
[w, -x, -y, -z],
[x, w, z, -y],
[y, -z, w, x],
[z, y, -x, w],
]
)
raise ValueError("Only 2D and 4D are supported")

def compose(self, B2):
"""Compose two Bingham distributions via complex or quaternion multiplication.

Computes the Bingham distribution approximating the scatter matrix of
the product x*y, where x ~ self and y ~ B2 are independent.

Parameters:
B2 (BinghamDistribution): second distribution

Returns:
BinghamDistribution: composed distribution
"""
assert isinstance(B2, BinghamDistribution)
assert self.dim == B2.dim, "Dimensions must match"
assert self.dim in (1, 3), "Compose only supported for 2D and 4D distributions"

d2 = B2.dF / B2.F
d2 = d2 / sum(d2)
S1 = self.moment()

n = self.input_dim
S = zeros((n, n))
for j in range(n):
R_j = BinghamDistribution._right_mult_matrix(B2.M[:, j])
S = S + d2[j] * R_j @ S1 @ R_j.T

S = (S + S.T) / 2
return BinghamDistribution.fit_to_moment(S)

@staticmethod
def fit_to_moment(S):
"""Fit a Bingham distribution to a given scatter/moment matrix.

Finds Z and M such that the moment of B(Z, M) matches S.

Parameters:
S (numpy.ndarray): symmetric positive semi-definite matrix with trace 1
(or will be normalized)

Returns:
BinghamDistribution: fitted distribution
"""
n = S.shape[0]
S_np = array(S, dtype=float)
S_np = (S_np + S_np.T) / 2

# Eigendecompose S: eigenvectors sorted by ascending eigenvalue
eigenvalues, M_np = linalg.eigh(S_np)
eigenvalues = eigenvalues.real
M_np = M_np.real

# Normalize eigenvalues to get target moments (they should sum to 1)
eigenvalues = maximum(eigenvalues, 0)
ev_sum = eigenvalues.sum()
if ev_sum == 0:
target_d = ones(n) / n
else:
target_d = eigenvalues / ev_sum

def moment_residual(z_free):
Z_cand = concatenate((z_free, array([0.0])))
Z_sorted = sort(Z_cand)
M_sorted = M_np[:, argsort(Z_cand)]
try:
B_temp = BinghamDistribution(array(Z_sorted), array(M_sorted))
d = array(B_temp.dF / B_temp.F, dtype=float)
d = d / d.sum()
return d[:-1] - target_d[:-1]
except (
AssertionError,
ValueError,
RuntimeError,
): # pylint: disable=broad-except
return ones(n - 1) * 1e6

# Initial guess: scale based on target moments relative to last
z0 = -(target_d[-1] - target_d[:-1]) * 10.0
z_sol = fsolve(moment_residual, z0, full_output=False)

Z_out = concatenate((z_sol, array([0.0])))
idx = argsort(Z_out)
Z_final = Z_out[idx]
M_final = M_np[:, idx]

return BinghamDistribution(array(Z_final), array(M_final))
2 changes: 2 additions & 0 deletions pyrecest/filters/__init__.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
from .abstract_filter import AbstractFilter
from .abstract_particle_filter import AbstractParticleFilter
from .bingham_filter import BinghamFilter
from .euclidean_particle_filter import EuclideanParticleFilter
from .hypertoroidal_particle_filter import HypertoroidalParticleFilter
from .kalman_filter import KalmanFilter
from .manifold_mixins import EuclideanFilterMixin, HypertoroidalFilterMixin

__all__ = [
"AbstractFilter",
"BinghamFilter",
"EuclideanFilterMixin",
"HypertoroidalFilterMixin",
"AbstractParticleFilter",
Expand Down
152 changes: 152 additions & 0 deletions pyrecest/filters/bingham_filter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,152 @@
# pylint: disable=no-name-in-module,no-member
import copy

from pyrecest.backend import array, diag
from pyrecest.distributions.hypersphere_subset.bingham_distribution import (
BinghamDistribution,
)

from .abstract_filter import AbstractFilter


class BinghamFilter(AbstractFilter):
"""Recursive filter based on the Bingham distribution.

Supports antipodally symmetric complex numbers (2D) and quaternions (4D).

References:
- Gerhard Kurz, Igor Gilitschenski, Simon Julier, Uwe D. Hanebeck,
Recursive Bingham Filter for Directional Estimation Involving 180
Degree Symmetry, Journal of Advances in Information Fusion,
9(2):90-105, December 2014.
- Igor Gilitschenski, Gerhard Kurz, Simon J. Julier, Uwe D. Hanebeck,
Unscented Orientation Estimation Based on the Bingham Distribution,
IEEE Transactions on Automatic Control, January 2016.
"""

def __init__(self):
# Default 4-D identity initial state (uniform on S^3, suitable for quaternion orientation)
initial_state = BinghamDistribution(
array([-1.0, -1.0, -1.0, 0.0]),
array(
[[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1]], dtype=float
),
)
AbstractFilter.__init__(self, initial_state)

@property
def filter_state(self):
return self._filter_state

@filter_state.setter
def filter_state(self, new_state):
assert isinstance(
new_state, BinghamDistribution
), "filter_state must be a BinghamDistribution"
assert new_state.dim in (
1,
3,
), "Only 2D and 4D Bingham distributions are supported"
self._filter_state = copy.deepcopy(new_state)

def predict_identity(self, bw):
"""Predict assuming identity system model with Bingham noise.

Computes x(k+1) = x(k) (*) w(k) where (*) is complex or quaternion
multiplication and w(k) ~ bw.

Parameters:
bw (BinghamDistribution): noise distribution
"""
assert isinstance(bw, BinghamDistribution)
self.filter_state = self.filter_state.compose(bw)

def predict_nonlinear(self, a, bw):
"""Predict assuming nonlinear system model with Bingham noise.

Computes x(k+1) = a(x(k)) (*) w(k) using a sigma-point approximation.

Parameters:
a (callable): nonlinear system function mapping R^n -> R^n
bw (BinghamDistribution): noise distribution
"""
assert isinstance(bw, BinghamDistribution)

samples, weights = self.filter_state.sample_deterministic(0.5)

# Propagate each sample through the system function
for i in range(len(weights)):
samples[:, i] = a(samples[:, i])

# Compute scatter matrix of propagated samples
S = samples @ diag(weights) @ samples.T
S = (S + S.T) / 2

predicted = BinghamDistribution.fit_to_moment(S)
self.filter_state = predicted.compose(bw)

def update_identity(self, bv, z):
"""Update assuming identity measurement model with Bingham noise.

Applies the measurement z using likelihood based on Bingham noise bv.

Parameters:
bv (BinghamDistribution): measurement noise distribution
z (numpy.ndarray): measurement as a unit vector of shape (dim+1,)
"""
assert isinstance(bv, BinghamDistribution)
assert bv.dim == self.filter_state.dim
assert z.shape == (self.filter_state.input_dim,)

bv = copy.deepcopy(bv)
n = bv.input_dim
for i in range(n):
m_conj = self._conjugate(bv.M[:, i])
bv.M[:, i] = self._compose(z, m_conj)

self.filter_state = self.filter_state.multiply(bv)

def get_point_estimate(self):
"""Return the mode of the current distribution as a point estimate."""
return self.filter_state.mode()

@staticmethod
def _conjugate(q):
"""Return the conjugate of a unit complex number or quaternion.

For q = [w, x, y, z], conjugate = [w, -x, -y, -z].
For q = [a, b], conjugate = [a, -b].
"""
result = q.copy()
result[1:] = -result[1:]
return result

@staticmethod
def _compose(q1, q2):
"""Compose two unit complex numbers or quaternions via multiplication.

Parameters:
q1, q2: unit vectors of length 2 or 4

Returns:
product q1 * q2
"""
if q1.shape[0] == 2:
# Complex multiplication
return array(
[
q1[0] * q2[0] - q1[1] * q2[1],
q1[0] * q2[1] + q1[1] * q2[0],
]
)
# Hamilton quaternion product
w1, x1, y1, z1 = q1[0], q1[1], q1[2], q1[3]
w2, x2, y2, z2 = q2[0], q2[1], q2[2], q2[3]
return array(
[
w1 * w2 - x1 * x2 - y1 * y2 - z1 * z2,
w1 * x2 + x1 * w2 + y1 * z2 - z1 * y2,
w1 * y2 - x1 * z2 + y1 * w2 + z1 * x2,
w1 * z2 + x1 * y2 - y1 * x2 + z1 * w2,
]
)
Loading
Loading