diff --git a/src/WaveBlocksND/DirectHomogeneousQuadrature.py b/src/WaveBlocksND/DirectHomogeneousQuadrature.py index 9a3f3a7c..d46a5d13 100644 --- a/src/WaveBlocksND/DirectHomogeneousQuadrature.py +++ b/src/WaveBlocksND/DirectHomogeneousQuadrature.py @@ -13,20 +13,18 @@ from scipy.linalg import sqrtm, inv #, svd, diagsvd from DirectQuadrature import DirectQuadrature +from InnerProductCompatibility import InnerProductCompatibility __all__ = ["DirectHomogeneousQuadrature"] -class DirectHomogeneousQuadrature(DirectQuadrature): +class DirectHomogeneousQuadrature(DirectQuadrature, InnerProductCompatibility): r""" """ def __init__(self, QR=None): # Pure convenience to allow setting of quadrature rule in constructor - if QR is not None: - self.set_qr(QR) - else: - self._QR = None + self.set_qr(QR) def __str__(self): @@ -45,6 +43,10 @@ def get_description(self): return d + def get_kind(self): + return ("homogeneous",) + + def initialize_packet(self, packet): r"""Provide the wavepacket part of the inner product to evaluate. Since the quadrature is homogeneous the same wavepacket is used diff --git a/src/WaveBlocksND/DirectInhomogeneousQuadrature.py b/src/WaveBlocksND/DirectInhomogeneousQuadrature.py index 8b70aa28..812a81ab 100644 --- a/src/WaveBlocksND/DirectInhomogeneousQuadrature.py +++ b/src/WaveBlocksND/DirectInhomogeneousQuadrature.py @@ -14,20 +14,18 @@ from scipy.linalg import sqrtm, inv, det from DirectQuadrature import DirectQuadrature +from InnerProductCompatibility import InnerProductCompatibility __all__ = ["DirectInhomogeneousQuadrature"] -class DirectInhomogeneousQuadrature(DirectQuadrature): +class DirectInhomogeneousQuadrature(DirectQuadrature, InnerProductCompatibility): r""" """ def __init__(self, QR=None): # Pure convenience to allow setting of quadrature rule in constructor - if QR is not None: - self.set_qr(QR) - else: - self._QR = None + self.set_qr(QR) def __str__(self): @@ -46,6 +44,10 @@ def get_description(self): return d + def get_kind(self): + return ("homogeneous", "inhomogeneous",) + + def initialize_packet(self, pacbra, packet=None): r"""Provide the wavepacket parts of the inner product to evaluate. Since the quadrature is inhomogeneous different wavepackets can be diff --git a/src/WaveBlocksND/HomogeneousInnerProduct.py b/src/WaveBlocksND/HomogeneousInnerProduct.py index dc3ef9c6..6a5ccca8 100644 --- a/src/WaveBlocksND/HomogeneousInnerProduct.py +++ b/src/WaveBlocksND/HomogeneousInnerProduct.py @@ -12,11 +12,12 @@ from numpy import zeros, complexfloating, sum, cumsum from InnerProduct import InnerProduct +from InnerProductCompatibility import InnerProductCompatibility __all__ = ["HomogeneousInnerProduct"] -class HomogeneousInnerProduct(InnerProduct): +class HomogeneousInnerProduct(InnerProduct, InnerProductCompatibility): r""" """ @@ -48,6 +49,14 @@ def get_description(self): return d + def get_kind(self): + return ("homogeneous",) + + + def require_kind(self): + return ("homogeneous",) + + def quadrature(self, packet, operator=None, summed=False, component=None, diag_component=None, diagonal=False, eval_at_once=False): r"""Delegates the evaluation of :math:`\langle\Psi|f|\Psi\rangle` for a general function :math:`f(x)` with :math:`x \in \mathbb{R}^D`. diff --git a/src/WaveBlocksND/HomogeneousInnerProductLCWP.py b/src/WaveBlocksND/HomogeneousInnerProductLCWP.py index 3ca00ba7..7dc375f3 100644 --- a/src/WaveBlocksND/HomogeneousInnerProductLCWP.py +++ b/src/WaveBlocksND/HomogeneousInnerProductLCWP.py @@ -13,11 +13,12 @@ from numpy import zeros, complexfloating, conjugate, transpose, dot, cumsum, sum, reshape, array, repeat from InnerProduct import InnerProduct +from InnerProductCompatibility import InnerProductCompatibility __all__ = ["HomogeneousInnerProductLCWP"] -class HomogeneousInnerProductLCWP(InnerProduct): +class HomogeneousInnerProductLCWP(InnerProduct, InnerProductCompatibility): def __init__(self, delegate=None, oracle=None): r""" @@ -55,6 +56,14 @@ def get_description(self): return d + def get_kind(self): + return ("homogeneous",) + + + def require_kind(self): + return ("homogeneous",) + + def get_oracle(self): r"""Return the sparsity oracle in use or ``None``. """ diff --git a/src/WaveBlocksND/InhomogeneousInnerProduct.py b/src/WaveBlocksND/InhomogeneousInnerProduct.py index 679d634b..fe8fda53 100644 --- a/src/WaveBlocksND/InhomogeneousInnerProduct.py +++ b/src/WaveBlocksND/InhomogeneousInnerProduct.py @@ -12,11 +12,12 @@ from numpy import zeros, complexfloating, sum, cumsum from InnerProduct import InnerProduct +from InnerProductCompatibility import InnerProductCompatibility __all__ = ["InhomogeneousInnerProduct"] -class InhomogeneousInnerProduct(InnerProduct): +class InhomogeneousInnerProduct(InnerProduct, InnerProductCompatibility): r""" """ @@ -48,6 +49,14 @@ def get_description(self): return d + def get_kind(self): + return ("homogeneous", "inhomogeneous",) + + + def require_kind(self): + return ("inhomogeneous",) + + def quadrature(self, pacbra, packet=None, operator=None, summed=False, component=None, diag_component=None, diagonal=False, eval_at_once=False): r"""Delegates the evaluation of :math:`\langle\Psi|f|\Psi^\prime\rangle` for a general function :math:`f(x)` with :math:`x \in \mathbb{R}^D`. diff --git a/src/WaveBlocksND/InhomogeneousInnerProductLCWP.py b/src/WaveBlocksND/InhomogeneousInnerProductLCWP.py index f6a98eac..b06b544d 100644 --- a/src/WaveBlocksND/InhomogeneousInnerProductLCWP.py +++ b/src/WaveBlocksND/InhomogeneousInnerProductLCWP.py @@ -13,11 +13,12 @@ from numpy import zeros, complexfloating, conjugate, transpose, dot, sum, cumsum, array, repeat, reshape from InnerProduct import InnerProduct +from InnerProductCompatibility import InnerProductCompatibility __all__ = ["InhomogeneousInnerProductLCWP"] -class InhomogeneousInnerProductLCWP(InnerProduct): +class InhomogeneousInnerProductLCWP(InnerProduct, InnerProductCompatibility): def __init__(self, delegate=None, oracle=None): r""" @@ -55,6 +56,14 @@ def get_description(self): return d + def get_kind(self): + return ("homogeneous", "inhomogeneous",) + + + def require_kind(self): + return ("inhomogeneous",) + + def get_oracle(self): r"""Return the sparsity oracle in use or ``None``. """ diff --git a/src/WaveBlocksND/InnerProduct.py b/src/WaveBlocksND/InnerProduct.py index bac21b23..7e621fc2 100644 --- a/src/WaveBlocksND/InnerProduct.py +++ b/src/WaveBlocksND/InnerProduct.py @@ -11,10 +11,12 @@ @license: Modified BSD License """ +from InnerProductCompatibility import InnerProductCompatibility + __all__ = ["InnerProduct", "InnerProductException"] -class InnerProduct(object): +class InnerProduct(InnerProductCompatibility): r"""This class is an abstract interface to inner products in general. """ @@ -45,7 +47,11 @@ def set_delegate(self, delegate): :param delegate: The new :py:class:`Quadrature` instance. """ # TODO: Allow a list of quads, one quad for each component of Psi - self._delegate = delegate + if delegate is not None: + if self.compatible(self, delegate): + self._delegate = delegate + else: + self._delegate = delegate def get_delegate(self): diff --git a/src/WaveBlocksND/InnerProductCompatibility.py b/src/WaveBlocksND/InnerProductCompatibility.py new file mode 100644 index 00000000..e76bcce8 --- /dev/null +++ b/src/WaveBlocksND/InnerProductCompatibility.py @@ -0,0 +1,32 @@ +"""The WaveBlocks Project + +This class abstracts compatibility conditions on nested inner products. + +@author: R. Bourquin +@copyright: Copyright (C) 2013 R. Bourquin +@license: Modified BSD License +""" + +__all__ = ["InnerProductCompatibility"] + + +class InnerProductCompatibility(object): + r"""This class abstracts compatibility conditions on nested inner products. + """ + + def get_kind(self): + return None + + def require_kind(self): + return None + + def compatible(self, ipouter, ipinner): + r""" + """ + inner = set(ipinner.get_kind()) + outer = set(ipouter.require_kind()) + + if not len(outer.intersection(inner)) == 0: + return True + else: + raise ValueError("Can not nest inner product with kind "+str(inner)+" into inner product which requires "+str(outer)) diff --git a/src/WaveBlocksND/NSDInhomogeneous.py b/src/WaveBlocksND/NSDInhomogeneous.py index e9f4c836..c77f3d29 100644 --- a/src/WaveBlocksND/NSDInhomogeneous.py +++ b/src/WaveBlocksND/NSDInhomogeneous.py @@ -15,11 +15,12 @@ from scipy.linalg import inv, schur, det, sqrtm from Quadrature import Quadrature +from InnerProductCompatibility import InnerProductCompatibility __all__ = ["NSDInhomogeneous"] -class NSDInhomogeneous(Quadrature): +class NSDInhomogeneous(Quadrature, InnerProductCompatibility): r""" """ @@ -30,10 +31,7 @@ def __init__(self, QR=None): :param QR: Typically one uses an instance of :py:class:`GaussHermiteOriginalQR`. """ # Pure convenience to allow setting of quadrature rule in constructor - if QR is not None: - self.set_qr(QR) - else: - self._QR = None + self.set_qr(QR) def __str__(self): @@ -52,6 +50,10 @@ def get_description(self): return d + def get_kind(self): + return ("homogeneous", "inhomogeneous") + + def initialize_packet(self, pacbra, packet=None): r"""Provide the wavepacket parts of the inner product to evaluate. Since the quadrature is inhomogeneous, different wavepackets can be diff --git a/src/WaveBlocksND/SymbolicIntegral.py b/src/WaveBlocksND/SymbolicIntegral.py index bc5c1f79..6b6a3dfd 100644 --- a/src/WaveBlocksND/SymbolicIntegral.py +++ b/src/WaveBlocksND/SymbolicIntegral.py @@ -17,11 +17,12 @@ from InnerProduct import InnerProductException from Quadrature import Quadrature +from InnerProductCompatibility import InnerProductCompatibility __all__ = ["SymbolicIntegral"] -class SymbolicIntegral(Quadrature): +class SymbolicIntegral(Quadrature, InnerProductCompatibility): r""" """ @@ -53,6 +54,10 @@ def get_description(self): return d + def get_kind(self): + return ("homogeneous", "inhomogeneous") + + def initialize_packet(self, pacbra, packet=None): r"""Provide the wavepacket parts of the inner product to evaluate. Since the formula is for the inhomogeneous case explicitly, different diff --git a/src/WaveBlocksND/__init__.py b/src/WaveBlocksND/__init__.py index 06290822..ddfd56f7 100644 --- a/src/WaveBlocksND/__init__.py +++ b/src/WaveBlocksND/__init__.py @@ -81,6 +81,8 @@ from HomogeneousInnerProductLCWP import HomogeneousInnerProductLCWP from InhomogeneousInnerProductLCWP import InhomogeneousInnerProductLCWP +#from InnerProductCompatibility import InnerProductCompatibility + from Quadrature import Quadrature from DirectQuadrature import DirectQuadrature from DirectHomogeneousQuadrature import DirectHomogeneousQuadrature