From b38005997b00955cc2de83777745ff423178ccfa Mon Sep 17 00:00:00 2001 From: Nikita Solodovnikov Date: Thu, 14 Aug 2025 05:39:19 +0300 Subject: [PATCH 01/20] Add index witness construction from runtime sizes and use it for testing with falsify. --- ml.cabal | 8 +++++++ package.yaml | 5 ++++ src/MLambda/Index.hs | 44 ++++++++++++++++++++++++++++++++++ src/MLambda/NDArr.hs | 8 +++---- src/MLambda/TypeLits.hs | 6 +---- stack.yaml | 9 +++---- test/Spec.hs | 2 ++ test/Test/MLambda/NDArr.hs | 48 ++++++++++++++++++++++++++++++++++++++ 8 files changed, 117 insertions(+), 13 deletions(-) create mode 100644 test/Test/MLambda/NDArr.hs diff --git a/ml.cabal b/ml.cabal index cc5cefb..9b6d1a8 100644 --- a/ml.cabal +++ b/ml.cabal @@ -48,6 +48,8 @@ library , massiv , mtl , netlib-ffi + , singletons + , singletons-base , template-haskell , vector default-language: GHC2024 @@ -56,6 +58,7 @@ test-suite ml-test type: exitcode-stdio-1.0 main-is: Spec.hs other-modules: + Test.MLambda.NDArr Paths_ml autogen-modules: Paths_ml @@ -72,11 +75,14 @@ test-suite ml-test base >=4.7 && <5 , blas-ffi , deepseq + , falsify , haskell-src-meta , massiv , ml , mtl , netlib-ffi + , singletons + , singletons-base , tasty , tasty-hunit , template-haskell @@ -110,6 +116,8 @@ benchmark ml-bench , netlib-ffi , normaldistribution , random + , singletons + , singletons-base , tasty-bench , template-haskell , vector diff --git a/package.yaml b/package.yaml index 50126bc..d9c9f7d 100644 --- a/package.yaml +++ b/package.yaml @@ -29,6 +29,8 @@ dependencies: - template-haskell - vector - haskell-src-meta +- singletons +- singletons-base ghc-options: - -Wall @@ -72,6 +74,9 @@ tests: - ml - tasty - tasty-hunit + - falsify + - singletons + - singletons-base language: GHC2024 default-extensions: diff --git a/src/MLambda/Index.hs b/src/MLambda/Index.hs index 99f4175..94eeb82 100644 --- a/src/MLambda/Index.hs +++ b/src/MLambda/Index.hs @@ -1,3 +1,6 @@ +{-# LANGUAGE PatternSynonyms #-} +{-# LANGUAGE TypeAbstractions #-} +{-# LANGUAGE ViewPatterns #-} -- | -- Module : MLambda.Index -- Description : Type of multidimensional array indices. @@ -15,11 +18,20 @@ module MLambda.Index , concatIndex , IndexI (..) , Ix + , pattern IxI , inst + -- * Lifting of runtime dimesions into indices + , singToIndexI + , withIx ) where import MLambda.TypeLits +import Data.Bool.Singletons +import Data.List.Singletons +import Data.Singletons +import GHC.TypeLits.Singletons hiding (natVal) + -- | @Index dim@ is the type of indices of multidimensional arrays of dimensions @dim@. -- Instances are provided for convenient use: -- @@ -100,6 +112,20 @@ data IndexI (dim :: [Natural]) where (:.=) :: Ix (d : ds) => IndexI '[n] -> IndexI (d : ds) -> IndexI (n : d : ds) II :: (KnownNat n, 1 <= n) => IndexI '[n] +data IxInstance (dim :: [Natural]) where + IxInstance :: Ix dim => IxInstance dim + +viewII :: IndexI dim -> IxInstance dim +viewII (II :.= _) = IxInstance +viewII II = IxInstance + +{-# COMPLETE IxI #-} +pattern IxI :: () => Ix dim => IndexI dim +pattern IxI <- (viewII -> IxInstance) where + IxI = inst + +deriving instance Show (IndexI dim) + infixr 5 :.= -- | A class used both as a shorthand for useful @`Index`@ instances and a way to obtain @@ -113,3 +139,21 @@ instance (KnownNat n, 1 <= n) => Ix '[n] where instance (KnownNat n, 1 <= n, Ix (d:ds)) => Ix (n:d:ds) where inst = II :.= inst + +-- | Convert @`Sing`@ of a +singToIndexI :: forall dim . Sing dim -> Maybe (IndexI dim) +singToIndexI (SCons sn@SNat SNil) = + case sing @1 %<=? sn of + STrue -> Just II + SFalse -> Nothing +singToIndexI (SCons sn@SNat sr@SCons{}) = + case (sing @1 %<=? sn, singToIndexI sr) of + (STrue, Just r@IxI) -> Just $ II :.= r + _ -> Nothing +singToIndexI _ = Nothing + +withIx :: Demote [Natural] -> (forall dim . Ix dim => Proxy dim -> r) -> Maybe r +withIx d f = withSomeSing d \(singToIndexI -> r) -> + flip fmap r \case (IxI :: IndexI dim) -> f $ Proxy @dim + + diff --git a/src/MLambda/NDArr.hs b/src/MLambda/NDArr.hs index 30cfb38..bdadf0e 100644 --- a/src/MLambda/NDArr.hs +++ b/src/MLambda/NDArr.hs @@ -131,10 +131,10 @@ zipWith :: zipWith f (MkNDArr xs) (MkNDArr ys) = MkNDArr (Storable.zipWith f xs ys) data StackWitness i d1 d2 dr where - SZ :: (a + b ~ c, Ix (a : d), Ix (b : d), Ix (c : d)) - => Proxy '(a, b, c) -> Proxy d - -> StackWitness PZ (a : d) (b : d) (c : d) - SS :: ( Ix (a : s), Ix (b : s), Ix (c : s), a + b ~ c) + SZ :: (a + b ~ c, Ix (a : s), Ix (b : s), Ix (c : s)) + => Proxy '(a, b, c) -> Proxy s + -> StackWitness PZ (a : s) (b : s) (c : s) + SS :: (a + b ~ c, Ix (a : s), Ix (b : s), Ix (c : s)) => Proxy p -> Proxy '(a, b, c) -> Proxy s -> StackWitness (PS i) (p ++ (a : s)) (p ++ (b : s)) (p ++ (c : s)) diff --git a/src/MLambda/TypeLits.hs b/src/MLambda/TypeLits.hs index 426e51e..4574dd8 100644 --- a/src/MLambda/TypeLits.hs +++ b/src/MLambda/TypeLits.hs @@ -27,6 +27,7 @@ module MLambda.TypeLits , rpnat ) where +import Data.List.Singletons import Data.Proxy (Proxy (Proxy)) import GHC.TypeError (ErrorMessage (..), TypeError) import GHC.TypeNats hiding (natVal) @@ -44,11 +45,6 @@ type family Unify n a b where Unify _ a a = a Unify n a b = TypeError (Text n :<>: Text " are not equal:" :$$: ShowType a :$$: ShowType b) --- | Type-level list concatenation. -type family xs ++ ys where - '[] ++ ys = ys - (x:xs) ++ ys = x : xs ++ ys - -- | Peano naturals. data PNat = PZ | PS PNat diff --git a/stack.yaml b/stack.yaml index 2dc9a56..0dd9cb2 100644 --- a/stack.yaml +++ b/stack.yaml @@ -34,10 +34,8 @@ packages: # These entries can reference officially published versions as well as # forks / in-progress versions pinned to a git hash. For example: # -# extra-deps: -# - acme-missiles-0.3 -# - git: https://github.com/commercialhaskell/stack.git -# commit: e7b331f14bcffb8367cd58fbfc8b40ec7642100a +extra-deps: +- falsify-0.2.0@sha256:af5c4142095d05775236c8e18e827d540ed22f57b3b74b2268b32040b514ac88,5451 # # extra-deps: [] @@ -64,3 +62,6 @@ packages: # # Allow a newer minor version of GHC than the snapshot specifies # compiler-check: newer-minor +allow-newer-deps: +- falsify +allow-newer: true diff --git a/test/Spec.hs b/test/Spec.hs index c18740c..e367a5a 100644 --- a/test/Spec.hs +++ b/test/Spec.hs @@ -1,5 +1,6 @@ {-# LANGUAGE QuasiQuotes #-} +import Test.MLambda.NDArr import Test.Tasty import Test.Tasty.HUnit @@ -26,4 +27,5 @@ tests :: TestTree tests = testGroup "Tests" [ testCase "Mul" $ a `cross` b @?= c + , testNDArr ] diff --git a/test/Test/MLambda/NDArr.hs b/test/Test/MLambda/NDArr.hs new file mode 100644 index 0000000..c23dcf1 --- /dev/null +++ b/test/Test/MLambda/NDArr.hs @@ -0,0 +1,48 @@ +module Test.MLambda.NDArr (testNDArr) where + +import MLambda.Index +import MLambda.NDArr + +import Control.Monad +import Data.Proxy +import Data.Singletons +import Numeric.Natural +import Test.Falsify.Generator qualified as Gen +import Test.Falsify.Predicate qualified as Pred +import Test.Falsify.Predicate ((.$)) +import Test.Falsify.Range qualified as Range +import Test.Tasty.Falsify +import Test.Tasty + +genSz :: Gen Natural +genSz = fromIntegral <$> Gen.int (Range.between (1, 5)) + +genDim :: Gen [Natural] +genDim = Gen.list (Range.between (1, 5)) $ genSz + +genInt :: Gen Int +genInt = Gen.inRange $ Range.between (-1000, 1000) + +instance Enum (Index dim) => Gen.Function (Index dim) where + function gb = Gen.functionMap fromEnum toEnum <$> Gen.function gb + +propFromIndex :: Property () +propFromIndex = do + dim <- gen genDim + let Just p = withIx dim \(Proxy @dim) -> do + Fn f <- gen $ Gen.fun genInt + let is = [minBound..maxBound] :: [Index dim] + ndarr = fromIndex f + values = map f is + ndarrValues = map (ndarr `at`) is + pred = Pred.relatedBy + ("==", (foldl1 (||) .) . zipWith (==)) + assert $ pred .$ ("fromIndex f `at` i", ndarrValues) + .$ ("f i", values) + p + +testFromIndex :: TestTree +testFromIndex = testProperty "fromIndex" propFromIndex + +testNDArr :: TestTree +testNDArr = testGroup "NDArr" [testFromIndex] From ed5b3a538c05840f69229a503b9200fdb150b67a Mon Sep 17 00:00:00 2001 From: Nikita Solodovnikov Date: Thu, 14 Aug 2025 05:42:49 +0300 Subject: [PATCH 02/20] Lint and format --- src/MLambda/Index.hs | 2 +- test/Test/MLambda/NDArr.hs | 8 ++++---- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/src/MLambda/Index.hs b/src/MLambda/Index.hs index 94eeb82..1b38f29 100644 --- a/src/MLambda/Index.hs +++ b/src/MLambda/Index.hs @@ -117,7 +117,7 @@ data IxInstance (dim :: [Natural]) where viewII :: IndexI dim -> IxInstance dim viewII (II :.= _) = IxInstance -viewII II = IxInstance +viewII II = IxInstance {-# COMPLETE IxI #-} pattern IxI :: () => Ix dim => IndexI dim diff --git a/test/Test/MLambda/NDArr.hs b/test/Test/MLambda/NDArr.hs index c23dcf1..3ff3362 100644 --- a/test/Test/MLambda/NDArr.hs +++ b/test/Test/MLambda/NDArr.hs @@ -8,17 +8,17 @@ import Data.Proxy import Data.Singletons import Numeric.Natural import Test.Falsify.Generator qualified as Gen -import Test.Falsify.Predicate qualified as Pred import Test.Falsify.Predicate ((.$)) +import Test.Falsify.Predicate qualified as Pred import Test.Falsify.Range qualified as Range -import Test.Tasty.Falsify import Test.Tasty +import Test.Tasty.Falsify genSz :: Gen Natural genSz = fromIntegral <$> Gen.int (Range.between (1, 5)) genDim :: Gen [Natural] -genDim = Gen.list (Range.between (1, 5)) $ genSz +genDim = Gen.list (Range.between (1, 5)) genSz genInt :: Gen Int genInt = Gen.inRange $ Range.between (-1000, 1000) @@ -36,7 +36,7 @@ propFromIndex = do values = map f is ndarrValues = map (ndarr `at`) is pred = Pred.relatedBy - ("==", (foldl1 (||) .) . zipWith (==)) + ("==", (or .) . zipWith (==)) assert $ pred .$ ("fromIndex f `at` i", ndarrValues) .$ ("f i", values) p From a8939b3ab986c8f03d736fc3f7721c5cf4691f30 Mon Sep 17 00:00:00 2001 From: Nikita Solodovnikov Date: Thu, 14 Aug 2025 05:49:05 +0300 Subject: [PATCH 03/20] Reduce warnings --- test/Test/MLambda/NDArr.hs | 25 ++++++++++++------------- 1 file changed, 12 insertions(+), 13 deletions(-) diff --git a/test/Test/MLambda/NDArr.hs b/test/Test/MLambda/NDArr.hs index 3ff3362..be6d7f7 100644 --- a/test/Test/MLambda/NDArr.hs +++ b/test/Test/MLambda/NDArr.hs @@ -1,11 +1,11 @@ +{-# LANGUAGE TypeAbstractions #-} module Test.MLambda.NDArr (testNDArr) where import MLambda.Index import MLambda.NDArr -import Control.Monad +import Data.Maybe import Data.Proxy -import Data.Singletons import Numeric.Natural import Test.Falsify.Generator qualified as Gen import Test.Falsify.Predicate ((.$)) @@ -29,17 +29,16 @@ instance Enum (Index dim) => Gen.Function (Index dim) where propFromIndex :: Property () propFromIndex = do dim <- gen genDim - let Just p = withIx dim \(Proxy @dim) -> do - Fn f <- gen $ Gen.fun genInt - let is = [minBound..maxBound] :: [Index dim] - ndarr = fromIndex f - values = map f is - ndarrValues = map (ndarr `at`) is - pred = Pred.relatedBy - ("==", (or .) . zipWith (==)) - assert $ pred .$ ("fromIndex f `at` i", ndarrValues) - .$ ("f i", values) - p + let p = withIx dim \(Proxy @dim) -> do + Fn f <- gen $ Gen.fun genInt + let is = [minBound..maxBound] :: [Index dim] + ndarr = fromIndex f + values = map f is + ndarrValues = map (ndarr `at`) is + assert $ Pred.relatedBy ("==", (or .) . zipWith (==)) + .$ ("fromIndex f `at` i", ndarrValues) + .$ ("f i", values) + fromMaybe (error "propFromIndex: impossible") p testFromIndex :: TestTree testFromIndex = testProperty "fromIndex" propFromIndex From 1d2a2dd93f997af86f417af86d87e2a3543bebb8 Mon Sep 17 00:00:00 2001 From: Nikita Solodovnikov Date: Thu, 14 Aug 2025 05:53:53 +0300 Subject: [PATCH 04/20] Write docs for new stuff --- src/MLambda/Index.hs | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/src/MLambda/Index.hs b/src/MLambda/Index.hs index 1b38f29..2fc7110 100644 --- a/src/MLambda/Index.hs +++ b/src/MLambda/Index.hs @@ -119,6 +119,8 @@ viewII :: IndexI dim -> IxInstance dim viewII (II :.= _) = IxInstance viewII II = IxInstance +-- | A simpler pattern for @`IndexI`@ in case you don't need to recursive instances +-- for suffixes. {-# COMPLETE IxI #-} pattern IxI :: () => Ix dim => IndexI dim pattern IxI <- (viewII -> IxInstance) where @@ -140,7 +142,8 @@ instance (KnownNat n, 1 <= n) => Ix '[n] where instance (KnownNat n, 1 <= n, Ix (d:ds)) => Ix (n:d:ds) where inst = II :.= inst --- | Convert @`Sing`@ of a +-- | Convert @`Sing`@ of a dimension list to a witness for @`Ix`@ instances +-- of that index. singToIndexI :: forall dim . Sing dim -> Maybe (IndexI dim) singToIndexI (SCons sn@SNat SNil) = case sing @1 %<=? sn of @@ -152,6 +155,9 @@ singToIndexI (SCons sn@SNat sr@SCons{}) = _ -> Nothing singToIndexI _ = Nothing +-- | Simple interface for lifting runtime dimensions into type level. +-- Provides the given continuation with the type of said dimensions and their +-- @`Ix`@ instance. withIx :: Demote [Natural] -> (forall dim . Ix dim => Proxy dim -> r) -> Maybe r withIx d f = withSomeSing d \(singToIndexI -> r) -> flip fmap r \case (IxI :: IndexI dim) -> f $ Proxy @dim From 35b6f0cd53657c496287aebddef8f264fe456c95 Mon Sep 17 00:00:00 2001 From: Nikita Solodovnikov Date: Thu, 14 Aug 2025 05:39:19 +0300 Subject: [PATCH 05/20] Add index witness construction from runtime sizes and use it for testing with falsify. --- ml.cabal | 8 +++++++ package.yaml | 5 ++++ src/MLambda/Index.hs | 42 ++++++++++++++++++++++++++++++++- src/MLambda/TypeLits.hs | 6 +---- stack.yaml | 9 +++---- test/Spec.hs | 2 ++ test/Test/MLambda/NDArr.hs | 48 ++++++++++++++++++++++++++++++++++++++ 7 files changed, 110 insertions(+), 10 deletions(-) create mode 100644 test/Test/MLambda/NDArr.hs diff --git a/ml.cabal b/ml.cabal index cc5cefb..9b6d1a8 100644 --- a/ml.cabal +++ b/ml.cabal @@ -48,6 +48,8 @@ library , massiv , mtl , netlib-ffi + , singletons + , singletons-base , template-haskell , vector default-language: GHC2024 @@ -56,6 +58,7 @@ test-suite ml-test type: exitcode-stdio-1.0 main-is: Spec.hs other-modules: + Test.MLambda.NDArr Paths_ml autogen-modules: Paths_ml @@ -72,11 +75,14 @@ test-suite ml-test base >=4.7 && <5 , blas-ffi , deepseq + , falsify , haskell-src-meta , massiv , ml , mtl , netlib-ffi + , singletons + , singletons-base , tasty , tasty-hunit , template-haskell @@ -110,6 +116,8 @@ benchmark ml-bench , netlib-ffi , normaldistribution , random + , singletons + , singletons-base , tasty-bench , template-haskell , vector diff --git a/package.yaml b/package.yaml index 50126bc..d9c9f7d 100644 --- a/package.yaml +++ b/package.yaml @@ -29,6 +29,8 @@ dependencies: - template-haskell - vector - haskell-src-meta +- singletons +- singletons-base ghc-options: - -Wall @@ -72,6 +74,9 @@ tests: - ml - tasty - tasty-hunit + - falsify + - singletons + - singletons-base language: GHC2024 default-extensions: diff --git a/src/MLambda/Index.hs b/src/MLambda/Index.hs index 2aae007..513fa02 100644 --- a/src/MLambda/Index.hs +++ b/src/MLambda/Index.hs @@ -1,5 +1,6 @@ +{-# LANGUAGE PatternSynonyms #-} +{-# LANGUAGE TypeAbstractions #-} {-# LANGUAGE ViewPatterns #-} - -- | -- Module : MLambda.Index -- Description : Type of multidimensional array indices. @@ -17,13 +18,22 @@ module MLambda.Index , concatIndex , IndexI (..) , Ix + , pattern IxI , inst + -- * Lifting of runtime dimesions into indices + , singToIndexI + , withIx ) where import Data.Proxy (Proxy (..)) import MLambda.TypeLits +import Data.Bool.Singletons +import Data.List.Singletons +import Data.Singletons +import GHC.TypeLits.Singletons hiding (natVal) + -- | @Index dim@ is the type of indices of multidimensional arrays of dimensions @dim@. -- Instances are provided for convenient use: -- @@ -110,6 +120,20 @@ data IndexI (dim :: [Natural]) where (:.=) :: (KnownNat n, 1 <= n, Ix ds) => Proxy n -> IndexI ds -> IndexI (n : ds) +data IxInstance (dim :: [Natural]) where + IxInstance :: Ix dim => IxInstance dim + +viewII :: IndexI dim -> IxInstance dim +viewII (II :.= _) = IxInstance +viewII EI = IxInstance + +{-# COMPLETE IxI #-} +pattern IxI :: () => Ix dim => IndexI dim +pattern IxI <- (viewII -> IxInstance) where + IxI = inst + +deriving instance Show (IndexI dim) + infixr 5 :.= -- | A class used both as a shorthand for useful @`Index`@ instances and a way to obtain @@ -123,3 +147,19 @@ instance Ix '[] where instance (KnownNat n, 1 <= n, Ix ds) => Ix (n:ds) where inst = Proxy :.= inst + +-- | Convert @`Sing`@ of a +singToIndexI :: forall dim . Sing dim -> Maybe (IndexI dim) +singToIndexI (SCons sn@SNat SNil) = + case sing @1 %<=? sn of + STrue -> Just II + SFalse -> Nothing +singToIndexI (SCons sn@SNat sr@SCons{}) = + case (sing @1 %<=? sn, singToIndexI sr) of + (STrue, Just r@IxI) -> Just $ II :.= r + _ -> Nothing +singToIndexI _ = Nothing + +withIx :: Demote [Natural] -> (forall dim . Ix dim => Proxy dim -> r) -> Maybe r +withIx d f = withSomeSing d \(singToIndexI -> r) -> + flip fmap r \case (IxI :: IndexI dim) -> f $ Proxy @dim diff --git a/src/MLambda/TypeLits.hs b/src/MLambda/TypeLits.hs index 5f0b54b..a6618da 100644 --- a/src/MLambda/TypeLits.hs +++ b/src/MLambda/TypeLits.hs @@ -28,6 +28,7 @@ module MLambda.TypeLits , rpnat ) where +import Data.List.Singletons import Data.Proxy (Proxy (Proxy)) import GHC.TypeError (ErrorMessage (..), TypeError) import GHC.TypeNats hiding (natVal) @@ -45,11 +46,6 @@ type family Unify n a b where Unify _ a a = a Unify n a b = TypeError (Text n :<>: Text " are not equal:" :$$: ShowType a :$$: ShowType b) --- | Type-level list concatenation. -type family xs ++ ys where - '[] ++ ys = ys - (x:xs) ++ ys = x : xs ++ ys - -- | Peano naturals. data PNat = PZ | PS PNat diff --git a/stack.yaml b/stack.yaml index 2dc9a56..0dd9cb2 100644 --- a/stack.yaml +++ b/stack.yaml @@ -34,10 +34,8 @@ packages: # These entries can reference officially published versions as well as # forks / in-progress versions pinned to a git hash. For example: # -# extra-deps: -# - acme-missiles-0.3 -# - git: https://github.com/commercialhaskell/stack.git -# commit: e7b331f14bcffb8367cd58fbfc8b40ec7642100a +extra-deps: +- falsify-0.2.0@sha256:af5c4142095d05775236c8e18e827d540ed22f57b3b74b2268b32040b514ac88,5451 # # extra-deps: [] @@ -64,3 +62,6 @@ packages: # # Allow a newer minor version of GHC than the snapshot specifies # compiler-check: newer-minor +allow-newer-deps: +- falsify +allow-newer: true diff --git a/test/Spec.hs b/test/Spec.hs index c18740c..e367a5a 100644 --- a/test/Spec.hs +++ b/test/Spec.hs @@ -1,5 +1,6 @@ {-# LANGUAGE QuasiQuotes #-} +import Test.MLambda.NDArr import Test.Tasty import Test.Tasty.HUnit @@ -26,4 +27,5 @@ tests :: TestTree tests = testGroup "Tests" [ testCase "Mul" $ a `cross` b @?= c + , testNDArr ] diff --git a/test/Test/MLambda/NDArr.hs b/test/Test/MLambda/NDArr.hs new file mode 100644 index 0000000..c23dcf1 --- /dev/null +++ b/test/Test/MLambda/NDArr.hs @@ -0,0 +1,48 @@ +module Test.MLambda.NDArr (testNDArr) where + +import MLambda.Index +import MLambda.NDArr + +import Control.Monad +import Data.Proxy +import Data.Singletons +import Numeric.Natural +import Test.Falsify.Generator qualified as Gen +import Test.Falsify.Predicate qualified as Pred +import Test.Falsify.Predicate ((.$)) +import Test.Falsify.Range qualified as Range +import Test.Tasty.Falsify +import Test.Tasty + +genSz :: Gen Natural +genSz = fromIntegral <$> Gen.int (Range.between (1, 5)) + +genDim :: Gen [Natural] +genDim = Gen.list (Range.between (1, 5)) $ genSz + +genInt :: Gen Int +genInt = Gen.inRange $ Range.between (-1000, 1000) + +instance Enum (Index dim) => Gen.Function (Index dim) where + function gb = Gen.functionMap fromEnum toEnum <$> Gen.function gb + +propFromIndex :: Property () +propFromIndex = do + dim <- gen genDim + let Just p = withIx dim \(Proxy @dim) -> do + Fn f <- gen $ Gen.fun genInt + let is = [minBound..maxBound] :: [Index dim] + ndarr = fromIndex f + values = map f is + ndarrValues = map (ndarr `at`) is + pred = Pred.relatedBy + ("==", (foldl1 (||) .) . zipWith (==)) + assert $ pred .$ ("fromIndex f `at` i", ndarrValues) + .$ ("f i", values) + p + +testFromIndex :: TestTree +testFromIndex = testProperty "fromIndex" propFromIndex + +testNDArr :: TestTree +testNDArr = testGroup "NDArr" [testFromIndex] From f33e80749cdde2af53028dedb6d9d17f9c226219 Mon Sep 17 00:00:00 2001 From: Nikita Solodovnikov Date: Thu, 14 Aug 2025 05:42:49 +0300 Subject: [PATCH 06/20] Lint and format --- src/MLambda/Index.hs | 2 +- test/Test/MLambda/NDArr.hs | 8 ++++---- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/src/MLambda/Index.hs b/src/MLambda/Index.hs index 513fa02..9755d2c 100644 --- a/src/MLambda/Index.hs +++ b/src/MLambda/Index.hs @@ -125,7 +125,7 @@ data IxInstance (dim :: [Natural]) where viewII :: IndexI dim -> IxInstance dim viewII (II :.= _) = IxInstance -viewII EI = IxInstance +viewII EI = IxInstance {-# COMPLETE IxI #-} pattern IxI :: () => Ix dim => IndexI dim diff --git a/test/Test/MLambda/NDArr.hs b/test/Test/MLambda/NDArr.hs index c23dcf1..3ff3362 100644 --- a/test/Test/MLambda/NDArr.hs +++ b/test/Test/MLambda/NDArr.hs @@ -8,17 +8,17 @@ import Data.Proxy import Data.Singletons import Numeric.Natural import Test.Falsify.Generator qualified as Gen -import Test.Falsify.Predicate qualified as Pred import Test.Falsify.Predicate ((.$)) +import Test.Falsify.Predicate qualified as Pred import Test.Falsify.Range qualified as Range -import Test.Tasty.Falsify import Test.Tasty +import Test.Tasty.Falsify genSz :: Gen Natural genSz = fromIntegral <$> Gen.int (Range.between (1, 5)) genDim :: Gen [Natural] -genDim = Gen.list (Range.between (1, 5)) $ genSz +genDim = Gen.list (Range.between (1, 5)) genSz genInt :: Gen Int genInt = Gen.inRange $ Range.between (-1000, 1000) @@ -36,7 +36,7 @@ propFromIndex = do values = map f is ndarrValues = map (ndarr `at`) is pred = Pred.relatedBy - ("==", (foldl1 (||) .) . zipWith (==)) + ("==", (or .) . zipWith (==)) assert $ pred .$ ("fromIndex f `at` i", ndarrValues) .$ ("f i", values) p From b63677fc188119ca941d1d1641b7ed0c211d9705 Mon Sep 17 00:00:00 2001 From: Nikita Solodovnikov Date: Thu, 14 Aug 2025 05:49:05 +0300 Subject: [PATCH 07/20] Reduce warnings --- test/Test/MLambda/NDArr.hs | 25 ++++++++++++------------- 1 file changed, 12 insertions(+), 13 deletions(-) diff --git a/test/Test/MLambda/NDArr.hs b/test/Test/MLambda/NDArr.hs index 3ff3362..be6d7f7 100644 --- a/test/Test/MLambda/NDArr.hs +++ b/test/Test/MLambda/NDArr.hs @@ -1,11 +1,11 @@ +{-# LANGUAGE TypeAbstractions #-} module Test.MLambda.NDArr (testNDArr) where import MLambda.Index import MLambda.NDArr -import Control.Monad +import Data.Maybe import Data.Proxy -import Data.Singletons import Numeric.Natural import Test.Falsify.Generator qualified as Gen import Test.Falsify.Predicate ((.$)) @@ -29,17 +29,16 @@ instance Enum (Index dim) => Gen.Function (Index dim) where propFromIndex :: Property () propFromIndex = do dim <- gen genDim - let Just p = withIx dim \(Proxy @dim) -> do - Fn f <- gen $ Gen.fun genInt - let is = [minBound..maxBound] :: [Index dim] - ndarr = fromIndex f - values = map f is - ndarrValues = map (ndarr `at`) is - pred = Pred.relatedBy - ("==", (or .) . zipWith (==)) - assert $ pred .$ ("fromIndex f `at` i", ndarrValues) - .$ ("f i", values) - p + let p = withIx dim \(Proxy @dim) -> do + Fn f <- gen $ Gen.fun genInt + let is = [minBound..maxBound] :: [Index dim] + ndarr = fromIndex f + values = map f is + ndarrValues = map (ndarr `at`) is + assert $ Pred.relatedBy ("==", (or .) . zipWith (==)) + .$ ("fromIndex f `at` i", ndarrValues) + .$ ("f i", values) + fromMaybe (error "propFromIndex: impossible") p testFromIndex :: TestTree testFromIndex = testProperty "fromIndex" propFromIndex From 2afb36675b5d43e16eff11f3146d05c12fbdf41f Mon Sep 17 00:00:00 2001 From: Nikita Solodovnikov Date: Thu, 14 Aug 2025 05:53:53 +0300 Subject: [PATCH 08/20] Write docs for new stuff --- src/MLambda/Index.hs | 22 +++++++++++----------- src/MLambda/NDArr.hs | 2 +- src/MLambda/TypeLits.hs | 8 ++++---- 3 files changed, 16 insertions(+), 16 deletions(-) diff --git a/src/MLambda/Index.hs b/src/MLambda/Index.hs index 9755d2c..74dcaeb 100644 --- a/src/MLambda/Index.hs +++ b/src/MLambda/Index.hs @@ -25,8 +25,6 @@ module MLambda.Index , withIx ) where -import Data.Proxy (Proxy (..)) - import MLambda.TypeLits import Data.Bool.Singletons @@ -124,9 +122,11 @@ data IxInstance (dim :: [Natural]) where IxInstance :: Ix dim => IxInstance dim viewII :: IndexI dim -> IxInstance dim -viewII (II :.= _) = IxInstance +viewII (_ :.= _) = IxInstance viewII EI = IxInstance +-- | A simpler pattern for @`IndexI`@ in case you don't need to recursive instances +-- for suffixes. {-# COMPLETE IxI #-} pattern IxI :: () => Ix dim => IndexI dim pattern IxI <- (viewII -> IxInstance) where @@ -148,18 +148,18 @@ instance Ix '[] where instance (KnownNat n, 1 <= n, Ix ds) => Ix (n:ds) where inst = Proxy :.= inst --- | Convert @`Sing`@ of a +-- | Convert @`Sing`@ of a dimension list to a witness for @`Ix`@ instances +-- of that index. singToIndexI :: forall dim . Sing dim -> Maybe (IndexI dim) -singToIndexI (SCons sn@SNat SNil) = - case sing @1 %<=? sn of - STrue -> Just II - SFalse -> Nothing -singToIndexI (SCons sn@SNat sr@SCons{}) = +singToIndexI SNil = Just EI +singToIndexI (SCons sn@SNat sr) = case (sing @1 %<=? sn, singToIndexI sr) of - (STrue, Just r@IxI) -> Just $ II :.= r + (STrue, Just r@IxI) -> Just $ Proxy :.= r _ -> Nothing -singToIndexI _ = Nothing +-- | Simple interface for lifting runtime dimensions into type level. +-- Provides the given continuation with the type of said dimensions and their +-- @`Ix`@ instance. withIx :: Demote [Natural] -> (forall dim . Ix dim => Proxy dim -> r) -> Maybe r withIx d f = withSomeSing d \(singToIndexI -> r) -> flip fmap r \case (IxI :: IndexI dim) -> f $ Proxy @dim diff --git a/src/MLambda/NDArr.hs b/src/MLambda/NDArr.hs index f60efc1..464dcdd 100644 --- a/src/MLambda/NDArr.hs +++ b/src/MLambda/NDArr.hs @@ -159,7 +159,7 @@ data StackWitness i d1 d2 dr where ( KnownNat k, KnownNat l, KnownNat (k + l), Ix t , 1 <= k, 1 <= l, 1 <= k + l ) => Proxy '(s, k, l, t) -> - StackWitness (Length s) (s ++ (k : t)) (s ++ (l : t)) (s ++ ((k + l) : t)) + StackWitness (PLength s) (s ++ (k : t)) (s ++ (l : t)) (s ++ ((k + l) : t)) instance ( KnownNat n, KnownNat m, KnownNat k, Ix d diff --git a/src/MLambda/TypeLits.hs b/src/MLambda/TypeLits.hs index a6618da..3ae8a67 100644 --- a/src/MLambda/TypeLits.hs +++ b/src/MLambda/TypeLits.hs @@ -19,7 +19,7 @@ module MLambda.TypeLits , Unify , type (++) , PNat (..) - , Length + , PLength , Peano , RNat (..) , RPNat (..) @@ -50,9 +50,9 @@ type family Unify n a b where data PNat = PZ | PS PNat -- | Compute length of a type-level list as a Peano natural. -type family Length xs where - Length '[] = PZ - Length (_:xs) = PS (Length xs) +type family PLength xs where + PLength '[] = PZ + PLength (_:xs) = PS (PLength xs) -- | Compute Peano representation from type-level natural. type family Peano n where From 7a5e82616567bc5c77c8fcea83ec82b48339b5ad Mon Sep 17 00:00:00 2001 From: Nikita Solodovnikov Date: Sat, 16 Aug 2025 16:53:01 +0300 Subject: [PATCH 09/20] Cleanup `Index` stuff --- src/MLambda/Index.hs | 69 +++++++++++++++++++++----------------------- 1 file changed, 33 insertions(+), 36 deletions(-) diff --git a/src/MLambda/Index.hs b/src/MLambda/Index.hs index 74dcaeb..1e52eaf 100644 --- a/src/MLambda/Index.hs +++ b/src/MLambda/Index.hs @@ -13,16 +13,18 @@ -- This module contains definition of 'Index' type of multidimensional array -- indices along with its instances and public interface. module MLambda.Index - ( Index (E, (:.)) - , consIndex - , concatIndex + ( -- * Index type + Index (E) + , pattern (:.) + -- * Index instances + , Ix(..) , IndexI (..) - , Ix , pattern IxI - , inst -- * Lifting of runtime dimesions into indices , singToIndexI + -- * Index operations , withIx + , concatIndex ) where import MLambda.TypeLits @@ -45,8 +47,20 @@ import GHC.TypeLits.Singletons hiding (natVal) -- order their respective elements are laid out in memory. data Index (dim :: [Natural]) where E :: Index '[] - I :: Int -> Index '[n] - (:.) :: Index '[n] -> Index (d : ds) -> Index (n : d : ds) + ICons :: Int -> Index ds -> Index (d : ds) + +{-# COMPLETE I #-} +pattern I :: Int -> Index '[d] +pattern I m = ICons m E + +viewI :: Index (d:ds) -> (Index '[d], Index ds) +viewI (ICons x xs) = (I x, xs) + +{-# COMPLETE (:.) #-} +-- | Prepend a single-dimensional index to multi-dimensional one +pattern (:.) :: Index '[d] -> Index ds -> Index (d:ds) +pattern x :. xs <- (viewI -> (x, xs)) + where (ICons x E) :. xs = ICons x xs deriving instance Eq (Index dim) deriving instance Ord (Index dim) @@ -67,8 +81,8 @@ instance Bounded (Index '[]) where maxBound = E instance (KnownNat n, 1 <= n, Bounded (Index d)) => Bounded (Index (n:d)) where - minBound = 0 `consIndex` minBound - maxBound = (-1) `consIndex` maxBound + minBound = 0 :. minBound + maxBound = (-1) :. maxBound instance Enum (Index '[]) where toEnum = const E @@ -76,34 +90,17 @@ instance Enum (Index '[]) where instance (KnownNat n, 1 <= n, Enum (Index d), Bounded (Index d)) => Enum (Index (n:d)) where - toEnum ((`quotRem` enumSize (Index d)) -> (q, r)) = I q `consIndex` toEnum r - fromEnum = \case - I i -> i - I q :. r -> q * enumSize (Index d) + fromEnum r - succ = \case - h@(I i) | h == maxBound -> error "Undefined succ" - | otherwise -> I (succ i) - h :. t | t == maxBound -> succ h :. minBound - | otherwise -> h :. succ t - pred = \case - h@(I i) | h == minBound -> error "Undefined pred" - | otherwise -> I (pred i) - h :. t | t == minBound -> pred h :. maxBound - | otherwise -> h :. pred t - --- | Prepend a single-dimensional index to multi-dimensional one -consIndex :: Index '[x] -> Index xs -> Index (x : xs) -consIndex (I x) = \case - E -> I x - I y -> I x :. I y - I y :. xs -> I x :. I y :. xs + toEnum ((`quotRem` enumSize (Index d)) -> (q, r)) = I q :. toEnum r + fromEnum (I q :. r) = q * enumSize (Index d) + fromEnum r + succ (h :. t) | t == maxBound = succ h :. minBound + | otherwise = h :. succ t + pred (h :. t) | t == minBound = pred h :. maxBound + | otherwise = h :. pred t -- | Concatenate two indices together -concatIndex :: Index xs -> Index ys -> Index (xs ++ ys) -concatIndex = \case - E -> id - I x -> consIndex (I x) - I x :. xs -> consIndex (I x) . concatIndex xs +concatIndex :: forall xs ys . Index xs -> Index ys -> Index (xs ++ ys) +concatIndex E = id +concatIndex (ICons x xs) = ICons x . concatIndex xs -- | A helper type that holds instances for everything you can get by pattern matching on -- an @`Index`@ value. If you need to get instances for a head/tail of an @`Index`@, @@ -125,7 +122,7 @@ viewII :: IndexI dim -> IxInstance dim viewII (_ :.= _) = IxInstance viewII EI = IxInstance --- | A simpler pattern for @`IndexI`@ in case you don't need to recursive instances +-- | A simpler pattern for @`IndexI`@ in case you don't need instances -- for suffixes. {-# COMPLETE IxI #-} pattern IxI :: () => Ix dim => IndexI dim From ac56fae08a6130259215fc989ac42ed2c56f10dd Mon Sep 17 00:00:00 2001 From: Nikita Solodovnikov Date: Sun, 17 Aug 2025 01:48:30 +0300 Subject: [PATCH 10/20] Cleanup (with a fork) --- src/MLambda/Index.hs | 2 +- src/MLambda/NDArr.hs | 22 +++++++++++----------- 2 files changed, 12 insertions(+), 12 deletions(-) diff --git a/src/MLambda/Index.hs b/src/MLambda/Index.hs index 1e52eaf..3a8b02d 100644 --- a/src/MLambda/Index.hs +++ b/src/MLambda/Index.hs @@ -90,7 +90,7 @@ instance Enum (Index '[]) where instance (KnownNat n, 1 <= n, Enum (Index d), Bounded (Index d)) => Enum (Index (n:d)) where - toEnum ((`quotRem` enumSize (Index d)) -> (q, r)) = I q :. toEnum r + toEnum ((`quotRem` enumSize (Index d)) -> (q, r)) = I (q `mod` natVal n) :. toEnum r fromEnum (I q :. r) = q * enumSize (Index d) + fromEnum r succ (h :. t) | t == maxBound = succ h :. minBound | otherwise = h :. succ t diff --git a/src/MLambda/NDArr.hs b/src/MLambda/NDArr.hs index 464dcdd..8b30d74 100644 --- a/src/MLambda/NDArr.hs +++ b/src/MLambda/NDArr.hs @@ -63,18 +63,18 @@ newtype NDArr (dim :: [Natural]) e = MkNDArr { unsafeMkNDArr :: forall dim e. Storable.Vector e -> NDArr dim e unsafeMkNDArr = MkNDArr -instance (Show e, Storable e) => Show (NDArr '[n] e) where - show = show . runNDArr - -instance (Ix (n:a:r), Show (NDArr (a:r) e), Storable e) => - Show (NDArr (n:a:r) e) where +instance (Ix dim, Show e, Storable e) => Show (NDArr dim e) where showsPrec _ = - case inst @(n:a:r) of - _ :.= _ -> (showString "[" .) . (. showString "]") - . foldl' (.) id - . intersperse (showString ",\n") - . map shows - . toList . rows @'[n] + case inst @dim of + _ :.= _ -> go + EI -> shows . (Storable.! 0) . runNDArr + where + go :: forall n r e' . (Ix r, Show (NDArr r e'), Storable e') => NDArr (n:r) e' -> ShowS + go = (showString "[" .) . (. showString "]") + . foldl' (.) id + . intersperse (showString ",\n") + . map shows + . toList . rows @'[n] instance (Ix d, Storable e) => Storable (NDArr d e) where sizeOf _ = sizeOf (undefined :: e) * enumSize (Index d) From 1f0043c816d64b426f96a3884053c560e0272808 Mon Sep 17 00:00:00 2001 From: Nikita Solodovnikov Date: Sun, 17 Aug 2025 01:49:00 +0300 Subject: [PATCH 11/20] Test another property for `fromIndex` --- test/Test/MLambda/NDArr.hs | 48 ++++++++++++++++++++++++++++---------- 1 file changed, 36 insertions(+), 12 deletions(-) diff --git a/test/Test/MLambda/NDArr.hs b/test/Test/MLambda/NDArr.hs index be6d7f7..2dd4b9f 100644 --- a/test/Test/MLambda/NDArr.hs +++ b/test/Test/MLambda/NDArr.hs @@ -23,25 +23,49 @@ genDim = Gen.list (Range.between (1, 5)) genSz genInt :: Gen Int genInt = Gen.inRange $ Range.between (-1000, 1000) +genIndex :: forall dim . Ix dim => Gen (Index dim) +genIndex = case inst @dim of + EI -> pure E + Proxy @h :.= (_ :: IndexI t) -> do + h <- genInt + t <- genIndex @t + pure ((toEnum h :: Index '[h]) :. t) + +genNDArr :: forall dim . Ix dim => Gen (NDArr dim Int) +genNDArr = do + Fn f <- Gen.fun genInt + pure $ fromIndex f + instance Enum (Index dim) => Gen.Function (Index dim) where function gb = Gen.functionMap fromEnum toEnum <$> Gen.function gb -propFromIndex :: Property () -propFromIndex = do +propAtDotFromIndex :: Property () +propAtDotFromIndex = do dim <- gen genDim let p = withIx dim \(Proxy @dim) -> do - Fn f <- gen $ Gen.fun genInt - let is = [minBound..maxBound] :: [Index dim] - ndarr = fromIndex f - values = map f is - ndarrValues = map (ndarr `at`) is - assert $ Pred.relatedBy ("==", (or .) . zipWith (==)) - .$ ("fromIndex f `at` i", ndarrValues) - .$ ("f i", values) - fromMaybe (error "propFromIndex: impossible") p + Fn f <- gen $ Gen.fun genInt + i <- gen (genIndex @dim) + assert $ Pred.eq + .$ ("at . fromIndex", (at . fromIndex) f i) + .$ ("id", f i) + fromMaybe (error "propAtDotFromIndex: impossible") p + +propFromIndexDotAt :: Property () +propFromIndexDotAt = do + dim <- gen genDim + let p = withIx dim \(Proxy @dim) -> + case inst @dim of + _ -> do + arr <- gen (genNDArr @dim) + assert $ Pred.eq + .$ ("fromIndex . at", (fromIndex . at) arr) + .$ ("id", arr) + fromMaybe (error "propFromIndexDotAt: impossible") p testFromIndex :: TestTree -testFromIndex = testProperty "fromIndex" propFromIndex +testFromIndex = testGroup "fromIndex" + [ testProperty "at . fromIndex = id" propAtDotFromIndex + , testProperty "fromIndex . at = id" propFromIndexDotAt] testNDArr :: TestTree testNDArr = testGroup "NDArr" [testFromIndex] From 5d50a37f43e45f2a6a72de1c3cecf5ea7febcb59 Mon Sep 17 00:00:00 2001 From: Nikita Solodovnikov Date: Sun, 17 Aug 2025 01:49:33 +0300 Subject: [PATCH 12/20] Format --- src/MLambda/Index.hs | 4 ++-- test/Test/MLambda/NDArr.hs | 12 +++++------- 2 files changed, 7 insertions(+), 9 deletions(-) diff --git a/src/MLambda/Index.hs b/src/MLambda/Index.hs index 3a8b02d..20b9904 100644 --- a/src/MLambda/Index.hs +++ b/src/MLambda/Index.hs @@ -119,8 +119,8 @@ data IxInstance (dim :: [Natural]) where IxInstance :: Ix dim => IxInstance dim viewII :: IndexI dim -> IxInstance dim -viewII (_ :.= _) = IxInstance -viewII EI = IxInstance +viewII (_ :.= _) = IxInstance +viewII EI = IxInstance -- | A simpler pattern for @`IndexI`@ in case you don't need instances -- for suffixes. diff --git a/test/Test/MLambda/NDArr.hs b/test/Test/MLambda/NDArr.hs index 2dd4b9f..2f44b1e 100644 --- a/test/Test/MLambda/NDArr.hs +++ b/test/Test/MLambda/NDArr.hs @@ -53,13 +53,11 @@ propAtDotFromIndex = do propFromIndexDotAt :: Property () propFromIndexDotAt = do dim <- gen genDim - let p = withIx dim \(Proxy @dim) -> - case inst @dim of - _ -> do - arr <- gen (genNDArr @dim) - assert $ Pred.eq - .$ ("fromIndex . at", (fromIndex . at) arr) - .$ ("id", arr) + let p = withIx dim \(Proxy @dim) -> do + arr <- gen (genNDArr @dim) + assert $ Pred.eq + .$ ("fromIndex . at", (fromIndex . at) arr) + .$ ("id", arr) fromMaybe (error "propFromIndexDotAt: impossible") p testFromIndex :: TestTree From 121b07fcf0644117370f32a5e5c78388fc7332f7 Mon Sep 17 00:00:00 2001 From: Nikita Solodovnikov Date: Sun, 17 Aug 2025 02:09:05 +0300 Subject: [PATCH 13/20] Fix docs --- src/MLambda/Index.hs | 2 +- src/MLambda/NDArr.hs | 1 + src/MLambda/TypeLits.hs | 2 -- 3 files changed, 2 insertions(+), 3 deletions(-) diff --git a/src/MLambda/Index.hs b/src/MLambda/Index.hs index 20b9904..8fc7c9e 100644 --- a/src/MLambda/Index.hs +++ b/src/MLambda/Index.hs @@ -22,8 +22,8 @@ module MLambda.Index , pattern IxI -- * Lifting of runtime dimesions into indices , singToIndexI - -- * Index operations , withIx + -- * Index operations , concatIndex ) where diff --git a/src/MLambda/NDArr.hs b/src/MLambda/NDArr.hs index 8b30d74..9842223 100644 --- a/src/MLambda/NDArr.hs +++ b/src/MLambda/NDArr.hs @@ -41,6 +41,7 @@ import Control.DeepSeq (NFData) import Control.Monad.ST (runST) import Data.Foldable (forM_) import Data.List (intersperse) +import Data.List.Singletons import Data.Proxy import Data.Vector.Storable qualified as Storable import Data.Vector.Storable.Mutable qualified as Mutable diff --git a/src/MLambda/TypeLits.hs b/src/MLambda/TypeLits.hs index 3ae8a67..7f0631d 100644 --- a/src/MLambda/TypeLits.hs +++ b/src/MLambda/TypeLits.hs @@ -17,7 +17,6 @@ module MLambda.TypeLits , natVal , enumSize , Unify - , type (++) , PNat (..) , PLength , Peano @@ -28,7 +27,6 @@ module MLambda.TypeLits , rpnat ) where -import Data.List.Singletons import Data.Proxy (Proxy (Proxy)) import GHC.TypeError (ErrorMessage (..), TypeError) import GHC.TypeNats hiding (natVal) From b392068e1e4695f98d61a586a42ddbd74e41f6e8 Mon Sep 17 00:00:00 2001 From: Nikita Solodovnikov Date: Sun, 17 Aug 2025 03:42:23 +0300 Subject: [PATCH 14/20] Test `stack` I need therapy now --- src/MLambda/Index.hs | 8 +++++ src/MLambda/NDArr.hs | 21 ++++++++----- test/Test/MLambda/NDArr.hs | 60 ++++++++++++++++++++++++++++++++++---- 3 files changed, 76 insertions(+), 13 deletions(-) diff --git a/src/MLambda/Index.hs b/src/MLambda/Index.hs index 8fc7c9e..8534137 100644 --- a/src/MLambda/Index.hs +++ b/src/MLambda/Index.hs @@ -20,6 +20,7 @@ module MLambda.Index , Ix(..) , IndexI (..) , pattern IxI + , concatIndexI -- * Lifting of runtime dimesions into indices , singToIndexI , withIx @@ -133,6 +134,13 @@ deriving instance Show (IndexI dim) infixr 5 :.= +-- | Concatenate two @`IndexI`@s to get the instances for concatenated +-- dimensions. +concatIndexI :: IndexI d1 -> IndexI d2 -> IndexI (d1 ++ d2) +concatIndexI EI d2 = d2 +concatIndexI (i1 :.= d1) d2 = case concatIndexI d1 d2 of + IxI -> i1 :.= IxI + -- | A class used both as a shorthand for useful @`Index`@ instances and a way to obtain -- a value of @`IndexI`@. class (Bounded (Index dim), Enum (Index dim)) => Ix dim where diff --git a/src/MLambda/NDArr.hs b/src/MLambda/NDArr.hs index 9842223..d281683 100644 --- a/src/MLambda/NDArr.hs +++ b/src/MLambda/NDArr.hs @@ -29,7 +29,9 @@ module MLambda.NDArr -- * Array composition , Stack , Stacks + , StackWitness(..) , stack + , stackWithWitness -- * Unsafe API , unsafeMkNDArr ) where @@ -42,7 +44,7 @@ import Control.Monad.ST (runST) import Data.Foldable (forM_) import Data.List (intersperse) import Data.List.Singletons -import Data.Proxy +import Data.Singletons import Data.Vector.Storable qualified as Storable import Data.Vector.Storable.Mutable qualified as Mutable import Foreign.Ptr (Ptr, castPtr) @@ -139,7 +141,7 @@ vstack :: vstack (MkNDArr xs) (MkNDArr ys) = MkNDArr (xs <> ys) -- | A type family which computes the resulting size of a stacked array. -type Stack n d e = StackImpl (StackError n d e) (Peano n) d e +type Stack n d e = StackImpl (StackError n d e) n d e type StackError n d e = Text "Not enough dimensions to stack along axis " :<>: ShowType n @@ -155,6 +157,7 @@ type family StackImpl msg i d e where class Stacks i dim1 dim2 dimr where stacks :: StackWitness i dim1 dim2 dimr +-- | A witness that carries the constraints needed to stack arrays together. data StackWitness i d1 d2 dr where SW :: ( KnownNat k, KnownNat l, KnownNat (k + l), Ix t @@ -172,11 +175,15 @@ instance Stacks i d e r => Stacks (PS i) (n : d) (n : e) (n : r) where stacks = case stacks @i of SW (Proxy @'(s, k, l, t)) -> SW (Proxy @'(n : s, k, l, t)) +-- | Sometimes you need to create a witness yourself for this to work. +stackWithWitness :: Storable e => StackWitness n d1 d2 dr + -> NDArr d1 e -> NDArr d2 e -> NDArr dr e +stackWithWitness (SW (Proxy @'(s, k, l, t))) xs ys = + concat $ zipWith vstack (rows @s @(k : t) xs) (rows @s @(l : t) ys) + -- | @stack i@ stacks arrays along the axis @i@. All other axes are required -- to be the same lengths. stack :: - forall n -> (Stacks (Peano n) d1 d2 (Stack n d1 d2), Storable e) => - NDArr d1 e -> NDArr d2 e -> NDArr (Stack n d1 d2) e -stack n xs ys = case stacks @(Peano n) of - SW (Proxy @'(s, k, l, t)) -> - concat $ zipWith vstack (rows @s @(k : t) xs) (rows @s @(l : t) ys) + forall n -> (Stacks (Peano n) d1 d2 (Stack (Peano n) d1 d2), Storable e) => + NDArr d1 e -> NDArr d2 e -> NDArr (Stack (Peano n) d1 d2) e +stack n xs ys = stackWithWitness (stacks @(Peano n)) xs ys diff --git a/test/Test/MLambda/NDArr.hs b/test/Test/MLambda/NDArr.hs index 2f44b1e..a312422 100644 --- a/test/Test/MLambda/NDArr.hs +++ b/test/Test/MLambda/NDArr.hs @@ -1,12 +1,19 @@ {-# LANGUAGE TypeAbstractions #-} +{-# LANGUAGE ViewPatterns #-} module Test.MLambda.NDArr (testNDArr) where import MLambda.Index import MLambda.NDArr +import MLambda.TypeLits +import Control.Monad +import Data.Bool.Singletons +import Data.List.Singletons (type (++)) import Data.Maybe import Data.Proxy -import Numeric.Natural +import Data.Singletons +import GHC.TypeLits.Singletons +import Prelude.Singletons ((%+)) import Test.Falsify.Generator qualified as Gen import Test.Falsify.Predicate ((.$)) import Test.Falsify.Predicate qualified as Pred @@ -17,8 +24,8 @@ import Test.Tasty.Falsify genSz :: Gen Natural genSz = fromIntegral <$> Gen.int (Range.between (1, 5)) -genDim :: Gen [Natural] -genDim = Gen.list (Range.between (1, 5)) genSz +genDim :: Word -> Word -> Gen [Natural] +genDim a b = Gen.list (Range.between (a, b)) genSz genInt :: Gen Int genInt = Gen.inRange $ Range.between (-1000, 1000) @@ -41,7 +48,7 @@ instance Enum (Index dim) => Gen.Function (Index dim) where propAtDotFromIndex :: Property () propAtDotFromIndex = do - dim <- gen genDim + dim <- gen $ genDim 0 5 let p = withIx dim \(Proxy @dim) -> do Fn f <- gen $ Gen.fun genInt i <- gen (genIndex @dim) @@ -52,7 +59,7 @@ propAtDotFromIndex = do propFromIndexDotAt :: Property () propFromIndexDotAt = do - dim <- gen genDim + dim <- gen $ genDim 0 5 let p = withIx dim \(Proxy @dim) -> do arr <- gen (genNDArr @dim) assert $ Pred.eq @@ -65,5 +72,46 @@ testFromIndex = testGroup "fromIndex" [ testProperty "at . fromIndex = id" propAtDotFromIndex , testProperty "fromIndex . at = id" propFromIndexDotAt] +propStack :: Property () +propStack = do + dimsuff <- gen $ genDim 0 3 + dimpref <- gen $ genDim 0 2 + dimmid1 <- gen genSz + dimmid2 <- gen genSz + case ( toSing dimsuff + , toSing dimpref + , toSing dimmid1 + , toSing dimmid2) of + (SomeSing (singToIndexI -> Just (IxI @p)) + , SomeSing (singToIndexI -> Just (IxI @s)) + , SomeSing sm1@(SNat @m1), SomeSing sm2@(SNat @m2)) -> + case ( sing @1 %<=? sm1 + , sing @1 %<=? sm2 + , sm1 %+ sm2 + , sing @1 %<=? sm1 %+ sm2) of + (STrue, STrue, SNat, STrue) -> + case ( concatIndexI (IxI @p) (IxI @(m1 : s)) + , concatIndexI (IxI @p) (IxI @(m2 : s)) + , concatIndexI (IxI @p) (IxI @((m1 + m2) : s))) of + (IxI, IxI, IxI) -> do + arr1 <- gen $ genNDArr @(p ++ (m1 : s)) + arr2 <- gen $ genNDArr @(p ++ (m2 : s)) + let arr3 = stackWithWitness (SW (Proxy @'(p, m1, m2, s))) arr1 arr2 + s <- gen $ genIndex @s + p <- gen $ genIndex @p + m <- gen $ genIndex @'[m1 + m2] + let i1 = p `concatIndex` ((toEnum (fromEnum m) :: Index '[m1]) :. s) + i2 = p `concatIndex` + ((toEnum (fromEnum m - enumSize (Index [m1])) :: Index '[m2]) :. s) + i3 = p `concatIndex` (m :. s) + if fromEnum m <= fromEnum (maxBound @(Index '[m1])) + then when (arr1 `at` i1 /= arr3 `at` i3) $ testFailed "mid index in lhs" + else when (arr2 `at` i2 /= arr3 `at` i3) $ testFailed "mid index in rhs" + _ -> error "propStack: impossible" + _ -> error "propStack: impossible" + +testStack :: TestTree +testStack = testProperty "stack" propStack + testNDArr :: TestTree -testNDArr = testGroup "NDArr" [testFromIndex] +testNDArr = testGroup "NDArr" [testFromIndex, testStack] From 0ef5e8e1079fa4c65e19c8633e310a1299061037 Mon Sep 17 00:00:00 2001 From: Nikita Solodovnikov Date: Sun, 17 Aug 2025 03:43:51 +0300 Subject: [PATCH 15/20] Eta reduce for hlint --- src/MLambda/NDArr.hs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/MLambda/NDArr.hs b/src/MLambda/NDArr.hs index d281683..22838c5 100644 --- a/src/MLambda/NDArr.hs +++ b/src/MLambda/NDArr.hs @@ -186,4 +186,4 @@ stackWithWitness (SW (Proxy @'(s, k, l, t))) xs ys = stack :: forall n -> (Stacks (Peano n) d1 d2 (Stack (Peano n) d1 d2), Storable e) => NDArr d1 e -> NDArr d2 e -> NDArr (Stack (Peano n) d1 d2) e -stack n xs ys = stackWithWitness (stacks @(Peano n)) xs ys +stack n = stackWithWitness (stacks @(Peano n)) From 483eb440928bc97302a76a591bee40f0ca2f76f7 Mon Sep 17 00:00:00 2001 From: Nikita Solodovnikov Date: Sun, 17 Aug 2025 20:29:40 +0300 Subject: [PATCH 16/20] Split out test utils --- ml.cabal | 1 + test/Test/MLambda/NDArr.hs | 28 ++-------------------------- test/Test/MLambda/Utils.hs | 37 +++++++++++++++++++++++++++++++++++++ 3 files changed, 40 insertions(+), 26 deletions(-) create mode 100644 test/Test/MLambda/Utils.hs diff --git a/ml.cabal b/ml.cabal index 9b6d1a8..2652c89 100644 --- a/ml.cabal +++ b/ml.cabal @@ -59,6 +59,7 @@ test-suite ml-test main-is: Spec.hs other-modules: Test.MLambda.NDArr + Test.MLambda.Utils Paths_ml autogen-modules: Paths_ml diff --git a/test/Test/MLambda/NDArr.hs b/test/Test/MLambda/NDArr.hs index a312422..a52e876 100644 --- a/test/Test/MLambda/NDArr.hs +++ b/test/Test/MLambda/NDArr.hs @@ -6,6 +6,8 @@ import MLambda.Index import MLambda.NDArr import MLambda.TypeLits +import Test.MLambda.Utils + import Control.Monad import Data.Bool.Singletons import Data.List.Singletons (type (++)) @@ -17,35 +19,9 @@ import Prelude.Singletons ((%+)) import Test.Falsify.Generator qualified as Gen import Test.Falsify.Predicate ((.$)) import Test.Falsify.Predicate qualified as Pred -import Test.Falsify.Range qualified as Range import Test.Tasty import Test.Tasty.Falsify -genSz :: Gen Natural -genSz = fromIntegral <$> Gen.int (Range.between (1, 5)) - -genDim :: Word -> Word -> Gen [Natural] -genDim a b = Gen.list (Range.between (a, b)) genSz - -genInt :: Gen Int -genInt = Gen.inRange $ Range.between (-1000, 1000) - -genIndex :: forall dim . Ix dim => Gen (Index dim) -genIndex = case inst @dim of - EI -> pure E - Proxy @h :.= (_ :: IndexI t) -> do - h <- genInt - t <- genIndex @t - pure ((toEnum h :: Index '[h]) :. t) - -genNDArr :: forall dim . Ix dim => Gen (NDArr dim Int) -genNDArr = do - Fn f <- Gen.fun genInt - pure $ fromIndex f - -instance Enum (Index dim) => Gen.Function (Index dim) where - function gb = Gen.functionMap fromEnum toEnum <$> Gen.function gb - propAtDotFromIndex :: Property () propAtDotFromIndex = do dim <- gen $ genDim 0 5 diff --git a/test/Test/MLambda/Utils.hs b/test/Test/MLambda/Utils.hs new file mode 100644 index 0000000..ef1b2af --- /dev/null +++ b/test/Test/MLambda/Utils.hs @@ -0,0 +1,37 @@ +module Test.MLambda.Utils where + +import MLambda.Index +import MLambda.NDArr +import MLambda.TypeLits + +import Data.Proxy +import Test.Falsify.Generator (Gen) +import Test.Falsify.Generator qualified as Gen +import Test.Falsify.Range (Range) +import Test.Falsify.Range qualified as Range + +genSz :: Gen Natural +genSz = fromIntegral <$> Gen.int (Range.between (1, 5)) + +genDim :: Word -> Word -> Gen [Natural] +genDim a b = Gen.list (Range.between (a, b)) genSz + +genInt :: Gen Int +genInt = Gen.inRange $ Range.between (-1000, 1000) + +genIndex :: forall dim . Ix dim => Gen (Index dim) +genIndex = case inst @dim of + EI -> pure E + Proxy @h :.= (_ :: IndexI t) -> do + h <- genInt + t <- genIndex @t + pure ((toEnum h :: Index '[h]) :. t) + +genNDArr :: forall dim . Ix dim => Gen (NDArr dim Int) +genNDArr = do + Gen.Fn f <- Gen.fun genInt + pure $ fromIndex f + +instance Enum (Index dim) => Gen.Function (Index dim) where + function gb = Gen.functionMap fromEnum toEnum <$> Gen.function gb + From 4ceb83f1f44bc09c90fb960a91efdd0965c01950 Mon Sep 17 00:00:00 2001 From: Nikita Solodovnikov Date: Sun, 17 Aug 2025 23:44:20 +0300 Subject: [PATCH 17/20] Implement linear maps and relate them to matrices --- ml.cabal | 8 ++++++ package.yaml | 2 ++ src/MLambda/Linear.hs | 45 ++++++++++++++++++++++++++++++++ src/MLambda/Matrix.hs | 34 +++++++++++++++++++++++- src/MLambda/NDArr.hs | 60 ++++++++++++++++++++++++++++++++++++++----- 5 files changed, 141 insertions(+), 8 deletions(-) create mode 100644 src/MLambda/Linear.hs diff --git a/ml.cabal b/ml.cabal index 2652c89..1f8d770 100644 --- a/ml.cabal +++ b/ml.cabal @@ -26,6 +26,7 @@ source-repository head library exposed-modules: MLambda.Index + MLambda.Linear MLambda.Matrix MLambda.NDArr MLambda.TypeLits @@ -39,11 +40,13 @@ library RankNTypes PolyKinds UndecidableInstances + NoStarIsType ghc-options: -Wall -Wcompat -Widentities -Wincomplete-record-updates -Wincomplete-uni-patterns -Wmissing-export-lists -Wmissing-home-modules -Wpartial-fields -Wredundant-constraints build-depends: base >=4.7 && <5 , blas-ffi , deepseq + , ghc-typelits-natnormalise , haskell-src-meta , massiv , mtl @@ -58,6 +61,7 @@ test-suite ml-test type: exitcode-stdio-1.0 main-is: Spec.hs other-modules: + Test.MLambda.Matrix Test.MLambda.NDArr Test.MLambda.Utils Paths_ml @@ -71,12 +75,14 @@ test-suite ml-test RankNTypes PolyKinds UndecidableInstances + NoStarIsType ghc-options: -Wall -Wcompat -Widentities -Wincomplete-record-updates -Wincomplete-uni-patterns -Wmissing-export-lists -Wmissing-home-modules -Wpartial-fields -Wredundant-constraints -threaded -rtsopts -with-rtsopts=-N build-depends: base >=4.7 && <5 , blas-ffi , deepseq , falsify + , ghc-typelits-natnormalise , haskell-src-meta , massiv , ml @@ -105,11 +111,13 @@ benchmark ml-bench RankNTypes PolyKinds UndecidableInstances + NoStarIsType ghc-options: -Wall -Wcompat -Widentities -Wincomplete-record-updates -Wincomplete-uni-patterns -Wmissing-export-lists -Wmissing-home-modules -Wpartial-fields -Wredundant-constraints -threaded -rtsopts -with-rtsopts=-N build-depends: base >=4.7 && <5 , blas-ffi , deepseq + , ghc-typelits-natnormalise , haskell-src-meta , massiv , ml diff --git a/package.yaml b/package.yaml index d9c9f7d..60f8660 100644 --- a/package.yaml +++ b/package.yaml @@ -31,6 +31,7 @@ dependencies: - haskell-src-meta - singletons - singletons-base +- ghc-typelits-natnormalise ghc-options: - -Wall @@ -85,3 +86,4 @@ default-extensions: - RankNTypes - PolyKinds - UndecidableInstances +- NoStarIsType diff --git a/src/MLambda/Linear.hs b/src/MLambda/Linear.hs new file mode 100644 index 0000000..c993962 --- /dev/null +++ b/src/MLambda/Linear.hs @@ -0,0 +1,45 @@ +-- | +-- Module : MLambda.Linear +-- Description : Linear maps. +-- Copyright : (c) neclitoris, TurtlePU, 2025 +-- License : BSD-3-Clause +-- Maintainer : nas140301@gmail.com +-- Stability : experimental +-- Portability : portable +-- +-- This module describes linear maps. +module MLambda.Linear (Additive(..), Module(..), LinearMap, LinearMap') where + +import Data.Kind + +-- | @Additive@ describes things you can add. +class Additive m where + add :: m -> m -> m + zero :: m + +-- | @Module@ is a class describing modules over a ring @r@. +-- Since @`Num`@ describes commutative rings, this is both a right +-- and a left module. +class (Num r, Additive m) => Module r m where + modMult :: r -> m -> m + +instance {-# OVERLAPPABLE #-} Num r => Additive r where + add = (+) + zero = 0 + +instance Num r => Module r r where + modMult = (*) + +-- | This version of LinearMap allows one to pass an arbitrary additional +-- constraint on the type mapped over. Note that this can break things: +-- @LinearMap' ((~) r) s m r@ will allow you to do naughty things. +type LinearMap' (e :: Type -> Constraint) t s r = + forall m . (Module r m, e m) => t m -> s m + +class Empty t + +instance Empty t + +-- | @LinearMap@ forces linearity (in mathematical sense) in its argument +-- since you can't multiply an element of @m@ by itself, only by some @a@. +type LinearMap t s r = LinearMap' Empty t s r diff --git a/src/MLambda/Matrix.hs b/src/MLambda/Matrix.hs index c9b8af3..c3fe01d 100644 --- a/src/MLambda/Matrix.hs +++ b/src/MLambda/Matrix.hs @@ -3,6 +3,7 @@ {-# LANGUAGE RequiredTypeArguments #-} {-# LANGUAGE TemplateHaskell #-} {-# LANGUAGE ViewPatterns #-} +{-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-} -- | -- Module : MLambda.Matrix @@ -20,10 +21,18 @@ module MLambda.Matrix , crossMassiv -- * Matrix creation , mat + , eye + , rep + , act + -- * Shape manipulation + , transpose ) where import MLambda.Foreign.Utils (asFPtr, asPtr, char) -import MLambda.NDArr +import MLambda.Index +import MLambda.Linear +import MLambda.NDArr hiding (concat, zipWith, map, foldr) +import MLambda.NDArr qualified as NDArr import MLambda.TypeLits import Control.Applicative @@ -176,3 +185,26 @@ matE s = do vec <- Storable.unsafeFreeze mvec pure $ unsafeMkNDArr @'[$tm, $tn] @Double vec |] + +-- | Identity matrix. +eye :: (KnownNat n, 1 <= n, Storable e, Num e) => NDArr '[n, n] e +eye = fromIndex go where + go (i :. j) | i == j = 1 + | otherwise = 0 + +-- | A matrix that corresponds to a linear map of vector spaces. +rep :: forall m n e . + ( KnownNat m, 1 <= m + , KnownNat n, 1 <= n + , Storable e, Num e) + => LinearMap' Storable (NDArr '[m]) (NDArr '[n]) e -> NDArr '[n, m] e +rep f = transpose $ NDArr.concat $ fromIndex (f . (rows @'[m] eye `at`)) + +act :: forall m n e . (KnownNat m, 1 <= m , Storable e) + => NDArr '[n, m] e -> LinearMap' Storable (NDArr '[m]) (NDArr '[n]) e +act a x = NDArr.map (NDArr.foldr add zero . flip (NDArr.zipWith modMult) x) (rows a) + +-- | Matrix transposition. +transpose :: (KnownNat m, 1 <= m , KnownNat n, 1 <= n , Storable e) + => NDArr '[m, n] e -> NDArr '[n, m] e +transpose m = fromIndex \(i :. j) -> m `at` (j :. i) diff --git a/src/MLambda/NDArr.hs b/src/MLambda/NDArr.hs index 22838c5..1270555 100644 --- a/src/MLambda/NDArr.hs +++ b/src/MLambda/NDArr.hs @@ -26,6 +26,12 @@ module MLambda.NDArr , at , row , rows + -- * Array operations + , toList + , concat + , map + , zipWith + , foldr -- * Array composition , Stack , Stacks @@ -34,15 +40,20 @@ module MLambda.NDArr , stackWithWitness -- * Unsafe API , unsafeMkNDArr + -- * Shape manipulation + , reshape + , prependDim + , stripDim ) where import MLambda.Index import MLambda.TypeLits +import MLambda.Linear import Control.DeepSeq (NFData) import Control.Monad.ST (runST) import Data.Foldable (forM_) -import Data.List (intersperse) +import Data.List qualified as List import Data.List.Singletons import Data.Singletons import Data.Vector.Storable qualified as Storable @@ -50,7 +61,7 @@ import Data.Vector.Storable.Mutable qualified as Mutable import Foreign.Ptr (Ptr, castPtr) import Foreign.Storable (Storable (..)) import GHC.TypeError (ErrorMessage (..), TypeError) -import Prelude hiding (concat, zipWith) +import Prelude hiding (map, concat, zipWith, foldr) -- | @NDArr [n1,...,nd] e@ is a type of arrays with dimensions @n1 x ... x nd@ -- consisting of elements of type @e@. @@ -75,8 +86,8 @@ instance (Ix dim, Show e, Storable e) => Show (NDArr dim e) where go :: forall n r e' . (Ix r, Show (NDArr r e'), Storable e') => NDArr (n:r) e' -> ShowS go = (showString "[" .) . (. showString "]") . foldl' (.) id - . intersperse (showString ",\n") - . map shows + . List.intersperse (showString ",\n") + . List.map shows . toList . rows @'[n] instance (Ix d, Storable e) => Storable (NDArr d e) where @@ -85,6 +96,13 @@ instance (Ix d, Storable e) => Storable (NDArr d e) where peek (castPtr -> (ptr :: Ptr e)) = fromIndexM (peekElemOff ptr . fromEnum) poke (castPtr -> (ptr :: Ptr e)) = (`Storable.iforM_` pokeElemOff ptr) . runNDArr +instance (Ix d, Storable r, Num r) => Additive (NDArr d r) where + add = zipWith (+) + zero = MkNDArr $ Storable.replicate (enumSize (Index d)) 0 + +instance (Ix d, Storable r, Num r) => Module r (NDArr d r) where + modMult r = MkNDArr . Storable.map (r *) . runNDArr + -- | Construct an array from a function @f@ that maps indices to elements. -- This function is strict and therefore @f@ can't refer to the result of -- @fromIndex f@. @@ -104,38 +122,50 @@ fromIndexM f = do vec <- Storable.unsafeFreeze mvec pure $ MkNDArr vec --- | Access array element by its index. This is a total function. +-- | O(1). Access array element by its index. This is a total function. at :: (Storable e, Ix dim) => NDArr dim e -> Index dim -> e (MkNDArr v) `at` i = v Storable.! fromEnum i infixl 9 `at` --- | Extract a "row" from the array. If you're used to C or numpy arrays, +-- | O(1). Extract a "row" from the array. If you're used to C or numpy arrays, -- this is similar to @a[i]@. row :: forall d1 d2 e. (Ix d1, Ix d2, Storable e) => Index d1 -> NDArr (d1 ++ d2) e -> NDArr d2 e row i a = rows a `at` i --- | Extract all "rows" from the array as an array. +-- | O(1). Extract all "rows" from the array as an array. rows :: forall d1 d2 e. (Ix d2, Storable e) => NDArr (d1 ++ d2) e -> NDArr d1 (NDArr d2 e) rows = MkNDArr . Storable.unsafeCast . runNDArr +-- | O(array_size). Extract @`NDArr`@ elements as a list in the order they are +-- laid out in memory. toList :: Storable e => NDArr d e -> [e] toList = Storable.toList . runNDArr +-- | O(1). Flatten an array of arrays into a single array. concat :: forall d1 d2 e. (Ix d2, Storable e) => NDArr d1 (NDArr d2 e) -> NDArr (d1 ++ d2) e concat = MkNDArr . Storable.unsafeCast . runNDArr +-- | O(array_size). Transform elements of an array. +map :: (Storable a, Storable b) => (a -> b) -> NDArr d a -> NDArr d b +map f = MkNDArr . Storable.map f . runNDArr + +-- | O(array_size). Combine two arrays into one element-by-element. zipWith :: (Storable a, Storable b, Storable c) => (a -> b -> c) -> NDArr d a -> NDArr d b -> NDArr d c zipWith f (MkNDArr xs) (MkNDArr ys) = MkNDArr (Storable.zipWith f xs ys) +-- | O(array_size). Reduce array using an accumulator function and initial value. +foldr :: Storable a => (a -> r -> r) -> r -> NDArr d a -> r +foldr f r = Storable.foldr f r . runNDArr + vstack :: Storable e => NDArr (k : d) e -> NDArr (l : d) e -> NDArr ((k + l) : d) e vstack (MkNDArr xs) (MkNDArr ys) = MkNDArr (xs <> ys) @@ -187,3 +217,19 @@ stack :: forall n -> (Stacks (Peano n) d1 d2 (Stack (Peano n) d1 d2), Storable e) => NDArr d1 e -> NDArr d2 e -> NDArr (Stack (Peano n) d1 d2) e stack n = stackWithWitness (stacks @(Peano n)) + +type family Size d where + Size '[] = 1 + Size (x:xs) = x * Size xs + +-- | Change the shape of an array. +reshape :: forall d2 -> (Size d1 ~ Size d2) => NDArr d1 e -> NDArr d2 e +reshape _ = MkNDArr . runNDArr + +-- | Prepend a single dimension of size 1. +prependDim :: NDArr d e -> NDArr (1:d) e +prependDim = MkNDArr . runNDArr + +-- | Remove a dimension of size 1 from the front. +stripDim :: NDArr (1:d) e -> NDArr d e +stripDim = MkNDArr . runNDArr From e96c2165180879deac2e752165da9b99fc0fe0a9 Mon Sep 17 00:00:00 2001 From: Nikita Solodovnikov Date: Sun, 17 Aug 2025 23:45:13 +0300 Subject: [PATCH 18/20] Test matmul with `rep` and `act` --- test/Spec.hs | 2 ++ test/Test/MLambda/Matrix.hs | 62 +++++++++++++++++++++++++++++++++++++ test/Test/MLambda/NDArr.hs | 6 ++-- test/Test/MLambda/Utils.hs | 13 +++++--- 4 files changed, 76 insertions(+), 7 deletions(-) create mode 100644 test/Test/MLambda/Matrix.hs diff --git a/test/Spec.hs b/test/Spec.hs index e367a5a..fcbb023 100644 --- a/test/Spec.hs +++ b/test/Spec.hs @@ -1,5 +1,6 @@ {-# LANGUAGE QuasiQuotes #-} +import Test.MLambda.Matrix import Test.MLambda.NDArr import Test.Tasty import Test.Tasty.HUnit @@ -28,4 +29,5 @@ tests = testGroup "Tests" [ testCase "Mul" $ a `cross` b @?= c , testNDArr + , testMatrix ] diff --git a/test/Test/MLambda/Matrix.hs b/test/Test/MLambda/Matrix.hs new file mode 100644 index 0000000..897bd40 --- /dev/null +++ b/test/Test/MLambda/Matrix.hs @@ -0,0 +1,62 @@ +{-# LANGUAGE TypeAbstractions #-} +module Test.MLambda.Matrix (testMatrix) where + +import MLambda.Index +import MLambda.Matrix +import MLambda.NDArr as NDArr +import MLambda.TypeLits + +import Test.MLambda.Utils + +import Control.Monad +import Data.Bool.Singletons +import Data.List.Singletons (type (++)) +import Data.Maybe +import Data.Proxy +import Data.Singletons +import GHC.TypeLits.Singletons +import Prelude.Singletons ((%+)) +import Test.Falsify.Generator qualified as Gen +import Test.Falsify.Predicate ((.$)) +import Test.Falsify.Predicate qualified as Pred +import Test.Tasty +import Test.Tasty.Falsify + +propId :: Property () +propId = do + (m, n) <- gen $ (,) <$> genSz <*> genSz + (SomeSing (SNat @m)) <- pure $ toSing m + (SomeSing (SNat @n)) <- pure $ toSing n + STrue <- pure (sing @1 %<=? sing @m) + STrue <- pure (sing @1 %<=? sing @n) + a <- gen $ genNDArr @'[n, m] genDouble + let a' = rep $ act a + assert $ Pred.eq .$ ("rep . act", a') .$ ("id", a) + +propCompose :: Property () +propCompose = do + (m, n) <- gen $ (,) <$> genSz <*> genSz + (SomeSing (SNat @m)) <- pure $ toSing m + (SomeSing (SNat @n)) <- pure $ toSing n + (SomeSing (SNat @k)) <- pure $ toSing n + STrue <- pure (sing @1 %<=? sing @m) + STrue <- pure (sing @1 %<=? sing @n) + STrue <- pure (sing @1 %<=? sing @k) + a <- gen $ genNDArr @'[n, m] genDouble + b <- gen $ genNDArr @'[m, k] genDouble + let c = a `cross` b + c' = rep (act a . act b) + eps = 10**(-7) + diff = NDArr.zipWith (\a b -> abs (a - b)) c c' + forM_ [minBound..maxBound] $ \i -> + when (diff `at` i > eps) $ testFailed "matrices unequal" + + +testMatmul :: TestTree +testMatmul = testGroup "matmul" + [ testProperty "rep . act = id" propId + , testProperty "rep (act a . act b) = a @ b" propCompose + ] + +testMatrix :: TestTree +testMatrix = testGroup "Matrix" [testMatmul] diff --git a/test/Test/MLambda/NDArr.hs b/test/Test/MLambda/NDArr.hs index a52e876..99edc48 100644 --- a/test/Test/MLambda/NDArr.hs +++ b/test/Test/MLambda/NDArr.hs @@ -37,7 +37,7 @@ propFromIndexDotAt :: Property () propFromIndexDotAt = do dim <- gen $ genDim 0 5 let p = withIx dim \(Proxy @dim) -> do - arr <- gen (genNDArr @dim) + arr <- gen $ genNDArr @dim genInt assert $ Pred.eq .$ ("fromIndex . at", (fromIndex . at) arr) .$ ("id", arr) @@ -70,8 +70,8 @@ propStack = do , concatIndexI (IxI @p) (IxI @(m2 : s)) , concatIndexI (IxI @p) (IxI @((m1 + m2) : s))) of (IxI, IxI, IxI) -> do - arr1 <- gen $ genNDArr @(p ++ (m1 : s)) - arr2 <- gen $ genNDArr @(p ++ (m2 : s)) + arr1 <- gen $ genNDArr @(p ++ (m1 : s)) genInt + arr2 <- gen $ genNDArr @(p ++ (m2 : s)) genInt let arr3 = stackWithWitness (SW (Proxy @'(p, m1, m2, s))) arr1 arr2 s <- gen $ genIndex @s p <- gen $ genIndex @p diff --git a/test/Test/MLambda/Utils.hs b/test/Test/MLambda/Utils.hs index ef1b2af..b3fec3f 100644 --- a/test/Test/MLambda/Utils.hs +++ b/test/Test/MLambda/Utils.hs @@ -1,3 +1,4 @@ +{-# LANGUAGE TypeAbstractions #-} module Test.MLambda.Utils where import MLambda.Index @@ -5,9 +6,9 @@ import MLambda.NDArr import MLambda.TypeLits import Data.Proxy +import Foreign.Storable import Test.Falsify.Generator (Gen) import Test.Falsify.Generator qualified as Gen -import Test.Falsify.Range (Range) import Test.Falsify.Range qualified as Range genSz :: Gen Natural @@ -19,6 +20,10 @@ genDim a b = Gen.list (Range.between (a, b)) genSz genInt :: Gen Int genInt = Gen.inRange $ Range.between (-1000, 1000) +genDouble :: Gen Double +genDouble = Gen.inRange $ Range.fromProperFraction 64 + \(Range.ProperFraction d) -> 1 + 4 * d + genIndex :: forall dim . Ix dim => Gen (Index dim) genIndex = case inst @dim of EI -> pure E @@ -27,9 +32,9 @@ genIndex = case inst @dim of t <- genIndex @t pure ((toEnum h :: Index '[h]) :. t) -genNDArr :: forall dim . Ix dim => Gen (NDArr dim Int) -genNDArr = do - Gen.Fn f <- Gen.fun genInt +genNDArr :: forall dim e . (Storable e, Ix dim) => Gen e -> Gen (NDArr dim e) +genNDArr g = do + Gen.Fn f <- Gen.fun g pure $ fromIndex f instance Enum (Index dim) => Gen.Function (Index dim) where From 0fbed2b2d827cc7b8568798cf4d60cc22a63c462 Mon Sep 17 00:00:00 2001 From: Nikita Solodovnikov Date: Sun, 17 Aug 2025 23:48:11 +0300 Subject: [PATCH 19/20] Format & doc --- src/MLambda/Matrix.hs | 3 ++- src/MLambda/NDArr.hs | 4 ++-- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/src/MLambda/Matrix.hs b/src/MLambda/Matrix.hs index c3fe01d..af436fa 100644 --- a/src/MLambda/Matrix.hs +++ b/src/MLambda/Matrix.hs @@ -31,7 +31,7 @@ module MLambda.Matrix import MLambda.Foreign.Utils (asFPtr, asPtr, char) import MLambda.Index import MLambda.Linear -import MLambda.NDArr hiding (concat, zipWith, map, foldr) +import MLambda.NDArr hiding (concat, foldr, map, zipWith) import MLambda.NDArr qualified as NDArr import MLambda.TypeLits @@ -200,6 +200,7 @@ rep :: forall m n e . => LinearMap' Storable (NDArr '[m]) (NDArr '[n]) e -> NDArr '[n, m] e rep f = transpose $ NDArr.concat $ fromIndex (f . (rows @'[m] eye `at`)) +-- | A linear map of vector spaces that corresponds to a matrix. act :: forall m n e . (KnownNat m, 1 <= m , Storable e) => NDArr '[n, m] e -> LinearMap' Storable (NDArr '[m]) (NDArr '[n]) e act a x = NDArr.map (NDArr.foldr add zero . flip (NDArr.zipWith modMult) x) (rows a) diff --git a/src/MLambda/NDArr.hs b/src/MLambda/NDArr.hs index 1270555..c41b1e7 100644 --- a/src/MLambda/NDArr.hs +++ b/src/MLambda/NDArr.hs @@ -47,8 +47,8 @@ module MLambda.NDArr ) where import MLambda.Index -import MLambda.TypeLits import MLambda.Linear +import MLambda.TypeLits import Control.DeepSeq (NFData) import Control.Monad.ST (runST) @@ -61,7 +61,7 @@ import Data.Vector.Storable.Mutable qualified as Mutable import Foreign.Ptr (Ptr, castPtr) import Foreign.Storable (Storable (..)) import GHC.TypeError (ErrorMessage (..), TypeError) -import Prelude hiding (map, concat, zipWith, foldr) +import Prelude hiding (concat, foldr, map, zipWith) -- | @NDArr [n1,...,nd] e@ is a type of arrays with dimensions @n1 x ... x nd@ -- consisting of elements of type @e@. From ca2043489c5cde4c29361cf715384e0f7462b459 Mon Sep 17 00:00:00 2001 From: Nikita Solodovnikov Date: Mon, 18 Aug 2025 00:00:49 +0300 Subject: [PATCH 20/20] Flatten code and clean up --- ml.cabal | 3 -- package.yaml | 2 +- src/MLambda/Matrix.hs | 1 - test/Test/MLambda/Matrix.hs | 8 +--- test/Test/MLambda/NDArr.hs | 82 +++++++++++++++++-------------------- 5 files changed, 39 insertions(+), 57 deletions(-) diff --git a/ml.cabal b/ml.cabal index 1f8d770..0943ff9 100644 --- a/ml.cabal +++ b/ml.cabal @@ -46,7 +46,6 @@ library base >=4.7 && <5 , blas-ffi , deepseq - , ghc-typelits-natnormalise , haskell-src-meta , massiv , mtl @@ -82,7 +81,6 @@ test-suite ml-test , blas-ffi , deepseq , falsify - , ghc-typelits-natnormalise , haskell-src-meta , massiv , ml @@ -117,7 +115,6 @@ benchmark ml-bench base >=4.7 && <5 , blas-ffi , deepseq - , ghc-typelits-natnormalise , haskell-src-meta , massiv , ml diff --git a/package.yaml b/package.yaml index 60f8660..6b67cea 100644 --- a/package.yaml +++ b/package.yaml @@ -31,7 +31,7 @@ dependencies: - haskell-src-meta - singletons - singletons-base -- ghc-typelits-natnormalise +# - ghc-typelits-natnormalise ghc-options: - -Wall diff --git a/src/MLambda/Matrix.hs b/src/MLambda/Matrix.hs index af436fa..87d6a3d 100644 --- a/src/MLambda/Matrix.hs +++ b/src/MLambda/Matrix.hs @@ -3,7 +3,6 @@ {-# LANGUAGE RequiredTypeArguments #-} {-# LANGUAGE TemplateHaskell #-} {-# LANGUAGE ViewPatterns #-} -{-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-} -- | -- Module : MLambda.Matrix diff --git a/test/Test/MLambda/Matrix.hs b/test/Test/MLambda/Matrix.hs index 897bd40..7b5d525 100644 --- a/test/Test/MLambda/Matrix.hs +++ b/test/Test/MLambda/Matrix.hs @@ -1,7 +1,6 @@ {-# LANGUAGE TypeAbstractions #-} module Test.MLambda.Matrix (testMatrix) where -import MLambda.Index import MLambda.Matrix import MLambda.NDArr as NDArr import MLambda.TypeLits @@ -10,13 +9,8 @@ import Test.MLambda.Utils import Control.Monad import Data.Bool.Singletons -import Data.List.Singletons (type (++)) -import Data.Maybe -import Data.Proxy import Data.Singletons import GHC.TypeLits.Singletons -import Prelude.Singletons ((%+)) -import Test.Falsify.Generator qualified as Gen import Test.Falsify.Predicate ((.$)) import Test.Falsify.Predicate qualified as Pred import Test.Tasty @@ -47,7 +41,7 @@ propCompose = do let c = a `cross` b c' = rep (act a . act b) eps = 10**(-7) - diff = NDArr.zipWith (\a b -> abs (a - b)) c c' + diff = NDArr.zipWith (\x y -> abs (x - y)) c c' forM_ [minBound..maxBound] $ \i -> when (diff `at` i > eps) $ testFailed "matrices unequal" diff --git a/test/Test/MLambda/NDArr.hs b/test/Test/MLambda/NDArr.hs index 99edc48..272bf8a 100644 --- a/test/Test/MLambda/NDArr.hs +++ b/test/Test/MLambda/NDArr.hs @@ -11,7 +11,6 @@ import Test.MLambda.Utils import Control.Monad import Data.Bool.Singletons import Data.List.Singletons (type (++)) -import Data.Maybe import Data.Proxy import Data.Singletons import GHC.TypeLits.Singletons @@ -25,23 +24,23 @@ import Test.Tasty.Falsify propAtDotFromIndex :: Property () propAtDotFromIndex = do dim <- gen $ genDim 0 5 - let p = withIx dim \(Proxy @dim) -> do - Fn f <- gen $ Gen.fun genInt - i <- gen (genIndex @dim) - assert $ Pred.eq - .$ ("at . fromIndex", (at . fromIndex) f i) - .$ ("id", f i) - fromMaybe (error "propAtDotFromIndex: impossible") p + SomeSing sdim <- pure $ toSing dim + Just (IxI @dim) <- pure $ singToIndexI sdim + Fn f <- gen $ Gen.fun genInt + i <- gen (genIndex @dim) + assert $ Pred.eq + .$ ("at . fromIndex", (at . fromIndex) f i) + .$ ("id", f i) propFromIndexDotAt :: Property () propFromIndexDotAt = do dim <- gen $ genDim 0 5 - let p = withIx dim \(Proxy @dim) -> do - arr <- gen $ genNDArr @dim genInt - assert $ Pred.eq - .$ ("fromIndex . at", (fromIndex . at) arr) - .$ ("id", arr) - fromMaybe (error "propFromIndexDotAt: impossible") p + SomeSing sdim <- pure $ toSing dim + Just (IxI @dim) <- pure $ singToIndexI sdim + arr <- gen $ genNDArr @dim genInt + assert $ Pred.eq + .$ ("fromIndex . at", (fromIndex . at) arr) + .$ ("id", arr) testFromIndex :: TestTree testFromIndex = testGroup "fromIndex" @@ -54,37 +53,30 @@ propStack = do dimpref <- gen $ genDim 0 2 dimmid1 <- gen genSz dimmid2 <- gen genSz - case ( toSing dimsuff - , toSing dimpref - , toSing dimmid1 - , toSing dimmid2) of - (SomeSing (singToIndexI -> Just (IxI @p)) - , SomeSing (singToIndexI -> Just (IxI @s)) - , SomeSing sm1@(SNat @m1), SomeSing sm2@(SNat @m2)) -> - case ( sing @1 %<=? sm1 - , sing @1 %<=? sm2 - , sm1 %+ sm2 - , sing @1 %<=? sm1 %+ sm2) of - (STrue, STrue, SNat, STrue) -> - case ( concatIndexI (IxI @p) (IxI @(m1 : s)) - , concatIndexI (IxI @p) (IxI @(m2 : s)) - , concatIndexI (IxI @p) (IxI @((m1 + m2) : s))) of - (IxI, IxI, IxI) -> do - arr1 <- gen $ genNDArr @(p ++ (m1 : s)) genInt - arr2 <- gen $ genNDArr @(p ++ (m2 : s)) genInt - let arr3 = stackWithWitness (SW (Proxy @'(p, m1, m2, s))) arr1 arr2 - s <- gen $ genIndex @s - p <- gen $ genIndex @p - m <- gen $ genIndex @'[m1 + m2] - let i1 = p `concatIndex` ((toEnum (fromEnum m) :: Index '[m1]) :. s) - i2 = p `concatIndex` - ((toEnum (fromEnum m - enumSize (Index [m1])) :: Index '[m2]) :. s) - i3 = p `concatIndex` (m :. s) - if fromEnum m <= fromEnum (maxBound @(Index '[m1])) - then when (arr1 `at` i1 /= arr3 `at` i3) $ testFailed "mid index in lhs" - else when (arr2 `at` i2 /= arr3 `at` i3) $ testFailed "mid index in rhs" - _ -> error "propStack: impossible" - _ -> error "propStack: impossible" + SomeSing (singToIndexI -> Just (IxI @p)) <- pure $ toSing dimsuff + SomeSing (singToIndexI -> Just (IxI @s)) <- pure $ toSing dimpref + SomeSing sm1@(SNat @m1) <- pure $ toSing dimmid1 + SomeSing sm2@(SNat @m2) <- pure $ toSing dimmid2 + STrue <- pure $ sing @1 %<=? sm1 + STrue <- pure $ sing @1 %<=? sm2 + SNat <- pure $ sm1 %+ sm2 + STrue <- pure $ sing @1 %<=? sm1 %+ sm2 + IxI <- pure $ concatIndexI (IxI @p) (IxI @(m1 : s)) + IxI <- pure $ concatIndexI (IxI @p) (IxI @(m2 : s)) + IxI <- pure $ concatIndexI (IxI @p) (IxI @((m1 + m2) : s)) + arr1 <- gen $ genNDArr @(p ++ (m1 : s)) genInt + arr2 <- gen $ genNDArr @(p ++ (m2 : s)) genInt + let arr3 = stackWithWitness (SW (Proxy @'(p, m1, m2, s))) arr1 arr2 + s <- gen $ genIndex @s + p <- gen $ genIndex @p + m <- gen $ genIndex @'[m1 + m2] + let i1 = p `concatIndex` ((toEnum (fromEnum m) :: Index '[m1]) :. s) + i2 = p `concatIndex` + ((toEnum (fromEnum m - enumSize (Index [m1])) :: Index '[m2]) :. s) + i3 = p `concatIndex` (m :. s) + if fromEnum m <= fromEnum (maxBound @(Index '[m1])) + then when (arr1 `at` i1 /= arr3 `at` i3) $ testFailed "mid index in lhs" + else when (arr2 `at` i2 /= arr3 `at` i3) $ testFailed "mid index in rhs" testStack :: TestTree testStack = testProperty "stack" propStack