Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions ml.cabal
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ source-repository head
library
exposed-modules:
MLambda.Index
MLambda.Linear
MLambda.Matrix
MLambda.NDArr
MLambda.TypeLits
Expand All @@ -39,6 +40,7 @@ library
RankNTypes
PolyKinds
UndecidableInstances
NoStarIsType
ghc-options: -Wall -Wcompat -Widentities -Wincomplete-record-updates -Wincomplete-uni-patterns -Wmissing-export-lists -Wmissing-home-modules -Wpartial-fields -Wredundant-constraints
build-depends:
base >=4.7 && <5
Expand All @@ -48,6 +50,8 @@ library
, massiv
, mtl
, netlib-ffi
, singletons
, singletons-base
, template-haskell
, vector
default-language: GHC2024
Expand All @@ -56,6 +60,9 @@ test-suite ml-test
type: exitcode-stdio-1.0
main-is: Spec.hs
other-modules:
Test.MLambda.Matrix
Test.MLambda.NDArr
Test.MLambda.Utils
Paths_ml
autogen-modules:
Paths_ml
Expand All @@ -67,16 +74,20 @@ test-suite ml-test
RankNTypes
PolyKinds
UndecidableInstances
NoStarIsType
ghc-options: -Wall -Wcompat -Widentities -Wincomplete-record-updates -Wincomplete-uni-patterns -Wmissing-export-lists -Wmissing-home-modules -Wpartial-fields -Wredundant-constraints -threaded -rtsopts -with-rtsopts=-N
build-depends:
base >=4.7 && <5
, blas-ffi
, deepseq
, falsify
, haskell-src-meta
, massiv
, ml
, mtl
, netlib-ffi
, singletons
, singletons-base
, tasty
, tasty-hunit
, template-haskell
Expand All @@ -98,6 +109,7 @@ benchmark ml-bench
RankNTypes
PolyKinds
UndecidableInstances
NoStarIsType
ghc-options: -Wall -Wcompat -Widentities -Wincomplete-record-updates -Wincomplete-uni-patterns -Wmissing-export-lists -Wmissing-home-modules -Wpartial-fields -Wredundant-constraints -threaded -rtsopts -with-rtsopts=-N
build-depends:
base >=4.7 && <5
Expand All @@ -110,6 +122,8 @@ benchmark ml-bench
, netlib-ffi
, normaldistribution
, random
, singletons
, singletons-base
, tasty-bench
, template-haskell
, vector
Expand Down
7 changes: 7 additions & 0 deletions package.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,9 @@ dependencies:
- template-haskell
- vector
- haskell-src-meta
- singletons
- singletons-base
# - ghc-typelits-natnormalise

ghc-options:
- -Wall
Expand Down Expand Up @@ -72,6 +75,9 @@ tests:
- ml
- tasty
- tasty-hunit
- falsify
- singletons
- singletons-base

