Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions pyrecest/distributions/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@
)
from .circle.von_mises_distribution import VonMisesDistribution
from .circle.wrapped_cauchy_distribution import WrappedCauchyDistribution
from .circle.wrapped_exponential_distribution import WrappedExponentialDistribution
from .circle.wrapped_laplace_distribution import WrappedLaplaceDistribution
from .circle.wrapped_normal_distribution import WrappedNormalDistribution
from .conditional.abstract_conditional_distribution import (
Expand Down Expand Up @@ -210,6 +211,7 @@
VMDistribution = VonMisesDistribution
WDDistribution = CircularDiracDistribution
VMFDistribution = VonMisesFisherDistribution
WEDistribution = WrappedExponentialDistribution

aliases = [
"HypertoroidalWNDistribution",
Expand All @@ -219,6 +221,7 @@
"VMDistribution",
"WDDistribution",
"VMFDistribution",
"WEDistribution",
]

__all__ = aliases + [
Expand Down Expand Up @@ -273,6 +276,7 @@
"SineSkewedWrappedNormalDistribution",
"VonMisesDistribution",
"WrappedCauchyDistribution",
"WrappedExponentialDistribution",
"WrappedLaplaceDistribution",
"WrappedNormalDistribution",
"AbstractConditionalDistribution",
Expand Down
43 changes: 43 additions & 0 deletions pyrecest/distributions/circle/wrapped_exponential_distribution.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
# pylint: disable=no-name-in-module,no-member
from typing import Union

# pylint: disable=redefined-builtin
from pyrecest.backend import exp, int32, int64, log, mod, ndim, pi, random

from .abstract_circular_distribution import AbstractCircularDistribution


class WrappedExponentialDistribution(AbstractCircularDistribution):
"""Wrapped exponential distribution on the circle.

See Sreenivasa Rao Jammalamadaka and Tomasz J. Kozubowski, "New
Families of Wrapped Distributions for Modeling Skew Circular Data",
Communications in Statistics - Theory and Methods, Vol. 33, No. 9,
pp. 2059-2074, 2004.
"""

def __init__(self, lambda_):
AbstractCircularDistribution.__init__(self)
assert lambda_.shape in ((1,), ())
assert lambda_ > 0.0
self.lambda_ = lambda_
self._normalization_const = 1.0 / (1.0 - exp(-2.0 * pi * lambda_))

def pdf(self, xs):
assert ndim(xs) <= 1
xs = mod(xs, 2.0 * pi)
return self.lambda_ * exp(-self.lambda_ * xs) * self._normalization_const

def trigonometric_moment(self, n):
return 1.0 / (1.0 - 1j * n / self.lambda_)

def sample(self, n: Union[int, int32, int64]):
# Use inverse CDF method: X = -ln(U)/lambda ~ Exp(lambda), then wrap
u = random.uniform(size=(n,))
return mod(-log(u) / self.lambda_, 2.0 * pi)

def entropy(self):
# log(exp(2*pi*lambda)) = 2*pi*lambda, avoiding redundant exp/log
log_beta = 2.0 * pi * self.lambda_
beta = exp(log_beta)
return 1.0 + log((beta - 1.0) / self.lambda_) - beta / (beta - 1.0) * log_beta
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
import unittest

import numpy.testing as npt

# pylint: disable=no-name-in-module,no-member
import pyrecest.backend

# pylint: disable=no-name-in-module,no-member
from pyrecest.backend import arange, arctan, array, exp, linspace, pi
from pyrecest.distributions.circle.wrapped_exponential_distribution import (
WrappedExponentialDistribution,
)


class WrappedExponentialDistributionTest(unittest.TestCase):
def setUp(self):
self.lambda_ = array(2.0)
self.we = WrappedExponentialDistribution(self.lambda_)

def test_pdf(self):
def pdftemp(x):
return sum(
self.lambda_ * exp(-self.lambda_ * (x + 2.0 * pi * k))
for k in arange(-20, 21)
if x + 2.0 * pi * k >= 0
)

for x in [0.0, 1.0, 2.0, 3.0, 4.0]:
npt.assert_allclose(
self.we.pdf(array(x)), pdftemp(array(x)), rtol=5e-7
)

@unittest.skipIf(
pyrecest.backend.__backend_name__ in ("pytorch", "jax"),
reason="Not supported on this backend",
)
def test_integral(self):
npt.assert_allclose(self.we.integrate(), 1.0, rtol=5e-7)
npt.assert_allclose(self.we.integrate_numerically(), 1.0, rtol=5e-7)
npt.assert_allclose(
self.we.integrate(array([0.0, pi]))
+ self.we.integrate(array([pi, 2.0 * pi])),
1.0,
rtol=5e-7,
)

@unittest.skipIf(
pyrecest.backend.__backend_name__ == "jax",
reason="Not supported on this backend",
)
def test_angular_moments(self):
for i in range(1, 4):
npt.assert_allclose(
self.we.trigonometric_moment(i),
self.we.trigonometric_moment_numerical(i),
rtol=5e-7,
)

def test_circular_mean(self):
npt.assert_allclose(
self.we.mean_direction(), float(arctan(1.0 / self.lambda_)), rtol=5e-7
)

@unittest.skipIf(
pyrecest.backend.__backend_name__ == "jax",
reason="Not supported on this backend",
)
def test_entropy(self):
npt.assert_allclose(
self.we.entropy(), self.we.entropy_numerical(), rtol=5e-7
)

def test_periodicity(self):
npt.assert_allclose(
self.we.pdf(linspace(-2.0 * pi, 0.0, 100)),
self.we.pdf(linspace(0.0, 2.0 * pi, 100)),
rtol=5e-6,
)

def test_sample(self):
n = 100
s = self.we.sample(n)
self.assertEqual(s.shape, (n,))
self.assertTrue((s >= 0).all())
self.assertTrue((s < 2.0 * pi).all())


if __name__ == "__main__":
unittest.main()
Loading