diff --git a/pyrecest/distributions/__init__.py b/pyrecest/distributions/__init__.py index 72efac1ad..c55c778e7 100644 --- a/pyrecest/distributions/__init__.py +++ b/pyrecest/distributions/__init__.py @@ -86,6 +86,7 @@ ) from .circle.von_mises_distribution import VonMisesDistribution from .circle.wrapped_cauchy_distribution import WrappedCauchyDistribution +from .circle.wrapped_exponential_distribution import WrappedExponentialDistribution from .circle.wrapped_laplace_distribution import WrappedLaplaceDistribution from .circle.wrapped_normal_distribution import WrappedNormalDistribution from .conditional.abstract_conditional_distribution import ( @@ -210,6 +211,7 @@ VMDistribution = VonMisesDistribution WDDistribution = CircularDiracDistribution VMFDistribution = VonMisesFisherDistribution +WEDistribution = WrappedExponentialDistribution aliases = [ "HypertoroidalWNDistribution", @@ -219,6 +221,7 @@ "VMDistribution", "WDDistribution", "VMFDistribution", + "WEDistribution", ] __all__ = aliases + [ @@ -273,6 +276,7 @@ "SineSkewedWrappedNormalDistribution", "VonMisesDistribution", "WrappedCauchyDistribution", + "WrappedExponentialDistribution", "WrappedLaplaceDistribution", "WrappedNormalDistribution", "AbstractConditionalDistribution", diff --git a/pyrecest/distributions/circle/wrapped_exponential_distribution.py b/pyrecest/distributions/circle/wrapped_exponential_distribution.py new file mode 100644 index 000000000..26c80f226 --- /dev/null +++ b/pyrecest/distributions/circle/wrapped_exponential_distribution.py @@ -0,0 +1,43 @@ +# pylint: disable=no-name-in-module,no-member +from typing import Union + +# pylint: disable=redefined-builtin +from pyrecest.backend import exp, int32, int64, log, mod, ndim, pi, random + +from .abstract_circular_distribution import AbstractCircularDistribution + + +class WrappedExponentialDistribution(AbstractCircularDistribution): + """Wrapped exponential distribution on the circle. + + See Sreenivasa Rao Jammalamadaka and Tomasz J. Kozubowski, "New + Families of Wrapped Distributions for Modeling Skew Circular Data", + Communications in Statistics - Theory and Methods, Vol. 33, No. 9, + pp. 2059-2074, 2004. + """ + + def __init__(self, lambda_): + AbstractCircularDistribution.__init__(self) + assert lambda_.shape in ((1,), ()) + assert lambda_ > 0.0 + self.lambda_ = lambda_ + self._normalization_const = 1.0 / (1.0 - exp(-2.0 * pi * lambda_)) + + def pdf(self, xs): + assert ndim(xs) <= 1 + xs = mod(xs, 2.0 * pi) + return self.lambda_ * exp(-self.lambda_ * xs) * self._normalization_const + + def trigonometric_moment(self, n): + return 1.0 / (1.0 - 1j * n / self.lambda_) + + def sample(self, n: Union[int, int32, int64]): + # Use inverse CDF method: X = -ln(U)/lambda ~ Exp(lambda), then wrap + u = random.uniform(size=(n,)) + return mod(-log(u) / self.lambda_, 2.0 * pi) + + def entropy(self): + # log(exp(2*pi*lambda)) = 2*pi*lambda, avoiding redundant exp/log + log_beta = 2.0 * pi * self.lambda_ + beta = exp(log_beta) + return 1.0 + log((beta - 1.0) / self.lambda_) - beta / (beta - 1.0) * log_beta diff --git a/pyrecest/tests/distributions/test_wrapped_exponential_distribution.py b/pyrecest/tests/distributions/test_wrapped_exponential_distribution.py new file mode 100644 index 000000000..c3e1a07b0 --- /dev/null +++ b/pyrecest/tests/distributions/test_wrapped_exponential_distribution.py @@ -0,0 +1,89 @@ +import unittest + +import numpy.testing as npt + +# pylint: disable=no-name-in-module,no-member +import pyrecest.backend + +# pylint: disable=no-name-in-module,no-member +from pyrecest.backend import arange, arctan, array, exp, linspace, pi +from pyrecest.distributions.circle.wrapped_exponential_distribution import ( + WrappedExponentialDistribution, +) + + +class WrappedExponentialDistributionTest(unittest.TestCase): + def setUp(self): + self.lambda_ = array(2.0) + self.we = WrappedExponentialDistribution(self.lambda_) + + def test_pdf(self): + def pdftemp(x): + return sum( + self.lambda_ * exp(-self.lambda_ * (x + 2.0 * pi * k)) + for k in arange(-20, 21) + if x + 2.0 * pi * k >= 0 + ) + + for x in [0.0, 1.0, 2.0, 3.0, 4.0]: + npt.assert_allclose( + self.we.pdf(array(x)), pdftemp(array(x)), rtol=5e-7 + ) + + @unittest.skipIf( + pyrecest.backend.__backend_name__ in ("pytorch", "jax"), + reason="Not supported on this backend", + ) + def test_integral(self): + npt.assert_allclose(self.we.integrate(), 1.0, rtol=5e-7) + npt.assert_allclose(self.we.integrate_numerically(), 1.0, rtol=5e-7) + npt.assert_allclose( + self.we.integrate(array([0.0, pi])) + + self.we.integrate(array([pi, 2.0 * pi])), + 1.0, + rtol=5e-7, + ) + + @unittest.skipIf( + pyrecest.backend.__backend_name__ == "jax", + reason="Not supported on this backend", + ) + def test_angular_moments(self): + for i in range(1, 4): + npt.assert_allclose( + self.we.trigonometric_moment(i), + self.we.trigonometric_moment_numerical(i), + rtol=5e-7, + ) + + def test_circular_mean(self): + npt.assert_allclose( + self.we.mean_direction(), float(arctan(1.0 / self.lambda_)), rtol=5e-7 + ) + + @unittest.skipIf( + pyrecest.backend.__backend_name__ == "jax", + reason="Not supported on this backend", + ) + def test_entropy(self): + npt.assert_allclose( + self.we.entropy(), self.we.entropy_numerical(), rtol=5e-7 + ) + + def test_periodicity(self): + npt.assert_allclose( + self.we.pdf(linspace(-2.0 * pi, 0.0, 100)), + self.we.pdf(linspace(0.0, 2.0 * pi, 100)), + rtol=5e-6, + ) + + def test_sample(self): + n = 100 + s = self.we.sample(n) + self.assertEqual(s.shape, (n,)) + self.assertTrue((s >= 0).all()) + self.assertTrue((s < 2.0 * pi).all()) + + +if __name__ == "__main__": + unittest.main()