Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
78 changes: 44 additions & 34 deletions src/MLambda/Index.hs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
{-# LANGUAGE ViewPatterns #-}

-- |
-- Module : MLambda.Index
-- Description : Type of multidimensional array indices.
Expand All @@ -10,14 +12,16 @@
-- This module contains definition of 'Index' type of multidimensional array
-- indices along with its instances and public interface.
module MLambda.Index
( Index ((:.))
( Index (E, (:.))
, consIndex
, concatIndex
, IndexI (..)
, Ix
, inst
) where

import Data.Proxy (Proxy (..))

import MLambda.TypeLits

-- | @Index dim@ is the type of indices of multidimensional arrays of dimensions @dim@.
Expand All @@ -32,8 +36,9 @@ import MLambda.TypeLits
-- - @`Enum`@ provides means to iterate through indices in the same
-- order their respective elements are laid out in memory.
data Index (dim :: [Natural]) where
(:.) :: Index '[n] -> Index (d : ds) -> Index (n : d : ds)
E :: Index '[]
I :: Int -> Index '[n]
(:.) :: Index '[n] -> Index (d : ds) -> Index (n : d : ds)

deriving instance Eq (Index dim)
deriving instance Ord (Index dim)
Expand All @@ -49,42 +54,46 @@ instance (KnownNat n, 1 <= n) => Num (Index '[n]) where
(I a) + (I b) = I $ (a + b) `mod` natVal n
(I a) * (I b) = I $ (a * b) `mod` natVal n

instance (KnownNat n, 1 <= n) => Bounded (Index '[n]) where
minBound = I 0
maxBound = I $ natVal n - 1

instance (KnownNat n, 1 <= n, Bounded (Index (a:r))) => Bounded (Index (n:a:r)) where
minBound = minBound :. minBound
maxBound = maxBound :. maxBound

instance (KnownNat n, 1 <= n) => Enum (Index '[n]) where
fromEnum (I m) = m
toEnum = I . (`mod` natVal n)
succ (I m) | m == natVal n - 1 = error "Undefined succ"
succ (I m) = I (m + 1)
pred (I 0) = error "Undefined pred"
pred (I m) = I (m - 1)

instance (KnownNat n, 1 <= n, Enum (Index (a:r)), Bounded (Index (a:r)))
=> Enum (Index (n:a:r)) where
fromEnum (I n :. t) = enumSize (Index (a:r)) * n + fromEnum t
toEnum m = I q :. toEnum t
where
(q, t) = m `quotRem` enumSize (Index (a:r))
succ (h :. t) | t == maxBound = succ h :. minBound
succ (h :. t) = h :. succ t
pred (h :. t) | t == minBound = pred h :. maxBound
pred (h :. t) = h :. pred t
instance Bounded (Index '[]) where
minBound = E
maxBound = E

instance (KnownNat n, 1 <= n, Bounded (Index d)) => Bounded (Index (n:d)) where
minBound = 0 `consIndex` minBound
maxBound = (-1) `consIndex` maxBound

instance Enum (Index '[]) where
toEnum = const E
fromEnum = const 0

instance (KnownNat n, 1 <= n, Enum (Index d), Bounded (Index d)) =>
Enum (Index (n:d)) where
toEnum ((`quotRem` enumSize (Index d)) -> (q, r)) = I q `consIndex` toEnum r
fromEnum = \case
I i -> i
I q :. r -> q * enumSize (Index d) + fromEnum r
succ = \case
h@(I i) | h == maxBound -> error "Undefined succ"
| otherwise -> I (succ i)
h :. t | t == maxBound -> succ h :. minBound
| otherwise -> h :. succ t
pred = \case
h@(I i) | h == minBound -> error "Undefined pred"
| otherwise -> I (pred i)
h :. t | t == minBound -> pred h :. maxBound
| otherwise -> h :. pred t

-- | Prepend a single-dimensional index to multi-dimensional one
consIndex :: Index '[x] -> Index xs -> Index (x : xs)
consIndex (I x) = \case
E -> I x
I y -> I x :. I y
I y :. xs -> I x :. I y :. xs

-- | Concatenate two indices together
concatIndex :: Index xs -> Index ys -> Index (xs ++ ys)
concatIndex = \case
E -> id
I x -> consIndex (I x)
I x :. xs -> consIndex (I x) . concatIndex xs

Expand All @@ -97,8 +106,9 @@ concatIndex = \case
-- > _ :.= _ -> f j -- @f j@ can be called here because we have @Ix ds@ now
-- > f i = ...
data IndexI (dim :: [Natural]) where
(:.=) :: Ix (d : ds) => IndexI '[n] -> IndexI (d : ds) -> IndexI (n : d : ds)
II :: (KnownNat n, 1 <= n) => IndexI '[n]
EI :: IndexI '[]
(:.=) ::
(KnownNat n, 1 <= n, Ix ds) => Proxy n -> IndexI ds -> IndexI (n : ds)

infixr 5 :.=

Expand All @@ -108,8 +118,8 @@ class (Bounded (Index dim), Enum (Index dim)) => Ix dim where
-- | Returns a term-level witness of @Ix@.
inst :: IndexI dim

instance (KnownNat n, 1 <= n) => Ix '[n] where
inst = II
instance Ix '[] where
inst = EI

instance (KnownNat n, 1 <= n, Ix (d:ds)) => Ix (n:d:ds) where
inst = II :.= inst
instance (KnownNat n, 1 <= n, Ix ds) => Ix (n:ds) where
inst = Proxy :.= inst
71 changes: 36 additions & 35 deletions src/MLambda/NDArr.hs
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,8 @@ module MLambda.NDArr
, row
, rows
-- * Array composition
, Stacks
, Stack
, Stacks
, stack
-- * Unsafe API
, unsafeMkNDArr
Expand Down Expand Up @@ -70,7 +70,7 @@ instance (Ix (n:a:r), Show (NDArr (a:r) e), Storable e) =>
Show (NDArr (n:a:r) e) where
showsPrec _ =
case inst @(n:a:r) of
II :.= _ -> (showString "[" .) . (. showString "]")
_ :.= _ -> (showString "[" .) . (. showString "]")
. foldl' (.) id
. intersperse (showString ",\n")
. map shows
Expand Down Expand Up @@ -115,66 +115,67 @@ row ::
row i a = rows a `at` i

-- | Extract all "rows" from the array as an array.
rows :: forall d1 d2 e. (Ix d2, Storable e) => NDArr (d1 ++ d2) e -> NDArr d1 (NDArr d2 e)
rows ::
forall d1 d2 e. (Ix d2, Storable e) =>
NDArr (d1 ++ d2) e -> NDArr d1 (NDArr d2 e)
rows = MkNDArr . Storable.unsafeCast . runNDArr

toList :: Storable e => NDArr d e -> [e]
toList = Storable.toList . runNDArr

concat ::
forall d1 d2 e. (Ix d2, Storable e) => NDArr d1 (NDArr d2 e) -> NDArr (d1 ++ d2) e
forall d1 d2 e. (Ix d2, Storable e) =>
NDArr d1 (NDArr d2 e) -> NDArr (d1 ++ d2) e
concat = MkNDArr . Storable.unsafeCast . runNDArr

zipWith ::
(Storable a, Storable b, Storable c) =>
(a -> b -> c) -> NDArr d a -> NDArr d b -> NDArr d c
zipWith f (MkNDArr xs) (MkNDArr ys) = MkNDArr (Storable.zipWith f xs ys)

data StackWitness i d1 d2 dr where
SZ :: (a + b ~ c, Ix (a : d), Ix (b : d), Ix (c : d))
=> Proxy '(a, b, c) -> Proxy d
-> StackWitness PZ (a : d) (b : d) (c : d)
SS :: ( Ix (a : s), Ix (b : s), Ix (c : s), a + b ~ c)
=> Proxy p -> Proxy '(a, b, c) -> Proxy s
-> StackWitness (PS i) (p ++ (a : s)) (p ++ (b : s)) (p ++ (c : s))
vstack ::
Storable e => NDArr (k : d) e -> NDArr (l : d) e -> NDArr ((k + l) : d) e
vstack (MkNDArr xs) (MkNDArr ys) = MkNDArr (xs <> ys)

-- | A type family which computes the resulting size of a stacked array.
type Stack n d e = StackImpl (StackError n d e) (Peano n) d e

type StackError n d e =
Text "Not enough dimensions to stack along axis " :<>: ShowType n
:<>: Text ":" :$$: ShowType d :$$: ShowType e

-- | A type family which computes the resulting size of a stacked array.
type family Stack msg i d e where
Stack _ PZ (n : d) (m : e) = n + m : Unify "Dimensions" d e
Stack msg (PS i) (n : d) (m : e) = Unify "Sizes" n m : Stack msg i d e
Stack msg _ _ _ = TypeError msg
type family StackImpl msg i d e where
StackImpl _ PZ (n : d) (m : e) = n + m : Unify "Dimensions" d e
StackImpl msg (PS i) (n : d) (m : e) = Unify "Sizes" n m : StackImpl msg i d e
StackImpl msg _ _ _ = TypeError msg

-- | A constraint which links together compatible dimensions and axis along
-- which they will be stacked.
class Stacks i dim1 dim2 dimr where
stacks :: StackWitness i dim1 dim2 dimr

instance (n + m ~ k, Ix (n : d), Ix (m : d), Ix (k : d))
=> Stacks PZ (n : d) (m : d) (k : d) where
stacks = SZ (Proxy @'(n, m, n + m)) (Proxy @d)
data StackWitness i d1 d2 dr where
SW ::
( KnownNat k, KnownNat l, KnownNat (k + l), Ix t
, 1 <= k, 1 <= l, 1 <= k + l
) => Proxy '(s, k, l, t) ->
StackWitness (Length s) (s ++ (k : t)) (s ++ (l : t)) (s ++ ((k + l) : t))

instance
( KnownNat n, KnownNat m, KnownNat k, Ix d
, n + m ~ k, 1 <= n, 1 <= m, 1 <= k
) => Stacks PZ (n : d) (m : d) (k : d) where
stacks = SW (Proxy @' ('[], n, m, d))

instance Stacks i d e r => Stacks (PS i) (n : d) (n : e) (n : r) where
stacks = case stacks @i @d @e @r of
(SZ m s) -> SS (Proxy @'[n]) m s
(SS (Proxy @p) m s) -> SS (Proxy @(n:p)) m s
stacks = case stacks @i of
SW (Proxy @'(s, k, l, t)) -> SW (Proxy @'(n : s, k, l, t))

-- | @stack i@ stacks arrays along the axis @i@. All other axes are required
-- to be the same lengths.
stack ::
forall n ->
( err ~ StackError n d1 d2
, Stacks (Peano n) d1 d2 (Stack err (Peano n) d1 d2)
, Storable e
) => NDArr d1 e -> NDArr d2 e -> NDArr (Stack err (Peano n) d1 d2) e
stack n = go (stacks @(Peano n))
where
go ::
forall n d1 d2 dr e. Storable e => StackWitness n d1 d2 dr ->
NDArr d1 e -> NDArr d2 e -> NDArr dr e
go (SZ _ _) (MkNDArr xs) (MkNDArr ys) = MkNDArr (xs <> ys)
go (SS (Proxy @p) m@(Proxy @'(a, b, _)) s@(Proxy @s)) xs ys =
concat $ zipWith (go (SZ m s)) (rows @p @(a : s) xs) (rows @p @(b : s) ys)
forall n -> (Stacks (Peano n) d1 d2 (Stack n d1 d2), Storable e) =>
NDArr d1 e -> NDArr d2 e -> NDArr (Stack n d1 d2) e
stack n xs ys = case stacks @(Peano n) of
SW (Proxy @'(s, k, l, t)) ->
concat $ zipWith vstack (rows @s @(k : t) xs) (rows @s @(l : t) ys)
6 changes: 6 additions & 0 deletions src/MLambda/TypeLits.hs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ module MLambda.TypeLits
, Unify
, type (++)
, PNat (..)
, Length
, Peano
, RNat (..)
, RPNat (..)
Expand Down Expand Up @@ -52,6 +53,11 @@ type family xs ++ ys where
-- | Peano naturals.
data PNat = PZ | PS PNat

-- | Compute length of a type-level list as a Peano natural.
type family Length xs where
Length '[] = PZ
Length (_:xs) = PS (Length xs)

-- | Compute Peano representation from type-level natural.
type family Peano n where
Peano 0 = PZ
Expand Down