From f554a8c86744caf7eb6c42678fa7ac17cfee6342 Mon Sep 17 00:00:00 2001 From: Nikita Solodovnikov Date: Mon, 18 Aug 2025 03:15:32 +0300 Subject: [PATCH 01/14] Make `cross` a bit more generic --- src/MLambda/Matrix.hs | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/src/MLambda/Matrix.hs b/src/MLambda/Matrix.hs index 87d6a3d..b58c9eb 100644 --- a/src/MLambda/Matrix.hs +++ b/src/MLambda/Matrix.hs @@ -55,7 +55,8 @@ import GHC.Ptr import Language.Haskell.Meta.Parse qualified as Haskell import Language.Haskell.TH qualified as TH import Language.Haskell.TH.Quote qualified as Quote -import Numeric.BLAS.FFI.Double +import Numeric.BLAS.FFI.Generic +import Numeric.Netlib.Class qualified as BLAS import Text.ParserCombinators.ReadP qualified as P massivSize :: forall m n -> (KnownNat m, KnownNat n) => Massiv.Sz2 @@ -74,13 +75,13 @@ toMassiv = Massiv.fromVector' Massiv.Par (massivSize m n) . runNDArr -- | Matrix product reused from massiv. crossMassiv :: - (KnownNat m, KnownNat k, KnownNat n) => - NDArr [m, k] Double -> NDArr [k, n] Double -> NDArr [m, n] Double + (KnownNat m, KnownNat k, KnownNat n, Num e, Storable e) => + NDArr [m, k] e -> NDArr [k, n] e -> NDArr [m, n] e crossMassiv a b = fromMassiv $ toMassiv a Massiv.!> NDArr [m, k] Double -> NDArr [k, n] Double -> NDArr [m, n] Double +cross :: forall m k n e . (KnownNat n, KnownNat m, KnownNat k, BLAS.Floating e) + => NDArr [m, k] e -> NDArr [k, n] e -> NDArr [m, n] e (runNDArr -> a) `cross` (runNDArr -> b) = unsafePerformIO $ evalContT do let (afptr, _alen) = Storable.unsafeToForeignPtr0 a (bfptr, _blen) = Storable.unsafeToForeignPtr0 b @@ -205,6 +206,6 @@ act :: forall m n e . (KnownNat m, 1 <= m , Storable e) act a x = NDArr.map (NDArr.foldr add zero . flip (NDArr.zipWith modMult) x) (rows a) -- | Matrix transposition. -transpose :: (KnownNat m, 1 <= m , KnownNat n, 1 <= n , Storable e) +transpose :: (KnownNat m, 1 <= m, KnownNat n, 1 <= n, Storable e) => NDArr '[m, n] e -> NDArr '[n, m] e transpose m = fromIndex \(i :. j) -> m `at` (j :. i) From 6d3fe15833c2f244f7e768399ec72429ba853fe6 Mon Sep 17 00:00:00 2001 From: Nikita Solodovnikov Date: Mon, 18 Aug 2025 03:21:27 +0300 Subject: [PATCH 02/14] Benchmark initialization against `Data.Vector.Storable` --- bench/Bench.hs | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/bench/Bench.hs b/bench/Bench.hs index c574199..ef93e35 100644 --- a/bench/Bench.hs +++ b/bench/Bench.hs @@ -2,9 +2,10 @@ import MLambda.Matrix import MLambda.NDArr -import MLambda.TypeLits (KnownNat) +import MLambda.TypeLits (KnownNat, natVal) import Data.Random.Normal (normalIO) +import Data.Vector.Storable qualified as Storable import GHC.TypeLits (type (<=)) import System.Random (mkStdGen, setStdGen) import Test.Tasty.Bench (bench, bgroup, defaultMain, env, nf, nfIO) @@ -19,9 +20,16 @@ setup = (,) <$ setStdGen (mkStdGen 0) mkNd :: forall m n -> (KnownNat m, KnownNat n, 1 <= m, 1 <= n) => IO (NDArr [m, n] Double) mkNd m n = fromIndexM @[m, n] (const normalIO) +mkVec :: forall m n -> (KnownNat m, KnownNat n) + => IO (Storable.Vector Double) +mkVec m n = Storable.replicateM (natVal n * natVal m) normalIO + main :: IO () main = defaultMain - [ bench "random init" $ nfIO $ mkNd M N + [ bgroup "random init" + [ bench "NDArr" $ nfIO $ mkNd M N + , bench "Storable.Vector" $ nfIO $ mkVec M N + ] , env (setup <*> mkNd M K <*> mkNd K N) \input -> bgroup "matmul" [ bench "Massiv" $ nf (uncurry crossMassiv) input From 620df61f452ce5be37f974dc7454fd36f2e7dfce Mon Sep 17 00:00:00 2001 From: Nikita Solodovnikov Date: Mon, 18 Aug 2025 08:02:56 +0300 Subject: [PATCH 03/14] Add naive matmul implementation --- bench/Bench.hs | 1 + src/MLambda/Matrix.hs | 31 +++++++++++++++++++++++++++++-- 2 files changed, 30 insertions(+), 2 deletions(-) diff --git a/bench/Bench.hs b/bench/Bench.hs index ef93e35..b88b890 100644 --- a/bench/Bench.hs +++ b/bench/Bench.hs @@ -34,5 +34,6 @@ main = defaultMain bgroup "matmul" [ bench "Massiv" $ nf (uncurry crossMassiv) input , bench "OpenBLAS" $ nf (uncurry cross) input + , bench "Naive" $ nf (uncurry crossNaive) input ] ] diff --git a/src/MLambda/Matrix.hs b/src/MLambda/Matrix.hs index b58c9eb..c3f531f 100644 --- a/src/MLambda/Matrix.hs +++ b/src/MLambda/Matrix.hs @@ -18,6 +18,7 @@ module MLambda.Matrix -- * Matrix multiplication ( cross , crossMassiv + , crossNaive -- * Matrix creation , mat , eye @@ -105,6 +106,29 @@ cross :: forall m k n e . (KnownNat n, KnownNat m, KnownNat k, BLAS.Floating e) let carr = Storable.unsafeFromForeignPtr0 cfptr len pure $ unsafeMkNDArr carr +crossGeneric :: forall m k n e1 e2 e3 . + ( KnownNat n, KnownNat m, KnownNat k + , 1 <= n, 1 <= m, 1 <= k + , Storable e1, Storable e2, Storable e3) + => (e1 -> e2 -> e3) -> (e3 -> e3 -> e3) + -> NDArr '[m, k] e1 -> NDArr '[k, n] e2 -> NDArr '[m, n] e3 +crossGeneric mul plus a b = runST do + mvec <- Mutable.new (natVal m * natVal n) + forM_ [minBound..maxBound :: Index '[m]] \i -> + forM_ [minBound..maxBound :: Index '[k]] \k -> + forM_ [minBound..maxBound :: Index '[n]] \j -> + Mutable.modify mvec + (plus (mul (a `at` (i :. k)) (b `at` (k :. j)))) + (fromEnum (i :. j)) + unsafeMkNDArr @'[m, n] <$> Storable.unsafeFreeze mvec + +-- | Naive implementation of matrix multiplication. +crossNaive :: ( KnownNat n, KnownNat m, KnownNat k + , 1 <= n, 1 <= m, 1 <= k + , Storable e, Num e) + => NDArr '[m, k] e -> NDArr '[k, n] e -> NDArr '[m, n] e +crossNaive = crossGeneric (*) (+) + infixl 7 `cross` infixl 7 `crossMassiv` @@ -201,9 +225,12 @@ rep :: forall m n e . rep f = transpose $ NDArr.concat $ fromIndex (f . (rows @'[m] eye `at`)) -- | A linear map of vector spaces that corresponds to a matrix. -act :: forall m n e . (KnownNat m, 1 <= m , Storable e) +act :: forall m n e . + ( KnownNat m, 1 <= m + , KnownNat n, 1 <= n + , Storable e) => NDArr '[n, m] e -> LinearMap' Storable (NDArr '[m]) (NDArr '[n]) e -act a x = NDArr.map (NDArr.foldr add zero . flip (NDArr.zipWith modMult) x) (rows a) +act a = reshape [n] . crossGeneric modMult add a . reshape [m, 1] -- | Matrix transposition. transpose :: (KnownNat m, 1 <= m, KnownNat n, 1 <= n, Storable e) From b87c7539194766919358a53e09ca965198e7f83f Mon Sep 17 00:00:00 2001 From: Nikita Solodovnikov Date: Sat, 23 Aug 2025 11:15:24 +0300 Subject: [PATCH 04/14] Add means of efficient iteration over indices --- src/MLambda/Index.hs | 13 +++++++++++++ src/MLambda/NDArr.hs | 2 +- 2 files changed, 14 insertions(+), 1 deletion(-) diff --git a/src/MLambda/Index.hs b/src/MLambda/Index.hs index 8534137..891d5cc 100644 --- a/src/MLambda/Index.hs +++ b/src/MLambda/Index.hs @@ -1,4 +1,5 @@ {-# LANGUAGE PatternSynonyms #-} +{-# LANGUAGE RequiredTypeArguments #-} {-# LANGUAGE TypeAbstractions #-} {-# LANGUAGE ViewPatterns #-} -- | @@ -26,15 +27,19 @@ module MLambda.Index , withIx -- * Index operations , concatIndex + -- * Efficient iteration + , loop_ ) where import MLambda.TypeLits +import Control.Monad import Data.Bool.Singletons import Data.List.Singletons import Data.Singletons import GHC.TypeLits.Singletons hiding (natVal) + -- | @Index dim@ is the type of indices of multidimensional arrays of dimensions @dim@. -- Instances are provided for convenient use: -- @@ -98,6 +103,14 @@ instance (KnownNat n, 1 <= n, Enum (Index d), Bounded (Index d)) => pred (h :. t) | t == minBound = pred h :. maxBound | otherwise = h :. pred t +-- | Efficiently (not yet) iterate through indices in lexicographic order. +loop_ :: forall d m e . (Ix d, Monad m) => (Index d -> m e) -> m () +loop_ f = + case IxI @d of + EI -> void $ f E + _ :.= _ -> forM_ [minBound..maxBound] \i -> loop_ (f . (i :.)) + + -- | Concatenate two indices together concatIndex :: forall xs ys . Index xs -> Index ys -> Index (xs ++ ys) concatIndex E = id diff --git a/src/MLambda/NDArr.hs b/src/MLambda/NDArr.hs index c41b1e7..d6279d6 100644 --- a/src/MLambda/NDArr.hs +++ b/src/MLambda/NDArr.hs @@ -118,7 +118,7 @@ fromIndexM :: forall dim m e . (Mutable.PrimMonad m, Ix dim, Storable e) => (Index dim -> m e) -> m (NDArr dim e) fromIndexM f = do mvec <- Mutable.new (enumSize (Index dim)) - forM_ [minBound..maxBound] (\i -> f i >>= Mutable.write mvec (fromEnum i)) + loop_ (\i -> f i >>= Mutable.write mvec (fromEnum i)) vec <- Storable.unsafeFreeze mvec pure $ MkNDArr vec From 5eefeed1c74c1933dc7f2a810b581cf32c563eb4 Mon Sep 17 00:00:00 2001 From: Nikita Solodovnikov Date: Sat, 23 Aug 2025 11:54:04 +0300 Subject: [PATCH 05/14] Make iteration actually efficient --- src/MLambda/Index.hs | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/src/MLambda/Index.hs b/src/MLambda/Index.hs index 891d5cc..202a925 100644 --- a/src/MLambda/Index.hs +++ b/src/MLambda/Index.hs @@ -28,6 +28,7 @@ module MLambda.Index -- * Index operations , concatIndex -- * Efficient iteration + , enumerate , loop_ ) where @@ -103,13 +104,17 @@ instance (KnownNat n, 1 <= n, Enum (Index d), Bounded (Index d)) => pred (h :. t) | t == minBound = pred h :. maxBound | otherwise = h :. pred t --- | Efficiently (not yet) iterate through indices in lexicographic order. -loop_ :: forall d m e . (Ix d, Monad m) => (Index d -> m e) -> m () -loop_ f = +-- | Efficiently (compared to @`enumFromTo`@) enumerate all indices in lexicographic +-- order. +enumerate :: forall d -> Ix d => [Index d] +enumerate d = case IxI @d of - EI -> void $ f E - _ :.= _ -> forM_ [minBound..maxBound] \i -> loop_ (f . (i :.)) + EI -> [] + _ :.= IxI @r -> (:.) <$> [minBound..maxBound] <*> enumerate r +-- | Efficiently (not yet) iterate through indices in lexicographic order. +loop_ :: forall d m e . (Ix d, Monad m) => (Index d -> m e) -> m () +loop_ = forM_ (enumerate d) -- | Concatenate two indices together concatIndex :: forall xs ys . Index xs -> Index ys -> Index (xs ++ ys) From 3b67c536c2cab3c271887c5f2f880b775b605c5b Mon Sep 17 00:00:00 2001 From: Nikita Solodovnikov Date: Sat, 23 Aug 2025 11:55:27 +0300 Subject: [PATCH 06/14] Style --- src/MLambda/Index.hs | 2 +- src/MLambda/NDArr.hs | 1 - 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/src/MLambda/Index.hs b/src/MLambda/Index.hs index 202a925..98a32b7 100644 --- a/src/MLambda/Index.hs +++ b/src/MLambda/Index.hs @@ -109,7 +109,7 @@ instance (KnownNat n, 1 <= n, Enum (Index d), Bounded (Index d)) => enumerate :: forall d -> Ix d => [Index d] enumerate d = case IxI @d of - EI -> [] + EI -> [] _ :.= IxI @r -> (:.) <$> [minBound..maxBound] <*> enumerate r -- | Efficiently (not yet) iterate through indices in lexicographic order. diff --git a/src/MLambda/NDArr.hs b/src/MLambda/NDArr.hs index d6279d6..406231b 100644 --- a/src/MLambda/NDArr.hs +++ b/src/MLambda/NDArr.hs @@ -52,7 +52,6 @@ import MLambda.TypeLits import Control.DeepSeq (NFData) import Control.Monad.ST (runST) -import Data.Foldable (forM_) import Data.List qualified as List import Data.List.Singletons import Data.Singletons From 7e46c4f38a067f9313fd894f20bed42a59f53e87 Mon Sep 17 00:00:00 2001 From: Nikita Solodovnikov Date: Sat, 23 Aug 2025 13:07:10 +0300 Subject: [PATCH 07/14] Optimize more --- bench/Bench.hs | 6 +++--- src/MLambda/Index.hs | 5 +++-- src/MLambda/Matrix.hs | 6 +++--- src/MLambda/NDArr.hs | 2 +- 4 files changed, 10 insertions(+), 9 deletions(-) diff --git a/bench/Bench.hs b/bench/Bench.hs index b88b890..e954069 100644 --- a/bench/Bench.hs +++ b/bench/Bench.hs @@ -10,9 +10,9 @@ import GHC.TypeLits (type (<=)) import System.Random (mkStdGen, setStdGen) import Test.Tasty.Bench (bench, bgroup, defaultMain, env, nf, nfIO) -type M = 1000 -type K = 1000 -type N = 1000 +type M = 100 +type K = 100 +type N = 100 setup :: IO (a -> b -> (a, b)) setup = (,) <$ setStdGen (mkStdGen 0) diff --git a/src/MLambda/Index.hs b/src/MLambda/Index.hs index 98a32b7..82c61f0 100644 --- a/src/MLambda/Index.hs +++ b/src/MLambda/Index.hs @@ -109,10 +109,11 @@ instance (KnownNat n, 1 <= n, Enum (Index d), Bounded (Index d)) => enumerate :: forall d -> Ix d => [Index d] enumerate d = case IxI @d of - EI -> [] + EI -> [E] + -- TODO: this performs very poorly, optimize _ :.= IxI @r -> (:.) <$> [minBound..maxBound] <*> enumerate r --- | Efficiently (not yet) iterate through indices in lexicographic order. +-- | Efficiently iterate through indices in lexicographic order. loop_ :: forall d m e . (Ix d, Monad m) => (Index d -> m e) -> m () loop_ = forM_ (enumerate d) diff --git a/src/MLambda/Matrix.hs b/src/MLambda/Matrix.hs index c3f531f..e6341df 100644 --- a/src/MLambda/Matrix.hs +++ b/src/MLambda/Matrix.hs @@ -114,9 +114,9 @@ crossGeneric :: forall m k n e1 e2 e3 . -> NDArr '[m, k] e1 -> NDArr '[k, n] e2 -> NDArr '[m, n] e3 crossGeneric mul plus a b = runST do mvec <- Mutable.new (natVal m * natVal n) - forM_ [minBound..maxBound :: Index '[m]] \i -> - forM_ [minBound..maxBound :: Index '[k]] \k -> - forM_ [minBound..maxBound :: Index '[n]] \j -> + loop_ \i -> + loop_ \k -> + loop_ \j -> Mutable.modify mvec (plus (mul (a `at` (i :. k)) (b `at` (k :. j)))) (fromEnum (i :. j)) diff --git a/src/MLambda/NDArr.hs b/src/MLambda/NDArr.hs index 406231b..1d449a7 100644 --- a/src/MLambda/NDArr.hs +++ b/src/MLambda/NDArr.hs @@ -116,7 +116,7 @@ fromIndex f = runST $ fromIndexM $ pure . f fromIndexM :: forall dim m e . (Mutable.PrimMonad m, Ix dim, Storable e) => (Index dim -> m e) -> m (NDArr dim e) fromIndexM f = do - mvec <- Mutable.new (enumSize (Index dim)) + mvec <- Mutable.unsafeNew (enumSize (Index dim)) loop_ (\i -> f i >>= Mutable.write mvec (fromEnum i)) vec <- Storable.unsafeFreeze mvec pure $ MkNDArr vec From b5e85f40accf3f0f9b3c19325e146b70fa89be45 Mon Sep 17 00:00:00 2001 From: Nikita Solodovnikov Date: Mon, 5 Jan 2026 18:44:58 +0300 Subject: [PATCH 08/14] Optimize `fromEnum` for indices --- src/MLambda/Index.hs | 18 +++++++++++++++--- 1 file changed, 15 insertions(+), 3 deletions(-) diff --git a/src/MLambda/Index.hs b/src/MLambda/Index.hs index 82c61f0..77abf46 100644 --- a/src/MLambda/Index.hs +++ b/src/MLambda/Index.hs @@ -87,6 +87,15 @@ instance Bounded (Index '[]) where minBound = E maxBound = E +class FoldI d where + foldI :: r -> (forall d -> KnownNat d => r -> Int -> r) -> Index d -> r + +instance FoldI '[] where + foldI !acc f = const acc + +instance (KnownNat d, FoldI ds) => FoldI (d:ds) where + foldI !acc f (ICons h t) = foldI (f d acc h) f t + instance (KnownNat n, 1 <= n, Bounded (Index d)) => Bounded (Index (n:d)) where minBound = 0 :. minBound maxBound = (-1) :. maxBound @@ -95,10 +104,13 @@ instance Enum (Index '[]) where toEnum = const E fromEnum = const 0 -instance (KnownNat n, 1 <= n, Enum (Index d), Bounded (Index d)) => +instance (KnownNat n, 1 <= n, Enum (Index d), Bounded (Index d), FoldI d) => Enum (Index (n:d)) where toEnum ((`quotRem` enumSize (Index d)) -> (q, r)) = I (q `mod` natVal n) :. toEnum r - fromEnum (I q :. r) = q * enumSize (Index d) + fromEnum r + fromEnum = foldI 0 f + where + f :: forall d -> KnownNat d => Int -> Int -> Int + f n r i = natVal n * r + i succ (h :. t) | t == maxBound = succ h :. minBound | otherwise = h :. succ t pred (h :. t) | t == minBound = pred h :. maxBound @@ -162,7 +174,7 @@ concatIndexI (i1 :.= d1) d2 = case concatIndexI d1 d2 of -- | A class used both as a shorthand for useful @`Index`@ instances and a way to obtain -- a value of @`IndexI`@. -class (Bounded (Index dim), Enum (Index dim)) => Ix dim where +class (Bounded (Index dim), Enum (Index dim), FoldI dim) => Ix dim where -- | Returns a term-level witness of @Ix@. inst :: IndexI dim From 013c3a0ff470330686d353cb46288fed638d464f Mon Sep 17 00:00:00 2001 From: Nikita Solodovnikov Date: Tue, 20 Jan 2026 23:10:39 +0300 Subject: [PATCH 09/14] Bench against `massiv` --- bench/Bench.hs | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/bench/Bench.hs b/bench/Bench.hs index e954069..c3274c1 100644 --- a/bench/Bench.hs +++ b/bench/Bench.hs @@ -1,3 +1,5 @@ +{-# LANGUAGE PartialTypeSignatures #-} +{-# LANGUAGE PatternSynonyms #-} {-# LANGUAGE RequiredTypeArguments #-} import MLambda.Matrix @@ -5,6 +7,9 @@ import MLambda.NDArr import MLambda.TypeLits (KnownNat, natVal) import Data.Random.Normal (normalIO) +import Data.Massiv.Array (Array, pattern Sz2, Ix2, Comp(..)) +import Data.Massiv.Array.Manifest (S) +import Data.Massiv.Array.Mutable (makeMArrayS, freeze) import Data.Vector.Storable qualified as Storable import GHC.TypeLits (type (<=)) import System.Random (mkStdGen, setStdGen) @@ -24,11 +29,15 @@ mkVec :: forall m n -> (KnownNat m, KnownNat n) => IO (Storable.Vector Double) mkVec m n = Storable.replicateM (natVal n * natVal m) normalIO +mkMassiv :: forall m n -> (KnownNat m, KnownNat n) => IO (Array S Ix2 Double) +mkMassiv m n = freeze Seq =<< makeMArrayS (Sz2 (natVal n) (natVal m)) (const normalIO) + main :: IO () main = defaultMain [ bgroup "random init" [ bench "NDArr" $ nfIO $ mkNd M N , bench "Storable.Vector" $ nfIO $ mkVec M N + , bench "massiv" $ nfIO $ mkMassiv M N ] , env (setup <*> mkNd M K <*> mkNd K N) \input -> bgroup "matmul" From 934a794a5fb6a3917a462dbc4f3979758b4d6086 Mon Sep 17 00:00:00 2001 From: Nikita Solodovnikov Date: Tue, 20 Jan 2026 23:11:06 +0300 Subject: [PATCH 10/14] Optimize `succ` (did not help) --- src/MLambda/Index.hs | 26 ++++++++++++++++++++------ 1 file changed, 20 insertions(+), 6 deletions(-) diff --git a/src/MLambda/Index.hs b/src/MLambda/Index.hs index 77abf46..21ec1bb 100644 --- a/src/MLambda/Index.hs +++ b/src/MLambda/Index.hs @@ -37,6 +37,7 @@ import MLambda.TypeLits import Control.Monad import Data.Bool.Singletons import Data.List.Singletons +import Data.Maybe (fromJust) import Data.Singletons import GHC.TypeLits.Singletons hiding (natVal) @@ -100,21 +101,34 @@ instance (KnownNat n, 1 <= n, Bounded (Index d)) => Bounded (Index (n:d)) where minBound = 0 :. minBound maxBound = (-1) :. maxBound +class EnumImpl i where + succM :: i -> Maybe i + predM :: i -> Maybe i + +instance EnumImpl (Index '[]) where + succM = const Nothing + predM = const Nothing + +instance (KnownNat n, 1 <= n, EnumImpl (Index d), Bounded (Index d)) => + EnumImpl (Index (n:d)) where + succM (h :. (succM -> Just t)) = Just $ h :. t + succM ((succM -> h) :. _) = fmap (:. minBound) h + predM (h :. (predM -> Just t)) = Just $ h :. t + predM ((predM -> h) :. _) = fmap (:. maxBound) h + instance Enum (Index '[]) where toEnum = const E fromEnum = const 0 -instance (KnownNat n, 1 <= n, Enum (Index d), Bounded (Index d), FoldI d) => +instance (KnownNat n, 1 <= n, EnumImpl (Index d), Enum (Index d), Bounded (Index d), FoldI d) => Enum (Index (n:d)) where toEnum ((`quotRem` enumSize (Index d)) -> (q, r)) = I (q `mod` natVal n) :. toEnum r fromEnum = foldI 0 f where f :: forall d -> KnownNat d => Int -> Int -> Int f n r i = natVal n * r + i - succ (h :. t) | t == maxBound = succ h :. minBound - | otherwise = h :. succ t - pred (h :. t) | t == minBound = pred h :. maxBound - | otherwise = h :. pred t + succ = fromJust . succM + pred = fromJust . predM -- | Efficiently (compared to @`enumFromTo`@) enumerate all indices in lexicographic -- order. @@ -174,7 +188,7 @@ concatIndexI (i1 :.= d1) d2 = case concatIndexI d1 d2 of -- | A class used both as a shorthand for useful @`Index`@ instances and a way to obtain -- a value of @`IndexI`@. -class (Bounded (Index dim), Enum (Index dim), FoldI dim) => Ix dim where +class (Bounded (Index dim), Enum (Index dim), EnumImpl (Index dim), FoldI dim) => Ix dim where -- | Returns a term-level witness of @Ix@. inst :: IndexI dim From 3241389c5a77d05202b2f68c03f65caad7735b77 Mon Sep 17 00:00:00 2001 From: Nikita Solodovnikov Date: Tue, 20 Jan 2026 23:23:03 +0300 Subject: [PATCH 11/14] Remove useless conversions from `Index` --- ml.cabal | 3 +++ package.yaml | 1 + src/MLambda/NDArr.hs | 7 ++++++- 3 files changed, 10 insertions(+), 1 deletion(-) diff --git a/ml.cabal b/ml.cabal index 0943ff9..e70cb9e 100644 --- a/ml.cabal +++ b/ml.cabal @@ -50,6 +50,7 @@ library , massiv , mtl , netlib-ffi + , primitive , singletons , singletons-base , template-haskell @@ -86,6 +87,7 @@ test-suite ml-test , ml , mtl , netlib-ffi + , primitive , singletons , singletons-base , tasty @@ -121,6 +123,7 @@ benchmark ml-bench , mtl , netlib-ffi , normaldistribution + , primitive , random , singletons , singletons-base diff --git a/package.yaml b/package.yaml index 6b67cea..6bb4f6b 100644 --- a/package.yaml +++ b/package.yaml @@ -31,6 +31,7 @@ dependencies: - haskell-src-meta - singletons - singletons-base +- primitive # - ghc-typelits-natnormalise ghc-options: diff --git a/src/MLambda/NDArr.hs b/src/MLambda/NDArr.hs index 1d449a7..39d9a6a 100644 --- a/src/MLambda/NDArr.hs +++ b/src/MLambda/NDArr.hs @@ -54,6 +54,7 @@ import Control.DeepSeq (NFData) import Control.Monad.ST (runST) import Data.List qualified as List import Data.List.Singletons +import Data.Primitive.PrimVar import Data.Singletons import Data.Vector.Storable qualified as Storable import Data.Vector.Storable.Mutable qualified as Mutable @@ -117,7 +118,11 @@ fromIndexM :: forall dim m e . (Mutable.PrimMonad m, Ix dim, Storable e) => (Index dim -> m e) -> m (NDArr dim e) fromIndexM f = do mvec <- Mutable.unsafeNew (enumSize (Index dim)) - loop_ (\i -> f i >>= Mutable.write mvec (fromEnum i)) + ivar <- newPrimVar 0 + loop_ (\index -> do + val <- f index + i <- fetchAddInt ivar 1 + Mutable.write mvec i val) vec <- Storable.unsafeFreeze mvec pure $ MkNDArr vec From ca6d07075c1eaca7cbc977fc7dd33a9ce8a1a00c Mon Sep 17 00:00:00 2001 From: Nikita Solodovnikov Date: Thu, 22 Jan 2026 00:49:45 +0300 Subject: [PATCH 12/14] Fix `EnumImpl` --- src/MLambda/Index.hs | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/src/MLambda/Index.hs b/src/MLambda/Index.hs index 21ec1bb..c096c63 100644 --- a/src/MLambda/Index.hs +++ b/src/MLambda/Index.hs @@ -111,10 +111,14 @@ instance EnumImpl (Index '[]) where instance (KnownNat n, 1 <= n, EnumImpl (Index d), Bounded (Index d)) => EnumImpl (Index (n:d)) where - succM (h :. (succM -> Just t)) = Just $ h :. t - succM ((succM -> h) :. _) = fmap (:. minBound) h - predM (h :. (predM -> Just t)) = Just $ h :. t - predM ((predM -> h) :. _) = fmap (:. maxBound) h + succM (I h :. t) + | Just t' <- succM t = Just $ I h :. t' + | h < natVal n - 1 = Just $ I h :. minBound + | otherwise = Nothing + predM (I h :. t) + | Just t' <- predM t = Just $ I h :. t' + | h > 0 = Just $ I h :. maxBound + | otherwise = Nothing instance Enum (Index '[]) where toEnum = const E From f8a2633b9b54e8d26c90405cceab8f6debd59ada Mon Sep 17 00:00:00 2001 From: Nikita Solodovnikov Date: Mon, 23 Feb 2026 19:10:31 +0300 Subject: [PATCH 13/14] Optimize array construction further --- bench/Bench.hs | 6 +++--- package.yaml | 2 -- src/MLambda/NDArr.hs | 9 +++------ 3 files changed, 6 insertions(+), 11 deletions(-) diff --git a/bench/Bench.hs b/bench/Bench.hs index c3274c1..f159cae 100644 --- a/bench/Bench.hs +++ b/bench/Bench.hs @@ -6,10 +6,10 @@ import MLambda.Matrix import MLambda.NDArr import MLambda.TypeLits (KnownNat, natVal) -import Data.Random.Normal (normalIO) -import Data.Massiv.Array (Array, pattern Sz2, Ix2, Comp(..)) +import Data.Massiv.Array (Array, Comp (..), Ix2, pattern Sz2) import Data.Massiv.Array.Manifest (S) -import Data.Massiv.Array.Mutable (makeMArrayS, freeze) +import Data.Massiv.Array.Mutable (freeze, makeMArrayS) +import Data.Random.Normal (normalIO) import Data.Vector.Storable qualified as Storable import GHC.TypeLits (type (<=)) import System.Random (mkStdGen, setStdGen) diff --git a/package.yaml b/package.yaml index 6bb4f6b..9531666 100644 --- a/package.yaml +++ b/package.yaml @@ -77,8 +77,6 @@ tests: - tasty - tasty-hunit - falsify - - singletons - - singletons-base language: GHC2024 default-extensions: diff --git a/src/MLambda/NDArr.hs b/src/MLambda/NDArr.hs index 39d9a6a..4144b66 100644 --- a/src/MLambda/NDArr.hs +++ b/src/MLambda/NDArr.hs @@ -51,10 +51,11 @@ import MLambda.Linear import MLambda.TypeLits import Control.DeepSeq (NFData) +import Control.Monad import Control.Monad.ST (runST) +import Data.Foldable (for_) import Data.List qualified as List import Data.List.Singletons -import Data.Primitive.PrimVar import Data.Singletons import Data.Vector.Storable qualified as Storable import Data.Vector.Storable.Mutable qualified as Mutable @@ -118,11 +119,7 @@ fromIndexM :: forall dim m e . (Mutable.PrimMonad m, Ix dim, Storable e) => (Index dim -> m e) -> m (NDArr dim e) fromIndexM f = do mvec <- Mutable.unsafeNew (enumSize (Index dim)) - ivar <- newPrimVar 0 - loop_ (\index -> do - val <- f index - i <- fetchAddInt ivar 1 - Mutable.write mvec i val) + for_ (zip [0..] $ enumerate dim) \(i, index) -> Mutable.write mvec i =<< f index vec <- Storable.unsafeFreeze mvec pure $ MkNDArr vec From c05281a8561cd4aa8569bb3d28f276e74833a2a9 Mon Sep 17 00:00:00 2001 From: Nikita Solodovnikov Date: Fri, 20 Mar 2026 00:38:36 +0300 Subject: [PATCH 14/14] add iota benchmarks because I said so --- bench/Bench.hs | 53 ++++++++++++++++++++++++++++++-------------------- ml.cabal | 1 + package.yaml | 1 + 3 files changed, 34 insertions(+), 21 deletions(-) diff --git a/bench/Bench.hs b/bench/Bench.hs index f159cae..45d424f 100644 --- a/bench/Bench.hs +++ b/bench/Bench.hs @@ -9,11 +9,14 @@ import MLambda.TypeLits (KnownNat, natVal) import Data.Massiv.Array (Array, Comp (..), Ix2, pattern Sz2) import Data.Massiv.Array.Manifest (S) import Data.Massiv.Array.Mutable (freeze, makeMArrayS) +import Data.Primitive.PrimVar import Data.Random.Normal (normalIO) import Data.Vector.Storable qualified as Storable -import GHC.TypeLits (type (<=)) +import Foreign.Storable +import GHC.TypeLits (type (<=), type (*)) import System.Random (mkStdGen, setStdGen) -import Test.Tasty.Bench (bench, bgroup, defaultMain, env, nf, nfIO) +import Test.Tasty (localOption) +import Test.Tasty.Bench (bench, bgroup, defaultMain, env, nf, nfIO, TimeMode(..)) type M = 100 type K = 100 @@ -22,27 +25,35 @@ type N = 100 setup :: IO (a -> b -> (a, b)) setup = (,) <$ setStdGen (mkStdGen 0) -mkNd :: forall m n -> (KnownNat m, KnownNat n, 1 <= m, 1 <= n) => IO (NDArr [m, n] Double) -mkNd m n = fromIndexM @[m, n] (const normalIO) +mkNd :: forall m n -> (KnownNat m, KnownNat n, 1 <= m, 1 <= n, Storable a) + => IO a -> IO (NDArr [m, n] a) +mkNd m n gen = fromIndexM @'[m, n] $ const gen -mkVec :: forall m n -> (KnownNat m, KnownNat n) - => IO (Storable.Vector Double) -mkVec m n = Storable.replicateM (natVal n * natVal m) normalIO +mkVec :: forall m n -> (KnownNat m, KnownNat n, Storable a) + => IO a -> IO (Storable.Vector a) +mkVec m n gen = Storable.replicateM (natVal n * natVal m) $ gen -mkMassiv :: forall m n -> (KnownNat m, KnownNat n) => IO (Array S Ix2 Double) -mkMassiv m n = freeze Seq =<< makeMArrayS (Sz2 (natVal n) (natVal m)) (const normalIO) +mkMassiv :: forall m n -> (KnownNat m, KnownNat n, Storable a) => IO a -> IO (Array S Ix2 a) +mkMassiv m n gen = freeze Seq =<< makeMArrayS (Sz2 (natVal n) (natVal m)) (const gen) main :: IO () -main = defaultMain - [ bgroup "random init" - [ bench "NDArr" $ nfIO $ mkNd M N - , bench "Storable.Vector" $ nfIO $ mkVec M N - , bench "massiv" $ nfIO $ mkMassiv M N +main = do + var <- newPrimVar 0 + defaultMain $ localOption WallTime <$> + [ bgroup "primvar iota init" + [ bench "NDArr" $ nfIO $ mkNd M N (fetchAddInt var 1) + , bench "Storable.Vector" $ nfIO $ mkVec M N (fetchAddInt var 1) + , bench "massiv" $ nfIO $ mkMassiv M N (fetchAddInt var 1) + ] + , bgroup "random init" + [ bench "NDArr" $ nfIO $ mkNd @Double M N normalIO + , bench "Storable.Vector" $ nfIO $ mkVec @Double M N normalIO + , bench "massiv" $ nfIO $ mkMassiv @Double M N normalIO + ] + , env (setup <*> mkNd @Double M K normalIO <*> mkNd @Double K N normalIO) \input -> + bgroup "matmul" + [ bench "Massiv" $ nf (uncurry crossMassiv) input + , bench "OpenBLAS" $ nf (uncurry cross) input + , bench "Naive" $ nf (uncurry crossNaive) input + ] ] - , env (setup <*> mkNd M K <*> mkNd K N) \input -> - bgroup "matmul" - [ bench "Massiv" $ nf (uncurry crossMassiv) input - , bench "OpenBLAS" $ nf (uncurry cross) input - , bench "Naive" $ nf (uncurry crossNaive) input - ] - ] diff --git a/ml.cabal b/ml.cabal index e70cb9e..75e726e 100644 --- a/ml.cabal +++ b/ml.cabal @@ -127,6 +127,7 @@ benchmark ml-bench , random , singletons , singletons-base + , tasty , tasty-bench , template-haskell , vector diff --git a/package.yaml b/package.yaml index 9531666..da5d39f 100644 --- a/package.yaml +++ b/package.yaml @@ -62,6 +62,7 @@ benchmarks: - ml - normaldistribution - random + - tasty - tasty-bench tests: