diff --git a/ml.cabal b/ml.cabal index cc5cefb..0943ff9 100644 --- a/ml.cabal +++ b/ml.cabal @@ -26,6 +26,7 @@ source-repository head library exposed-modules: MLambda.Index + MLambda.Linear MLambda.Matrix MLambda.NDArr MLambda.TypeLits @@ -39,6 +40,7 @@ library RankNTypes PolyKinds UndecidableInstances + NoStarIsType ghc-options: -Wall -Wcompat -Widentities -Wincomplete-record-updates -Wincomplete-uni-patterns -Wmissing-export-lists -Wmissing-home-modules -Wpartial-fields -Wredundant-constraints build-depends: base >=4.7 && <5 @@ -48,6 +50,8 @@ library , massiv , mtl , netlib-ffi + , singletons + , singletons-base , template-haskell , vector default-language: GHC2024 @@ -56,6 +60,9 @@ test-suite ml-test type: exitcode-stdio-1.0 main-is: Spec.hs other-modules: + Test.MLambda.Matrix + Test.MLambda.NDArr + Test.MLambda.Utils Paths_ml autogen-modules: Paths_ml @@ -67,16 +74,20 @@ test-suite ml-test RankNTypes PolyKinds UndecidableInstances + NoStarIsType ghc-options: -Wall -Wcompat -Widentities -Wincomplete-record-updates -Wincomplete-uni-patterns -Wmissing-export-lists -Wmissing-home-modules -Wpartial-fields -Wredundant-constraints -threaded -rtsopts -with-rtsopts=-N build-depends: base >=4.7 && <5 , blas-ffi , deepseq + , falsify , haskell-src-meta , massiv , ml , mtl , netlib-ffi + , singletons + , singletons-base , tasty , tasty-hunit , template-haskell @@ -98,6 +109,7 @@ benchmark ml-bench RankNTypes PolyKinds UndecidableInstances + NoStarIsType ghc-options: -Wall -Wcompat -Widentities -Wincomplete-record-updates -Wincomplete-uni-patterns -Wmissing-export-lists -Wmissing-home-modules -Wpartial-fields -Wredundant-constraints -threaded -rtsopts -with-rtsopts=-N build-depends: base >=4.7 && <5 @@ -110,6 +122,8 @@ benchmark ml-bench , netlib-ffi , normaldistribution , random + , singletons + , singletons-base , tasty-bench , template-haskell , vector diff --git a/package.yaml b/package.yaml index 50126bc..6b67cea 100644 --- a/package.yaml +++ b/package.yaml @@ -29,6 +29,9 @@ dependencies: - template-haskell - vector - haskell-src-meta +- singletons +- singletons-base +# - ghc-typelits-natnormalise ghc-options: - -Wall @@ -72,6 +75,9 @@ tests: - ml - tasty - tasty-hunit + - falsify + - singletons + - singletons-base language: GHC2024 default-extensions: @@ -80,3 +86,4 @@ default-extensions: - RankNTypes - PolyKinds - UndecidableInstances +- NoStarIsType diff --git a/src/MLambda/Index.hs b/src/MLambda/Index.hs index 2aae007..8534137 100644 --- a/src/MLambda/Index.hs +++ b/src/MLambda/Index.hs @@ -1,5 +1,6 @@ +{-# LANGUAGE PatternSynonyms #-} +{-# LANGUAGE TypeAbstractions #-} {-# LANGUAGE ViewPatterns #-} - -- | -- Module : MLambda.Index -- Description : Type of multidimensional array indices. @@ -12,18 +13,28 @@ -- This module contains definition of 'Index' type of multidimensional array -- indices along with its instances and public interface. module MLambda.Index - ( Index (E, (:.)) - , consIndex - , concatIndex + ( -- * Index type + Index (E) + , pattern (:.) + -- * Index instances + , Ix(..) , IndexI (..) - , Ix - , inst + , pattern IxI + , concatIndexI + -- * Lifting of runtime dimesions into indices + , singToIndexI + , withIx + -- * Index operations + , concatIndex ) where -import Data.Proxy (Proxy (..)) - import MLambda.TypeLits +import Data.Bool.Singletons +import Data.List.Singletons +import Data.Singletons +import GHC.TypeLits.Singletons hiding (natVal) + -- | @Index dim@ is the type of indices of multidimensional arrays of dimensions @dim@. -- Instances are provided for convenient use: -- @@ -37,8 +48,20 @@ import MLambda.TypeLits -- order their respective elements are laid out in memory. data Index (dim :: [Natural]) where E :: Index '[] - I :: Int -> Index '[n] - (:.) :: Index '[n] -> Index (d : ds) -> Index (n : d : ds) + ICons :: Int -> Index ds -> Index (d : ds) + +{-# COMPLETE I #-} +pattern I :: Int -> Index '[d] +pattern I m = ICons m E + +viewI :: Index (d:ds) -> (Index '[d], Index ds) +viewI (ICons x xs) = (I x, xs) + +{-# COMPLETE (:.) #-} +-- | Prepend a single-dimensional index to multi-dimensional one +pattern (:.) :: Index '[d] -> Index ds -> Index (d:ds) +pattern x :. xs <- (viewI -> (x, xs)) + where (ICons x E) :. xs = ICons x xs deriving instance Eq (Index dim) deriving instance Ord (Index dim) @@ -59,8 +82,8 @@ instance Bounded (Index '[]) where maxBound = E instance (KnownNat n, 1 <= n, Bounded (Index d)) => Bounded (Index (n:d)) where - minBound = 0 `consIndex` minBound - maxBound = (-1) `consIndex` maxBound + minBound = 0 :. minBound + maxBound = (-1) :. maxBound instance Enum (Index '[]) where toEnum = const E @@ -68,34 +91,17 @@ instance Enum (Index '[]) where instance (KnownNat n, 1 <= n, Enum (Index d), Bounded (Index d)) => Enum (Index (n:d)) where - toEnum ((`quotRem` enumSize (Index d)) -> (q, r)) = I q `consIndex` toEnum r - fromEnum = \case - I i -> i - I q :. r -> q * enumSize (Index d) + fromEnum r - succ = \case - h@(I i) | h == maxBound -> error "Undefined succ" - | otherwise -> I (succ i) - h :. t | t == maxBound -> succ h :. minBound - | otherwise -> h :. succ t - pred = \case - h@(I i) | h == minBound -> error "Undefined pred" - | otherwise -> I (pred i) - h :. t | t == minBound -> pred h :. maxBound - | otherwise -> h :. pred t - --- | Prepend a single-dimensional index to multi-dimensional one -consIndex :: Index '[x] -> Index xs -> Index (x : xs) -consIndex (I x) = \case - E -> I x - I y -> I x :. I y - I y :. xs -> I x :. I y :. xs + toEnum ((`quotRem` enumSize (Index d)) -> (q, r)) = I (q `mod` natVal n) :. toEnum r + fromEnum (I q :. r) = q * enumSize (Index d) + fromEnum r + succ (h :. t) | t == maxBound = succ h :. minBound + | otherwise = h :. succ t + pred (h :. t) | t == minBound = pred h :. maxBound + | otherwise = h :. pred t -- | Concatenate two indices together -concatIndex :: Index xs -> Index ys -> Index (xs ++ ys) -concatIndex = \case - E -> id - I x -> consIndex (I x) - I x :. xs -> consIndex (I x) . concatIndex xs +concatIndex :: forall xs ys . Index xs -> Index ys -> Index (xs ++ ys) +concatIndex E = id +concatIndex (ICons x xs) = ICons x . concatIndex xs -- | A helper type that holds instances for everything you can get by pattern matching on -- an @`Index`@ value. If you need to get instances for a head/tail of an @`Index`@, @@ -110,8 +116,31 @@ data IndexI (dim :: [Natural]) where (:.=) :: (KnownNat n, 1 <= n, Ix ds) => Proxy n -> IndexI ds -> IndexI (n : ds) +data IxInstance (dim :: [Natural]) where + IxInstance :: Ix dim => IxInstance dim + +viewII :: IndexI dim -> IxInstance dim +viewII (_ :.= _) = IxInstance +viewII EI = IxInstance + +-- | A simpler pattern for @`IndexI`@ in case you don't need instances +-- for suffixes. +{-# COMPLETE IxI #-} +pattern IxI :: () => Ix dim => IndexI dim +pattern IxI <- (viewII -> IxInstance) where + IxI = inst + +deriving instance Show (IndexI dim) + infixr 5 :.= +-- | Concatenate two @`IndexI`@s to get the instances for concatenated +-- dimensions. +concatIndexI :: IndexI d1 -> IndexI d2 -> IndexI (d1 ++ d2) +concatIndexI EI d2 = d2 +concatIndexI (i1 :.= d1) d2 = case concatIndexI d1 d2 of + IxI -> i1 :.= IxI + -- | A class used both as a shorthand for useful @`Index`@ instances and a way to obtain -- a value of @`IndexI`@. class (Bounded (Index dim), Enum (Index dim)) => Ix dim where @@ -123,3 +152,19 @@ instance Ix '[] where instance (KnownNat n, 1 <= n, Ix ds) => Ix (n:ds) where inst = Proxy :.= inst + +-- | Convert @`Sing`@ of a dimension list to a witness for @`Ix`@ instances +-- of that index. +singToIndexI :: forall dim . Sing dim -> Maybe (IndexI dim) +singToIndexI SNil = Just EI +singToIndexI (SCons sn@SNat sr) = + case (sing @1 %<=? sn, singToIndexI sr) of + (STrue, Just r@IxI) -> Just $ Proxy :.= r + _ -> Nothing + +-- | Simple interface for lifting runtime dimensions into type level. +-- Provides the given continuation with the type of said dimensions and their +-- @`Ix`@ instance. +withIx :: Demote [Natural] -> (forall dim . Ix dim => Proxy dim -> r) -> Maybe r +withIx d f = withSomeSing d \(singToIndexI -> r) -> + flip fmap r \case (IxI :: IndexI dim) -> f $ Proxy @dim diff --git a/src/MLambda/Linear.hs b/src/MLambda/Linear.hs new file mode 100644 index 0000000..c993962 --- /dev/null +++ b/src/MLambda/Linear.hs @@ -0,0 +1,45 @@ +-- | +-- Module : MLambda.Linear +-- Description : Linear maps. +-- Copyright : (c) neclitoris, TurtlePU, 2025 +-- License : BSD-3-Clause +-- Maintainer : nas140301@gmail.com +-- Stability : experimental +-- Portability : portable +-- +-- This module describes linear maps. +module MLambda.Linear (Additive(..), Module(..), LinearMap, LinearMap') where + +import Data.Kind + +-- | @Additive@ describes things you can add. +class Additive m where + add :: m -> m -> m + zero :: m + +-- | @Module@ is a class describing modules over a ring @r@. +-- Since @`Num`@ describes commutative rings, this is both a right +-- and a left module. +class (Num r, Additive m) => Module r m where + modMult :: r -> m -> m + +instance {-# OVERLAPPABLE #-} Num r => Additive r where + add = (+) + zero = 0 + +instance Num r => Module r r where + modMult = (*) + +-- | This version of LinearMap allows one to pass an arbitrary additional +-- constraint on the type mapped over. Note that this can break things: +-- @LinearMap' ((~) r) s m r@ will allow you to do naughty things. +type LinearMap' (e :: Type -> Constraint) t s r = + forall m . (Module r m, e m) => t m -> s m + +class Empty t + +instance Empty t + +-- | @LinearMap@ forces linearity (in mathematical sense) in its argument +-- since you can't multiply an element of @m@ by itself, only by some @a@. +type LinearMap t s r = LinearMap' Empty t s r diff --git a/src/MLambda/Matrix.hs b/src/MLambda/Matrix.hs index c9b8af3..87d6a3d 100644 --- a/src/MLambda/Matrix.hs +++ b/src/MLambda/Matrix.hs @@ -20,10 +20,18 @@ module MLambda.Matrix , crossMassiv -- * Matrix creation , mat + , eye + , rep + , act + -- * Shape manipulation + , transpose ) where import MLambda.Foreign.Utils (asFPtr, asPtr, char) -import MLambda.NDArr +import MLambda.Index +import MLambda.Linear +import MLambda.NDArr hiding (concat, foldr, map, zipWith) +import MLambda.NDArr qualified as NDArr import MLambda.TypeLits import Control.Applicative @@ -176,3 +184,27 @@ matE s = do vec <- Storable.unsafeFreeze mvec pure $ unsafeMkNDArr @'[$tm, $tn] @Double vec |] + +-- | Identity matrix. +eye :: (KnownNat n, 1 <= n, Storable e, Num e) => NDArr '[n, n] e +eye = fromIndex go where + go (i :. j) | i == j = 1 + | otherwise = 0 + +-- | A matrix that corresponds to a linear map of vector spaces. +rep :: forall m n e . + ( KnownNat m, 1 <= m + , KnownNat n, 1 <= n + , Storable e, Num e) + => LinearMap' Storable (NDArr '[m]) (NDArr '[n]) e -> NDArr '[n, m] e +rep f = transpose $ NDArr.concat $ fromIndex (f . (rows @'[m] eye `at`)) + +-- | A linear map of vector spaces that corresponds to a matrix. +act :: forall m n e . (KnownNat m, 1 <= m , Storable e) + => NDArr '[n, m] e -> LinearMap' Storable (NDArr '[m]) (NDArr '[n]) e +act a x = NDArr.map (NDArr.foldr add zero . flip (NDArr.zipWith modMult) x) (rows a) + +-- | Matrix transposition. +transpose :: (KnownNat m, 1 <= m , KnownNat n, 1 <= n , Storable e) + => NDArr '[m, n] e -> NDArr '[n, m] e +transpose m = fromIndex \(i :. j) -> m `at` (j :. i) diff --git a/src/MLambda/NDArr.hs b/src/MLambda/NDArr.hs index f60efc1..c41b1e7 100644 --- a/src/MLambda/NDArr.hs +++ b/src/MLambda/NDArr.hs @@ -26,28 +26,42 @@ module MLambda.NDArr , at , row , rows + -- * Array operations + , toList + , concat + , map + , zipWith + , foldr -- * Array composition , Stack , Stacks + , StackWitness(..) , stack + , stackWithWitness -- * Unsafe API , unsafeMkNDArr + -- * Shape manipulation + , reshape + , prependDim + , stripDim ) where import MLambda.Index +import MLambda.Linear import MLambda.TypeLits import Control.DeepSeq (NFData) import Control.Monad.ST (runST) import Data.Foldable (forM_) -import Data.List (intersperse) -import Data.Proxy +import Data.List qualified as List +import Data.List.Singletons +import Data.Singletons import Data.Vector.Storable qualified as Storable import Data.Vector.Storable.Mutable qualified as Mutable import Foreign.Ptr (Ptr, castPtr) import Foreign.Storable (Storable (..)) import GHC.TypeError (ErrorMessage (..), TypeError) -import Prelude hiding (concat, zipWith) +import Prelude hiding (concat, foldr, map, zipWith) -- | @NDArr [n1,...,nd] e@ is a type of arrays with dimensions @n1 x ... x nd@ -- consisting of elements of type @e@. @@ -63,18 +77,18 @@ newtype NDArr (dim :: [Natural]) e = MkNDArr { unsafeMkNDArr :: forall dim e. Storable.Vector e -> NDArr dim e unsafeMkNDArr = MkNDArr -instance (Show e, Storable e) => Show (NDArr '[n] e) where - show = show . runNDArr - -instance (Ix (n:a:r), Show (NDArr (a:r) e), Storable e) => - Show (NDArr (n:a:r) e) where +instance (Ix dim, Show e, Storable e) => Show (NDArr dim e) where showsPrec _ = - case inst @(n:a:r) of - _ :.= _ -> (showString "[" .) . (. showString "]") - . foldl' (.) id - . intersperse (showString ",\n") - . map shows - . toList . rows @'[n] + case inst @dim of + _ :.= _ -> go + EI -> shows . (Storable.! 0) . runNDArr + where + go :: forall n r e' . (Ix r, Show (NDArr r e'), Storable e') => NDArr (n:r) e' -> ShowS + go = (showString "[" .) . (. showString "]") + . foldl' (.) id + . List.intersperse (showString ",\n") + . List.map shows + . toList . rows @'[n] instance (Ix d, Storable e) => Storable (NDArr d e) where sizeOf _ = sizeOf (undefined :: e) * enumSize (Index d) @@ -82,6 +96,13 @@ instance (Ix d, Storable e) => Storable (NDArr d e) where peek (castPtr -> (ptr :: Ptr e)) = fromIndexM (peekElemOff ptr . fromEnum) poke (castPtr -> (ptr :: Ptr e)) = (`Storable.iforM_` pokeElemOff ptr) . runNDArr +instance (Ix d, Storable r, Num r) => Additive (NDArr d r) where + add = zipWith (+) + zero = MkNDArr $ Storable.replicate (enumSize (Index d)) 0 + +instance (Ix d, Storable r, Num r) => Module r (NDArr d r) where + modMult r = MkNDArr . Storable.map (r *) . runNDArr + -- | Construct an array from a function @f@ that maps indices to elements. -- This function is strict and therefore @f@ can't refer to the result of -- @fromIndex f@. @@ -101,44 +122,56 @@ fromIndexM f = do vec <- Storable.unsafeFreeze mvec pure $ MkNDArr vec --- | Access array element by its index. This is a total function. +-- | O(1). Access array element by its index. This is a total function. at :: (Storable e, Ix dim) => NDArr dim e -> Index dim -> e (MkNDArr v) `at` i = v Storable.! fromEnum i infixl 9 `at` --- | Extract a "row" from the array. If you're used to C or numpy arrays, +-- | O(1). Extract a "row" from the array. If you're used to C or numpy arrays, -- this is similar to @a[i]@. row :: forall d1 d2 e. (Ix d1, Ix d2, Storable e) => Index d1 -> NDArr (d1 ++ d2) e -> NDArr d2 e row i a = rows a `at` i --- | Extract all "rows" from the array as an array. +-- | O(1). Extract all "rows" from the array as an array. rows :: forall d1 d2 e. (Ix d2, Storable e) => NDArr (d1 ++ d2) e -> NDArr d1 (NDArr d2 e) rows = MkNDArr . Storable.unsafeCast . runNDArr +-- | O(array_size). Extract @`NDArr`@ elements as a list in the order they are +-- laid out in memory. toList :: Storable e => NDArr d e -> [e] toList = Storable.toList . runNDArr +-- | O(1). Flatten an array of arrays into a single array. concat :: forall d1 d2 e. (Ix d2, Storable e) => NDArr d1 (NDArr d2 e) -> NDArr (d1 ++ d2) e concat = MkNDArr . Storable.unsafeCast . runNDArr +-- | O(array_size). Transform elements of an array. +map :: (Storable a, Storable b) => (a -> b) -> NDArr d a -> NDArr d b +map f = MkNDArr . Storable.map f . runNDArr + +-- | O(array_size). Combine two arrays into one element-by-element. zipWith :: (Storable a, Storable b, Storable c) => (a -> b -> c) -> NDArr d a -> NDArr d b -> NDArr d c zipWith f (MkNDArr xs) (MkNDArr ys) = MkNDArr (Storable.zipWith f xs ys) +-- | O(array_size). Reduce array using an accumulator function and initial value. +foldr :: Storable a => (a -> r -> r) -> r -> NDArr d a -> r +foldr f r = Storable.foldr f r . runNDArr + vstack :: Storable e => NDArr (k : d) e -> NDArr (l : d) e -> NDArr ((k + l) : d) e vstack (MkNDArr xs) (MkNDArr ys) = MkNDArr (xs <> ys) -- | A type family which computes the resulting size of a stacked array. -type Stack n d e = StackImpl (StackError n d e) (Peano n) d e +type Stack n d e = StackImpl (StackError n d e) n d e type StackError n d e = Text "Not enough dimensions to stack along axis " :<>: ShowType n @@ -154,12 +187,13 @@ type family StackImpl msg i d e where class Stacks i dim1 dim2 dimr where stacks :: StackWitness i dim1 dim2 dimr +-- | A witness that carries the constraints needed to stack arrays together. data StackWitness i d1 d2 dr where SW :: ( KnownNat k, KnownNat l, KnownNat (k + l), Ix t , 1 <= k, 1 <= l, 1 <= k + l ) => Proxy '(s, k, l, t) -> - StackWitness (Length s) (s ++ (k : t)) (s ++ (l : t)) (s ++ ((k + l) : t)) + StackWitness (PLength s) (s ++ (k : t)) (s ++ (l : t)) (s ++ ((k + l) : t)) instance ( KnownNat n, KnownNat m, KnownNat k, Ix d @@ -171,11 +205,31 @@ instance Stacks i d e r => Stacks (PS i) (n : d) (n : e) (n : r) where stacks = case stacks @i of SW (Proxy @'(s, k, l, t)) -> SW (Proxy @'(n : s, k, l, t)) +-- | Sometimes you need to create a witness yourself for this to work. +stackWithWitness :: Storable e => StackWitness n d1 d2 dr + -> NDArr d1 e -> NDArr d2 e -> NDArr dr e +stackWithWitness (SW (Proxy @'(s, k, l, t))) xs ys = + concat $ zipWith vstack (rows @s @(k : t) xs) (rows @s @(l : t) ys) + -- | @stack i@ stacks arrays along the axis @i@. All other axes are required -- to be the same lengths. stack :: - forall n -> (Stacks (Peano n) d1 d2 (Stack n d1 d2), Storable e) => - NDArr d1 e -> NDArr d2 e -> NDArr (Stack n d1 d2) e -stack n xs ys = case stacks @(Peano n) of - SW (Proxy @'(s, k, l, t)) -> - concat $ zipWith vstack (rows @s @(k : t) xs) (rows @s @(l : t) ys) + forall n -> (Stacks (Peano n) d1 d2 (Stack (Peano n) d1 d2), Storable e) => + NDArr d1 e -> NDArr d2 e -> NDArr (Stack (Peano n) d1 d2) e +stack n = stackWithWitness (stacks @(Peano n)) + +type family Size d where + Size '[] = 1 + Size (x:xs) = x * Size xs + +-- | Change the shape of an array. +reshape :: forall d2 -> (Size d1 ~ Size d2) => NDArr d1 e -> NDArr d2 e +reshape _ = MkNDArr . runNDArr + +-- | Prepend a single dimension of size 1. +prependDim :: NDArr d e -> NDArr (1:d) e +prependDim = MkNDArr . runNDArr + +-- | Remove a dimension of size 1 from the front. +stripDim :: NDArr (1:d) e -> NDArr d e +stripDim = MkNDArr . runNDArr diff --git a/src/MLambda/TypeLits.hs b/src/MLambda/TypeLits.hs index 5f0b54b..7f0631d 100644 --- a/src/MLambda/TypeLits.hs +++ b/src/MLambda/TypeLits.hs @@ -17,9 +17,8 @@ module MLambda.TypeLits , natVal , enumSize , Unify - , type (++) , PNat (..) - , Length + , PLength , Peano , RNat (..) , RPNat (..) @@ -45,18 +44,13 @@ type family Unify n a b where Unify _ a a = a Unify n a b = TypeError (Text n :<>: Text " are not equal:" :$$: ShowType a :$$: ShowType b) --- | Type-level list concatenation. -type family xs ++ ys where - '[] ++ ys = ys - (x:xs) ++ ys = x : xs ++ ys - -- | Peano naturals. data PNat = PZ | PS PNat -- | Compute length of a type-level list as a Peano natural. -type family Length xs where - Length '[] = PZ - Length (_:xs) = PS (Length xs) +type family PLength xs where + PLength '[] = PZ + PLength (_:xs) = PS (PLength xs) -- | Compute Peano representation from type-level natural. type family Peano n where diff --git a/stack.yaml b/stack.yaml index 2dc9a56..0dd9cb2 100644 --- a/stack.yaml +++ b/stack.yaml @@ -34,10 +34,8 @@ packages: # These entries can reference officially published versions as well as # forks / in-progress versions pinned to a git hash. For example: # -# extra-deps: -# - acme-missiles-0.3 -# - git: https://github.com/commercialhaskell/stack.git -# commit: e7b331f14bcffb8367cd58fbfc8b40ec7642100a +extra-deps: +- falsify-0.2.0@sha256:af5c4142095d05775236c8e18e827d540ed22f57b3b74b2268b32040b514ac88,5451 # # extra-deps: [] @@ -64,3 +62,6 @@ packages: # # Allow a newer minor version of GHC than the snapshot specifies # compiler-check: newer-minor +allow-newer-deps: +- falsify +allow-newer: true diff --git a/test/Spec.hs b/test/Spec.hs index c18740c..fcbb023 100644 --- a/test/Spec.hs +++ b/test/Spec.hs @@ -1,5 +1,7 @@ {-# LANGUAGE QuasiQuotes #-} +import Test.MLambda.Matrix +import Test.MLambda.NDArr import Test.Tasty import Test.Tasty.HUnit @@ -26,4 +28,6 @@ tests :: TestTree tests = testGroup "Tests" [ testCase "Mul" $ a `cross` b @?= c + , testNDArr + , testMatrix ] diff --git a/test/Test/MLambda/Matrix.hs b/test/Test/MLambda/Matrix.hs new file mode 100644 index 0000000..7b5d525 --- /dev/null +++ b/test/Test/MLambda/Matrix.hs @@ -0,0 +1,56 @@ +{-# LANGUAGE TypeAbstractions #-} +module Test.MLambda.Matrix (testMatrix) where + +import MLambda.Matrix +import MLambda.NDArr as NDArr +import MLambda.TypeLits + +import Test.MLambda.Utils + +import Control.Monad +import Data.Bool.Singletons +import Data.Singletons +import GHC.TypeLits.Singletons +import Test.Falsify.Predicate ((.$)) +import Test.Falsify.Predicate qualified as Pred +import Test.Tasty +import Test.Tasty.Falsify + +propId :: Property () +propId = do + (m, n) <- gen $ (,) <$> genSz <*> genSz + (SomeSing (SNat @m)) <- pure $ toSing m + (SomeSing (SNat @n)) <- pure $ toSing n + STrue <- pure (sing @1 %<=? sing @m) + STrue <- pure (sing @1 %<=? sing @n) + a <- gen $ genNDArr @'[n, m] genDouble + let a' = rep $ act a + assert $ Pred.eq .$ ("rep . act", a') .$ ("id", a) + +propCompose :: Property () +propCompose = do + (m, n) <- gen $ (,) <$> genSz <*> genSz + (SomeSing (SNat @m)) <- pure $ toSing m + (SomeSing (SNat @n)) <- pure $ toSing n + (SomeSing (SNat @k)) <- pure $ toSing n + STrue <- pure (sing @1 %<=? sing @m) + STrue <- pure (sing @1 %<=? sing @n) + STrue <- pure (sing @1 %<=? sing @k) + a <- gen $ genNDArr @'[n, m] genDouble + b <- gen $ genNDArr @'[m, k] genDouble + let c = a `cross` b + c' = rep (act a . act b) + eps = 10**(-7) + diff = NDArr.zipWith (\x y -> abs (x - y)) c c' + forM_ [minBound..maxBound] $ \i -> + when (diff `at` i > eps) $ testFailed "matrices unequal" + + +testMatmul :: TestTree +testMatmul = testGroup "matmul" + [ testProperty "rep . act = id" propId + , testProperty "rep (act a . act b) = a @ b" propCompose + ] + +testMatrix :: TestTree +testMatrix = testGroup "Matrix" [testMatmul] diff --git a/test/Test/MLambda/NDArr.hs b/test/Test/MLambda/NDArr.hs new file mode 100644 index 0000000..272bf8a --- /dev/null +++ b/test/Test/MLambda/NDArr.hs @@ -0,0 +1,85 @@ +{-# LANGUAGE TypeAbstractions #-} +{-# LANGUAGE ViewPatterns #-} +module Test.MLambda.NDArr (testNDArr) where + +import MLambda.Index +import MLambda.NDArr +import MLambda.TypeLits + +import Test.MLambda.Utils + +import Control.Monad +import Data.Bool.Singletons +import Data.List.Singletons (type (++)) +import Data.Proxy +import Data.Singletons +import GHC.TypeLits.Singletons +import Prelude.Singletons ((%+)) +import Test.Falsify.Generator qualified as Gen +import Test.Falsify.Predicate ((.$)) +import Test.Falsify.Predicate qualified as Pred +import Test.Tasty +import Test.Tasty.Falsify + +propAtDotFromIndex :: Property () +propAtDotFromIndex = do + dim <- gen $ genDim 0 5 + SomeSing sdim <- pure $ toSing dim + Just (IxI @dim) <- pure $ singToIndexI sdim + Fn f <- gen $ Gen.fun genInt + i <- gen (genIndex @dim) + assert $ Pred.eq + .$ ("at . fromIndex", (at . fromIndex) f i) + .$ ("id", f i) + +propFromIndexDotAt :: Property () +propFromIndexDotAt = do + dim <- gen $ genDim 0 5 + SomeSing sdim <- pure $ toSing dim + Just (IxI @dim) <- pure $ singToIndexI sdim + arr <- gen $ genNDArr @dim genInt + assert $ Pred.eq + .$ ("fromIndex . at", (fromIndex . at) arr) + .$ ("id", arr) + +testFromIndex :: TestTree +testFromIndex = testGroup "fromIndex" + [ testProperty "at . fromIndex = id" propAtDotFromIndex + , testProperty "fromIndex . at = id" propFromIndexDotAt] + +propStack :: Property () +propStack = do + dimsuff <- gen $ genDim 0 3 + dimpref <- gen $ genDim 0 2 + dimmid1 <- gen genSz + dimmid2 <- gen genSz + SomeSing (singToIndexI -> Just (IxI @p)) <- pure $ toSing dimsuff + SomeSing (singToIndexI -> Just (IxI @s)) <- pure $ toSing dimpref + SomeSing sm1@(SNat @m1) <- pure $ toSing dimmid1 + SomeSing sm2@(SNat @m2) <- pure $ toSing dimmid2 + STrue <- pure $ sing @1 %<=? sm1 + STrue <- pure $ sing @1 %<=? sm2 + SNat <- pure $ sm1 %+ sm2 + STrue <- pure $ sing @1 %<=? sm1 %+ sm2 + IxI <- pure $ concatIndexI (IxI @p) (IxI @(m1 : s)) + IxI <- pure $ concatIndexI (IxI @p) (IxI @(m2 : s)) + IxI <- pure $ concatIndexI (IxI @p) (IxI @((m1 + m2) : s)) + arr1 <- gen $ genNDArr @(p ++ (m1 : s)) genInt + arr2 <- gen $ genNDArr @(p ++ (m2 : s)) genInt + let arr3 = stackWithWitness (SW (Proxy @'(p, m1, m2, s))) arr1 arr2 + s <- gen $ genIndex @s + p <- gen $ genIndex @p + m <- gen $ genIndex @'[m1 + m2] + let i1 = p `concatIndex` ((toEnum (fromEnum m) :: Index '[m1]) :. s) + i2 = p `concatIndex` + ((toEnum (fromEnum m - enumSize (Index [m1])) :: Index '[m2]) :. s) + i3 = p `concatIndex` (m :. s) + if fromEnum m <= fromEnum (maxBound @(Index '[m1])) + then when (arr1 `at` i1 /= arr3 `at` i3) $ testFailed "mid index in lhs" + else when (arr2 `at` i2 /= arr3 `at` i3) $ testFailed "mid index in rhs" + +testStack :: TestTree +testStack = testProperty "stack" propStack + +testNDArr :: TestTree +testNDArr = testGroup "NDArr" [testFromIndex, testStack] diff --git a/test/Test/MLambda/Utils.hs b/test/Test/MLambda/Utils.hs new file mode 100644 index 0000000..b3fec3f --- /dev/null +++ b/test/Test/MLambda/Utils.hs @@ -0,0 +1,42 @@ +{-# LANGUAGE TypeAbstractions #-} +module Test.MLambda.Utils where + +import MLambda.Index +import MLambda.NDArr +import MLambda.TypeLits + +import Data.Proxy +import Foreign.Storable +import Test.Falsify.Generator (Gen) +import Test.Falsify.Generator qualified as Gen +import Test.Falsify.Range qualified as Range + +genSz :: Gen Natural +genSz = fromIntegral <$> Gen.int (Range.between (1, 5)) + +genDim :: Word -> Word -> Gen [Natural] +genDim a b = Gen.list (Range.between (a, b)) genSz + +genInt :: Gen Int +genInt = Gen.inRange $ Range.between (-1000, 1000) + +genDouble :: Gen Double +genDouble = Gen.inRange $ Range.fromProperFraction 64 + \(Range.ProperFraction d) -> 1 + 4 * d + +genIndex :: forall dim . Ix dim => Gen (Index dim) +genIndex = case inst @dim of + EI -> pure E + Proxy @h :.= (_ :: IndexI t) -> do + h <- genInt + t <- genIndex @t + pure ((toEnum h :: Index '[h]) :. t) + +genNDArr :: forall dim e . (Storable e, Ix dim) => Gen e -> Gen (NDArr dim e) +genNDArr g = do + Gen.Fn f <- Gen.fun g + pure $ fromIndex f + +instance Enum (Index dim) => Gen.Function (Index dim) where + function gb = Gen.functionMap fromEnum toEnum <$> Gen.function gb +