{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DerivingStrategies #-}
{-# LANGUAGE PatternSynonyms #-}
{-# LANGUAGE RoleAnnotations #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE UndecidableInstances #-}
{-# LANGUAGE ViewPatterns #-}

-- | Functions for manipulating shape. The module tends to supply equivalent functionality at type-level and value-level with functions of the same name (except for capitalization).
module Harpie.Shape
  ( -- * Type-level Nat
    SNat,
    pattern SNat,
    valueOf,

    -- * Type-level [Nat]
    SNats,
    pattern SNats,
    fromSNats,
    KnownNats (..),
    natVals,
    withKnownNats,
    SomeNats,
    someNatVals,
    withSomeSNats,

    -- * Shape
    valuesOf,
    rankOf,
    sizeOf,
    Fin (..),
    fin,
    safeFin,
    Fins (..),
    fins,
    safeFins,

    -- * Shape Operators at value- and type- level.
    rank,
    Rank,
    range,
    Range,
    rerank,
    Rerank,
    dimsOf,
    DimsOf,
    endDimsOf,
    EndDimsOf,
    size,
    Size,
    flatten,
    shapen,
    asSingleton,
    AsSingleton,
    asScalar,
    AsScalar,
    isSubset,
    IsSubset,
    exceptDims,
    ExceptDims,
    reorder,
    Reorder,
    ReorderOk,
    squeeze,
    Squeeze,

    -- * Primitives
    Min,
    Max,
    minimum,
    Minimum,

    -- * Position
    isFin,
    IsFin,
    isFins,
    IsFins,
    isDim,
    IsDim,
    isDims,
    IsDims,
    lastPos,
    LastPos,
    minDim,
    MinDim,

    -- * combinators
    EnumFromTo,
    Foldl',

    -- * single dimension
    GetIndex,
    SetIndex,
    getDim,
    GetDim,
    modifyDim,
    ModifyDim,
    incAt,
    IncAt,
    decAt,
    DecAt,
    setDim,
    SetDim,
    takeDim,
    TakeDim,
    dropDim,
    DropDim,
    deleteDim,
    DeleteDim,
    insertDim,
    InsertDim,
    InsertOk,
    SliceOk,
    SlicesOk,
    concatenate,
    Concatenate,
    ConcatenateOk,

    -- * multiple dimension
    getDims,
    GetDims,
    getLastPositions,
    GetLastPositions,
    modifyDims,
    insertDims,
    InsertDims,
    preDeletePositions,
    PreDeletePositions,
    preInsertPositions,
    PreInsertPositions,
    setDims,
    SetDims,
    deleteDims,
    DeleteDims,
    dropDims,
    DropDims,
    concatDims,
    ConcatDims,

    -- * value-only operations
    unconcatDimsIndex,
    reverseIndex,
    rotate,
    rotateIndex,
    rotatesIndex,
    isDiag,

    -- * windowed
    expandWindows,
    ExpandWindows,
    indexWindows,
    dimWindows,
    DimWindows,

    -- * Fcf re-exports
    Eval,
    type (++),
  )
where

import Data.Bool
import Data.Foldable hiding (minimum)
import Data.Function
import Data.List qualified as List
import Data.Maybe
import Data.Proxy
import Data.Type.Bool hiding (Not)
import Data.Type.Equality
import Data.Type.Ord hiding (Max, Min)
import Fcf hiding (type (&&), type (+), type (++), type (-), type (<), type (>), type (||))
import Fcf qualified
import Fcf.Class.Foldable
import Fcf.Data.List
import GHC.Exts
import GHC.TypeLits (ErrorMessage (..))
import GHC.TypeLits qualified as L
import GHC.TypeNats
import Prelude as P hiding (minimum)

{-# ANN module ("doctest-parallel: --no-implicit-module-import" :: String) #-}

-- $setup
-- >>> :set -XDataKinds
-- >>> :set -XTypeFamilies
-- >>> import Prelude
-- >>> import Fcf
-- >>> import GHC.Exts ()
-- >>> import Harpie.Shape as S

-- | Get the value of a type-level 'Nat' as an 'Int'.
--
-- The type variable does not appear in the result, so this must be used with
-- an explicit type application. The 'Nat' is converted with 'fromIntegral', so
-- values larger than @maxBound :: Int@ wrap around.
--
-- >>> valueOf @42
-- 42
valueOf :: forall n. (KnownNat n) => Int
valueOf = fromIntegral (fromSNat (SNat @n))
{-# INLINE valueOf #-}

type role SNats nominal

-- | A value-level witness for a type-level list of natural numbers.
--
-- Obtain an SNats value using:
--
-- - The natsSing method of KnownNats
-- - The SNats pattern
-- - The withSomeSNats function
--
-- >>> :t SNats @[2,3,4]
-- SNats @[2,3,4] :: KnownNats [2, 3, 4] => SNats [2, 3, 4]
-- >>> SNats @[2,3,4]
-- SNats @[2, 3, 4]
newtype SNats (ns :: [Nat]) = UnsafeSNats [Nat]

-- | Two values of type @SNats ns@ are always equal. Only the type index is
-- compared; the wrapped list is never inspected.
instance Eq (SNats ns) where
  _ == _ = True

-- | Consistent with the 'Eq' instance: every pair compares 'EQ'.
instance Ord (SNats ns) where
  compare _ _ = EQ

-- | Matches GHC printing quirks: promoted lists with fewer than two elements
-- are shown with a leading tick (@'[]@, @'[2]@); longer lists are not.
instance Show (SNats ns) where
  show (UnsafeSNats s) =
    "SNats @"
      <> bool "" "'" (length s < 2)
      <> "["
      <> List.intercalate ", " (show <$> s)
      <> "]"

-- | An explicitly bidirectional pattern synonym relating an 'SNats' to a 'KnownNats' constraint.
--
-- As an expression: Constructs an explicit 'SNats' ns value from an implicit 'KnownNats' ns constraint:
--
-- > SNats @ns :: KnownNats ns => SNats ns
--
-- As a pattern: Matches on an explicit 'SNats' ns value bringing an implicit 'KnownNats' ns constraint into scope:
--
-- > f :: SNats ns -> ..
-- > f SNats = {- KnownNats ns in scope -}
--
-- or, if you need to both bring the KnownNats into scope and reuse the SNats input:
--
-- > f (SNats :: SNats s) = g (SNats @s)
pattern SNats :: forall ns. () => (KnownNats ns) => SNats ns
-- Matching: turn the witness into a KnownNatsInstance, whose constructor
-- carries the KnownNats dictionary; matching that constructor provides it.
pattern SNats <- (knownNatsInstance -> KnownNatsInstance)
  where
    -- Building: read the witness off the KnownNats dictionary in scope.
    SNats = natsSing

{-# COMPLETE SNats #-}

-- | Return the value-level list of naturals held by an @SNats ns@ value.
--
-- >>> fromSNats (SNats @[2,3,4])
-- [2,3,4]
fromSNats :: SNats s -> [Nat]
fromSNats (UnsafeSNats s) = s

-- An internal data type that is only used for defining the SNats pattern
-- synonym. Its single constructor captures a 'KnownNats' dictionary, so
-- matching on it brings @KnownNats ns@ into scope.
data KnownNatsInstance (ns :: [Nat]) where
  KnownNatsInstance :: (KnownNats ns) => KnownNatsInstance ns

-- An internal function that is only used for defining the SNats pattern
-- synonym. It turns an explicit 'SNats' witness into a 'KnownNatsInstance'
-- by installing the witness as the 'KnownNats' dictionary ('withKnownNats').
knownNatsInstance :: SNats ns -> KnownNatsInstance ns
knownNatsInstance dims = withKnownNats dims KnownNatsInstance

-- | Reflect a list of naturals.
--
-- >>> natsSing @'[2]
-- SNats @'[2]
class KnownNats (ns :: [Nat]) where
  -- | The value-level witness for @ns@.
  natsSing :: SNats ns

-- | The empty type-level list reflects to the empty value-level list.
instance KnownNats '[] where
  natsSing = UnsafeSNats []

-- | A cons reflects to the head's 'KnownNat' value followed by the values
-- reflected from the tail.
instance (KnownNat n, KnownNats s) => KnownNats (n ': s) where
  natsSing = UnsafeSNats (fromSNat (SNat @n) : fromSNats (SNats @s))

-- | Obtain a value-level list of naturals from a type-level proxy.
--
-- Only the type of the proxy is used; its value is never inspected.
--
-- >>> natVals (SNats @[2,3,4])
-- [2,3,4]
natVals :: forall ns proxy. (KnownNats ns) => proxy ns -> [Nat]
natVals _ = fromSNats (natsSing @ns)

-- | Convert an explicit @SNats ns@ value into an implicit @KnownNats ns@ constraint.
--
-- This works because 'KnownNats' is a single-method class whose method has
-- type @SNats ns@. 'withDict' requires exactly that shape to use the witness
-- as the class dictionary. The result may be levity-polymorphic.
withKnownNats ::
  forall ns rep (r :: TYPE rep).
  SNats ns -> ((KnownNats ns) => r) -> r
withKnownNats = withDict @(KnownNats ns)

-- | Convert a list of naturals into an @SNats s@ value, where @s@ is a fresh
-- type-level list of naturals that the continuation cannot inspect statically.
withSomeSNats ::
  forall rep (r :: TYPE rep).
  [Nat] -> (forall s. SNats s -> r) -> r
withSomeSNats s k = k (UnsafeSNats s)
-- NOINLINE follows base's 'GHC.TypeNats.withSomeSNat' (see the
-- "NOINLINE withSomeSNat" note there). Presumably it stops GHC from treating
-- different fresh indices as the same type after inlining. TODO confirm.
{-# NOINLINE withSomeSNats #-}

-- | An unknown type-level list of naturals, packaged with its 'KnownNats' dictionary.
data SomeNats = forall s. (KnownNats s) => SomeNats (Proxy s)

-- | Promote a list of naturals to an unknown type-level list.
someNatVals :: [Nat] -> SomeNats
someNatVals ns =
  withSomeSNats
    ns
    ( \(sn :: SNats s) ->
        -- install the witness as the KnownNats dictionary, then capture it
        withKnownNats sn (SomeNats @s Proxy)
    )

-- * shape primitives

-- | The value of a 'KnownNats', as a list of 'Int's.
--
-- Each 'Nat' is converted with 'fromIntegral', so values beyond
-- @maxBound :: Int@ wrap around.
--
-- >>> valuesOf @[2,3,4]
-- [2,3,4]
valuesOf :: forall s. (KnownNats s) => [Int]
valuesOf = fromIntegral <$> fromSNats (SNats @s)
{-# INLINE valuesOf #-}

-- | The rank (or length) of a KnownNats.
--
-- >>> rankOf @[2,3,4]
-- 3
rankOf :: forall s. (KnownNats s) => Int
rankOf = length (valuesOf @s)
{-# INLINE rankOf #-}

-- | The size (or product) of a KnownNats: the number of elements in an array
-- of that shape. The empty shape has size 1.
--
-- >>> sizeOf @[2,3,4]
-- 24
sizeOf :: forall s. (KnownNats s) => Int
sizeOf = product (valuesOf @s)
{-# INLINE sizeOf #-}

type role Fin nominal

-- | Fin most often represents a (finite) zero-based index for a single dimension (of a multi-dimensioned hyper-rectangular array).
--
-- Build values with 'fin' or 'safeFin', which check the index against the
-- type-level bound. The raw 'UnsafeFin' constructor performs no check.
newtype Fin s = UnsafeFin
  { fromFin :: Int
  }
  deriving stock (Eq, Ord)

-- | Shows only the underlying index.
instance Show (Fin n) where
  show (UnsafeFin x) = show x

-- | Construct a Fin.
--
-- Errors on out-of-bounds (see 'safeFin' for the total version).
--
-- >>> fin @2 1
-- 1
--
-- >>> fin @2 2
-- *** Exception: value outside bounds
-- ...
fin :: forall n. (KnownNat n) => Int -> Fin n
fin x = fromMaybe (error "value outside bounds") (safeFin x)

-- | Construct a Fin safely: 'Just' when @0 <= x < n@, otherwise 'Nothing'.
--
-- >>> safeFin 1 :: Maybe (Fin 2)
-- Just 1
--
-- >>> safeFin 2 :: Maybe (Fin 2)
-- Nothing
safeFin :: forall n. (KnownNat n) => Int -> Maybe (Fin n)
safeFin x = bool Nothing (Just (UnsafeFin x)) (0 <= x && x < valueOf @n)

type role Fins nominal

-- | Fins most often represents (finite) indexes for multiple dimensions (of a multi-dimensioned hyper-rectangular array).
--
-- Build values with 'fins' or 'safeFins', which check the indexes against the
-- type-level shape. The raw 'UnsafeFins' constructor performs no check.
newtype Fins s = UnsafeFins
  { fromFins :: [Int]
  }
  -- NOTE(review): the stock Functor instance maps over the phantom index, so
  -- 'fmap' re-tags the indexes without any bounds check (for Type-kinded
  -- indices only). It is kept for backward compatibility.
  deriving stock (Eq, Ord, Functor)

-- | Shows only the underlying index list.
instance Show (Fins n) where
  show (UnsafeFins x) = show x

-- | Construct a Fins.
--
-- Errors on out-of-bounds (see 'safeFins' for the total version).
--
-- >>> fins @[2,3,4] [1,2,3]
-- [1,2,3]
--
-- >>> fins @[2,3,4] [1,2,5]
-- *** Exception: value outside bounds
-- ...
fins :: forall s. (KnownNats s) => [Int] -> Fins s
fins x = fromMaybe (error "value outside bounds") (safeFins x)

-- | Construct a Fins safely: 'Just' when the index list has the same rank as
-- the shape and every index is in bounds ('isFins'), otherwise 'Nothing'.
--
-- >>> safeFins [1,2,3] :: Maybe (Fins [2,3,4])
-- Just [1,2,3]
--
-- >>> safeFins [2] :: Maybe (Fins '[2])
-- Nothing
safeFins :: forall s. (KnownNats s) => [Int] -> Maybe (Fins s)
safeFins xs = bool Nothing (Just (UnsafeFins xs)) (isFins xs (valuesOf @s))

-- | Number of dimensions (the length of the shape list).
--
-- >>> rank @Int [2,3,4]
-- 3
rank :: [a] -> Int
rank = length
{-# INLINE rank #-}

-- | Number of dimensions
--
-- >>> :k! Eval (Rank [2,3,4])
-- ...
-- = 3
data Rank :: [a] -> Exp Natural

type instance
  Eval (Rank xs) =
    Eval (Length xs)

-- | Enumerate a range of rank n: @[0 .. n - 1]@, which is empty when @n <= 0@.
--
-- >>> range 0
-- []
--
-- >>> range 3
-- [0,1,2]
range :: Int -> [Int]
range n = [0 .. n - 1]

-- | Enumerate a range of rank n
--
-- >>> :k! Eval (Range 0)
-- ...
-- = '[]
--
-- >>> :k! Eval (Range 3)
-- ...
-- = [0, 1, 2]
data Range :: Nat -> Exp [Nat]

-- The zero case is handled separately because @0 - 1@ has no 'Nat' result.
type instance
  Eval (Range x) =
    If (x == 0) '[] (Eval (EnumFromTo 0 (Eval ((Fcf.-) x 1))))

-- | Create a new rank by adding ones to the left, if the new rank is greater, or combining dimensions (from left to right) into rows, if the new rank is lower.
--
-- When the requested rank is 1 or less, the whole shape collapses into a
-- single dimension holding its size (so @rerank 0 []@ is @[1]@).
--
-- >>> rerank 4 [2,3,4]
-- [1,2,3,4]
-- >>> rerank 2 [2,3,4]
-- [6,4]
rerank :: Int -> [Int] -> [Int]
rerank r xs
  | r > r' = replicate (r - r') 1 <> xs
  | otherwise = product (take k xs) : drop k xs
  where
    r' = rank xs
    -- number of leading dimensions folded into the first row
    k = r' - r + 1

-- | Create a new rank by adding ones to the left, if the new rank is greater, or combining dimensions (from left to right) into rows, if the new rank is lower.
--
-- >>> :k! Eval (Rerank 4 [2,3,4])
-- ...
-- = [1, 2, 3, 4]
-- >>> :k! Eval (Rerank 2 [2,3,4])
-- ...
-- = [6, 4]
data Rerank :: Nat -> [Nat] -> Exp [Nat]

type instance
  Eval (Rerank r xs) =
    If
      (Eval ((Fcf.>) r (Eval (Rank xs))))
      (Eval (Eval (Replicate (Eval ((Fcf.-) r (Eval (Rank xs)))) 1) ++ xs))
      ( Eval
          ( '[Eval (Size (Eval (Take (Eval ((Fcf.+) (Eval ((Fcf.-) (Eval (Rank xs)) r)) 1)) xs)))]
              ++ Eval (Drop (Eval ((Fcf.+) (Eval ((Fcf.-) (Eval (Rank xs)) r)) 1)) xs)
          )
      )

-- | Enumerate the dimensions of a shape: @[0 .. rank s - 1]@.
--
-- >>> dimsOf [2,3,4]
-- [0,1,2]
dimsOf :: [Int] -> [Int]
dimsOf s = range (rank s)

-- | Enumerate the dimensions of a shape.
--
-- >>> :k! Eval (DimsOf [2,3,4])
-- ...
-- = [0, 1, 2]
data DimsOf :: [Nat] -> Exp [Nat]

type instance
  Eval (DimsOf xs) =
    Eval (Range =<< Rank xs)

-- | Enumerate the final dimensions of a shape: the last @rank xs@ dimensions
-- of @s@, listed from the innermost outwards.
--
-- >>> endDimsOf [1,0] [2,3,4]
-- [2,1]
endDimsOf :: [Int] -> [Int] -> [Int]
endDimsOf xs s = take (rank xs) (List.reverse (dimsOf s))

-- | Enumerate the final dimensions of a shape.
--
-- >>> :k! Eval (EndDimsOf [1,0] [2,3,4])
-- ...
-- = [2, 1]
data EndDimsOf :: [Nat] -> [Nat] -> Exp [Nat]

type instance
  Eval (EndDimsOf xs s) =
    Eval (LiftM2 Take (Rank xs) (Reverse =<< DimsOf s))

-- | Total number of elements (if the list is the shape of a hyper-rectangular array).
--
-- The empty (scalar) shape has size 1.
--
-- >>> size [2,3,4]
-- 24
size :: [Int] -> Int
-- 'P.product' already gives 1 for [] and x for [x], so no special cases are needed.
size xs = P.product xs
{-# INLINE size #-}

-- | Total number of elements (if the list is the shape of a hyper-rectangular array).
--
-- >>> :k! (Eval (Size [2,3,4]))
-- ...
-- = 24
data Size :: [Nat] -> Exp Nat

type instance Eval (Size xs) = Eval (Foldr (Fcf.*) 1 xs)

-- | Convert from an n-dimensional shape index to a flat index: the
-- lexicographic position of that index in a row-major array of shape @ns@.
--
-- An empty shape always gives 0. A single-element index is returned as is.
--
-- >>> flatten [2,3,4] [1,1,1]
-- 17
--
-- >>> flatten [] [1,1,1]
-- 0
flatten :: [Int] -> [Int] -> Int
flatten [] _ = 0
flatten _ [x'] = x'
flatten ns xs = sum (zipWith (*) xs strides)
  where
    -- row-major strides: each entry is the product of all later dimensions
    strides = drop 1 (scanr (*) 1 ns)
{-# INLINE flatten #-}

-- | Convert from a flat (row-major) index to a shape index. This is the
-- inverse of 'flatten' for in-range indexes.
--
-- NOTE(review): for an index outside @[0, size ns)@ the results differ by
-- rank. The rank-1 and rank-2 cases leave the outermost index unbounded; the
-- general case reduces it modulo the outermost dimension.
--
-- >>> shapen [2,3,4] 17
-- [1,1,1]
shapen :: [Int] -> Int -> [Int]
shapen [] _ = []
shapen [_] x' = [x']
shapen [_, y] x' = let (i, j) = divMod x' y in [i, j]
shapen ns x = fst (foldr step ([], x) ns)
  where
    -- Process dimensions innermost first. The remainder is this dimension's
    -- index; the quotient carries on to the outer dimensions.
    step a (acc, r) = let (d, m) = divMod r a in (m : acc, d)
{-# INLINE shapen #-}

-- | Convert a scalar to a dimensioned shape: the empty shape becomes @[1]@.
-- Any other shape is returned unchanged.
--
-- >>> asSingleton []
-- [1]
-- >>> asSingleton [2,3,4]
-- [2,3,4]
asSingleton :: [Int] -> [Int]
asSingleton [] = [1]
asSingleton x = x

-- | Convert a scalar to a dimensioned shape
--
-- >>> :k! Eval (AsSingleton '[])
-- ...
-- = '[1]
-- >>> :k! Eval (AsSingleton [2,3,4])
-- ...
-- = [2, 3, 4]
data AsSingleton :: [Nat] -> Exp [Nat]

type instance
  Eval (AsSingleton xs) =
    If (xs == '[]) '[1] xs

-- | Convert a (potentially) [1] dimensioned shape to a scalar shape: exactly
-- @[1]@ becomes @[]@. Any other shape is returned unchanged.
--
-- >>> asScalar [1]
-- []
-- >>> asScalar [2,3,4]
-- [2,3,4]
asScalar :: [Int] -> [Int]
asScalar [1] = []
asScalar x = x

-- | Convert a (potentially) [1] dimensioned shape to a scalar shape
--
-- >>> :k! Eval (AsScalar '[1])
-- ...
-- = '[]
-- >>> :k! Eval (AsScalar [2,3,4])
-- ...
-- = [2, 3, 4]
data AsScalar :: [Nat] -> Exp [Nat]

type instance
  Eval (AsScalar xs) =
    If (xs == '[1]) '[] xs

-- | Pointwise @<=@ comparison of two shapes that must also have the same rank.
lte :: [Int] -> [Int] -> Bool
lte xs ys =
  and (zipWith (<=) xs ys)
    && rank xs == rank ys

-- | Type-level pointwise @<=@ comparison of two shapes of equal rank.
data LTE :: [Nat] -> [Nat] -> Exp Bool

type instance
  Eval (LTE xs ys) =
    Eval
      ( LiftM2
          (Fcf.&&)
          (And =<< ZipWith (Fcf.<=) xs ys)
          (LiftM2 TyEq (Rank xs) (Rank ys))
      )

-- | Check if a shape is a subset (<=) another shape after reranking.
--
-- The first shape is reranked to the rank of the second ('rerank'), then the
-- two are compared pointwise.
--
-- >>> isSubset [2,3,4] [2,3,4]
-- True
--
-- >>> isSubset [1,2] [2,3,4]
-- True
--
-- >>> isSubset [2,1] [1]
-- False
isSubset :: [Int] -> [Int] -> Bool
isSubset xs ys = lte (rerank (rank ys) xs) ys

-- | Check if a shape is a subset (<=) another shape after reranking.
--
-- >>> :k! Eval (IsSubset [2,3,4] [2,3,4])
-- ...
-- = True
--
-- >>> :k! Eval (IsSubset [1,2] [2,3,4])
-- ...
-- = True
--
-- >>> :k! Eval (IsSubset [2,1] '[1])
-- ...
-- = False
data IsSubset :: [Nat] -> [Nat] -> Exp Bool

type instance
  Eval (IsSubset xs ys) =
    Eval (LTE (Eval (Rerank (Eval (Rank ys)) xs)) ys)

-- | Compute dimensions for a shape other than the supplied dimensions.
--
-- >>> exceptDims [1,2] [2,3,4]
-- [0]
exceptDims :: [Int] -> [Int] -> [Int]
exceptDims ds s = deleteDims ds (dimsOf s)

-- | Compute dimensions for a shape other than the supplied dimensions.
--
-- Uses 'DimsOf' rather than @EnumFromTo 0 (Rank s - 1)@. The latter gets
-- stuck on @0 - 1@ for a rank-0 shape, while the value-level version
-- handles that case.
--
-- >>> :k! Eval (ExceptDims [1,2] [2,3,4])
-- ...
-- = '[0]
data ExceptDims :: [Nat] -> [Nat] -> Exp [Nat]

type instance
  Eval (ExceptDims ds s) =
    Eval (DeleteDims ds =<< DimsOf s)

-- | Reorder the dimensions of shape according to a list of positions.
--
-- The i-th element of the result is the dimension of @s@ at the i-th
-- position. An empty shape always gives an empty result.
--
-- >>> reorder [2,3,4] [2,0,1]
-- [4,2,3]
reorder :: [Int] -> [Int] -> [Int]
reorder [] _ = []
reorder s ds = (`getDim` s) <$> ds

-- | Reorder the dimensions of shape according to a list of positions.
--
-- >>> :k! Eval (Reorder [2,3,4] [2,0,1])
-- ...
-- = [4, 2, 3]
data Reorder :: [Nat] -> [Nat] -> Exp [Nat]

-- s is the shape, ps the list of positions.
type instance
  Eval (Reorder s ps) =
    If
      (Eval (ReorderOk s ps))
      (Eval (Map (Flip GetDim s) ps))
      (L.TypeError ('Text "Reorder dimension indices out of bounds"))

-- | Test if a Reorder is valid. The position list must have the same rank as
-- the shape, and every position must be in bounds for that rank.
--
-- >>> :k! Eval (ReorderOk [2,3,4] [0,1])
-- ...
-- = False
data ReorderOk :: [Nat] -> [Nat] -> Exp Bool

type instance
  Eval (ReorderOk s ps) =
    Eval (TyEq (Eval (Rank s)) (Eval (Rank ps)))
      && Eval (And =<< Map (Flip IsFin (Eval (Rank s))) ps)

-- | Remove 1's from a list (drop all dimensions of size 1 from a shape).
--
-- >>> squeeze [0,1,2,3]
-- [0,2,3]
squeeze :: [Int] -> [Int]
squeeze = filter (/= 1)

-- | Remove 1's from a list.
--
-- >>> :k! (Eval (Squeeze [0,1,2,3]))
-- ...
-- = [0, 2, 3]
data Squeeze :: [a] -> Exp [a]

type instance
  Eval (Squeeze xs) =
    Eval (Filter (Not <=< TyEq 1) xs)

-- | Minimum of a list. Errors on an empty list.
--
-- >>> S.minimum []
-- *** Exception: zero-ranked
-- ...
-- >>> S.minimum [2,3,4]
-- 2
minimum :: [Int] -> Int
minimum [] = error "zero-ranked"
-- A strict left fold runs in constant stack space, unlike naive recursion.
minimum (x : xs) = List.foldl' P.min x xs

-- | minimum of a list
--
-- >>> :k! Eval (Minimum '[])
-- ...
-- = (TypeError ...)
--
-- >>> :k! Eval (Minimum [2,3,4])
-- ...
-- = 2
data Minimum :: [a] -> Exp a

type instance Eval (Minimum '[]) = L.TypeError (L.Text "zero ranked")

type instance
  Eval (Minimum (x ': xs)) =
    Eval (Foldr Min x xs)

-- | Minimum of two type values. The comparison uses @<?@ from
-- "Data.Type.Ord"; when the values are equal, the second one is returned.
--
-- >>> :k! Eval (Min 0 1)
-- ...
-- = 0
data Min :: a -> a -> Exp a

type instance Eval (Min a b) = If (a <? b) a b

-- | Maximum of two type values. The comparison uses @>?@ from
-- "Data.Type.Ord"; when the values are equal, the second one is returned.
--
-- >>> :k! Eval (Max 0 1)
-- ...
-- = 1
data Max :: a -> a -> Exp a

type instance Eval (Max a b) = If (a >? b) a b

-- | Check if i is a valid Fin (aka in-bounds index of a dimension): @0 <= i < d@.
--
-- The upper bound is tested as @i < d@ rather than @i + 1 <= d@. The latter
-- overflows for @i = maxBound@ and would wrongly report an in-bounds index.
--
-- >>> isFin 0 2
-- True
-- >>> isFin 2 2
-- False
isFin :: Int -> Int -> Bool
isFin i d = 0 <= i && i < d

-- | Check if i is a valid Fin (aka in-bounds index of a dimension)
--
-- Both arguments are 'Nat's and therefore never negative, so only the upper
-- bound @x < d@ is checked (using @<?@ from "Data.Type.Ord").
--
-- >>> :k! Eval (IsFin 0 2)
-- ...
-- = True
-- >>> :k! Eval (IsFin 2 2)
-- ...
-- = False
data IsFin :: Nat -> Nat -> Exp Bool

-- in bounds iff x is strictly below the dimension size d
type instance
  Eval (IsFin x d) =
    x <? d

-- | Check if i is a valid Fins (aka in-bounds index of a Shape).
--
-- The index list must have the same length as the shape, and each index must
-- be a valid 'isFin' for its dimension.
--
-- >>> isFins [0,1] [2,2]
-- True
-- >>> isFins [0,1] [2,1]
-- False
isFins :: [Int] -> [Int] -> Bool
isFins xs ds = length xs == length ds && and (zipWith isFin xs ds)

-- | Check if i is a valid Fins (aka in-bounds index of a Shape)
--
-- >>> :k! Eval (IsFins [0,1] [2,2])
-- ...
-- = True
-- >>> :k! Eval (IsFins [0,1] [2,1])
-- ...
-- = False
data IsFins :: [Nat] -> [Nat] -> Exp Bool

type instance
  Eval (IsFins xs ds) =
    Eval (And (Eval (ZipWith IsFin xs ds)))
      && Eval (LiftM2 TyEq (Rank xs) (Rank ds))

-- | Is a value a valid dimension of a shape? Positions that are in bounds for
-- the 'rank' of @s@ are valid. A scalar (empty shape) additionally accepts
-- dimension 0.
--
-- >>> isDim 2 [2,3,4]
-- True
-- >>> isDim 0 []
-- True
isDim :: Int -> [Int] -> Bool
isDim d s
  | isFin d (rank s) = True
  | otherwise = d == 0 && null s

-- | Is a value a valid dimension of a shape.
--
-- Type-level counterpart of 'isDim': @d@ must be below the rank of @s@, or be
-- @0@ when @s@ is a scalar (empty shape).
--
-- >>> :k! Eval (IsDim 2 [2,3,4])
-- ...
-- = True
-- >>> :k! Eval (IsDim 0 '[])
-- ...
-- = True
data IsDim :: Nat -> [Nat] -> Exp Bool

type instance
  Eval (IsDim d s) =
    Eval (IsFin d =<< Rank s)
      || (0 == d && s == '[])

-- | Are all the values valid dimensions of a shape? Each element of @ds@ is
-- checked with 'isDim' against @s@.
--
-- >>> isDims [2,1] [2,3,4]
-- True
-- >>> isDims [0] []
-- True
isDims :: [Int] -> [Int] -> Bool
isDims ds s = and [isDim d s | d <- ds]

-- | Are values valid dimensions of a shape.
--
-- Type-level counterpart of 'isDims': every element of @ds@ must satisfy
-- 'IsDim' against @s@.
--
-- >>> :k! Eval (IsDims [2,1] [2,3,4])
-- ...
-- = True
-- >>> :k! Eval (IsDims '[0] '[])
-- ...
-- = True
data IsDims :: [Nat] -> [Nat] -> Exp Bool

type instance
  Eval (IsDims ds s) =
    Eval (And =<< Map (Flip IsDim s) ds)

-- | Get the last position of a dimension of a shape, i.e. the size at that
-- dimension minus one. Dimension 0 of a scalar (empty shape) reports 0.
--
-- >>> lastPos 2 [2,3,4]
-- 3
-- >>> lastPos 0 []
-- 0
lastPos :: Int -> [Int] -> Int
lastPos d s
  | d == 0, null s = 0
  | otherwise = getDim d s - 1

-- | Get the last position of a dimension of a shape.
--
-- Type-level counterpart of 'lastPos': dimension 0 of a scalar reports 0.
-- Otherwise the result is the size from 'GetDim' minus one, so an
-- out-of-bounds dimension surfaces as 'GetDim''s type error.
--
-- >>> :k! Eval (LastPos 2 [2,3,4])
-- ...
-- = 3
-- >>> :k! Eval (LastPos 0 '[])
-- ...
-- = 0
data LastPos :: Nat -> [Nat] -> Exp Nat

type instance
  Eval (LastPos d s) =
    If
      (0 == d && s == '[])
      0
      (Eval (GetDim d s) - 1)

-- | Get the minimum dimension as a singleton dimension. A scalar (empty
-- shape) stays a scalar, so 'minimum' is never applied to an empty list.
--
-- >>> minDim [2,3,4]
-- [2]
-- >>> minDim []
-- []
minDim :: [Int] -> [Int]
minDim s = [minimum s | not (null s)]

-- | Get the minimum dimension as a singleton dimension.
--
-- Type-level counterpart of 'minDim'. The empty shape is checked first, so
-- 'Minimum' (which errors on '[]) is only reached for non-empty shapes.
--
-- >>> :k! Eval (MinDim [2,3,4])
-- ...
-- = '[2]
-- >>> :k! Eval (MinDim '[])
-- ...
-- = '[]
data MinDim :: [Nat] -> Exp [Nat]

type instance
  Eval (MinDim s) =
    If
      (s == '[])
      '[]
      '[Eval (Minimum s)]

-- | Enumerate between two Nats, inclusive at both ends.
--
-- >>> :k! Eval (EnumFromTo 0 3)
-- ...
-- = [0, 1, 2, 3]
data EnumFromTo :: Nat -> Nat -> Exp [Nat]

type instance Eval (EnumFromTo a b) = Eval (Unfoldr (EnumFromToHelper b) a)

-- Unfoldr step for 'EnumFromTo'. Given the upper bound @b@ and the current
-- value @a@, it emits @a@ and continues from @a + 1@, stopping once @a > b@.
--
-- The emitted element is always the current 'Nat', so the result kind is
-- fixed to @(Nat, Nat)@ instead of leaving a free kind variable that the
-- instance could never satisfy for any other kind.
data EnumFromToHelper :: Nat -> Nat -> Exp (Maybe (Nat, Nat))

type instance
  Eval (EnumFromToHelper b a) =
    If
      (a >? b)
      'Nothing
      ('Just '(a, a + 1))

-- | Left fold.
--
-- Applies @f@ to the accumulator and each element from left to right. The
-- instances below cover type-level lists.
--
-- >>> :k! Eval (Foldl' (Fcf.+) 0 [1,2,3])
-- ...
-- = 6
data Foldl' :: (b -> a -> Exp b) -> b -> t a -> Exp b

type instance Eval (Foldl' f y '[]) = y

-- The accumulator is evaluated before recursing into the tail.
type instance Eval (Foldl' f y (x ': xs)) = Eval (Foldl' f (Eval (f y x)) xs)

-- | Get an element at a given index.
--
-- Reduces to @'Nothing@ when the index is past the end of the list.
--
-- >>> :kind! Eval (GetIndex 2 [2,3,4])
-- ...
-- = Just 4
data GetIndex :: Nat -> [a] -> Exp (Maybe a)

type instance Eval (GetIndex d xs) = GetIndexImpl d xs

-- Closed type family doing the recursion: walk down the list, decrementing
-- the index until it reaches 0 (hit) or the list runs out (miss).
type family GetIndexImpl (n :: Nat) (xs :: [k]) where
  GetIndexImpl _ '[] = 'Nothing
  GetIndexImpl 0 (x ': _) = 'Just x
  GetIndexImpl n (_ ': xs) = GetIndexImpl (n - 1) xs

-- | Get the dimension of a shape at the supplied index. Error if out-of-bounds.
-- A scalar (empty shape) reports a size of 1 at index 0.
--
-- >>> getDim 1 [2,3,4]
-- 3
-- >>> getDim 3 [2,3,4]
-- *** Exception: getDim outside bounds
-- ...
-- >>> getDim 0 []
-- 1
getDim :: Int -> [Int] -> Int
getDim 0 [] = 1
getDim i s = case maybeGetDim s i of
  Just n -> n
  Nothing -> error "getDim outside bounds"

-- | Safe list indexing: the element at position @n@, or 'Nothing' when @n@ is
-- negative or past the end of the list.
--
-- The 'foldr' builds a function of the remaining index, which stops walking
-- the list as soon as the target element is reached.
maybeGetDim :: [a] -> Int -> Maybe a
maybeGetDim xs n
  | n < 0 = Nothing
  | otherwise = foldr step (const Nothing) xs n
  where
    step x continue k
      | k == 0 = Just x
      | otherwise = continue (k - 1)
{-# INLINEABLE maybeGetDim #-}

-- | Get the dimension of a shape at the supplied index. Error if out-of-bounds or non-computable (usually unknown to the compiler).
--
-- Type-level counterpart of 'getDim': index 0 of the empty shape reports 1.
-- Otherwise the element is looked up with 'GetIndex', and a missing element
-- becomes a custom type error naming the index and shape.
--
-- >>> :k! Eval (GetDim 1 [2,3,4])
-- ...
-- = 3
-- >>> :k! Eval (GetDim 3 [2,3,4])
-- ...
-- = (TypeError ...)
-- >>> :k! Eval (GetDim 0 '[])
-- ...
-- = 1
data GetDim :: Nat -> [Nat] -> Exp Nat

type instance
  Eval (GetDim n xs) =
    If
      (Eval (And [Eval (TyEq n 0), Eval (TyEq xs ('[] :: [Nat]))]))
      1
      (Eval (FromMaybe (L.TypeError (L.Text "GetDim out of bounds or non-computable: " :<>: ShowType n :<>: L.Text " " :<>: ShowType xs)) (Eval (GetIndex n xs))))

-- | Modify the size at a specific dimension. Errors if out of bounds (via
-- 'getDim'). A scalar (empty shape) is treated as a size of 1 at dimension 0,
-- so it becomes a singleton.
--
-- >>> modifyDim 0 (+1) [0,1,2]
-- [1,1,2]
-- >>> modifyDim 0 (+1) []
-- [2]
modifyDim :: Int -> (Int -> Int) -> [Int] -> [Int]
modifyDim 0 f [] = [f 1]
modifyDim d f xs = prefix <> (f (getDim d xs) : suffix)
  where
    prefix = take d xs
    suffix = drop (d + 1) xs

-- | modify an index at a specific dimension. Errors if out of bounds.
--
-- Type-level counterpart of 'modifyDim'. It rebuilds the shape as
-- @Take d s ++ (f (GetDim d s) : Drop (d + 1) s)@, so the bounds error comes
-- from 'GetDim'.
--
-- >>> :k! Eval (ModifyDim 0 ((Fcf.+) 1) [0,1,2])
-- ...
-- = [1, 1, 2]
data ModifyDim :: Nat -> (Nat -> Exp Nat) -> [Nat] -> Exp [Nat]

type instance
  Eval (ModifyDim d f s) =
    Eval (LiftM2 (Fcf.++) (Take d s) (LiftM2 Cons (f =<< GetDim d s) (Drop (d + 1) s)))

-- | Increment the size at dimension @d@ of a shape by 1. Scalars are first
-- turned into singletons (via 'asSingleton').
--
-- >>> incAt 1 [2,3,4]
-- [2,4,4]
-- >>> incAt 0 []
-- [2]
incAt :: Int -> [Int] -> [Int]
incAt d = modifyDim d (+ 1) . asSingleton

-- | Increment the index at a dimension of a shape by 1. Scalars turn into singletons.
--
-- Type-level counterpart of 'incAt': the shape is passed through
-- 'AsSingleton' before 'ModifyDim' adds 1 at dimension @d@.
--
-- >>> :k! Eval (IncAt 1 [2,3,4])
-- ...
-- = [2, 4, 4]
-- >>> :k! Eval (IncAt 0 '[])
-- ...
-- = '[2]
data IncAt :: Nat -> [Nat] -> Exp [Nat]

type instance
  Eval (IncAt d ds) =
    Eval (ModifyDim d ((Fcf.+) 1) (Eval (AsSingleton ds)))

-- | Decrement the size at dimension @d@ of a shape by 1.
--
-- >>> decAt 1 [2,3,4]
-- [2,2,4]
decAt :: Int -> [Int] -> [Int]
decAt d = modifyDim d (subtract 1)

-- | Decrement the index at a dimension of a shape by 1.
--
-- Type-level counterpart of 'decAt': 'ModifyDim' applies @x - 1@ at
-- dimension @d@.
--
-- >>> :k! Eval (DecAt 1 [2,3,4])
-- ...
-- = [2, 2, 4]
data DecAt :: Nat -> [Nat] -> Exp [Nat]

type instance
  Eval (DecAt d ds) =
    Eval (ModifyDim d (Flip (Fcf.-) 1) ds)

-- | Replace the size at dimension @d@ with @x@. A scalar (empty shape) becomes
-- 1-dimensional.
--
-- >>> setDim 0 1 [2,3,4]
-- [1,3,4]
-- >>> setDim 0 3 []
-- [3]
setDim :: Int -> Int -> [Int] -> [Int]
setDim d x = modifyDim d (\_ -> x)

-- | replace an index at a specific dimension.
--
-- Type-level counterpart of 'setDim': 'ModifyDim' with a constant function.
--
-- >>> :k! Eval (SetDim 0 1 [2,3,4])
-- ...
-- = [1, 3, 4]
data SetDim :: Nat -> Nat -> [Nat] -> Exp [Nat]

type instance
  Eval (SetDim d x ds) =
    Eval (ModifyDim d (ConstFn x) ds)

-- | 'SetDim' taking its (dimension, new value) pair as a single tuple, so it
-- can be folded over a zipped list (as 'SetDims' does).
data SetDimUncurried :: (Nat, Nat) -> [Nat] -> Exp [Nat]

type instance
  Eval (SetDimUncurried xs ds) =
    Eval (SetDim (Eval (Fst xs)) (Eval (Snd xs)) ds)

-- | Take along a dimension: clamp the size at dimension @d@ to at most @t@.
--
-- >>> takeDim 0 1 [2,3,4]
-- [1,3,4]
takeDim :: Int -> Int -> [Int] -> [Int]
takeDim d t = modifyDim d (\n -> min t n)

-- | Take along a dimension.
--
-- Type-level counterpart of 'takeDim': the size at dimension @d@ is clamped
-- to at most @t@ via 'Min'.
--
-- >>> :k! Eval (TakeDim 0 1 [2,3,4])
-- ...
-- = [1, 3, 4]
data TakeDim :: Nat -> Nat -> [Nat] -> Exp [Nat]

type instance
  Eval (TakeDim d t s) =
    Eval
      (ModifyDim d (Min t) s)

-- | Drop along a dimension: subtract @t@ from the size at dimension @d@,
-- never going below 0.
--
-- >>> dropDim 2 1 [2,3,4]
-- [2,3,3]
dropDim :: Int -> Int -> [Int] -> [Int]
dropDim d t = modifyDim d (\n -> max 0 (n - t))

-- | Drop along a dimension.
--
-- Type-level counterpart of 'dropDim': subtracts @t@ from the size at
-- dimension @d@, then applies @Max 0@.
--
-- NOTE(review): with 'Nat' subtraction, a @t@ larger than the dimension
-- presumably fails to reduce before 'Max' can clamp it. Confirm whether that
-- case needs handling.
--
-- >>> :k! Eval (DropDim 2 1 [2,3,4])
-- ...
-- = [2, 3, 3]
data DropDim :: Nat -> Nat -> [Nat] -> Exp [Nat]

type instance
  Eval (DropDim d t s) =
    Eval
      ( ModifyDim
          d
          (Max 0 <=< Flip (Fcf.-) t)
          s
      )

-- | Delete the i'th dimension. No effect on a scalar, or when @i@ is past the
-- last dimension.
--
-- >>> deleteDim 1 [2, 3, 4]
-- [2,4]
-- >>> deleteDim 2 []
-- []
deleteDim :: Int -> [Int] -> [Int]
deleteDim i s = before <> after
  where
    before = take i s
    after = drop (i + 1) s

-- | delete the i'th dimension
--
-- Type-level counterpart of 'deleteDim': @Take i ds ++ Drop (i + 1) ds@.
--
-- >>> :k! Eval (DeleteDim 1 [2, 3, 4])
-- ...
-- = [2, 4]
-- >>> :k! Eval (DeleteDim 1 '[])
-- ...
-- = '[]
data DeleteDim :: Nat -> [Nat] -> Exp [Nat]

type instance
  Eval (DeleteDim i ds) =
    Eval (LiftM2 (Fcf.++) (Take i ds) (Drop (i + 1) ds))

-- | Insert a new dimension of size @i@ at position @d@, or at the end if @d@
-- exceeds the rank.
--
-- >>> insertDim 1 3 [2,4]
-- [2,3,4]
-- >>> insertDim 0 4 []
-- [4]
insertDim :: Int -> Int -> [Int] -> [Int]
insertDim d i s = front <> (i : back)
  where
    (front, back) = splitAt d s

-- | Insert a new dimension at a position (or at the end if > rank).
--
-- Type-level counterpart of 'insertDim': @Take d ds ++ (i : Drop d ds)@.
--
-- >>> :k! Eval (InsertDim 1 3 [2,4])
-- ...
-- = [2, 3, 4]
-- >>> :k! Eval (InsertDim 0 4 '[])
-- ...
-- = '[4]
data InsertDim :: Nat -> Nat -> [Nat] -> Exp [Nat]

type instance
  Eval (InsertDim d i ds) =
    Eval (LiftM2 (Fcf.++) (Take d ds) (Cons i =<< Drop d ds))

-- | 'InsertDim' taking its (position, size) pair as a single tuple, so it can
-- be folded over a zipped list (as 'InsertDims' does).
data InsertDimUncurried :: (Nat, Nat) -> [Nat] -> Exp [Nat]

type instance
  Eval (InsertDimUncurried xs ds) =
    Eval (InsertDim (Eval (Fst xs)) (Eval (Snd xs)) ds)

-- | Is an insertion ok? Can a value of shape @si@ be inserted along dimension
-- @d@ of shape @s@?
--
-- This requires @d@ to be a valid dimension of @s@ ('IsDim'), and @si@ to
-- equal @s@ with dimension @d@ removed ('DeleteDim').
--
-- >>> :k! Eval (InsertOk 2 [2,3,4] [2,3])
-- ...
-- = True
-- >>> :k! Eval (InsertOk 0 '[] '[])
-- ...
-- = True
data InsertOk :: Nat -> [Nat] -> [Nat] -> Exp Bool

type instance
  Eval (InsertOk d s si) =
    Eval
      ( And
          [ Eval (IsDim d s),
            Eval (TyEq si (Eval (DeleteDim d s)))
          ]
      )

-- | Is a slice ok?
--
-- Checks a slice of length @l@ starting at offset @off@ along dimension @d@
-- of shape @s@. The conditions, in order, are:
--
-- * the offset is in bounds for the dimension,
-- * @l@ is compared against the dimension size,
-- * @off + l@ does not exceed the dimension size, and
-- * @d@ is a valid dimension of @s@.
--
-- NOTE(review): assuming @Fcf.<@ is strict less-than, the second condition
-- rejects a full-length slice (@off = 0@, @l@ equal to the dimension size)
-- that the third condition alone would allow. Confirm this is intended.
--
-- >>> :k! Eval (SliceOk 1 1 2 [2,3,4])
-- ...
-- = True
data SliceOk :: Nat -> Nat -> Nat -> [Nat] -> Exp Bool

type instance
  Eval (SliceOk d off l s) =
    Eval
      ( And
          [ Eval (IsFin off =<< GetDim d s),
            Eval ((Fcf.<) l =<< GetDim d s),
            Eval ((Fcf.<) (off + l) (Eval (GetDim d s) + 1)),
            Eval (IsDim d s)
          ]
      )

-- | Combine elements of three lists element-wise, stopping as soon as any of
-- the lists runs out.
data ZipWith3 :: (a -> b -> c -> Exp d) -> [a] -> [b] -> [c] -> Exp [d]

type instance Eval (ZipWith3 _f '[] _bs _cs) = '[]

type instance Eval (ZipWith3 _f _as '[] _cs) = '[]

type instance Eval (ZipWith3 _f _as _bs '[]) = '[]

type instance
  Eval (ZipWith3 f (a ': as) (b ': bs) (c ': cs)) =
    Eval (f a b c) ': Eval (ZipWith3 f as bs cs)

-- | 'SliceOk' with the shape moved to the first argument, so it can be
-- partially applied and zipped over the remaining arguments in 'SlicesOk'.
data SliceOk_ :: [Nat] -> Nat -> Nat -> Nat -> Exp Bool

type instance Eval (SliceOk_ s d off l) = Eval (SliceOk d off l s)

-- | Are slices ok?
--
-- Zips dimensions, offsets and lengths with 'ZipWith3', so elements beyond
-- the shortest list are ignored. Every resulting 'SliceOk' check must hold.
--
-- >>> :k! Eval (SlicesOk '[1] '[1] '[2] [2,3,4])
-- ...
-- = True
data SlicesOk :: [Nat] -> [Nat] -> [Nat] -> [Nat] -> Exp Bool

type instance
  Eval (SlicesOk ds offs ls s) =
    Eval (And =<< ZipWith3 (SliceOk_ s) ds offs ls)

-- | Concatenate two array shapes at dimension @i@: the sizes at @i@ are
-- summed and the rest of the shape is taken from the first argument.
--
-- Bespoke logic for scalars: a scalar counts as a single element. It combines
-- with another scalar or a 1-dimensional shape regardless of @i@.
--
-- >>> concatenate 1 [2,3,4] [2,3,4]
-- [2,6,4]
-- >>> concatenate 0 [3] []
-- [4]
-- >>> concatenate 0 [] [3]
-- [4]
-- >>> concatenate 0 [] []
-- [2]
concatenate :: Int -> [Int] -> [Int] -> [Int]
concatenate _ [] [] = [2]
concatenate _ [] [x] = [x + 1]
concatenate _ [x] [] = [x + 1]
concatenate i s0 s1 = take i s0 <> (combined : drop (i + 1) s0)
  where
    combined = getDim i s0 + getDim i s1

-- | concatenate two arrays at dimension i
--
-- Bespoke logic for scalars.
--
-- Reduces to a custom type error when 'ConcatenateOk' does not hold.
--
-- >>> :k! Eval (Concatenate 1 [2,3,4] [2,3,4])
-- ...
-- = [2, 6, 4]
-- >>> :k! Eval (Concatenate 0 '[3] '[])
-- ...
-- = '[4]
-- >>> :k! Eval (Concatenate 0 '[] '[3])
-- ...
-- = '[4]
-- >>> :k! Eval (Concatenate 0 '[] '[])
-- ...
-- = '[2]
data Concatenate :: Nat -> [Nat] -> [Nat] -> Exp [Nat]

type instance
  Eval (Concatenate i s0 s1) =
    If
      (Eval (ConcatenateOk i s0 s1))
      (Eval (Eval (Take i s0) ++ (Eval (GetDim i s0) + Eval (GetDim i s1) : Eval (Drop (i + 1) s0))))
      (L.TypeError (L.Text "Concatenate Mis-matched shapes."))

-- | Concatenate is Ok if ranks are the same and the non-indexed portion of the shapes are the same.
--
-- Concretely, the following must all hold:
--
-- * @i@ is a valid dimension of both shapes ('IsDim'),
-- * the shapes agree once dimension @i@ is deleted, and
-- * the ranks are equal after scalars are promoted to singletons
--   ('AsSingleton').
data ConcatenateOk :: Nat -> [Nat] -> [Nat] -> Exp Bool

type instance
  Eval (ConcatenateOk i s0 s1) =
    Eval (IsDim i s0)
      && Eval (IsDim i s1)
      && Eval (LiftM2 TyEq (DeleteDim i s0) (DeleteDim i s1))
      && Eval (LiftM2 TyEq (Rank =<< AsSingleton s0) (Rank =<< AsSingleton s1))

-- * multiple dimension manipulations

-- | Get the sizes of several dimensions of a shape, in the order requested.
-- A scalar (empty shape) always yields an empty result.
--
-- >>> getDims [2,0] [2,3,4]
-- [4,2]
-- >>> getDims [2] []
-- []
getDims :: [Int] -> [Int] -> [Int]
getDims _ [] = []
getDims ps s = [getDim p s | p <- ps]

-- | Get dimensions of a shape.
--
-- Unlike 'getDims', a scalar shape is not special-cased. Each position goes
-- through 'GetDim', so an out-of-range position leaves a type error inside
-- the resulting list.
--
-- >>> :k! Eval (GetDims [2,0] [2,3,4])
-- ...
-- = [4, 2]
-- >>> :k! Eval (GetDims '[2] '[])
-- ...
-- = '[(TypeError ...)]
data GetDims :: [Nat] -> [Nat] -> Exp [Nat]

type instance
  Eval (GetDims xs ds) =
    Eval (Map (Flip GetDim ds) xs)

-- | Get the index of the last position in the selected dimensions of a shape
-- (each selected size minus one). A zero-sized dimension yields -1.
--
-- >>> getLastPositions [2,0] [2,3,4]
-- [3,1]
-- >>> getLastPositions [0] [0]
-- [-1]
getLastPositions :: [Int] -> [Int] -> [Int]
getLastPositions ds s = map (subtract 1) (getDims ds s)

-- | Get the index of the last position in the selected dimensions of a shape. Errors on a 0-dimension.
--
-- Type-level counterpart of 'getLastPositions': subtracts 1 from each size
-- returned by 'GetDims'.
--
-- >>> :k! Eval (GetLastPositions [2,0] [2,3,4])
-- ...
-- = [3, 1]
data GetLastPositions :: [Nat] -> [Nat] -> Exp [Nat]

type instance
  Eval (GetLastPositions ds s) =
    Eval (Map (Flip (Fcf.-) 1) (Eval (GetDims ds s)))

-- | Modify dimensions of a shape with (separate) functions. The pairs of
-- dimension and function are applied left to right. Extra elements in the
-- longer of @ds@ and @fs@ are ignored.
--
-- >>> modifyDims [0,1] [(+1), (+5)] [2,3,4]
-- [3,8,4]
modifyDims :: [Int] -> [Int -> Int] -> [Int] -> [Int]
modifyDims ds fs ns = foldl' (flip (uncurry modifyDim)) ns (zip ds fs)

-- | Convert deletion positions that all refer to the initial shape into
-- positions for sequential, one-at-a-time deletions.
--
-- For example, to delete the original positions [1,2,5] one at a time, first
-- delete position 1. Then delete position 1 again, which is the original
-- position 2 shifted down. Finally delete position 3, which is the original
-- position 5 shifted down twice.
--
-- >>> preDeletePositions [1,2,5]
-- [1,1,3]
--
-- >>> preDeletePositions [1,2,0]
-- [1,1,0]
preDeletePositions :: [Int] -> [Int]
preDeletePositions = reverse . loop []
  where
    -- Emit each position as-is. Every later position at or after it then
    -- shifts down by one, to account for the element just removed.
    loop acc [] = acc
    loop acc (p : ps) = loop (p : acc) (map (shiftPast p) ps)
    shiftPast p q
      | q < p = q
      | otherwise = q - 1

-- | Convert deletion positions that all refer to the initial shape into
-- positions for sequential, one-at-a-time deletions.
--
-- For example, to delete the original positions [1,2,5] one at a time, first
-- delete position 1. Then delete position 1 again, which is the original
-- position 2 shifted down. Finally delete position 3, which is the original
-- position 5 shifted down twice.
--
-- >>> :k! Eval (PreDeletePositions [1,2,5])
-- ...
-- = [1, 1, 3]
--
-- >>> :k! Eval (PreDeletePositions [1,2,0])
-- ...
-- = [1, 1, 0]
data PreDeletePositions :: [Nat] -> Exp [Nat]

type instance
  Eval (PreDeletePositions xs) =
    Eval (Reverse (Eval (PreDeletePositionsGo xs '[])))

-- Worker: emits each position onto the accumulator (in reverse), shifting
-- the remaining positions with 'DecPast'.
data PreDeletePositionsGo :: [Nat] -> [Nat] -> Exp [Nat]

type instance Eval (PreDeletePositionsGo '[] rs) = rs

type instance
  Eval (PreDeletePositionsGo (x : xs) r) =
    Eval (PreDeletePositionsGo (Eval (Map (DecPast x) xs)) (x : r))

-- Decrement @d@ when it lies strictly after the deleted position @x@.
-- NOTE(review): the value-level helper also decrements when @d == x@. The two
-- only differ for duplicate positions; confirm duplicates cannot occur.
data DecPast :: Nat -> Nat -> Exp Nat

type instance
  Eval (DecPast x d) =
    If (x + 1 <=? d) (d - 1) d

-- | Convert insertion positions that refer to the final shape into positions
-- for sequential, one-at-a-time insertions into the initial shape.
--
-- For example, to reach positions [1,2,0] starting from a 2 element list,
-- first insert at position 0 (arriving at a 3 element list). Then insert at
-- position 1 (arriving at a 4 element list), and finally at position 0.
--
-- > preInsertPositions == reverse . preDeletePositions . reverse
-- >>> preInsertPositions [1,2,5]
-- [1,2,5]
--
-- >>> preInsertPositions [1,2,0]
-- [0,1,0]
preInsertPositions :: [Int] -> [Int]
preInsertPositions ps = reverse (preDeletePositions (reverse ps))

-- | Convert insertion positions that refer to the final shape into positions
-- for sequential, one-at-a-time insertions into the initial shape.
--
-- For example, to reach positions [1,2,0] starting from a 2 element list,
-- first insert at position 0 (arriving at a 3 element list). Then insert at
-- position 1 (arriving at a 4 element list), and finally at position 0.
--
-- > preInsertPositions == reverse . preDeletePositions . reverse
-- >>> :k! Eval (PreInsertPositions [1,2,5])
-- ...
-- = [1, 2, 5]
--
-- >>> :k! Eval (PreInsertPositions [1,2,0])
-- ...
-- = [0, 1, 0]
data PreInsertPositions :: [Nat] -> Exp [Nat]

type instance
  Eval (PreInsertPositions xs) =
    Eval (Reverse =<< (PreDeletePositions =<< Reverse xs))

-- | Delete dimensions of a shape according to a list of positions, where each
-- position refers to the initial shape. The positions are first converted
-- with 'preDeletePositions' and then deleted one at a time.
--
-- >>> deleteDims [1,0] [2, 3, 4]
-- [4]
deleteDims :: [Int] -> [Int] -> [Int]
deleteDims i s = foldl' (\acc p -> deleteDim p acc) s (preDeletePositions i)

-- | drop dimensions of a shape according to a list of positions (where position refers to the initial shape)
--
-- Type-level counterpart of 'deleteDims': the positions are converted with
-- 'PreDeletePositions', then 'DeleteDim' is folded over them.
--
-- >>> :k! Eval (DeleteDims [1,0] [2, 3, 4])
-- ...
-- = '[4]
data DeleteDims :: [Nat] -> [Nat] -> Exp [Nat]

type instance
  Eval (DeleteDims xs ds) =
    Eval (Foldl' (Flip DeleteDim) ds =<< PreDeletePositions xs)

-- | Insert a list of dimension sizes @xs@ at the positions @ds@. The
-- positions reference the final shape, not the initial one. They are
-- converted with 'preInsertPositions' and applied one at a time.
--
-- >>> insertDims [0] [5] []
-- [5]
-- >>> insertDims [1,0] [3,2] [4]
-- [2,3,4]
insertDims :: [Int] -> [Int] -> [Int] -> [Int]
insertDims ds xs s =
  foldl' (\acc (p, x) -> insertDim p x acc) s (zip (preInsertPositions ds) xs)

-- | Insert a list of dimension sizes @xs@ at the positions @ds@. The
-- positions reference the final shape, not the initial one.
--
-- Type-level counterpart of 'insertDims': the positions are converted with
-- 'PreInsertPositions', zipped with the sizes, and folded with
-- 'InsertDimUncurried'.
--
-- >>> :k! Eval (InsertDims '[0] '[5] '[])
-- ...
-- = '[5]
-- >>> :k! Eval (InsertDims [1,0] [3,2] '[4])
-- ...
-- = [2, 3, 4]
data InsertDims :: [Nat] -> [Nat] -> [Nat] -> Exp [Nat]

type instance
  Eval (InsertDims ds xs s) =
    Eval (Foldl' (Flip InsertDimUncurried) s =<< Flip Zip xs =<< PreInsertPositions ds)

-- | Set dimensions of a shape. Each size in @xs@ replaces the dimension at the
-- matching position in @ds@, applied left to right with 'setDim'.
--
-- >>> setDims [0,1] [1,5] [2,3,4]
-- [1,5,4]
--
-- >>> setDims [0] [3] []
-- [3]
setDims :: [Int] -> [Int] -> [Int] -> [Int]
setDims ds xs ns = foldl' (flip (uncurry setDim)) ns (zip ds xs)

-- | Set dimensions of a shape.
--
-- Type-level counterpart of 'setDims': zips positions with new sizes and
-- folds 'SetDimUncurried' over the pairs.
--
-- >>> :k! Eval (SetDims [0,1] [1,5] [2,3,4])
-- ...
-- = [1, 5, 4]
--
-- >>> :k! Eval (SetDims '[0] '[3] '[])
-- ...
-- = '[3]
data SetDims :: [Nat] -> [Nat] -> [Nat] -> Exp [Nat]

type instance
  Eval (SetDims ds xs ns) =
    Eval (Foldl' (Flip SetDimUncurried) ns =<< Zip ds xs)

-- | Drop a number of elements of a shape along the supplied dimensions: each
-- selected size is reduced by the matching count in @xs@.
--
-- >>> dropDims [0,2] [1,3] [2,3,4]
-- [1,3,1]
dropDims :: [Int] -> [Int] -> [Int] -> [Int]
dropDims ds xs s = setDims ds (zipWith (-) (getDims ds s) xs) s

-- | Drop a number of elements of a shape along the supplied dimensions.
--
-- Type-level counterpart of 'dropDims': each selected size from 'GetDims' is
-- reduced by the matching count and written back with 'SetDims'.
--
-- >>> :k! Eval (DropDims [0,2] [1,3] [2,3,4])
-- ...
-- = [1, 3, 1]
data DropDims :: [Nat] -> [Nat] -> [Nat] -> Exp [Nat]

type instance
  Eval (DropDims ds xs s) =
    Eval (SetDims ds (Eval (ZipWith (Fcf.-) (Eval (GetDims ds s)) xs)) s)

-- | Concatenate and replace dimensions, creating a new dimension at the
-- supplied position. The dimensions @ds@ are removed, and a single dimension
-- whose size is their 'size' is inserted at position @n@.
--
-- >>> concatDims [0,1] 1 [2,3,4]
-- [4,6]
concatDims :: [Int] -> Int -> [Int] -> [Int]
concatDims ds n s = insertDim n merged rest
  where
    merged = size (getDims ds s)
    rest = deleteDims ds s

-- | Concatenate and replace dimensions, creating a new dimension at the
-- supplied position.
--
-- Type-level counterpart of 'concatDims': the dimensions @ds@ are deleted,
-- and their 'Size' is inserted as a single dimension at position @n@.
--
-- >>> :k! Eval (ConcatDims [0,1] 1 [2,3,4])
-- ...
-- = [4, 6]
data ConcatDims :: [Nat] -> Nat -> [Nat] -> Exp [Nat]

type instance
  Eval (ConcatDims ds n s) =
    Eval (InsertDim n (Eval (Size (Eval (GetDims ds s)))) (Eval (DeleteDims ds s)))

-- | Unconcatenate and reinsert dimensions for an index. The flat position at
-- dimension @n@ of index @i@ is expanded with 'shapen', using the sizes of
-- dimensions @ds@ taken from @s@. The results are reinserted at positions
-- @ds@.
--
-- >>> unconcatDimsIndex [0,1] 1 [4,6] [2,3]
-- [0,3,2]
unconcatDimsIndex :: [Int] -> Int -> [Int] -> [Int] -> [Int]
unconcatDimsIndex ds n s i = insertDims ds unpacked (deleteDim n i)
  where
    unpacked = shapen (getDims ds s) (getDim n i)

-- | Reverse an index along specific dimensions. For each selected dimension,
-- coordinate @x@ of a size-@n@ dimension becomes @n - 1 - x@. All other
-- coordinates are unchanged.
--
-- >>> reverseIndex [0] [2,3,4] [0,1,2]
-- [1,1,2]
reverseIndex :: [Int] -> [Int] -> [Int] -> [Int]
reverseIndex ds ns xs =
  [ if d `elem` ds then n - 1 - x else x
  | (d, x, n) <- zip3 [0 ..] xs ns
  ]

-- | Rotate a list. A positive count rotates left and a negative count rotates
-- right. The count wraps modulo the list length, and rotating an empty list
-- returns it unchanged.
--
-- >>> rotate 1 [0..3]
-- [1,2,3,0]
-- >>> rotate (-1) [0..3]
-- [3,0,1,2]
-- >>> rotate 2 ([] :: [Int])
-- []
rotate :: Int -> [a] -> [a]
-- The empty list is handled explicitly, because @r `mod` 0@ would throw a
-- divide-by-zero exception.
rotate _ [] = []
rotate r xs = drop r' xs <> take r' xs
  where
    r' = r `mod` List.length xs

-- | Rotate an index along dimension @d@ by @r@ positions, wrapping modulo the
-- size of that dimension in shape @s@.
--
-- >>> rotateIndex 0 1 [2,3,4] [0,1,2]
-- [1,1,2]
rotateIndex :: Int -> Int -> [Int] -> [Int] -> [Int]
rotateIndex d r s = modifyDim d wrap
  where
    wrap x = (x + r) `mod` extent
    extent = getDim d s

-- | Rotate an index along several dimensions. Each @(d, r)@ pair from @ds@ and
-- @rs@ is applied with 'rotateIndex' against shape @s@. The fold is a
-- 'foldr', so the last pair is applied first.
--
-- >>> rotatesIndex [0] [1] [2,3,4] [0,1,2]
-- [1,1,2]
rotatesIndex :: [Int] -> [Int] -> [Int] -> [Int] -> [Int]
rotatesIndex ds rs s xs = foldr rotateOne xs (zip ds rs)
  where
    rotateOne (d, r) = rotateIndex d r s

-- | Test whether an index is a diagonal one, i.e. every coordinate equals its
-- neighbour. Empty and single-element indices are diagonal.
--
-- >>> isDiag [2,2,2]
-- True
-- >>> isDiag [1,2]
-- False
isDiag :: (Eq a) => [a] -> Bool
isDiag xs = and (zipWith (==) xs (drop 1 xs))

-- | Expanded shape of a windowed array. Each windowed dimension of size @d@
-- with window size @w@ becomes @d - w + 1@. These are followed by the window
-- shape itself and by any trailing dimensions not covered by @ws@.
--
-- >>> expandWindows [2,2] [4,3,2]
-- [3,2,2,2,2]
expandWindows :: [Int] -> [Int] -> [Int]
expandWindows ws ds = positions <> ws <> trailing
  where
    positions = List.zipWith (\d w -> d - w + 1) ds ws
    trailing = List.drop (rank ws) ds

-- | Expanded shape of a windowed array
--
-- Type-level counterpart of 'expandWindows'. It computes @(d + 1) - w@
-- rather than @d - w + 1@, so the 'Nat' subtraction does not underflow when
-- @w == d + 1@.
--
-- >>> :k! Eval (ExpandWindows [2,2] [4,3,2])
-- ...
-- = [3, 2, 2, 2, 2]
data ExpandWindows :: [Nat] -> [Nat] -> Exp [Nat]

type instance
  Eval (ExpandWindows ws ds) =
    Eval (Eval (ZipWith (Fcf.-) (Eval (Map ((Fcf.+) 1) ds)) ws) ++ Eval (ws ++ Eval (Drop (Eval (Rank ws)) ds)))

-- | Index into windows of an expanded windowed array, given a rank @r@ of the
-- windows. The first @r@ coordinates (window origin) are added to the next
-- @r@ coordinates (offset within the window). Any remaining coordinates are
-- kept as-is.
--
-- >>> indexWindows 2 [0,1,2,1,1]
-- [2,2,1]
indexWindows :: Int -> [Int] -> [Int]
indexWindows r ds = List.zipWith (+) outer inner <> List.drop (2 * r) ds
  where
    outer = List.take r ds
    inner = List.take r (List.drop r ds)

-- | Dimensions of a windowed array: @range (rank s)@ followed by the positions
-- from @2 * rank s@ up to @rank ws - 1@ (often empty).
--
-- >>> dimWindows [2,2] [2,3,4]
-- [0,1,2]
dimWindows :: [Int] -> [Int] -> [Int]
dimWindows ws s = range r <> [2 * r .. rank ws - 1]
  where
    r = rank s

-- | Dimensions of a windowed array.
--
-- Type-level counterpart of 'dimWindows': @Range (Rank s)@ followed by
-- @EnumFromTo (2 * Rank s) (Rank ws - 1)@.
--
-- >>> :k! Eval (DimWindows [2,2] [4,3,2])
-- ...
-- = [0, 1, 2]
data DimWindows :: [Nat] -> [Nat] -> Exp [Nat]

type instance
  Eval (DimWindows ws s) =
    Eval (Eval (Range =<< Rank s) ++ Eval (EnumFromTo (Eval ((Fcf.*) 2 (Eval (Rank s)))) (Eval (Rank ws) - 1)))