language: GHC2024
default-extensions:
Expand All @@ -80,3 +86,4 @@ default-extensions:
- RankNTypes
- PolyKinds
- UndecidableInstances
- NoStarIsType
121 changes: 83 additions & 38 deletions src/MLambda/Index.hs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
{-# LANGUAGE PatternSynonyms #-}
{-# LANGUAGE TypeAbstractions #-}
{-# LANGUAGE ViewPatterns #-}

-- |
-- Module : MLambda.Index
-- Description : Type of multidimensional array indices.
Expand All @@ -12,18 +13,28 @@
-- This module contains definition of 'Index' type of multidimensional array
-- indices along with its instances and public interface.
module MLambda.Index
( Index (E, (:.))
, consIndex
, concatIndex
( -- * Index type
Index (E)
, pattern (:.)
-- * Index instances
, Ix(..)
, IndexI (..)
, Ix
, inst
, pattern IxI
, concatIndexI
-- * Lifting of runtime dimesions into indices
, singToIndexI
, withIx
-- * Index operations
, concatIndex
) where

import Data.Proxy (Proxy (..))

import MLambda.TypeLits

import Data.Bool.Singletons
import Data.List.Singletons
import Data.Singletons
import GHC.TypeLits.Singletons hiding (natVal)

-- | @Index dim@ is the type of indices of multidimensional arrays of dimensions @dim@.
-- Instances are provided for convenient use:
--
Expand All @@ -37,8 +48,20 @@ import MLambda.TypeLits
-- order their respective elements are laid out in memory.
data Index (dim :: [Natural]) where
E :: Index '[]
I :: Int -> Index '[n]
(:.) :: Index '[n] -> Index (d : ds) -> Index (n : d : ds)
ICons :: Int -> Index ds -> Index (d : ds)

{-# COMPLETE I #-}
pattern I :: Int -> Index '[d]
pattern I m = ICons m E

viewI :: Index (d:ds) -> (Index '[d], Index ds)
viewI (ICons x xs) = (I x, xs)

{-# COMPLETE (:.) #-}
-- | Prepend a single-dimensional index to multi-dimensional one
pattern (:.) :: Index '[d] -> Index ds -> Index (d:ds)
pattern x :. xs <- (viewI -> (x, xs))
where (ICons x E) :. xs = ICons x xs

deriving instance Eq (Index dim)
deriving instance Ord (Index dim)
Expand All @@ -59,43 +82,26 @@ instance Bounded (Index '[]) where
maxBound = E

instance (KnownNat n, 1 <= n, Bounded (Index d)) => Bounded (Index (n:d)) where
minBound = 0 `consIndex` minBound
maxBound = (-1) `consIndex` maxBound
minBound = 0 :. minBound
maxBound = (-1) :. maxBound

instance Enum (Index '[]) where
toEnum = const E
fromEnum = const 0

instance (KnownNat n, 1 <= n, Enum (Index d), Bounded (Index d)) =>
Enum (Index (n:d)) where
toEnum ((`quotRem` enumSize (Index d)) -> (q, r)) = I q `consIndex` toEnum r
fromEnum = \case
I i -> i
I q :. r -> q * enumSize (Index d) + fromEnum r
succ = \case
h@(I i) | h == maxBound -> error "Undefined succ"
| otherwise -> I (succ i)
h :. t | t == maxBound -> succ h :. minBound
| otherwise -> h :. succ t
pred = \case
h@(I i) | h == minBound -> error "Undefined pred"
| otherwise -> I (pred i)
h :. t | t == minBound -> pred h :. maxBound
| otherwise -> h :. pred t

-- | Prepend a single-dimensional index to multi-dimensional one
consIndex :: Index '[x] -> Index xs -> Index (x : xs)
consIndex (I x) = \case
E -> I x
I y -> I x :. I y
I y :. xs -> I x :. I y :. xs
toEnum ((`quotRem` enumSize (Index d)) -> (q, r)) = I (q `mod` natVal n) :. toEnum r
fromEnum (I q :. r) = q * enumSize (Index d) + fromEnum r
succ (h :. t) | t == maxBound = succ h :. minBound
| otherwise = h :. succ t
pred (h :. t) | t == minBound = pred h :. maxBound
| otherwise = h :. pred t

-- | Concatenate two indices together
concatIndex :: Index xs -> Index ys -> Index (xs ++ ys)
concatIndex = \case
E -> id
I x -> consIndex (I x)
I x :. xs -> consIndex (I x) . concatIndex xs
concatIndex :: forall xs ys . Index xs -> Index ys -> Index (xs ++ ys)
concatIndex E = id
concatIndex (ICons x xs) = ICons x . concatIndex xs

-- | A helper type that holds instances for everything you can get by pattern matching on
-- an @`Index`@ value. If you need to get instances for a head/tail of an @`Index`@,
Expand All @@ -110,8 +116,31 @@ data IndexI (dim :: [Natural]) where
(:.=) ::
(KnownNat n, 1 <= n, Ix ds) => Proxy n -> IndexI ds -> IndexI (n : ds)

data IxInstance (dim :: [Natural]) where
IxInstance :: Ix dim => IxInstance dim

viewII :: IndexI dim -> IxInstance dim
viewII (_ :.= _) = IxInstance
viewII EI = IxInstance

-- | A simpler pattern for @`IndexI`@ in case you don't need instances
-- for suffixes.
{-# COMPLETE IxI #-}
pattern IxI :: () => Ix dim => IndexI dim
pattern IxI <- (viewII -> IxInstance) where
IxI = inst

deriving instance Show (IndexI dim)

infixr 5 :.=

-- | Concatenate two @`IndexI`@s to get the instances for concatenated
-- dimensions.
concatIndexI :: IndexI d1 -> IndexI d2 -> IndexI (d1 ++ d2)
concatIndexI EI d2 = d2
concatIndexI (i1 :.= d1) d2 = case concatIndexI d1 d2 of
IxI -> i1 :.= IxI

-- | A class used both as a shorthand for useful @`Index`@ instances and a way to obtain
-- a value of @`IndexI`@.
class (Bounded (Index dim), Enum (Index dim)) => Ix dim where
Expand All @@ -123,3 +152,19 @@ instance Ix '[] where

instance (KnownNat n, 1 <= n, Ix ds) => Ix (n:ds) where
inst = Proxy :.= inst

-- | Convert @`Sing`@ of a dimension list to a witness for @`Ix`@ instances
-- of that index.
singToIndexI :: forall dim . Sing dim -> Maybe (IndexI dim)
singToIndexI SNil = Just EI
singToIndexI (SCons sn@SNat sr) =
case (sing @1 %<=? sn, singToIndexI sr) of
(STrue, Just r@IxI) -> Just $ Proxy :.= r
_ -> Nothing

-- | Simple interface for lifting runtime dimensions into type level.
-- Provides the given continuation with the type of said dimensions and their
-- @`Ix`@ instance.
withIx :: Demote [Natural] -> (forall dim . Ix dim => Proxy dim -> r) -> Maybe r
withIx d f = withSomeSing d \(singToIndexI -> r) ->
flip fmap r \case (IxI :: IndexI dim) -> f $ Proxy @dim
45 changes: 45 additions & 0 deletions src/MLambda/Linear.hs
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
-- |
-- Module : MLambda.Linear
-- Description : Linear maps.
-- Copyright : (c) neclitoris, TurtlePU, 2025
-- License : BSD-3-Clause
-- Maintainer : nas140301@gmail.com
-- Stability : experimental
-- Portability : portable
--
-- This module describes linear maps.
module MLambda.Linear (Additive(..), Module(..), LinearMap, LinearMap') where

import Data.Kind

-- | @Additive@ describes things you can add.
class Additive m where
add :: m -> m -> m
zero :: m

-- | @Module@ is a class describing modules over a ring @r@.
-- Since @`Num`@ describes commutative rings, this is both a right
-- and a left module.
class (Num r, Additive m) => Module r m where
modMult :: r -> m -> m

instance {-# OVERLAPPABLE #-} Num r => Additive r where
add = (+)
zero = 0

instance Num r => Module r r where
modMult = (*)

-- | This version of LinearMap allows one to pass an arbitrary additional
-- constraint on the type mapped over. Note that this can break things:
-- @LinearMap' ((~) r) s m r@ will allow you to do naughty things.
type LinearMap' (e :: Type -> Constraint) t s r =
forall m . (Module r m, e m) => t m -> s m

class Empty t

instance Empty t

-- | @LinearMap@ forces linearity (in mathematical sense) in its argument
-- since you can't multiply an element of @m@ by itself, only by some @a@.
type LinearMap t s r = LinearMap' Empty t s r
34 changes: 33 additions & 1 deletion src/MLambda/Matrix.hs
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,18 @@ module MLambda.Matrix
, crossMassiv
-- * Matrix creation
, mat
, eye
, rep
, act
-- * Shape manipulation
, transpose
) where

import MLambda.Foreign.Utils (asFPtr, asPtr, char)
import MLambda.NDArr
import MLambda.Index
import MLambda.Linear
import MLambda.NDArr hiding (concat, foldr, map, zipWith)
import MLambda.NDArr qualified as NDArr
import MLambda.TypeLits

import Control.Applicative
Expand Down Expand Up @@ -176,3 +184,27 @@ matE s = do
vec <- Storable.unsafeFreeze mvec
pure $ unsafeMkNDArr @'[$tm, $tn] @Double vec
|]

-- | Identity matrix.
eye :: (KnownNat n, 1 <= n, Storable e, Num e) => NDArr '[n, n] e
eye = fromIndex go where
go (i :. j) | i == j = 1
| otherwise = 0

-- | A matrix that corresponds to a linear map of vector spaces.
rep :: forall m n e .
( KnownNat m, 1 <= m
, KnownNat n, 1 <= n
, Storable e, Num e)
=> LinearMap' Storable (NDArr '[m]) (NDArr '[n]) e -> NDArr '[n, m] e
rep f = transpose $ NDArr.concat $ fromIndex (f . (rows @'[m] eye `at`))

-- | A linear map of vector spaces that corresponds to a matrix.
act :: forall m n e . (KnownNat m, 1 <= m , Storable e)
=> NDArr '[n, m] e -> LinearMap' Storable (NDArr '[m]) (NDArr '[n]) e
act a x = NDArr.map (NDArr.foldr add zero . flip (NDArr.zipWith modMult) x) (rows a)

-- | Matrix transposition.
transpose :: (KnownNat m, 1 <= m , KnownNat n, 1 <= n , Storable e)
=> NDArr '[m, n] e -> NDArr '[n, m] e
transpose m = fromIndex \(i :. j) -> m `at` (j :. i)
Loading