From 01d9153328e609804d5443a96f8fdf7871133200 Mon Sep 17 00:00:00 2001 From: TurtlePU Date: Sat, 16 Aug 2025 05:23:50 +0300 Subject: [PATCH 1/2] + 0-dim indices --- src/MLambda/Index.hs | 76 +++++++++++++++++++++++------------------ src/MLambda/NDArr.hs | 71 +++++++++++++++++++------------------- src/MLambda/TypeLits.hs | 6 ++++ 3 files changed, 85 insertions(+), 68 deletions(-) diff --git a/src/MLambda/Index.hs b/src/MLambda/Index.hs index 99f4175..49c9467 100644 --- a/src/MLambda/Index.hs +++ b/src/MLambda/Index.hs @@ -1,3 +1,5 @@ +{-# LANGUAGE ViewPatterns #-} + -- | -- Module : MLambda.Index -- Description : Type of multidimensional array indices. @@ -18,6 +20,8 @@ module MLambda.Index , inst ) where +import Data.Proxy (Proxy (..)) + import MLambda.TypeLits -- | @Index dim@ is the type of indices of multidimensional arrays of dimensions @dim@. @@ -32,8 +36,9 @@ import MLambda.TypeLits -- - @`Enum`@ provides means to iterate through indices in the same -- order their respective elements are laid out in memory. data Index (dim :: [Natural]) where - (:.) :: Index '[n] -> Index (d : ds) -> Index (n : d : ds) + E :: Index '[] I :: Int -> Index '[n] + (:.) :: Index '[n] -> Index (d : ds) -> Index (n : d : ds) deriving instance Eq (Index dim) deriving instance Ord (Index dim) @@ -49,42 +54,46 @@ instance (KnownNat n, 1 <= n) => Num (Index '[n]) where (I a) + (I b) = I $ (a + b) `mod` natVal n (I a) * (I b) = I $ (a * b) `mod` natVal n -instance (KnownNat n, 1 <= n) => Bounded (Index '[n]) where - minBound = I 0 - maxBound = I $ natVal n - 1 - -instance (KnownNat n, 1 <= n, Bounded (Index (a:r))) => Bounded (Index (n:a:r)) where - minBound = minBound :. minBound - maxBound = maxBound :. maxBound - -instance (KnownNat n, 1 <= n) => Enum (Index '[n]) where - fromEnum (I m) = m - toEnum = I . (`mod` natVal n) - succ (I m) | m == natVal n - 1 = error "Undefined succ" - succ (I m) = I (m + 1) - pred (I 0) = error "Undefined pred" - pred (I m) = I (m - 1) - -instance (KnownNat n, 1 <= n, Enum (Index (a:r)), Bounded (Index (a:r))) - => Enum (Index (n:a:r)) where - fromEnum (I n :. t) = enumSize (Index (a:r)) * n + fromEnum t - toEnum m = I q :. toEnum t - where - (q, t) = m `quotRem` enumSize (Index (a:r)) - succ (h :. t) | t == maxBound = succ h :. minBound - succ (h :. t) = h :. succ t - pred (h :. t) | t == minBound = pred h :. maxBound - pred (h :. t) = h :. pred t +instance Bounded (Index '[]) where + minBound = E + maxBound = E + +instance (KnownNat n, 1 <= n, Bounded (Index d)) => Bounded (Index (n:d)) where + minBound = 0 `consIndex` minBound + maxBound = (-1) `consIndex` maxBound + +instance Enum (Index '[]) where + toEnum = const E + fromEnum = const 0 + +instance (KnownNat n, 1 <= n, Enum (Index d), Bounded (Index d)) => + Enum (Index (n:d)) where + toEnum ((`quotRem` enumSize (Index d)) -> (q, r)) = I q `consIndex` toEnum r + fromEnum = \case + I i -> i + I q :. r -> q * enumSize (Index d) + fromEnum r + succ = \case + h@(I i) | h == maxBound -> error "Undefined succ" + | otherwise -> I (succ i) + h :. t | t == maxBound -> succ h :. minBound + | otherwise -> h :. succ t + pred = \case + h@(I i) | h == minBound -> error "Undefined pred" + | otherwise -> I (pred i) + h :. t | t == minBound -> pred h :. maxBound + | otherwise -> h :. pred t -- | Prepend a single-dimensional index to multi-dimensional one consIndex :: Index '[x] -> Index xs -> Index (x : xs) consIndex (I x) = \case + E -> I x I y -> I x :. I y I y :. xs -> I x :. I y :. xs -- | Concatenate two indices together concatIndex :: Index xs -> Index ys -> Index (xs ++ ys) concatIndex = \case + E -> id I x -> consIndex (I x) I x :. xs -> consIndex (I x) . concatIndex xs @@ -97,8 +106,9 @@ concatIndex = \case -- > _ :.= _ -> f j -- @f j@ can be called here because we have @Ix ds@ now -- > f i = ... data IndexI (dim :: [Natural]) where - (:.=) :: Ix (d : ds) => IndexI '[n] -> IndexI (d : ds) -> IndexI (n : d : ds) - II :: (KnownNat n, 1 <= n) => IndexI '[n] + EI :: IndexI '[] + (:.=) :: + (KnownNat n, 1 <= n, Ix ds) => Proxy n -> IndexI ds -> IndexI (n : ds) infixr 5 :.= @@ -108,8 +118,8 @@ class (Bounded (Index dim), Enum (Index dim)) => Ix dim where -- | Returns a term-level witness of @Ix@. inst :: IndexI dim -instance (KnownNat n, 1 <= n) => Ix '[n] where - inst = II +instance Ix '[] where + inst = EI -instance (KnownNat n, 1 <= n, Ix (d:ds)) => Ix (n:d:ds) where - inst = II :.= inst +instance (KnownNat n, 1 <= n, Ix ds) => Ix (n:ds) where + inst = Proxy :.= inst diff --git a/src/MLambda/NDArr.hs b/src/MLambda/NDArr.hs index 30cfb38..f60efc1 100644 --- a/src/MLambda/NDArr.hs +++ b/src/MLambda/NDArr.hs @@ -27,8 +27,8 @@ module MLambda.NDArr , row , rows -- * Array composition - , Stacks , Stack + , Stacks , stack -- * Unsafe API , unsafeMkNDArr @@ -70,7 +70,7 @@ instance (Ix (n:a:r), Show (NDArr (a:r) e), Storable e) => Show (NDArr (n:a:r) e) where showsPrec _ = case inst @(n:a:r) of - II :.= _ -> (showString "[" .) . (. showString "]") + _ :.= _ -> (showString "[" .) . (. showString "]") . foldl' (.) id . intersperse (showString ",\n") . map shows @@ -115,14 +115,17 @@ row :: row i a = rows a `at` i -- | Extract all "rows" from the array as an array. -rows :: forall d1 d2 e. (Ix d2, Storable e) => NDArr (d1 ++ d2) e -> NDArr d1 (NDArr d2 e) +rows :: + forall d1 d2 e. (Ix d2, Storable e) => + NDArr (d1 ++ d2) e -> NDArr d1 (NDArr d2 e) rows = MkNDArr . Storable.unsafeCast . runNDArr toList :: Storable e => NDArr d e -> [e] toList = Storable.toList . runNDArr concat :: - forall d1 d2 e. (Ix d2, Storable e) => NDArr d1 (NDArr d2 e) -> NDArr (d1 ++ d2) e + forall d1 d2 e. (Ix d2, Storable e) => + NDArr d1 (NDArr d2 e) -> NDArr (d1 ++ d2) e concat = MkNDArr . Storable.unsafeCast . runNDArr zipWith :: @@ -130,51 +133,49 @@ zipWith :: (a -> b -> c) -> NDArr d a -> NDArr d b -> NDArr d c zipWith f (MkNDArr xs) (MkNDArr ys) = MkNDArr (Storable.zipWith f xs ys) -data StackWitness i d1 d2 dr where - SZ :: (a + b ~ c, Ix (a : d), Ix (b : d), Ix (c : d)) - => Proxy '(a, b, c) -> Proxy d - -> StackWitness PZ (a : d) (b : d) (c : d) - SS :: ( Ix (a : s), Ix (b : s), Ix (c : s), a + b ~ c) - => Proxy p -> Proxy '(a, b, c) -> Proxy s - -> StackWitness (PS i) (p ++ (a : s)) (p ++ (b : s)) (p ++ (c : s)) +vstack :: + Storable e => NDArr (k : d) e -> NDArr (l : d) e -> NDArr ((k + l) : d) e +vstack (MkNDArr xs) (MkNDArr ys) = MkNDArr (xs <> ys) + +-- | A type family which computes the resulting size of a stacked array. +type Stack n d e = StackImpl (StackError n d e) (Peano n) d e type StackError n d e = Text "Not enough dimensions to stack along axis " :<>: ShowType n :<>: Text ":" :$$: ShowType d :$$: ShowType e --- | A type family which computes the resulting size of a stacked array. -type family Stack msg i d e where - Stack _ PZ (n : d) (m : e) = n + m : Unify "Dimensions" d e - Stack msg (PS i) (n : d) (m : e) = Unify "Sizes" n m : Stack msg i d e - Stack msg _ _ _ = TypeError msg +type family StackImpl msg i d e where + StackImpl _ PZ (n : d) (m : e) = n + m : Unify "Dimensions" d e + StackImpl msg (PS i) (n : d) (m : e) = Unify "Sizes" n m : StackImpl msg i d e + StackImpl msg _ _ _ = TypeError msg -- | A constraint which links together compatible dimensions and axis along -- which they will be stacked. class Stacks i dim1 dim2 dimr where stacks :: StackWitness i dim1 dim2 dimr -instance (n + m ~ k, Ix (n : d), Ix (m : d), Ix (k : d)) - => Stacks PZ (n : d) (m : d) (k : d) where - stacks = SZ (Proxy @'(n, m, n + m)) (Proxy @d) +data StackWitness i d1 d2 dr where + SW :: + ( KnownNat k, KnownNat l, KnownNat (k + l), Ix t + , 1 <= k, 1 <= l, 1 <= k + l + ) => Proxy '(s, k, l, t) -> + StackWitness (Length s) (s ++ (k : t)) (s ++ (l : t)) (s ++ ((k + l) : t)) + +instance + ( KnownNat n, KnownNat m, KnownNat k, Ix d + , n + m ~ k, 1 <= n, 1 <= m, 1 <= k + ) => Stacks PZ (n : d) (m : d) (k : d) where + stacks = SW (Proxy @' ('[], n, m, d)) instance Stacks i d e r => Stacks (PS i) (n : d) (n : e) (n : r) where - stacks = case stacks @i @d @e @r of - (SZ m s) -> SS (Proxy @'[n]) m s - (SS (Proxy @p) m s) -> SS (Proxy @(n:p)) m s + stacks = case stacks @i of + SW (Proxy @'(s, k, l, t)) -> SW (Proxy @'(n : s, k, l, t)) -- | @stack i@ stacks arrays along the axis @i@. All other axes are required -- to be the same lengths. stack :: - forall n -> - ( err ~ StackError n d1 d2 - , Stacks (Peano n) d1 d2 (Stack err (Peano n) d1 d2) - , Storable e - ) => NDArr d1 e -> NDArr d2 e -> NDArr (Stack err (Peano n) d1 d2) e -stack n = go (stacks @(Peano n)) - where - go :: - forall n d1 d2 dr e. Storable e => StackWitness n d1 d2 dr -> - NDArr d1 e -> NDArr d2 e -> NDArr dr e - go (SZ _ _) (MkNDArr xs) (MkNDArr ys) = MkNDArr (xs <> ys) - go (SS (Proxy @p) m@(Proxy @'(a, b, _)) s@(Proxy @s)) xs ys = - concat $ zipWith (go (SZ m s)) (rows @p @(a : s) xs) (rows @p @(b : s) ys) + forall n -> (Stacks (Peano n) d1 d2 (Stack n d1 d2), Storable e) => + NDArr d1 e -> NDArr d2 e -> NDArr (Stack n d1 d2) e +stack n xs ys = case stacks @(Peano n) of + SW (Proxy @'(s, k, l, t)) -> + concat $ zipWith vstack (rows @s @(k : t) xs) (rows @s @(l : t) ys) diff --git a/src/MLambda/TypeLits.hs b/src/MLambda/TypeLits.hs index 426e51e..5f0b54b 100644 --- a/src/MLambda/TypeLits.hs +++ b/src/MLambda/TypeLits.hs @@ -19,6 +19,7 @@ module MLambda.TypeLits , Unify , type (++) , PNat (..) + , Length , Peano , RNat (..) , RPNat (..) @@ -52,6 +53,11 @@ type family xs ++ ys where -- | Peano naturals. data PNat = PZ | PS PNat +-- | Compute length of a type-level list as a Peano natural. +type family Length xs where + Length '[] = PZ + Length (_:xs) = PS (Length xs) + -- | Compute Peano representation from type-level natural. type family Peano n where Peano 0 = PZ From fe693b2d03abea4dd81e74a5c476cc81589e69d4 Mon Sep 17 00:00:00 2001 From: TurtlePU Date: Sat, 16 Aug 2025 05:32:07 +0300 Subject: [PATCH 2/2] export new constructor --- src/MLambda/Index.hs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/MLambda/Index.hs b/src/MLambda/Index.hs index 49c9467..2aae007 100644 --- a/src/MLambda/Index.hs +++ b/src/MLambda/Index.hs @@ -12,7 +12,7 @@ -- This module contains definition of 'Index' type of multidimensional array -- indices along with its instances and public interface. module MLambda.Index - ( Index ((:.)) + ( Index (E, (:.)) , consIndex , concatIndex , IndexI (..)