diff --git a/bench/Bench.hs b/bench/Bench.hs index c574199..45d424f 100644 --- a/bench/Bench.hs +++ b/bench/Bench.hs @@ -1,30 +1,59 @@ +{-# LANGUAGE PartialTypeSignatures #-} +{-# LANGUAGE PatternSynonyms #-} {-# LANGUAGE RequiredTypeArguments #-} import MLambda.Matrix import MLambda.NDArr -import MLambda.TypeLits (KnownNat) +import MLambda.TypeLits (KnownNat, natVal) +import Data.Massiv.Array (Array, Comp (..), Ix2, pattern Sz2) +import Data.Massiv.Array.Manifest (S) +import Data.Massiv.Array.Mutable (freeze, makeMArrayS) +import Data.Primitive.PrimVar import Data.Random.Normal (normalIO) -import GHC.TypeLits (type (<=)) +import Data.Vector.Storable qualified as Storable +import Foreign.Storable +import GHC.TypeLits (type (<=), type (*)) import System.Random (mkStdGen, setStdGen) -import Test.Tasty.Bench (bench, bgroup, defaultMain, env, nf, nfIO) +import Test.Tasty (localOption) +import Test.Tasty.Bench (bench, bgroup, defaultMain, env, nf, nfIO, TimeMode(..)) -type M = 1000 -type K = 1000 -type N = 1000 +type M = 100 +type K = 100 +type N = 100 setup :: IO (a -> b -> (a, b)) setup = (,) <$ setStdGen (mkStdGen 0) -mkNd :: forall m n -> (KnownNat m, KnownNat n, 1 <= m, 1 <= n) => IO (NDArr [m, n] Double) -mkNd m n = fromIndexM @[m, n] (const normalIO) +mkNd :: forall m n -> (KnownNat m, KnownNat n, 1 <= m, 1 <= n, Storable a) + => IO a -> IO (NDArr [m, n] a) +mkNd m n gen = fromIndexM @'[m, n] $ const gen + +mkVec :: forall m n -> (KnownNat m, KnownNat n, Storable a) + => IO a -> IO (Storable.Vector a) +mkVec m n gen = Storable.replicateM (natVal n * natVal m) $ gen + +mkMassiv :: forall m n -> (KnownNat m, KnownNat n, Storable a) => IO a -> IO (Array S Ix2 a) +mkMassiv m n gen = freeze Seq =<< makeMArrayS (Sz2 (natVal n) (natVal m)) (const gen) main :: IO () -main = defaultMain - [ bench "random init" $ nfIO $ mkNd M N - , env (setup <*> mkNd M K <*> mkNd K N) \input -> - bgroup "matmul" - [ bench "Massiv" $ nf (uncurry crossMassiv) input - , bench "OpenBLAS" $ 
nf (uncurry cross) input +main = do + var <- newPrimVar 0 + defaultMain $ localOption WallTime <$> + [ bgroup "primvar iota init" + [ bench "NDArr" $ nfIO $ mkNd M N (fetchAddInt var 1) + , bench "Storable.Vector" $ nfIO $ mkVec M N (fetchAddInt var 1) + , bench "massiv" $ nfIO $ mkMassiv M N (fetchAddInt var 1) + ] + , bgroup "random init" + [ bench "NDArr" $ nfIO $ mkNd @Double M N normalIO + , bench "Storable.Vector" $ nfIO $ mkVec @Double M N normalIO + , bench "massiv" $ nfIO $ mkMassiv @Double M N normalIO + ] + , env (setup <*> mkNd @Double M K normalIO <*> mkNd @Double K N normalIO) \input -> + bgroup "matmul" + [ bench "Massiv" $ nf (uncurry crossMassiv) input + , bench "OpenBLAS" $ nf (uncurry cross) input + , bench "Naive" $ nf (uncurry crossNaive) input + ] ] - ] diff --git a/ml.cabal b/ml.cabal index 0943ff9..75e726e 100644 --- a/ml.cabal +++ b/ml.cabal @@ -50,6 +50,7 @@ library , massiv , mtl , netlib-ffi + , primitive , singletons , singletons-base , template-haskell @@ -86,6 +87,7 @@ test-suite ml-test , ml , mtl , netlib-ffi + , primitive , singletons , singletons-base , tasty @@ -121,9 +123,11 @@ benchmark ml-bench , mtl , netlib-ffi , normaldistribution + , primitive , random , singletons , singletons-base + , tasty , tasty-bench , template-haskell , vector diff --git a/package.yaml b/package.yaml index 6b67cea..da5d39f 100644 --- a/package.yaml +++ b/package.yaml @@ -31,6 +31,7 @@ dependencies: - haskell-src-meta - singletons - singletons-base +- primitive # - ghc-typelits-natnormalise ghc-options: @@ -61,6 +62,7 @@ benchmarks: - ml - normaldistribution - random + - tasty - tasty-bench tests: @@ -76,8 +78,6 @@ tests: - tasty - tasty-hunit - falsify - - singletons - - singletons-base language: GHC2024 default-extensions: diff --git a/src/MLambda/Index.hs b/src/MLambda/Index.hs index 8534137..c096c63 100644 --- a/src/MLambda/Index.hs +++ b/src/MLambda/Index.hs @@ -1,4 +1,5 @@ {-# LANGUAGE PatternSynonyms #-} +{-# LANGUAGE RequiredTypeArguments #-} 
{-# LANGUAGE TypeAbstractions #-} {-# LANGUAGE ViewPatterns #-} -- | @@ -26,15 +27,21 @@ module MLambda.Index , withIx -- * Index operations , concatIndex + -- * Efficient iteration + , enumerate + , loop_ ) where import MLambda.TypeLits +import Control.Monad import Data.Bool.Singletons import Data.List.Singletons +import Data.Maybe (fromJust) import Data.Singletons import GHC.TypeLits.Singletons hiding (natVal) + -- | @Index dim@ is the type of indices of multidimensional arrays of dimensions @dim@. -- Instances are provided for convenient use: -- @@ -81,22 +88,64 @@ instance Bounded (Index '[]) where minBound = E maxBound = E +class FoldI d where + foldI :: r -> (forall d -> KnownNat d => r -> Int -> r) -> Index d -> r + +instance FoldI '[] where + foldI !acc f = const acc + +instance (KnownNat d, FoldI ds) => FoldI (d:ds) where + foldI !acc f (ICons h t) = foldI (f d acc h) f t + instance (KnownNat n, 1 <= n, Bounded (Index d)) => Bounded (Index (n:d)) where minBound = 0 :. minBound maxBound = (-1) :. maxBound +class EnumImpl i where + succM :: i -> Maybe i + predM :: i -> Maybe i + +instance EnumImpl (Index '[]) where + succM = const Nothing + predM = const Nothing + +instance (KnownNat n, 1 <= n, EnumImpl (Index d), Bounded (Index d)) => + EnumImpl (Index (n:d)) where + succM (I h :. t) + | Just t' <- succM t = Just $ I h :. t' + | h < natVal n - 1 = Just $ I (h + 1) :. minBound + | otherwise = Nothing + predM (I h :. t) + | Just t' <- predM t = Just $ I h :. t' + | h > 0 = Just $ I (h - 1) :. maxBound + | otherwise = Nothing + instance Enum (Index '[]) where toEnum = const E fromEnum = const 0 -instance (KnownNat n, 1 <= n, Enum (Index d), Bounded (Index d)) => +instance (KnownNat n, 1 <= n, EnumImpl (Index d), Enum (Index d), Bounded (Index d), FoldI d) => Enum (Index (n:d)) where toEnum ((`quotRem` enumSize (Index d)) -> (q, r)) = I (q `mod` natVal n) :. toEnum r - fromEnum (I q :. r) = q * enumSize (Index d) + fromEnum r - succ (h :. t) | t == maxBound = succ h :. 
minBound - | otherwise = h :. succ t - pred (h :. t) | t == minBound = pred h :. maxBound - | otherwise = h :. pred t + fromEnum = foldI 0 f + where + f :: forall d -> KnownNat d => Int -> Int -> Int + f n r i = natVal n * r + i + succ = fromJust . succM + pred = fromJust . predM + +-- | Efficiently (compared to @`enumFromTo`@) enumerate all indices in lexicographic +-- order. +enumerate :: forall d -> Ix d => [Index d] +enumerate d = + case IxI @d of + EI -> [E] + -- TODO: this performs very poorly, optimize + _ :.= IxI @r -> (:.) <$> [minBound..maxBound] <*> enumerate r + +-- | Efficiently iterate through indices in lexicographic order. +loop_ :: forall d m e . (Ix d, Monad m) => (Index d -> m e) -> m () +loop_ = forM_ (enumerate d) -- | Concatenate two indices together concatIndex :: forall xs ys . Index xs -> Index ys -> Index (xs ++ ys) @@ -143,7 +192,7 @@ concatIndexI (i1 :.= d1) d2 = case concatIndexI d1 d2 of -- | A class used both as a shorthand for useful @`Index`@ instances and a way to obtain -- a value of @`IndexI`@. -class (Bounded (Index dim), Enum (Index dim)) => Ix dim where +class (Bounded (Index dim), Enum (Index dim), EnumImpl (Index dim), FoldI dim) => Ix dim where -- | Returns a term-level witness of @Ix@. 
inst :: IndexI dim diff --git a/src/MLambda/Matrix.hs b/src/MLambda/Matrix.hs index 87d6a3d..e6341df 100644 --- a/src/MLambda/Matrix.hs +++ b/src/MLambda/Matrix.hs @@ -18,6 +18,7 @@ module MLambda.Matrix -- * Matrix multiplication ( cross , crossMassiv + , crossNaive -- * Matrix creation , mat , eye @@ -55,7 +56,8 @@ import GHC.Ptr import Language.Haskell.Meta.Parse qualified as Haskell import Language.Haskell.TH qualified as TH import Language.Haskell.TH.Quote qualified as Quote -import Numeric.BLAS.FFI.Double +import Numeric.BLAS.FFI.Generic +import Numeric.Netlib.Class qualified as BLAS import Text.ParserCombinators.ReadP qualified as P massivSize :: forall m n -> (KnownNat m, KnownNat n) => Massiv.Sz2 @@ -74,13 +76,13 @@ toMassiv = Massiv.fromVector' Massiv.Par (massivSize m n) . runNDArr -- | Matrix product reused from massiv. crossMassiv :: - (KnownNat m, KnownNat k, KnownNat n) => - NDArr [m, k] Double -> NDArr [k, n] Double -> NDArr [m, n] Double + (KnownNat m, KnownNat k, KnownNat n, Num e, Storable e) => + NDArr [m, k] e -> NDArr [k, n] e -> NDArr [m, n] e crossMassiv a b = fromMassiv $ toMassiv a Massiv.!> NDArr [m, k] Double -> NDArr [k, n] Double -> NDArr [m, n] Double +cross :: forall m k n e . (KnownNat n, KnownNat m, KnownNat k, BLAS.Floating e) + => NDArr [m, k] e -> NDArr [k, n] e -> NDArr [m, n] e (runNDArr -> a) `cross` (runNDArr -> b) = unsafePerformIO $ evalContT do let (afptr, _alen) = Storable.unsafeToForeignPtr0 a (bfptr, _blen) = Storable.unsafeToForeignPtr0 b @@ -104,6 +106,29 @@ cross :: forall m k n . (KnownNat n, KnownNat m, KnownNat k) let carr = Storable.unsafeFromForeignPtr0 cfptr len pure $ unsafeMkNDArr carr +crossGeneric :: forall m k n e1 e2 e3 . 
+ ( KnownNat n, KnownNat m, KnownNat k + , 1 <= n, 1 <= m, 1 <= k + , Storable e1, Storable e2, Storable e3) + => (e1 -> e2 -> e3) -> (e3 -> e3 -> e3) + -> NDArr '[m, k] e1 -> NDArr '[k, n] e2 -> NDArr '[m, n] e3 +crossGeneric mul plus a b = runST do + mvec <- Mutable.new (natVal m * natVal n) + loop_ \i -> + loop_ \k -> + loop_ \j -> + Mutable.modify mvec + (plus (mul (a `at` (i :. k)) (b `at` (k :. j)))) + (fromEnum (i :. j)) + unsafeMkNDArr @'[m, n] <$> Storable.unsafeFreeze mvec + +-- | Naive implementation of matrix multiplication. +crossNaive :: ( KnownNat n, KnownNat m, KnownNat k + , 1 <= n, 1 <= m, 1 <= k + , Storable e, Num e) + => NDArr '[m, k] e -> NDArr '[k, n] e -> NDArr '[m, n] e +crossNaive = crossGeneric (*) (+) + infixl 7 `cross` infixl 7 `crossMassiv` @@ -200,11 +225,14 @@ rep :: forall m n e . rep f = transpose $ NDArr.concat $ fromIndex (f . (rows @'[m] eye `at`)) -- | A linear map of vector spaces that corresponds to a matrix. -act :: forall m n e . (KnownNat m, 1 <= m , Storable e) +act :: forall m n e . + ( KnownNat m, 1 <= m + , KnownNat n, 1 <= n + , Storable e) => NDArr '[n, m] e -> LinearMap' Storable (NDArr '[m]) (NDArr '[n]) e -act a x = NDArr.map (NDArr.foldr add zero . flip (NDArr.zipWith modMult) x) (rows a) +act a = reshape [n] . crossGeneric modMult add a . reshape [m, 1] -- | Matrix transposition. -transpose :: (KnownNat m, 1 <= m , KnownNat n, 1 <= n , Storable e) +transpose :: (KnownNat m, 1 <= m, KnownNat n, 1 <= n, Storable e) => NDArr '[m, n] e -> NDArr '[n, m] e transpose m = fromIndex \(i :. j) -> m `at` (j :. 
i) diff --git a/src/MLambda/NDArr.hs b/src/MLambda/NDArr.hs index c41b1e7..4144b66 100644 --- a/src/MLambda/NDArr.hs +++ b/src/MLambda/NDArr.hs @@ -51,8 +51,9 @@ import MLambda.Linear import MLambda.TypeLits import Control.DeepSeq (NFData) +import Control.Monad import Control.Monad.ST (runST) -import Data.Foldable (forM_) +import Data.Foldable (for_) import Data.List qualified as List import Data.List.Singletons import Data.Singletons @@ -117,8 +118,8 @@ fromIndex f = runST $ fromIndexM $ pure . f fromIndexM :: forall dim m e . (Mutable.PrimMonad m, Ix dim, Storable e) => (Index dim -> m e) -> m (NDArr dim e) fromIndexM f = do - mvec <- Mutable.new (enumSize (Index dim)) - forM_ [minBound..maxBound] (\i -> f i >>= Mutable.write mvec (fromEnum i)) + mvec <- Mutable.unsafeNew (enumSize (Index dim)) + for_ (zip [0..] $ enumerate dim) \(i, index) -> Mutable.write mvec i =<< f index vec <- Storable.unsafeFreeze mvec pure $ MkNDArr vec