Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 44 additions & 15 deletions bench/Bench.hs
Original file line number Diff line number Diff line change
@@ -1,30 +1,59 @@
{-# LANGUAGE PartialTypeSignatures #-}
{-# LANGUAGE PatternSynonyms #-}
{-# LANGUAGE RequiredTypeArguments #-}

import MLambda.Matrix
import MLambda.NDArr
import MLambda.TypeLits (KnownNat)
import MLambda.TypeLits (KnownNat, natVal)

import Data.Massiv.Array (Array, Comp (..), Ix2, pattern Sz2)
import Data.Massiv.Array.Manifest (S)
import Data.Massiv.Array.Mutable (freeze, makeMArrayS)
import Data.Primitive.PrimVar
import Data.Random.Normal (normalIO)
import GHC.TypeLits (type (<=))
import Data.Vector.Storable qualified as Storable
import Foreign.Storable
import GHC.TypeLits (type (<=), type (*))

Check warning on line 16 in bench/Bench.hs

View workflow job for this annotation

GitHub Actions / GHC 9.12.2 on ubuntu-latest

The import of ‘*’ from module ‘GHC.TypeLits’ is redundant
import System.Random (mkStdGen, setStdGen)
import Test.Tasty.Bench (bench, bgroup, defaultMain, env, nf, nfIO)
import Test.Tasty (localOption)
import Test.Tasty.Bench (bench, bgroup, defaultMain, env, nf, nfIO, TimeMode(..))

type M = 1000
type K = 1000
type N = 1000
type M = 100
type K = 100
type N = 100

setup :: IO (a -> b -> (a, b))
setup = (,) <$ setStdGen (mkStdGen 0)

mkNd :: forall m n -> (KnownNat m, KnownNat n, 1 <= m, 1 <= n) => IO (NDArr [m, n] Double)
mkNd m n = fromIndexM @[m, n] (const normalIO)
mkNd :: forall m n -> (KnownNat m, KnownNat n, 1 <= m, 1 <= n, Storable a)
=> IO a -> IO (NDArr [m, n] a)
mkNd m n gen = fromIndexM @'[m, n] $ const gen

mkVec :: forall m n -> (KnownNat m, KnownNat n, Storable a)
=> IO a -> IO (Storable.Vector a)
mkVec m n gen = Storable.replicateM (natVal n * natVal m) $ gen

mkMassiv :: forall m n -> (KnownNat m, KnownNat n, Storable a) => IO a -> IO (Array S Ix2 a)
mkMassiv m n gen = freeze Seq =<< makeMArrayS (Sz2 (natVal n) (natVal m)) (const gen)

main :: IO ()
main = defaultMain
[ bench "random init" $ nfIO $ mkNd M N
, env (setup <*> mkNd M K <*> mkNd K N) \input ->
bgroup "matmul"
[ bench "Massiv" $ nf (uncurry crossMassiv) input
, bench "OpenBLAS" $ nf (uncurry cross) input
main = do
var <- newPrimVar 0
defaultMain $ localOption WallTime <$>
[ bgroup "primvar iota init"
[ bench "NDArr" $ nfIO $ mkNd M N (fetchAddInt var 1)
, bench "Storable.Vector" $ nfIO $ mkVec M N (fetchAddInt var 1)
, bench "massiv" $ nfIO $ mkMassiv M N (fetchAddInt var 1)
]
, bgroup "random init"
[ bench "NDArr" $ nfIO $ mkNd @Double M N normalIO
, bench "Storable.Vector" $ nfIO $ mkVec @Double M N normalIO
, bench "massiv" $ nfIO $ mkMassiv @Double M N normalIO
]
, env (setup <*> mkNd @Double M K normalIO <*> mkNd @Double K N normalIO) \input ->
bgroup "matmul"
[ bench "Massiv" $ nf (uncurry crossMassiv) input
, bench "OpenBLAS" $ nf (uncurry cross) input
, bench "Naive" $ nf (uncurry crossNaive) input
]
]
]
4 changes: 4 additions & 0 deletions ml.cabal
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ library
, massiv
, mtl
, netlib-ffi
, primitive
, singletons
, singletons-base
, template-haskell
Expand Down Expand Up @@ -86,6 +87,7 @@ test-suite ml-test
, ml
, mtl
, netlib-ffi
, primitive
, singletons
, singletons-base
, tasty
Expand Down Expand Up @@ -121,9 +123,11 @@ benchmark ml-bench
, mtl
, netlib-ffi
, normaldistribution
, primitive
, random
, singletons
, singletons-base
, tasty
, tasty-bench
, template-haskell
, vector
Expand Down
4 changes: 2 additions & 2 deletions package.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ dependencies:
- haskell-src-meta
- singletons
- singletons-base
- primitive
# - ghc-typelits-natnormalise

ghc-options:
Expand Down Expand Up @@ -61,6 +62,7 @@ benchmarks:
- ml
- normaldistribution
- random
- tasty
- tasty-bench

tests:
Expand All @@ -76,8 +78,6 @@ tests:
- tasty
- tasty-hunit
- falsify
- singletons
- singletons-base

language: GHC2024
default-extensions:
Expand Down
63 changes: 56 additions & 7 deletions src/MLambda/Index.hs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
{-# LANGUAGE PatternSynonyms #-}
{-# LANGUAGE RequiredTypeArguments #-}
{-# LANGUAGE TypeAbstractions #-}
{-# LANGUAGE ViewPatterns #-}
-- |
Expand Down Expand Up @@ -26,15 +27,21 @@
, withIx
-- * Index operations
, concatIndex
-- * Efficient iteration
, enumerate
, loop_
) where

import MLambda.TypeLits

import Control.Monad
import Data.Bool.Singletons
import Data.List.Singletons
import Data.Maybe (fromJust)
import Data.Singletons
import GHC.TypeLits.Singletons hiding (natVal)


-- | @Index dim@ is the type of indices of multidimensional arrays of dimensions @dim@.
-- Instances are provided for convenient use:
--
Expand Down Expand Up @@ -81,22 +88,64 @@
minBound = E
maxBound = E

class FoldI d where
foldI :: r -> (forall d -> KnownNat d => r -> Int -> r) -> Index d -> r

Check warning on line 92 in src/MLambda/Index.hs

View workflow job for this annotation

GitHub Actions / GHC 9.12.2 on ubuntu-latest

This binding for ‘d’ shadows the existing binding

Check warning on line 92 in src/MLambda/Index.hs

View workflow job for this annotation

GitHub Actions / GHC 9.12.2 on ubuntu-latest

This binding for ‘d’ shadows the existing binding

Check warning on line 92 in src/MLambda/Index.hs

View workflow job for this annotation

GitHub Actions / GHC 9.12.2 on ubuntu-latest

This binding for ‘d’ shadows the existing binding

instance FoldI '[] where
foldI !acc f = const acc

Check warning on line 95 in src/MLambda/Index.hs

View workflow job for this annotation

GitHub Actions / GHC 9.12.2 on ubuntu-latest

Defined but not used: ‘f’

Check warning on line 95 in src/MLambda/Index.hs

View workflow job for this annotation

GitHub Actions / GHC 9.12.2 on ubuntu-latest

Defined but not used: ‘f’

Check warning on line 95 in src/MLambda/Index.hs

View workflow job for this annotation

GitHub Actions / GHC 9.12.2 on ubuntu-latest

Defined but not used: ‘f’

instance (KnownNat d, FoldI ds) => FoldI (d:ds) where
foldI !acc f (ICons h t) = foldI (f d acc h) f t

instance (KnownNat n, 1 <= n, Bounded (Index d)) => Bounded (Index (n:d)) where
minBound = 0 :. minBound
maxBound = (-1) :. maxBound

class EnumImpl i where
succM :: i -> Maybe i
predM :: i -> Maybe i

instance EnumImpl (Index '[]) where
succM = const Nothing
predM = const Nothing

instance (KnownNat n, 1 <= n, EnumImpl (Index d), Bounded (Index d)) =>
EnumImpl (Index (n:d)) where
succM (I h :. t)
| Just t' <- succM t = Just $ I h :. t'
| h < natVal n - 1 = Just $ I h :. minBound
| otherwise = Nothing
predM (I h :. t)
| Just t' <- predM t = Just $ I h :. t'
| h > 0 = Just $ I h :. maxBound
| otherwise = Nothing

instance Enum (Index '[]) where
toEnum = const E
fromEnum = const 0

instance (KnownNat n, 1 <= n, Enum (Index d), Bounded (Index d)) =>
instance (KnownNat n, 1 <= n, EnumImpl (Index d), Enum (Index d), Bounded (Index d), FoldI d) =>
Enum (Index (n:d)) where
toEnum ((`quotRem` enumSize (Index d)) -> (q, r)) = I (q `mod` natVal n) :. toEnum r
fromEnum (I q :. r) = q * enumSize (Index d) + fromEnum r
succ (h :. t) | t == maxBound = succ h :. minBound
| otherwise = h :. succ t
pred (h :. t) | t == minBound = pred h :. maxBound
| otherwise = h :. pred t
fromEnum = foldI 0 f
where
f :: forall d -> KnownNat d => Int -> Int -> Int

Check warning on line 132 in src/MLambda/Index.hs

View workflow job for this annotation

GitHub Actions / GHC 9.12.2 on ubuntu-latest

This binding for ‘d’ shadows the existing binding

Check warning on line 132 in src/MLambda/Index.hs

View workflow job for this annotation

GitHub Actions / GHC 9.12.2 on ubuntu-latest

This binding for ‘d’ shadows the existing binding

Check warning on line 132 in src/MLambda/Index.hs

View workflow job for this annotation

GitHub Actions / GHC 9.12.2 on ubuntu-latest

This binding for ‘d’ shadows the existing binding
f n r i = natVal n * r + i
succ = fromJust . succM
pred = fromJust . predM

-- | Efficiently (compared to @`enumFromTo`@) enumerate all indices in lexicographic
-- order.
enumerate :: forall d -> Ix d => [Index d]
enumerate d =
case IxI @d of
EI -> [E]
-- TODO: this performs very poorly, optimize
_ :.= IxI @r -> (:.) <$> [minBound..maxBound] <*> enumerate r

-- | Efficiently iterate through indices in lexicographic order.
loop_ :: forall d m e . (Ix d, Monad m) => (Index d -> m e) -> m ()
loop_ = forM_ (enumerate d)

-- | Concatenate two indices together
concatIndex :: forall xs ys . Index xs -> Index ys -> Index (xs ++ ys)
Expand Down Expand Up @@ -143,7 +192,7 @@

-- | A class used both as a shorthand for useful @`Index`@ instances and a way to obtain
-- a value of @`IndexI`@.
class (Bounded (Index dim), Enum (Index dim)) => Ix dim where
class (Bounded (Index dim), Enum (Index dim), EnumImpl (Index dim), FoldI dim) => Ix dim where
-- | Returns a term-level witness of @Ix@.
inst :: IndexI dim

Expand Down
44 changes: 36 additions & 8 deletions src/MLambda/Matrix.hs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ module MLambda.Matrix
-- * Matrix multiplication
( cross
, crossMassiv
, crossNaive
-- * Matrix creation
, mat
, eye
Expand Down Expand Up @@ -55,7 +56,8 @@ import GHC.Ptr
import Language.Haskell.Meta.Parse qualified as Haskell
import Language.Haskell.TH qualified as TH
import Language.Haskell.TH.Quote qualified as Quote
import Numeric.BLAS.FFI.Double
import Numeric.BLAS.FFI.Generic
import Numeric.Netlib.Class qualified as BLAS
import Text.ParserCombinators.ReadP qualified as P

massivSize :: forall m n -> (KnownNat m, KnownNat n) => Massiv.Sz2
Expand All @@ -74,13 +76,13 @@ toMassiv = Massiv.fromVector' Massiv.Par (massivSize m n) . runNDArr

-- | Matrix product reused from massiv.
crossMassiv ::
(KnownNat m, KnownNat k, KnownNat n) =>
NDArr [m, k] Double -> NDArr [k, n] Double -> NDArr [m, n] Double
(KnownNat m, KnownNat k, KnownNat n, Num e, Storable e) =>
NDArr [m, k] e -> NDArr [k, n] e -> NDArr [m, n] e
crossMassiv a b = fromMassiv $ toMassiv a Massiv.!><! toMassiv b

-- | Your usual matrix product. Calls into BLAS's @gemm@ operation.
cross :: forall m k n . (KnownNat n, KnownNat m, KnownNat k)
=> NDArr [m, k] Double -> NDArr [k, n] Double -> NDArr [m, n] Double
cross :: forall m k n e . (KnownNat n, KnownNat m, KnownNat k, BLAS.Floating e)
=> NDArr [m, k] e -> NDArr [k, n] e -> NDArr [m, n] e
(runNDArr -> a) `cross` (runNDArr -> b) = unsafePerformIO $ evalContT do
let (afptr, _alen) = Storable.unsafeToForeignPtr0 a
(bfptr, _blen) = Storable.unsafeToForeignPtr0 b
Expand All @@ -104,6 +106,29 @@ cross :: forall m k n . (KnownNat n, KnownNat m, KnownNat k)
let carr = Storable.unsafeFromForeignPtr0 cfptr len
pure $ unsafeMkNDArr carr

crossGeneric :: forall m k n e1 e2 e3 .
( KnownNat n, KnownNat m, KnownNat k
, 1 <= n, 1 <= m, 1 <= k
, Storable e1, Storable e2, Storable e3)
=> (e1 -> e2 -> e3) -> (e3 -> e3 -> e3)
-> NDArr '[m, k] e1 -> NDArr '[k, n] e2 -> NDArr '[m, n] e3
crossGeneric mul plus a b = runST do
mvec <- Mutable.new (natVal m * natVal n)
loop_ \i ->
loop_ \k ->
loop_ \j ->
Mutable.modify mvec
(plus (mul (a `at` (i :. k)) (b `at` (k :. j))))
(fromEnum (i :. j))
unsafeMkNDArr @'[m, n] <$> Storable.unsafeFreeze mvec

-- | Naive implementation of matrix multiplication.
crossNaive :: ( KnownNat n, KnownNat m, KnownNat k
, 1 <= n, 1 <= m, 1 <= k
, Storable e, Num e)
=> NDArr '[m, k] e -> NDArr '[k, n] e -> NDArr '[m, n] e
crossNaive = crossGeneric (*) (+)

infixl 7 `cross`
infixl 7 `crossMassiv`

Expand Down Expand Up @@ -200,11 +225,14 @@ rep :: forall m n e .
rep f = transpose $ NDArr.concat $ fromIndex (f . (rows @'[m] eye `at`))

-- | A linear map of vector spaces that corresponds to a matrix.
act :: forall m n e . (KnownNat m, 1 <= m , Storable e)
act :: forall m n e .
( KnownNat m, 1 <= m
, KnownNat n, 1 <= n
, Storable e)
=> NDArr '[n, m] e -> LinearMap' Storable (NDArr '[m]) (NDArr '[n]) e
act a x = NDArr.map (NDArr.foldr add zero . flip (NDArr.zipWith modMult) x) (rows a)
act a = reshape [n] . crossGeneric modMult add a . reshape [m, 1]

-- | Matrix transposition.
transpose :: (KnownNat m, 1 <= m , KnownNat n, 1 <= n , Storable e)
transpose :: (KnownNat m, 1 <= m, KnownNat n, 1 <= n, Storable e)
=> NDArr '[m, n] e -> NDArr '[n, m] e
transpose m = fromIndex \(i :. j) -> m `at` (j :. i)
7 changes: 4 additions & 3 deletions src/MLambda/NDArr.hs
Original file line number Diff line number Diff line change
Expand Up @@ -51,8 +51,9 @@
import MLambda.TypeLits

import Control.DeepSeq (NFData)
import Control.Monad
import Control.Monad.ST (runST)
import Data.Foldable (forM_)
import Data.Foldable (for_)
import Data.List qualified as List
import Data.List.Singletons
import Data.Singletons
Expand Down Expand Up @@ -117,8 +118,8 @@
fromIndexM :: forall dim m e . (Mutable.PrimMonad m, Ix dim, Storable e)
=> (Index dim -> m e) -> m (NDArr dim e)
fromIndexM f = do
mvec <- Mutable.new (enumSize (Index dim))
forM_ [minBound..maxBound] (\i -> f i >>= Mutable.write mvec (fromEnum i))
mvec <- Mutable.unsafeNew (enumSize (Index dim))
for_ (zip [0..] $ enumerate dim) \(i, index) -> Mutable.write mvec i =<< f index
vec <- Storable.unsafeFreeze mvec
pure $ MkNDArr vec

Expand Down Expand Up @@ -223,7 +224,7 @@
Size (x:xs) = x * Size xs

-- | Change the shape of an array.
reshape :: forall d2 -> (Size d1 ~ Size d2) => NDArr d1 e -> NDArr d2 e

Check warning on line 227 in src/MLambda/NDArr.hs

View workflow job for this annotation

GitHub Actions / GHC 9.12.2 on ubuntu-latest

Redundant constraint: Size d1 ~ Size d2

Check warning on line 227 in src/MLambda/NDArr.hs

View workflow job for this annotation

GitHub Actions / GHC 9.12.2 on ubuntu-latest

Redundant constraint: Size d1 ~ Size d2

Check warning on line 227 in src/MLambda/NDArr.hs

View workflow job for this annotation

GitHub Actions / GHC 9.12.2 on ubuntu-latest

Redundant constraint: Size d1 ~ Size d2
reshape _ = MkNDArr . runNDArr

-- | Prepend a single dimension of size 1.
Expand Down
Loading