{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DerivingStrategies #-}
{-# LANGUAGE FunctionalDependencies #-}
{-# LANGUAGE PatternSynonyms #-}
{-# LANGUAGE RebindableSyntax #-}
{-# LANGUAGE RoleAnnotations #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE UndecidableInstances #-}
{-# LANGUAGE ViewPatterns #-}
{-# OPTIONS_GHC -Wno-incomplete-patterns #-}
{-# OPTIONS_GHC -Wno-incomplete-uni-patterns #-}
{-# OPTIONS_GHC -Wno-redundant-constraints #-}

-- | Arrays with shape information and computations at the type level.
module Harpie.Fixed
  ( -- * Usage
    -- $usage

    -- * Fixed Arrays
    Array (..),
    unsafeArray,
    validate,
    safeArray,
    array,
    unsafeModifyShape,
    unsafeModifyVector,

    -- * Dimensions
    Dim,
    pattern Dim,
    Dims,
    pattern Dims,

    -- * Conversion
    FromVector (..),
    toDynamic,
    with,
    SomeArray (..),
    someArray,

    -- * Shape Access
    shape,
    rank,
    size,
    length,
    isNull,

    -- * Indexing
    index,
    unsafeIndex,
    (!),
    (!?),
    tabulate,
    unsafeTabulate,
    backpermute,
    unsafeBackpermute,

    -- * Scalars
    fromScalar,
    toScalar,
    isScalar,
    asSingleton,
    asScalar,

    -- * Array Creation
    empty,
    range,
    corange,
    indices,
    ident,
    konst,
    singleton,
    diag,
    undiag,

    -- * Element-level functions
    zipWith,
    modify,
    imap,

    -- * Function generalisers
    rowWise,
    colWise,

    -- * Single-dimension functions
    take,
    takeB,
    drop,
    dropB,
    select,
    insert,
    delete,
    append,
    prepend,
    concatenate,
    couple,
    slice,
    rotate,

    -- * Multi-dimension functions
    takes,
    takeBs,
    drops,
    dropBs,
    indexes,
    indexesT,
    slices,
    heads,
    lasts,
    tails,
    inits,

    -- * Function application
    extracts,
    reduces,
    joins,
    join,
    traverses,
    maps,
    filters,
    zips,
    modifies,
    diffs,

    -- * Array expansion & contraction
    expand,
    coexpand,
    contract,
    prod,
    dot,
    mult,
    windows,

    -- * Search
    find,
    findNoOverlap,
    isPrefixOf,
    isSuffixOf,
    isInfixOf,

    -- * Shape manipulations
    fill,
    cut,
    cutSuffix,
    pad,
    lpad,
    reshape,
    flat,
    repeat,
    cycle,
    rerank,
    reorder,
    squeeze,
    elongate,
    transpose,
    inflate,
    intercalate,
    intersperse,
    concats,
    reverses,
    rotates,

    -- * Sorting
    sorts,
    sortsBy,
    orders,
    ordersBy,

    -- * Transmission
    telecasts,
    transmit,

    -- * Row specializations
    pattern (:<),
    cons,
    uncons,
    pattern (:>),
    snoc,
    unsnoc,

    -- * Shape specializations
    Vector,
    vector,
    vector',
    iota,
    Matrix,

    -- * Math
    uniform,
    invtri,
    inverse,
    chol,
    -- cross_,
    -- norm_,
  )
where

import Data.Bool
import Data.Distributive (Distributive (..))
import Data.Foldable hiding (find, length, minimum)
import Data.Functor.Classes
import Data.Functor.Rep
import Data.List qualified as List
import Data.Maybe
import Data.Vector qualified as V
import Fcf hiding (type (&&), type (+), type (++), type (-))
import Fcf qualified
import Fcf.Data.List
import GHC.Generics
import GHC.TypeNats
import Harpie.Array qualified as A
import Harpie.Shape hiding (asScalar, asSingleton, concatenate, range, rank, reorder, rerank, rotate, size, squeeze)
import Harpie.Shape qualified as S
import Harpie.Sort
import Prettyprinter hiding (dot, fill)
import System.Random hiding (uniform)
import System.Random.Stateful hiding (uniform)
import Test.QuickCheck hiding (tabulate, vector)
import Test.QuickCheck.Instances.Natural ()
import Unsafe.Coerce
import Prelude as P hiding (cycle, drop, length, repeat, sequence, take, zipWith)
import Prelude qualified

-- $setup
--
-- >>> :m -Prelude
-- >>> :set -XDataKinds
-- >>> :set -Wno-type-defaults
-- >>> :set -Wno-name-shadowing
-- >>> import Prelude hiding (cycle, repeat, take, drop, zipWith, length)
-- >>> import Harpie.Fixed as F
-- >>> import Harpie.Shape qualified as S
-- >>> import Harpie.Shape (SNats, Fin (..), Fins (..))
-- >>> import GHC.TypeNats
-- >>> import Data.List qualified as List
-- >>> import Prettyprinter hiding (dot,fill)
-- >>> import Data.Functor.Rep
-- >>> s = 1 :: Array '[] Int
-- >>> s
-- [1]
-- >>> shape s
-- []
-- >>> pretty s
-- 1
-- >>> let v = range @'[3]
-- >>> pretty v
-- [0,1,2]
-- >>> let m = range @[2,3]
-- >>> pretty m
-- [[0,1,2],
--  [3,4,5]]
-- >>> a = range @[2,3,4]
-- >>> a
-- [0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23]
-- >>> pretty a
-- [[[0,1,2,3],
--   [4,5,6,7],
--   [8,9,10,11]],
--  [[12,13,14,15],
--   [16,17,18,19],
--   [20,21,22,23]]]
-- >>> e = array @[3,3] @Double [4,12,-16,12,37,-43,-16,-43,98]
-- >>> l = chol e

-- $usage
--
-- >>> :set -XDataKinds
--
-- Several names used in @harpie@ conflict with [Prelude](https://hackage.haskell.org/package/base/docs/Prelude.html):
--
-- >>> import Prelude hiding (cycle, repeat, take, drop, zipWith, length)
--
-- In general, 'Array' functionality is contained in @Harpie.Fixed@ and shape functionality is contained in @Harpie.Shape@. These two modules have name clashes, so at least one needs to be qualified:
--
-- >>> import Harpie.Fixed as F
-- >>> import Harpie.Shape qualified as S
--
-- [@prettyprinter@](https://hackage.haskell.org/package/prettyprinter) is used to prettily render arrays to better visualise shape.
--
-- >>> import Prettyprinter hiding (dot,fill)
--
-- The 'Representable' class from [@adjunctions@](https://hackage.haskell.org/package/adjunctions) is used heavily by the module.
--
-- >>> import Data.Functor.Rep
--
-- 'Array' shape is accounted for with singleton types, in particular 'SNat' (a singleton for a type-level 'Natural' or 'Nat') from [GHC.TypeNats](https://hackage.haskell.org/package/base/docs/GHC-TypeNats.html) in base.
--
-- >>> import GHC.TypeNats
--
-- The [first-class-families](https://hackage.haskell.org/package/first-class-families) library is used to code most of the type-level constraint logic.
--
-- >>> import Fcf qualified
--
-- Examples of arrays:
--
-- An array with no dimensions (a scalar).
--
-- >>> s = 1 :: Array '[] Int
-- >>> s
-- [1]
-- >>> shape s
-- []
-- >>> pretty s
-- 1
--
-- A single-dimension array (a vector).
--
-- >>> let v = range @'[3]
-- >>> pretty v
-- [0,1,2]
--
-- A two-dimensional array (a matrix).
--
-- >>> let m = range @[2,3]
-- >>> pretty m
-- [[0,1,2],
--  [3,4,5]]
--
-- An n-dimensional array (n should be finite).
--
-- >>> a = range @[2,3,4]
-- >>> a
-- [0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23]
-- >>> pretty a
-- [[[0,1,2,3],
--   [4,5,6,7],
--   [8,9,10,11]],
--  [[12,13,14,15],
--   [16,17,18,19],
--   [20,21,22,23]]]
--
-- Conversion to a dynamic 'Harpie.Array.Array', whose shape is held at the value level:
--
-- >>> toDynamic a
-- UnsafeArray [2,3,4] [0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23]

-- | A hyperrectangular (or multidimensional) array with a type-level shape.
--
-- >>> array @[2,3,4] @Int [1..24]
-- [1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24]
-- >>> array [1..24] :: Array '[2,3,4] Int
-- [1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24]
-- >>> pretty (array @[2,3,4] @Int [1..24])
-- [[[1,2,3,4],
--   [5,6,7,8],
--   [9,10,11,12]],
--  [[13,14,15,16],
--   [17,18,19,20],
--   [21,22,23,24]]]
--
-- >>> array [1,2,3] :: Array '[2,2] Int
-- *** Exception: Shape Mismatch
-- ...
--
-- In many situations, the use of [TypeApplications](https://ghc.gitlab.haskell.org/ghc/doc/users_guide/exts/type_applications.html) can lead to a clean coding style.
--
-- >>> array @[2,3] @Int [1..6]
-- [1,2,3,4,5,6]
--
-- The main computational entry and exit points are often via 'index' and 'tabulate' with arrays indexed by 'Fins':
--
-- >>> index a (S.UnsafeFins [1,2,3])
-- 23
--
-- >>> :t tabulate id :: Array [2,3] (Fins [2,3])
-- tabulate id :: Array [2,3] (Fins [2,3])
--   :: Array [2, 3] (Fins [2, 3])
-- >>> pretty (tabulate id :: Array [2,3] (Fins [2,3]))
-- [[[0,0],[0,1],[0,2]],
--  [[1,0],[1,1],[1,2]]]
type role Array nominal representational

-- The shape index @s@ is nominal, so 'Data.Coerce.coerce' cannot reinterpret
-- an array at a different shape; the element type is representational, so
-- elements may still be coerced (e.g. through a newtype wrapper).
--
-- The only runtime payload is the flat 'V.Vector' of elements; the shape
-- exists purely at the type level.
newtype Array (s :: [Nat]) a where
  Array :: V.Vector a -> Array s a
  -- Structural instances over the element vector.
  deriving stock (Functor, Foldable, Generic, Traversable)
  -- Comparison and display reuse the underlying 'V.Vector' instances, so
  -- 'show' prints the flat element list without any shape information.
  deriving newtype (Eq, Eq1, Ord, Ord1, Show, Show1)

-- | Element-wise arithmetic over arrays of the same shape.
--
-- Numeric literals are broadcast to every element via 'konst', so
-- @1 :: Array '[] Int@ is the scalar @[1]@.
--
-- @(*)@ is not defined: calling it raises an 'error'.
instance (Num a, KnownNats s) => Num (Array s a) where
  (+) = zipWith (+)
  (-) = zipWith (-)
  -- NOTE(review): presumably left undefined because both an element-wise
  -- and a matrix product are plausible readings of (*) — confirm.
  (*) = error "multiplication not defined"
  abs = fmap abs
  signum = fmap signum
  -- Element-wise, like 'abs' and 'signum'. The class default (@0 - x@) would
  -- first build a 'konst' array, and for floating types it maps @0.0@ to
  -- @0.0@ rather than @-0.0@.
  negate = fmap negate
  fromInteger x = konst @s (fromInteger x)

-- | Pretty-print by converting to the dynamic 'A.Array' (see 'toDynamic').
-- Its rendering nests brackets by dimension; for example, a @[2,3]@ array
-- renders as two bracketed rows of three elements.
instance (KnownNats s, Show a) => Pretty (Array s a) where
  pretty = pretty . toDynamic

-- | 'distribute' comes from the 'Representable' instance via 'distributeRep'.
instance
  (KnownNats s) =>
  Data.Distributive.Distributive (Array s)
  where
  distribute :: (KnownNats s, Functor f) => f (Array s a) -> Array s (f a)
  distribute = distributeRep
  {-# INLINE distribute #-}

-- | Arrays are representable functors indexed by 'Fins' s: a multi-index
-- with one coordinate per dimension of the shape @s@.
instance
  forall s.
  (KnownNats s) =>
  Representable (Array s)
  where
  type Rep (Array s) = Fins s

  -- Generate the backing vector in row-major order: each flat position is
  -- converted to a multi-index with 'shapen' and passed to @f@.
  tabulate f = Array (V.generate (S.size sh) (f . UnsafeFins . shapen sh))
    where
      sh = valuesOf @s
  {-# INLINE tabulate #-}

  -- Convert the multi-index to a flat position with 'flatten' and read the
  -- backing vector with no bounds check.
  -- NOTE(review): assumes the 'Fins' value is within @s@; an out-of-range
  -- 'UnsafeFins' is not caught here.
  index (Array v) i = V.unsafeIndex v (flatten sh (fromFins i))
    where
      sh = valuesOf @s
  {-# INLINE index #-}

-- | Conversion to and from a `V.Vector`
--
-- Note that conversion of an 'Array' to a vector drops shape information, so that:
--
-- > vectorAs . asVector == id
-- > asVector . vectorAs == 'flat'
--
-- >>> asVector (range @[2,3])
-- [0,1,2,3,4,5]
--
-- >>> import Data.Vector qualified as V
-- >>> vectorAs (V.fromList [0..5]) :: Array [2,3] Int
-- [0,1,2,3,4,5]
--
-- The functional dependency @t -> a@ means the element type @a@ is
-- determined by the container type @t@.
class FromVector t a | t -> a where
  -- | Convert to a vector of elements. For an 'Array' this discards the
  -- type-level shape.
  asVector :: t -> V.Vector a
  -- | Build a value from a vector of elements. For an 'Array' the vector's
  -- length is not checked against the shape.
  vectorAs :: V.Vector a -> t

-- | The identity conversion: a 'V.Vector' is already in vector form.
instance FromVector (V.Vector a) a where
  asVector = id
  vectorAs = id

-- | Lists convert element by element, preserving order.
instance FromVector [a] a where
  asVector = V.fromList
  vectorAs = V.toList

-- | Unwrap or wrap the backing vector. 'vectorAs' performs no shape check;
-- use 'safeArray' or 'array' for checked construction.
instance FromVector (Array s a) a where
  asVector (Array v) = v
  vectorAs = Array

-- | Construct an array without shape validation.
--
-- The elements are taken as-is, so their count may disagree with the shape
-- @s@; use 'validate' to check, or 'safeArray' / 'array' for checked
-- construction.
--
-- >>> unsafeArray [0..4] :: Array [2,3] Int
-- [0,1,2,3,4]
unsafeArray :: (KnownNats s, FromVector t a) => t -> Array s a
unsafeArray = Array . asVector

-- | Validate the size and shape of an array.
--
-- Returns 'True' exactly when the number of stored elements equals the size
-- implied by the type-level shape @s@ (see 'size').
--
-- >>> validate (unsafeArray [0..4] :: Array [2,3] Int)
-- False
validate :: (KnownNats s) => Array s a -> Bool
validate arr = V.length (asVector arr) == size arr

-- | Construct an Array, checking shape.
--
-- Returns 'Nothing' when the number of elements does not match the shape
-- @s@ (as decided by 'validate').
--
-- >>> (safeArray [0..23] :: Maybe (Array [2,3,4] Int)) == Just a
-- True
safeArray :: (KnownNats s, FromVector t a) => t -> Maybe (Array s a)
safeArray v = bool Nothing (Just candidate) (validate candidate)
  where
    candidate = unsafeArray v

-- | Construct an Array, throwing an exception on a bad shape.
--
-- This is 'safeArray' with the 'Nothing' case turned into an 'error'. The
-- type variables are ordered @s@, @a@, @t@ for use with type applications,
-- e.g. @array \@[2,3] \@Int [1..6]@.
--
-- >>> array [0..22] :: Array [2,3,4] Int
-- *** Exception: Shape Mismatch
-- ...
array :: forall s a t. (KnownNats s, FromVector t a) => t -> Array s a
array v = fromMaybe (error "Shape Mismatch") (safeArray v)

-- | Unsafely modify an array shape.
--
-- The backing vector is reused unchanged under the new shape @s'@. No check
-- is made that the sizes of @s@ and @s'@ agree.
--
-- >>> pretty (unsafeModifyShape @[3,2] (array @[2,3] @Int [0..5]))
-- [[0,1],
--  [2,3],
--  [4,5]]
unsafeModifyShape :: forall s' s a. (KnownNats s, KnownNats s') => Array s a -> Array s' a
unsafeModifyShape a = unsafeArray (asVector a)

-- | Unsafely modify an array vector.
--
-- The backing vector is converted to @u@ with 'vectorAs', transformed by
-- @f@, and the result is converted back with 'asVector'. The resulting
-- element count is not checked against the shape @s@.
--
-- >>> import Data.Vector qualified as V
-- >>> pretty (unsafeModifyVector (V.map (+1)) (array [0..5] :: Array [2,3] Int))
-- [[1,2,3],
--  [4,5,6]]
unsafeModifyVector :: (KnownNats s) => (FromVector u a) => (FromVector v b) => (u -> v) -> Array s a -> Array s b
unsafeModifyVector f a = unsafeArray (asVector (f (vectorAs (asVector a))))

-- | Representation of an index into a shape (a type-level @[Nat]@): a
-- singleton naming a single dimension. @Dim \@0@ is commonly thought of as
-- the row dimension of an array.
type Dim = SNat

-- | Pattern synonym for a 'Dim'.
--
-- Matching on it brings a 'KnownNat' constraint for the dimension into
-- scope; using it as an expression requires one.
pattern Dim :: () => (KnownNat n) => SNat n
pattern Dim = SNat

-- The single pattern covers every 'SNat', so matches on it are exhaustive.
{-# COMPLETE Dim #-}

-- | Representation of indexes into a shape (a type-level [Nat]). The indexes are dimensions of the shape.
type Dims = SNats

-- | Pattern synonym for a 'Dims'
pattern Dims :: () => (KnownNats ns) => SNats ns
-- Matching on 'Dims' brings a 'KnownNats' dictionary for @ns@ into scope.
pattern Dims = SNats

{-# COMPLETE Dims #-}

-- | Convert to a dynamic array with shape at the value level.
--
-- >>> toDynamic a
-- UnsafeArray [2,3,4] [0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23]
toDynamic :: (KnownNats s) => Array s a -> A.Array a
-- Move the type-level shape down to a value-level shape list and pair it with
-- the same backing vector.
toDynamic a = A.array (shape a) (asVector a)

-- | Use a dynamic array in a fixed context.
--
-- >>> import qualified Harpie.Array as A
-- >>> with (A.range [2,3,4]) show
-- "[0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23]"
--
-- This doesn't work for anything more complex where KnownNats need to be type computed:
--
-- >>> :t with (A.range [2,3,4]) (pretty . F.takes (Dims @'[0]) (S.SNats @'[1]))
-- ...
--     • Could not deduce ‘S.KnownNats (Fcf.Data.List.Drop_ 1 s)’
-- ...
with ::
  forall a r.
  A.Array a ->
  (forall s. (KnownNats s) => Array s a -> r) ->
  r
with d f =
  -- Reify the runtime shape of @d@ as a type-level list @s@, then rebuild the
  -- elements as a fixed-shape array at exactly that @s@ and hand it to @f@.
  withSomeSNats (fromIntegral <$> A.shape d) $ \(SNats :: SNats s) ->
    withKnownNats (SNats @s) (f (array @s (A.asVector d)))

-- | Sigma type for an 'Array'
--
-- A fixed Array where shape was unknown at runtime.
--
-- The library design encourages the use of value-level shape arrays (in @Harpie.Array@) via 'toDynamic' in preference to dependent-type styles of coding. In particular, no attempt has been made to prove to the compiler that a particular Shape (resulting from any of the supplied functions) exists. Life is short.
--
-- >> P.take 4 <$> sample' arbitrary :: IO [SomeArray Int]
-- >> [SomeArray SNats @'[] [0],SomeArray SNats @'[0] [],SomeArray SNats @[1, 1] [1],SomeArray SNats @[5, 1, 4] [2,1,0,2,-6,0,5,6,-1,-4,0,5,-1,6,4,-6,1,0,3,-1]]
data SomeArray a
  = forall s. SomeArray (SNats s) (Array s a)

-- Showing goes through the stored shape singleton and the wrapped array.
deriving stock instance (Show a) => Show (SomeArray a)

instance Functor SomeArray where
  -- Map over the wrapped array; the shape singleton is carried through untouched.
  fmap f (SomeArray sn a) = SomeArray sn (fmap f a)

instance Foldable SomeArray where
  -- Fold only the wrapped array's elements; the shape singleton is ignored.
  foldMap f (SomeArray _ a) = foldMap f a

-- | Construct a SomeArray
someArray :: forall s t a. (FromVector t a) => SNats s -> t -> SomeArray a
-- NOTE(review): this uses the raw 'Array' constructor, so nothing here checks
-- that the number of elements in @t@ matches the shape @s@; callers must ensure it.
someArray s t = SomeArray s (Array (asVector t))

instance (Arbitrary a) => Arbitrary (SomeArray a) where
  arbitrary = do
    s <- arbitrary :: Gen [Small Nat]
    -- Keep at most three dimensions.
    let s' = Prelude.take 3 (getSmall <$> s)
    -- Generate exactly as many elements as the chosen shape holds.
    v <- V.replicateM (product (Prelude.fromIntegral <$> s')) arbitrary
    -- Lift the value-level shape to a singleton and package it with the elements.
    withSomeSNats s' $ \sn -> pure (someArray sn v)

-- | Get shape of an Array as a value.
--
-- >>> shape a
-- [2,3,4]
shape :: forall a s. (KnownNats s) => Array s a -> [Int]
-- The shape is read from the type index @s@; the array value is never inspected.
shape _ = valuesOf @s
{-# INLINE shape #-}

-- | Get rank of an Array as a value.
--
-- >>> rank a
-- 3
rank :: forall a s. (KnownNats s) => Array s a -> Int
-- Rank is derived from the value-level shape list.
rank = S.rank . shape
{-# INLINE rank #-}

-- | Get size of an Array as a value.
--
-- >>> size a
-- 24
size :: forall a s. (KnownNats s) => Array s a -> Int
-- Element count is derived from the value-level shape list.
size = S.size . shape
{-# INLINE size #-}

-- | Number of rows (first dimension size) in an Array. As a convention, a scalar value is still a single row.
--
-- >>> length a
-- 2
-- >>> length (toScalar 0)
-- 1
length :: (KnownNats s) => Array s a -> Int
length a = case shape a of
  -- A scalar (empty shape) counts as a single row by convention.
  [] -> 1
  (x : _) -> x

-- | Is the Array empty (i.e. does it contain no elements)?
--
-- >>> isNull (array [] :: Array [2,0] ())
-- True
-- >>> isNull (array [4] :: Array '[] Int)
-- False
isNull :: (KnownNats s) => Array s a -> Bool
-- Empty exactly when some dimension is zero, i.e. the element count is zero.
isNull = (0 ==) . size

-- | Extract an element at an index, unsafely.
--
-- >>> unsafeIndex a [1,2,3]
-- 23
unsafeIndex :: (KnownNats s) => Array s a -> [Int] -> a
-- 'UnsafeFins' wraps the raw index list as-is; contrast '!?', which goes
-- through 'safeFins' and can reject an index.
unsafeIndex a xs = index a (UnsafeFins xs)

-- | Extract an element at an index, unsafely.
--
-- >>> a ! [1,2,3]
-- 23
(!) :: (KnownNats s) => Array s a -> [Int] -> a
-- Operator alias for 'unsafeIndex'; delegate rather than duplicate its body.
(!) = unsafeIndex

infixl 9 !

-- | Extract an element at an index, safely.
--
-- >>> a !? [1,2,3]
-- Just 23
-- >>> a !? [2,3,1]
-- Nothing
(!?) :: (KnownNats s) => Array s a -> [Int] -> Maybe a
-- 'safeFins' yields Nothing for an out-of-shape index; otherwise index normally.
(!?) a xs = index a <$> safeFins xs

infixl 9 !?

-- | Tabulate unsafely.
--
-- >>> :t tabulate @(Array [2,3]) id
-- tabulate @(Array [2,3]) id :: Array [2, 3] (Fins [2, 3])
-- >>> :t unsafeTabulate @[2,3] id
-- unsafeTabulate @[2,3] id :: Array [2, 3] [Int]
-- >>> pretty $ unsafeTabulate @[2,3] id
-- [[[0,0],[0,1],[0,2]],
--  [[1,0],[1,1],[1,2]]]
unsafeTabulate :: (KnownNats s) => ([Int] -> a) -> Array s a
-- Strip the 'Fins' wrapper so @f@ sees a plain index list.
unsafeTabulate f = tabulate (f . fromFins)

-- | @backpermute@ is a tabulation where the contents of an array do not need to be accessed, and is thus a fulcrum for leveraging laziness and fusion via the rule:
--
-- > backpermute f (backpermute f' a) == backpermute (f' . f) a
--
-- Many functions in this module are examples of backpermute usage.
--
-- >>> pretty $ backpermute @[4,3,2] (UnsafeFins . List.reverse . fromFins) a
-- [[[0,12],
--   [4,16],
--   [8,20]],
--  [[1,13],
--   [5,17],
--   [9,21]],
--  [[2,14],
--   [6,18],
--   [10,22]],
--  [[3,15],
--   [7,19],
--   [11,23]]]
backpermute :: forall s' s a. (KnownNats s, KnownNats s') => (Fins s' -> Fins s) -> Array s a -> Array s' a
-- Each result index is mapped back to a source index; elements are only read
-- when demanded. Unfolding the definition gives the composition law
--   backpermute f (backpermute g a) = backpermute (g . f) a
backpermute f a = tabulate (index a . f)
{-# INLINEABLE backpermute #-}

{- RULES
   "backpermute/backpermute" forall f f' a. backpermute f (backpermute f' a) = backpermute (f' . f) a
-}

-- | Unsafe backpermute
--
-- >>> pretty $ unsafeBackpermute @[4,3,2] List.reverse a
-- [[[0,12],
--   [4,16],
--   [8,20]],
--  [[1,13],
--   [5,17],
--   [9,21]],
--  [[2,14],
--   [6,18],
--   [10,22]],
--  [[3,15],
--   [7,19],
--   [11,23]]]
unsafeBackpermute :: forall s' s a. (KnownNats s, KnownNats s') => ([Int] -> [Int]) -> Array s a -> Array s' a
-- Like 'backpermute', but @f@ works on raw index lists; its output is rewrapped
-- with 'UnsafeFins', so any out-of-range index it produces is not caught here.
unsafeBackpermute f a = tabulate (index a . UnsafeFins . f . fromFins)

{- RULES
   "unsafeBackpermute/unsafeBackpermute" forall f f' a. unsafeBackpermute f (unsafeBackpermute f' a) = unsafeBackpermute (f' . f) a
-}

-- | Unwrap a scalar.
--
-- >>> s = array @'[] @Int [3]
-- >>> :t fromScalar s
-- fromScalar s :: Int
fromScalar :: Array '[] a -> a
-- A rank-0 array is indexed by the empty index list.
fromScalar a = index a (UnsafeFins [])

-- | Wrap a scalar.
--
-- >>> :t toScalar @Int 2
-- toScalar @Int 2 :: Array '[] Int
toScalar :: a -> Array '[] a
-- A scalar is backed by a one-element vector.
toScalar a = Array (V.singleton a)

-- | Is an array a scalar?
--
-- >>> isScalar (toScalar (2::Int))
-- True
isScalar :: (KnownNats s) => Array s a -> Bool
-- Scalars are exactly the rank-0 arrays.
isScalar a = rank a == 0

-- | Convert a scalar to being a dimensioned array. Do nothing if not a scalar.
--
-- >>> asSingleton (toScalar 4)
-- [4]
asSingleton :: (KnownNats s, KnownNats s', s' ~ Eval (AsSingleton s)) => Array s a -> Array s' a
-- Only the type-level shape changes (computed by 'AsSingleton'); the data is reused.
asSingleton = unsafeModifyShape

-- | Convert an array with shape [1] to being a scalar (Do nothing if not a shape [1] array).
--
-- >>> pretty (asScalar (singleton 3))
-- 3
asScalar :: (KnownNats s, KnownNats s', s' ~ Eval (AsScalar s)) => Array s a -> Array s' a
-- Only the type-level shape changes (computed by 'AsScalar'); the data is reused.
asScalar = unsafeModifyShape

-- | An array with no elements.
--
-- >>> toDynamic empty
-- UnsafeArray [0] []
empty :: Array '[0] a
-- Built through the validating 'array' path from an empty list.
empty = array []

-- | An enumeration of row-major or [lexicographic](https://en.wikipedia.org/wiki/Lexicographic_order) order.
--
-- >>> pretty (range :: Array [2,3] Int)
-- [[0,1,2],
--  [3,4,5]]
range :: forall s. (KnownNats s) => Array s Int
-- Each element is its own row-major (flattened) position within shape @s@.
range = tabulate (S.flatten (valuesOf @s) . fromFins)

-- | An enumeration of col-major or [colexicographic](https://en.wikipedia.org/wiki/Lexicographic_order) order.
--
-- >>> pretty (corange @[2,3,4])
-- [[[0,6,12,18],
--   [2,8,14,20],
--   [4,10,16,22]],
--  [[1,7,13,19],
--   [3,9,15,21],
--   [5,11,17,23]]]
corange :: forall s. (KnownNats s) => Array s Int
-- Flatten against the reversed shape using the reversed index, so the first
-- dimension varies fastest (column-major order).
corange = tabulate (S.flatten (List.reverse (valuesOf @s)) . List.reverse . fromFins)

-- | Indices of an array shape.
--
-- >>> pretty $ indices @[3,3]
-- [[[0,0],[0,1],[0,2]],
--  [[1,0],[1,1],[1,2]],
--  [[2,0],[2,1],[2,2]]]
indices :: (KnownNats s) => Array s [Int]
-- Each element is its own index, unwrapped to a plain list.
indices = tabulate fromFins

-- | The identity array.
--
-- >>> pretty $ ident @[3,3]
-- [[1,0,0],
--  [0,1,0],
--  [0,0,1]]
ident :: (KnownNats s, Num a) => Array s a
-- 1 where the index is on the (generalised) diagonal per 'S.isDiag', else 0.
ident = tabulate (bool 0 1 . S.isDiag . fromFins)

-- | Create an array composed of a single value.
--
-- >>> pretty $ konst @[3,2] 1
-- [[1,1],
--  [1,1],
--  [1,1]]
konst :: (KnownNats s) => a -> Array s a
-- Every index maps to the same value.
konst a = tabulate (const a)

-- | Create an array of shape [1].
--
-- >>> pretty $ singleton 1
-- [1]
singleton :: a -> Array '[1] a
-- A one-element vector under the shape [1].
singleton a = unsafeArray (V.singleton a)

-- | Extract the diagonal of an array.
--
-- >>> pretty $ diag (ident @[3,3])
-- [1,1,1]
diag ::
  forall s' a s.
  ( KnownNats s,
    KnownNats s',
    s' ~ Eval (MinDim s)
  ) =>
  Array s a ->
  Array s' a
-- Result index i (its first coordinate) reads source index [i, i, ..., i],
-- repeated once per dimension of the source.
diag a = unsafeBackpermute (replicate (rank a) . getDim 0) a

-- | Expand an array to form a diagonal array
--
-- >>> pretty $ undiag (range @'[3])
-- [[0,0,0],
--  [0,1,0],
--  [0,0,2]]
undiag ::
  forall s' a s.
  ( KnownNats s,
    KnownNats s',
    s' ~ Eval ((++) s s),
    Num a
  ) =>
  Array s a ->
  Array s' a
-- On-diagonal positions (per 'isDiag') take the source element at the first
-- coordinate; every other position is 0.
-- NOTE(review): @a@ is indexed with a single coordinate, which only addresses
-- a rank-1 @s@ fully; confirm the intended semantics for higher-rank inputs.
undiag a = tabulate (\xs -> bool 0 (index a (UnsafeFins $ pure $ getDim 0 (fromFins xs))) (isDiag (fromFins xs)))

-- | Zip two arrays at an element level.
--
-- >>> zipWith (-) v v
-- [0,0,0]
zipWith :: (KnownNats s) => (a -> b -> c) -> Array s a -> Array s b -> Array s c
-- Both arrays share shape @s@, so their backing vectors line up element for element.
zipWith f (asVector -> a) (asVector -> b) = unsafeArray (V.zipWith f a b)

-- | Modify a single value at an index.
--
-- >>> pretty $ modify (S.UnsafeFins [0,0]) (const 100) (range @[3,2])
-- [[100,1],
--  [2,3],
--  [4,5]]
modify :: (KnownNats s) => Fins s -> (a -> a) -> Array s a -> Array s a
-- Only the element at index @ds@ is passed through @f@; all others are copied.
modify ds f a = tabulate (\s -> bool id f (s == ds) (index a s))

-- | Maps an index function at element-level.
--
-- >>> pretty $ imap (\xs x -> x - sum xs) a
-- [[[0,0,0,0],
--   [3,3,3,3],
--   [6,6,6,6]],
--  [[11,11,11,11],
--   [14,14,14,14],
--   [17,17,17,17]]]
imap ::
  (KnownNats s) =>
  ([Int] -> a -> b) ->
  Array s a ->
  Array s b
-- Pair every element with its own index (from 'indices') and combine with @f@.
imap f a = zipWith f indices a

-- | With a function that takes dimensions and (type-level) parameters, apply the parameters to the initial dimensions, i.e.
--
-- > rowWise f xs = f [0..rank xs - 1] xs
--
-- >>> toDynamic $ rowWise indexesT (S.SNats @[1,0]) a
-- UnsafeArray [4] [12,13,14,15]
rowWise ::
  forall a ds s s' xs proxy.
  ( KnownNats s,
    KnownNats ds,
    ds ~ Eval (DimsOf xs)
  ) =>
  (Dims ds -> proxy xs -> Array s a -> Array s' a) ->
  proxy xs ->
  Array s a ->
  Array s' a
-- The target dimensions @ds@ are fixed entirely at the type level by 'DimsOf'.
rowWise f xs a = f (Dims @ds) xs a

-- | With a function that takes dimensions and (type-level) parameters, apply the parameters to the last dimensions, i.e.
--
-- > colWise f xs = f (List.reverse [0 .. (rank a - 1)]) xs
--
-- >>> toDynamic $ colWise indexesT (S.SNats @[1,0]) a
-- UnsafeArray [2] [1,13]
colWise ::
  forall a ds s s' xs proxy.
  ( KnownNats s,
    KnownNats ds,
    ds ~ Eval (EndDimsOf xs s)
  ) =>
  (Dims ds -> proxy xs -> Array s a -> Array s' a) ->
  proxy xs ->
  Array s a ->
  Array s' a
-- The target dimensions @ds@ are fixed entirely at the type level by 'EndDimsOf'.
colWise f xs a = f (Dims @ds) xs a

-- | Take the top-most elements across the specified dimension.
--
-- >>> pretty $ take (Dim @2) (SNat @1) a
-- [[[0],
--   [4],
--   [8]],
--  [[12],
--   [16],
--   [20]]]
take ::
  forall d t s s' a.
  ( KnownNats s,
    KnownNats s',
    s' ~ Eval (TakeDim d t s)
  ) =>
  Dim d ->
  SNat t ->
  Array s a ->
  Array s' a
-- The kept slice starts at offset 0 along @d@, so every result index reads the
-- same source index; only the result shape @s'@ is smaller.
take _ _ a = unsafeBackpermute id a

-- | Take the bottom-most elements across the specified dimension.
--
-- >>> pretty $ takeB (Dim @2) (SNat @1) a
-- [[[3],
--   [7],
--   [11]],
--  [[15],
--   [19],
--   [23]]]
takeB ::
  forall s s' a d t.
  ( KnownNats s,
    KnownNats s',
    s' ~ Eval (TakeDim d t s)
  ) =>
  Dim d ->
  SNat t ->
  Array s a ->
  Array s' a
-- Matching 'Dim'/'SNat' supplies the KnownNat dictionaries that 'valueOf' needs.
-- Along @d@, result position x reads source position x + (n - t), where n is
-- the source extent of @d@.
takeB Dim SNat a = unsafeBackpermute (modifyDim (valueOf @d) (\x -> x + getDim (valueOf @d) (shape a) - valueOf @t)) a

-- | Drop the top-most elements across the specified dimension.
--
-- >>> pretty $ drop (Dim @2) (SNat @1) a
-- [[[1,2,3],
--   [5,6,7],
--   [9,10,11]],
--  [[13,14,15],
--   [17,18,19],
--   [21,22,23]]]
drop ::
  forall s s' a d t.
  ( KnownNats s,
    KnownNats s',
    Eval (DropDim d t s) ~ s'
  ) =>
  Dim d ->
  SNat t ->
  Array s a ->
  Array s' a
-- Matching 'Dim'/'SNat' supplies the KnownNat dictionaries that 'valueOf' needs.
-- Along @d@, result position x reads source position x + t.
drop Dim SNat a = unsafeBackpermute (S.modifyDim (valueOf @d) (\x -> x + valueOf @t)) a

-- | Drop the bottom-most elements across the specified dimension.
--
-- >>> pretty $ dropB (Dim @2) (SNat @1) a
-- [[[0,1,2],
--   [4,5,6],
--   [8,9,10]],
--  [[12,13,14],
--   [16,17,18],
--   [20,21,22]]]
dropB ::
  forall s s' a d t.
  ( KnownNats s,
    KnownNats s',
    Eval (DropDim d t s) ~ s'
  ) =>
  Dim d ->
  SNat t ->
  Array s a ->
  Array s' a
-- Dropping from the end keeps the slice that starts at offset 0, so indices map
-- straight through; only the result shape @s'@ is smaller.
dropB _ _ a = unsafeBackpermute id a

-- | Select an index along a dimension.
--
-- >>> let s = select (Dim @2) (S.fin @4 3) a
-- >>> pretty s
-- [[3,7,11],
--  [15,19,23]]
select ::
  forall d a p s s'.
  ( KnownNats s,
    KnownNats s',
    s' ~ Eval (DeleteDim d s),
    p ~ Eval (GetDim d s)
  ) =>
  Dim d ->
  Fin p ->
  Array s a ->
  Array s' a
-- Matching 'Dim' supplies KnownNat d. Each result index gets the fixed
-- position @p@ re-inserted at dimension @d@ to form the source index.
select Dim p a = unsafeBackpermute (S.insertDim (valueOf @d) (fromFin p)) a

-- | Insert along a dimension at a position.
--
-- >>> pretty $ insert (Dim @2) (UnsafeFin 0) a (konst @[2,3] 0)
-- [[[0,0,1,2,3],
--   [0,4,5,6,7],
--   [0,8,9,10,11]],
--  [[0,12,13,14,15],
--   [0,16,17,18,19],
--   [0,20,21,22,23]]]
-- >>> toDynamic $ insert (Dim @0) (UnsafeFin 0) (toScalar 1) (toScalar 2)
-- UnsafeArray [2] [2,1]
insert ::
  forall s' s si d p a.
  ( KnownNats s,
    KnownNats si,
    KnownNats s',
    s' ~ Eval (IncAt d s),
    p ~ Eval (GetDim d s),
    True ~ Eval (InsertOk d s si)
  ) =>
  Dim d ->
  Fin p ->
  Array s a ->
  Array si a ->
  Array s' a
insert Dim i a b = tabulate go
  where
    go :: Fins s' -> a
    go s
      -- At the insertion position: read @b@ with dimension d removed.
      | getDim d xs == fromFin i = index b (UnsafeFins (deleteDim d xs))
      -- Before the insertion position: read @a@ unchanged.
      | getDim d xs < fromFin i = index a (UnsafeFins xs)
      -- After it: read @a@ shifted back by one along dimension d.
      | otherwise = index a (UnsafeFins (decAt d xs))
      where
        xs :: [Int]
        xs = fromFins s
    -- Available because matching 'Dim' brought KnownNat d into scope.
    d :: Int
    d = valueOf @d

-- | Delete along a dimension at a position.
--
-- >>> pretty $ delete (Dim @2) (UnsafeFin 3) a
-- [[[0,1,2],
--   [4,5,6],
--   [8,9,10]],
--  [[12,13,14],
--   [16,17,18],
--   [20,21,22]]]
delete ::
  forall d s s' p a.
  ( KnownNats s,
    KnownNats s',
    s' ~ Eval (DecAt d s),
    p ~ 1 + Eval (GetDim d s)
  ) =>
  Dim d ->
  Fin p ->
  Array s a ->
  Array s' a
-- Result positions before @p@ along @d@ map straight through; later positions
-- skip over the deleted slice by incrementing dimension d.
delete Dim p a = unsafeBackpermute (\s -> bool (incAt d s) s (getDim d s < fromFin p)) a
  where
    -- Available because matching 'Dim' brought KnownNat d into scope.
    d :: Int
    d = valueOf @d

-- | Attach the second array after the first, along the given dimension.
--
-- This is 'insert' at the position equal to the current size of dimension
-- @d@ in @s@. That position sits one past the 'Fin' bound 'insert' accepts,
-- which is why it is built with 'UnsafeFin'.
--
-- >>> pretty $ append (Dim @2) a (konst @[2,3] 0)
-- [[[0,1,2,3,0],
--   [4,5,6,7,0],
--   [8,9,10,11,0]],
--  [[12,13,14,15,0],
--   [16,17,18,19,0],
--   [20,21,22,23,0]]]
append ::
  forall a d s si s'.
  ( KnownNats s,
    KnownNats si,
    KnownNats s',
    s' ~ Eval (IncAt d s),
    True ~ Eval (InsertOk d s si)
  ) =>
  Dim d ->
  Array s a ->
  Array si a ->
  Array s' a
append Dim base extra = insert (Dim @d) (UnsafeFin endOfDim) base extra
  where
    -- size of dimension d in the base array, i.e. one past its last index
    endOfDim :: Int
    endOfDim = getDim (valueOf @d) (valuesOf @s)

-- | Attach the first array in front of the second, along the given dimension.
--
-- Implemented as 'insert' at position 0. Note the argument order: the array
-- being inserted comes first here, whereas 'insert' takes it last.
--
-- >>> pretty $ prepend (Dim @2) (konst @[2,3] 0) a
-- [[[0,0,1,2,3],
--   [0,4,5,6,7],
--   [0,8,9,10,11]],
--  [[0,12,13,14,15],
--   [0,16,17,18,19],
--   [0,20,21,22,23]]]
prepend ::
  forall a d s si s'.
  ( KnownNats s,
    KnownNats si,
    KnownNats s',
    s' ~ Eval (IncAt d s),
    True ~ Eval (InsertOk d s si)
  ) =>
  Dim d ->
  Array si a ->
  Array s a ->
  Array s' a
prepend dim front base = insert dim (UnsafeFin 0) base front

-- | Join two arrays end-to-end along a dimension.
--
-- An output index whose coordinate along @d@ is below the size of that
-- dimension in the first array reads from the first array. Otherwise the
-- coordinate is shifted down by that size and read from the second array.
--
-- >>> shape $ concatenate (Dim @1) a a
-- [2,6,4]
-- >>> toDynamic $ concatenate (Dim @0) (toScalar 1) (toScalar 2)
-- UnsafeArray [2] [1,2]
-- >>> toDynamic $ concatenate (Dim @0) (array @'[1] [0]) (array @'[3] [1..3])
-- UnsafeArray [4] [0,1,2,3]
concatenate ::
  forall a s0 s1 d s.
  ( KnownNats s0,
    KnownNats s1,
    KnownNats s,
    Eval (Concatenate d s0 s1) ~ s
  ) =>
  Dim d ->
  Array s0 a ->
  Array s1 a ->
  Array s a
concatenate Dim first second = tabulate (pick . fromFins)
  where
    axis :: Int
    axis = valueOf @d
    -- first coordinate along the axis that belongs to the second array
    boundary :: Int
    boundary = getDim axis (shape first)
    pick :: [Int] -> a
    pick ix
      | getDim axis ix < boundary = index first (UnsafeFins ix)
      | otherwise =
          let shifted = getDim axis ix - boundary
           in index second (UnsafeFins (insertDim axis shifted (deleteDim axis ix)))

-- | Stack two same-shaped arrays along a new dimension of size 2.
--
-- Each input gets a size-1 dimension inserted at @d@ (via 'elongate'). The
-- two results are then joined with 'concatenate'.
--
-- >>> pretty $ couple (Dim @0) (array @'[3] [1,2,3]) (array @'[3] @Int [4,5,6])
-- [[1,2,3],
--  [4,5,6]]
-- >>> couple (Dim @0) (toScalar @Int 0) (toScalar 1)
-- [0,1]
couple ::
  forall d a s s' se.
  ( KnownNat d,
    KnownNats s,
    KnownNats s',
    KnownNats se,
    s' ~ Eval (Concatenate d se se),
    se ~ Eval (InsertDim d 1 s)
  ) =>
  Dim d ->
  Array s a ->
  Array s a ->
  Array s' a
couple d x y = concatenate d (widen x) (widen y)
  where
    -- add the unit dimension that the two inputs are concatenated along
    widen :: Array s a -> Array se a
    widen = elongate d

-- | Take a contiguous window along one dimension, given an offset and a length.
--
-- Each output coordinate along @d@ is shifted up by @off@ to find its source.
-- The length argument is matched as @_@; it only shapes the result type
-- (through @SetDim d l s@).
--
-- >>> pretty $ slice (Dim @2) (SNat @1) (SNat @2) a
-- [[[1,2],
--   [5,6],
--   [9,10]],
--  [[13,14],
--   [17,18],
--   [21,22]]]
slice ::
  forall a d off l s s'.
  ( KnownNats s,
    KnownNats s',
    s' ~ Eval (SetDim d l s),
    Eval (SliceOk d off l s) ~ True
  ) =>
  Dim d ->
  SNat off ->
  SNat l ->
  Array s a ->
  Array s' a
slice Dim SNat _ arr = unsafeBackpermute shift arr
  where
    shift :: [Int] -> [Int]
    shift = S.modifyDim (valueOf @d) (+ valueOf @off)

-- | Cyclically rotate an array along one dimension by a runtime amount.
--
-- The index remapping is delegated to 'rotateIndex'. In the example below,
-- output row @i@ along dimension 1 comes from input row @(i + 2) mod 3@.
--
-- >>> pretty $ rotate (Dim @1) 2 a
-- [[[8,9,10,11],
--   [0,1,2,3],
--   [4,5,6,7]],
--  [[20,21,22,23],
--   [12,13,14,15],
--   [16,17,18,19]]]
rotate ::
  forall d s a.
  (KnownNats s) =>
  Dim d ->
  Int ->
  Array s a ->
  Array s a
rotate Dim r arr = unsafeBackpermute rotation arr
  where
    rotation :: [Int] -> [Int]
    rotation = rotateIndex (valueOf @d) r (shape arr)

-- * multi-dimensional operators

-- | Keep the leading elements along each of the listed dimensions.
--
-- Kept elements sit at the same indices in source and result, so the
-- backpermute is the identity. The dimension and count arguments are not
-- inspected; only the result shape @SetDims ds xs s@ depends on them.
--
-- >>> pretty $ takes (Dims @[0,1]) (S.SNats @[1,2]) a
-- [[[0,1,2,3],
--   [4,5,6,7]]]
takes ::
  forall ds xs s' s a.
  ( KnownNats s,
    KnownNats s',
    s' ~ Eval (SetDims ds xs s)
  ) =>
  Dims ds ->
  SNats xs ->
  Array s a ->
  Array s' a
takes _ _ = unsafeBackpermute id

-- | Keep the trailing elements along each of the listed dimensions.
--
-- Each output index is offset by the difference between the full shape and
-- the kept shape. That difference is zero in the dimensions not listed.
--
-- >>> pretty (takeBs (Dims @[0,1]) (S.SNats @[1,2]) a)
-- [[[16,17,18,19],
--   [20,21,22,23]]]
takeBs ::
  forall s' s a ds xs.
  ( KnownNats s,
    KnownNats s',
    KnownNats ds,
    KnownNats xs,
    s' ~ Eval (SetDims ds xs s)
  ) =>
  Dims ds ->
  SNats xs ->
  Array s a ->
  Array s' a
takeBs _ _ arr = unsafeBackpermute (List.zipWith (+) offsets) arr
  where
    full :: [Int]
    full = shape arr
    kept :: [Int]
    kept = S.setDims (valuesOf @ds) (valuesOf @xs) full
    offsets :: [Int]
    offsets = List.zipWith (-) full kept

-- | Discard the leading elements along each of the listed dimensions.
--
-- The per-dimension offset is the source shape minus the result shape, read
-- directly from the type-level shapes @s@ and @s'@.
--
-- >>> pretty $ drops (Dims @[0,2]) (S.SNats @[1,3]) a
-- [[[15],
--   [19],
--   [23]]]
drops ::
  forall ds xs s' s a.
  ( KnownNats s,
    KnownNats s',
    KnownNats ds,
    KnownNats xs,
    s' ~ Eval (DropDims ds xs s)
  ) =>
  Dims ds ->
  SNats xs ->
  Array s a ->
  Array s' a
drops _ _ = unsafeBackpermute (List.zipWith (+) offsets)
  where
    offsets :: [Int]
    offsets = List.zipWith (-) (valuesOf @s) (valuesOf @s')

-- | Discard the trailing elements along each of the listed dimensions.
--
-- The surviving elements keep their indices, so the backpermute is the
-- identity. Only the result shape @DropDims ds xs s@ depends on the
-- (otherwise unused) dimension and count arguments.
--
-- >>> pretty $ dropBs (Dims @[0,2]) (S.SNats @[1,3]) a
-- [[[0],
--   [4],
--   [8]]]
dropBs ::
  forall s' s ds xs a.
  ( KnownNats s,
    KnownNats s',
    KnownNats ds,
    KnownNats xs,
    s' ~ Eval (DropDims ds xs s)
  ) =>
  Dims ds ->
  SNats xs ->
  Array s a ->
  Array s' a
dropBs _ _ = unsafeBackpermute id

-- | Fix a position along each listed dimension, removing those dimensions.
--
-- Each result index is widened back to a source index by inserting the
-- fixed positions at the listed dimensions (via 'S.insertDims').
--
-- >>> pretty $ indexes (Dims @[0,1]) (S.UnsafeFins [1,1]) a
-- [16,17,18,19]
-- >>> F.indexes (S.SNats @'[1]) (S.fins @'[3] [1]) (F.range @[2,3])
-- [1,4]
indexes ::
  forall s' s ds xs a.
  ( KnownNats s,
    KnownNats s',
    s' ~ Eval (DeleteDims ds s),
    xs ~ Eval (GetDims ds s)
  ) =>
  Dims ds ->
  Fins xs ->
  Array s a ->
  Array s' a
indexes Dims fixed arr = unsafeBackpermute widen arr
  where
    widen :: [Int] -> [Int]
    widen = S.insertDims (valuesOf @ds) (fromFins fixed)

-- | Like 'indexes', but the fixed positions are given as a type-level list.
--
-- The positions are bounds-checked at compile time by the @IsFins@
-- constraint. At runtime they are read with 'valuesOf' and passed on as
-- 'UnsafeFins'.
--
-- >>> pretty $ indexesT (Dims @[0,1]) (S.SNats @[1,1]) a
-- [16,17,18,19]
indexesT ::
  forall ds xs s s' a.
  ( KnownNats s,
    KnownNats ds,
    KnownNats xs,
    KnownNats s',
    s' ~ Eval (DeleteDims ds s),
    True ~ Eval (IsFins xs =<< GetDims ds s)
  ) =>
  Dims ds ->
  SNats xs ->
  Array s a ->
  Array s' a
indexesT ds _ = indexes ds (UnsafeFins (valuesOf @xs))

-- | Take a window along each listed dimension, given offsets and lengths.
--
-- A full-rank offset vector is built with the listed offsets at the listed
-- dimensions and zero elsewhere, and it is added to every output index. The
-- lengths only determine the result shape @SetDims ds ls s@.
--
-- >>> pretty $ slices (Dims @'[2]) (S.SNats @'[1]) (S.SNats @'[2]) a
-- [[[1,2],
--   [5,6],
--   [9,10]],
--  [[13,14],
--   [17,18],
--   [21,22]]]
slices ::
  forall a ds ls offs s s'.
  ( KnownNats s,
    KnownNats s',
    KnownNats ds,
    KnownNats ls,
    KnownNats offs,
    Eval (SlicesOk ds offs ls s) ~ True,
    Eval (SetDims ds ls s) ~ s'
  ) =>
  Dims ds ->
  SNats offs ->
  SNats ls ->
  Array s a ->
  Array s' a
slices _ _ _ arr = unsafeBackpermute (List.zipWith (+) starts) arr
  where
    zeros :: [Int]
    zeros = replicate (rank arr) 0
    starts :: [Int]
    starts = S.setDims (valuesOf @ds) (valuesOf @offs) zeros

-- | Fix position 0 along every listed dimension (see 'indexes').
--
-- >>> pretty $ heads (Dims @[0,2]) a
-- [0,4,8]
heads ::
  forall a ds s s'.
  ( KnownNats s,
    KnownNats s',
    KnownNats ds,
    s' ~ Eval (DeleteDims ds s)
  ) =>
  Dims ds ->
  Array s a ->
  Array s' a
heads ds = indexes ds (UnsafeFins (replicate (rankOf @ds) 0))

-- | Fix the final position (size - 1) along every listed dimension (see 'indexes').
--
-- >>> pretty $ lasts (Dims @[0,2]) a
-- [15,19,23]
lasts ::
  forall ds s s' a.
  ( KnownNats s,
    KnownNats ds,
    KnownNats s',
    s' ~ Eval (DeleteDims ds s)
  ) =>
  Dims ds ->
  Array s a ->
  Array s' a
lasts ds arr = indexes ds (UnsafeFins finals) arr
  where
    dims :: [Int]
    dims = shape arr
    finals :: [Int]
    finals = [getDim i dims - 1 | i <- valuesOf @ds]

-- | Drop the first element along each listed dimension.
--
-- Implemented with 'slices', using offset 1 in every listed dimension and
-- lengths from @GetLastPositions ds s@ (size - 1, judging by the example).
--
-- >>> pretty $ tails (Dims @[0,2]) a
-- [[[13,14,15],
--   [17,18,19],
--   [21,22,23]]]
tails ::
  forall ds os s s' a ls.
  ( KnownNats s,
    KnownNats ds,
    KnownNats s',
    KnownNats ls,
    KnownNats os,
    Eval (SlicesOk ds os ls s) ~ True,
    os ~ Eval (Replicate (Eval (Rank ds)) 1),
    ls ~ Eval (GetLastPositions ds s),
    s' ~ Eval (SetDims ds ls s)
  ) =>
  Dims ds ->
  Array s a ->
  Array s' a
tails ds = slices ds (SNats @os) (SNats @ls)

-- | Drop the last element along each listed dimension.
--
-- Implemented with 'slices', using offset 0 in every listed dimension and
-- lengths from @GetLastPositions ds s@ (size - 1, judging by the example).
--
-- >>> pretty $ inits (Dims @[0,2]) a
-- [[[0,1,2],
--   [4,5,6],
--   [8,9,10]]]
inits ::
  forall ds os s s' a ls.
  ( KnownNats s,
    KnownNats ds,
    KnownNats s',
    KnownNats ls,
    KnownNats os,
    Eval (SlicesOk ds os ls s) ~ True,
    os ~ Eval (Replicate (Eval (Rank ds)) 0),
    ls ~ Eval (GetLastPositions ds s),
    s' ~ Eval (SetDims ds ls s)
  ) =>
  Dims ds ->
  Array s a ->
  Array s' a
inits ds = slices ds (SNats @os) (SNats @ls)

-- | Lift the listed dimensions into an outer array of inner arrays.
--
-- The outer array is indexed by the listed dimensions. Each of its elements
-- is what 'indexes' returns when those dimensions are fixed at that outer
-- position.
--
-- >>> :t extracts (Dims @'[0]) (range @[2,3,4])
-- extracts (Dims @'[0]) (range @[2,3,4])
--   :: Array '[2] (Array [3, 4] Int)
extracts ::
  forall ds st si so a.
  ( KnownNats st,
    KnownNats ds,
    KnownNats si,
    KnownNats so,
    si ~ Eval (DeleteDims ds st),
    so ~ Eval (GetDims ds st)
  ) =>
  Dims ds ->
  Array st a ->
  Array so (Array si a)
extracts ds arr = tabulate pick
  where
    pick :: Fins so -> Array si a
    pick outer = indexes ds outer arr

-- | Collapse each inner array produced by 'extracts' with the given function.
--
-- The result is indexed by the listed dimensions.
--
-- >>> pretty $ reduces (Dims @'[0]) sum a
-- [66,210]
-- >>> pretty $ reduces (Dims @[0,2]) sum a
-- [[12,15,18,21],
--  [48,51,54,57]]
reduces ::
  forall ds st si so a b.
  ( KnownNats st,
    KnownNats ds,
    KnownNats si,
    KnownNats so,
    si ~ Eval (DeleteDims ds st),
    so ~ Eval (GetDims ds st)
  ) =>
  Dims ds ->
  (Array si a -> b) ->
  Array st a ->
  Array so b
reduces ds f = fmap f . extracts ds

-- | Flatten an array of arrays into a single array (the reverse of 'extracts').
--
-- For each result index, the coordinates at the listed dimensions select
-- the outer element. The remaining coordinates index into that inner array.
--
-- >>> let e = extracts (Dims @[1,0]) a
-- >>> let j = joins (Dims @[1,0]) e
-- >>> a == j
-- True
joins ::
  forall a ds si so st.
  ( KnownNats ds,
    KnownNats st,
    KnownNats si,
    KnownNats so,
    Eval (InsertDims ds so si) ~ st
  ) =>
  Dims ds ->
  Array so (Array si a) ->
  Array st a
joins _ nested = tabulate go
  where
    go :: Fins st -> a
    go fs = index (index nested (UnsafeFins outer)) (UnsafeFins inner)
      where
        ix = fromFins fs
        outer = S.getDims (valuesOf @ds) ix
        inner = S.deleteDims (valuesOf @ds) ix

-- | 'joins' with the listed dimensions fixed to @DimsOf so@.
--
-- That list is presumably @[0 .. rank so - 1]@, which would place the outer
-- dimensions first in the result.
--
-- >>> a == join (extracts (Dims @[0,1]) a)
-- True
join ::
  forall a si so st ds.
  ( KnownNats st,
    KnownNats si,
    KnownNats so,
    KnownNats ds,
    ds ~ Eval (DimsOf so),
    st ~ Eval (InsertDims ds so si)
  ) =>
  Array so (Array si a) ->
  Array st a
join = joins (SNats @ds)

-- | Run an applicative action over every element, with the listed dimensions as the outer loop.
--
-- The array is split with 'extracts' and traversed in two layers (outer,
-- then inner). The results are reassembled with 'joins'.
--
-- >>> traverses (Dims @'[1]) print (range @[2,3])
-- 0
-- 3
-- 1
-- 4
-- 2
-- 5
-- [(),(),(),(),(),()]
traverses ::
  ( Applicative f,
    KnownNats s,
    KnownNats si,
    KnownNats so,
    si ~ Eval (GetDims ds s),
    so ~ Eval (DeleteDims ds s),
    s ~ Eval (InsertDims ds si so)
  ) =>
  Dims ds ->
  (a -> f b) ->
  Array s a ->
  f (Array s b)
traverses (Dims :: Dims ds) f arr =
  fmap (joins (SNats @ds)) (traverse (traverse f) (extracts (Dims @ds) arr))

-- | Apply an array-to-array function to every sub-array selected by the listed dimensions.
--
-- The pipeline is 'extracts', then @fmap f@, then 'joins'. The inner shape
-- may change from @si@ to @si'@.
--
-- >>> pretty $ maps (Dims @'[1]) transpose a
-- [[[0,12],
--   [4,16],
--   [8,20]],
--  [[1,13],
--   [5,17],
--   [9,21]],
--  [[2,14],
--   [6,18],
--   [10,22]],
--  [[3,15],
--   [7,19],
--   [11,23]]]
maps ::
  forall ds s s' si si' so a b.
  ( KnownNats s,
    KnownNats s',
    KnownNats si,
    KnownNats si',
    KnownNats so,
    si ~ Eval (DeleteDims ds s),
    so ~ Eval (GetDims ds s),
    s' ~ Eval (InsertDims ds so si'),
    s ~ Eval (InsertDims ds so si)
  ) =>
  Dims ds ->
  (Array si a -> Array si' b) ->
  Array s a ->
  Array s' b
maps SNats f = joins (SNats @ds) . fmap f . extracts (SNats @ds)

-- | Keep the sub-arrays, one per position along the listed dimensions, that satisfy a predicate.
--
-- The sub-arrays come from 'extracts' and are flattened to a vector in
-- order. The survivors are returned as a dynamic array via 'A.asArray'.
--
-- >>> pretty $ filters (Dims @[0,1]) (any ((==0) . (`mod` 7))) a
-- [[0,1,2,3],[4,5,6,7],[12,13,14,15],[20,21,22,23]]
filters ::
  forall ds si so a.
  ( KnownNats si,
    KnownNats so,
    si ~ Eval (DeleteDims ds so),
    KnownNats (Eval (GetDims ds so))
  ) =>
  Dims ds ->
  (Array si a -> Bool) ->
  Array so a ->
  A.Array (Array si a)
filters Dims p arr =
  A.asArray (V.filter p (asVector (extracts (Dims @ds) arr)))

-- | Zips two arrays with a function along specified dimensions.
--
-- Both arrays are split into subarrays over the dimensions @ds@ (see
-- 'extracts'), the function is applied to each pair of corresponding
-- subarrays, and the results are joined back along the same dimensions
-- (see 'joins').
--
-- >>> pretty $ zips (Dims @[0,1]) (zipWith (,)) a (reverses (Dims @'[0]) a)
-- [[[(0,12),(1,13),(2,14),(3,15)],
--   [(4,16),(5,17),(6,18),(7,19)],
--   [(8,20),(9,21),(10,22),(11,23)]],
--  [[(12,0),(13,1),(14,2),(15,3)],
--   [(16,4),(17,5),(18,6),(19,7)],
--   [(20,8),(21,9),(22,10),(23,11)]]]
zips ::
  forall ds s s' si si' so a b c.
  ( KnownNats s,
    KnownNats s',
    KnownNats si,
    KnownNats si',
    KnownNats so,
    si ~ Eval (DeleteDims ds s),
    so ~ Eval (GetDims ds s),
    s' ~ Eval (InsertDims ds so si'),
    s ~ Eval (InsertDims ds so si)
  ) =>
  Dims ds ->
  (Array si a -> Array si b -> Array si' c) ->
  Array s a ->
  Array s b ->
  Array s' c
-- Matching on 'SNats' supplies the @KnownNats ds@ evidence needed by
-- 'extracts' and 'joins'.
zips SNats f a b =
  joins
    (Dims @ds)
    (zipWith f (extracts (Dims @ds) a) (extracts (Dims @ds) b))

-- | Modify using the supplied function along dimensions and positions.
--
-- The array is split into subarrays over the dimensions @ds@ (see
-- 'extracts'); the subarray at outer position @ps@ is transformed by the
-- function via 'modify', and the pieces are joined back together (see
-- 'joins').
--
-- >>> pretty $ modifies (fmap (100+)) (Dims @'[2]) (S.UnsafeFins [0]) a
-- [[[100,1,2,3],
--   [104,5,6,7],
--   [108,9,10,11]],
--  [[112,13,14,15],
--   [116,17,18,19],
--   [120,21,22,23]]]
modifies ::
  forall a si s ds so.
  ( KnownNats s,
    KnownNats si,
    KnownNats so,
    si ~ Eval (DeleteDims ds s),
    so ~ Eval (GetDims ds s),
    s ~ Eval (InsertDims ds so si)
  ) =>
  (Array si a -> Array si a) ->
  Dims ds ->
  Fins so ->
  Array s a ->
  Array s a
-- Matching on 'SNats' supplies the @KnownNats ds@ evidence needed by
-- 'extracts' and 'joins'.
modifies f SNats ps a =
  joins (Dims @ds) $ modify ps f (extracts (Dims @ds) a)

-- | Apply a binary function between successive slices, across dimensions and lags.
--
-- For the dimensions @ds@ and lags @ls@, the first argument to the function
-- comes from the array with the leading lag elements dropped ('drops'), and
-- the second from the array with the trailing lag elements dropped
-- ('dropBs'). The two are then combined subarray-wise with 'zips'. With
-- @zipWith (-)@ this gives the lagged difference \"later minus earlier\".
--
-- >>> pretty $ diffs (Dims @'[1]) (S.SNats @'[1]) (zipWith (-)) a
-- [[[4,4,4,4],
--   [4,4,4,4]],
--  [[4,4,4,4],
--   [4,4,4,4]]]
diffs ::
  forall a b ds ls si si' st st' so postDrop.
  ( KnownNats ls,
    KnownNats si,
    KnownNats si',
    KnownNats st,
    KnownNats st',
    KnownNats so,
    KnownNats postDrop,
    si ~ Eval (DeleteDims ds postDrop),
    so ~ Eval (GetDims ds postDrop),
    st' ~ Eval (InsertDims ds so si'),
    postDrop ~ Eval (InsertDims ds so si),
    postDrop ~ Eval (DropDims ds ls st)
  ) =>
  Dims ds ->
  SNats ls ->
  (Array si a -> Array si a -> Array si' b) ->
  Array st a ->
  Array st' b
-- Matching on 'SNats' supplies the @KnownNats ds@ evidence needed by
-- 'drops', 'dropBs' and 'zips'.
diffs SNats xs f a =
  zips
    (Dims @ds)
    f
    (drops (Dims @ds) xs a)
    (dropBs (Dims @ds) xs a)

-- | Form the product of two arrays using the supplied binary function.
--
-- Every element of the first array is paired with every element of the
-- second. The result has shape @sa ++ sb@: the leading dimensions index the
-- first array and the trailing dimensions index the second.
--
-- For context, if the function is multiply, and the arrays are tensors,
-- then this can be interpreted as a [tensor product](https://en.wikipedia.org/wiki/Tensor_product).
-- The concept of a tensor product is a dense crossroad, and a complete treatment is elsewhere.  To quote the wiki article:
--
-- ... the tensor product can be extended to other categories of mathematical objects in addition to vector spaces, such as to matrices, tensors, algebras, topological vector spaces, and modules. In each such case the tensor product is characterized by a similar universal property: it is the freest bilinear operation. The general concept of a "tensor product" is captured by monoidal categories; that is, the class of all things that have a tensor product is a monoidal category.
--
-- >>> x = array [1,2,3] :: Array '[3] Int
-- >>> pretty $ expand (*) x x
-- [[1,2,3],
--  [2,4,6],
--  [3,6,9]]
--
-- Alternatively, expand can be understood as enumerating all pairs of elements from the two arrays, like the Applicative List instance.
--
-- >>> i2 = indices @[2,2]
-- >>> pretty $ expand (,) i2 i2
-- [[[[([0,0],[0,0]),([0,0],[0,1])],
--    [([0,0],[1,0]),([0,0],[1,1])]],
--   [[([0,1],[0,0]),([0,1],[0,1])],
--    [([0,1],[1,0]),([0,1],[1,1])]]],
--  [[[([1,0],[0,0]),([1,0],[0,1])],
--    [([1,0],[1,0]),([1,0],[1,1])]],
--   [[([1,1],[0,0]),([1,1],[0,1])],
--    [([1,1],[1,0]),([1,1],[1,1])]]]]
expand ::
  forall sc sa sb a b c.
  ( KnownNats sa,
    KnownNats sb,
    KnownNats sc,
    sc ~ Eval ((++) sa sb)
  ) =>
  (a -> b -> c) ->
  Array sa a ->
  Array sb b ->
  Array sc c
expand f a b =
  tabulate
    ( \i ->
        let (ia, ib) = List.splitAt r (fromFins i)
         in f (index a (UnsafeFins ia)) (index b (UnsafeFins ib))
    )
  where
    -- The leading @r@ components of a result index address the first array;
    -- the remaining components address the second.
    r = rank a

-- | Like expand, but permutes the first array first, rather than the second.
--
-- The result has shape @sb ++ sa@: the leading dimensions index the second
-- array and the trailing dimensions index the first.
--
-- >>> pretty $ expand (,) v (fmap (+3) v)
-- [[(0,3),(0,4),(0,5)],
--  [(1,3),(1,4),(1,5)],
--  [(2,3),(2,4),(2,5)]]
--
-- >>> pretty $ coexpand (,) v (fmap (+3) v)
-- [[(0,3),(1,3),(2,3)],
--  [(0,4),(1,4),(2,4)],
--  [(0,5),(1,5),(2,5)]]
--
-- >>> shape $ coexpand (,) (range @'[2]) (range @'[3])
-- [3,2]
coexpand ::
  forall sc sa sb a b c.
  ( KnownNats sa,
    KnownNats sb,
    KnownNats sc,
    sc ~ Eval ((++) sb sa)
  ) =>
  (a -> b -> c) ->
  Array sa a ->
  Array sb b ->
  Array sc c
coexpand f a b =
  tabulate
    ( \i ->
        let (ib, ia) = List.splitAt r (fromFins i)
         in f (index a (UnsafeFins ia)) (index b (UnsafeFins ib))
    )
  where
    -- The leading dimensions of the result belong to the second array, so a
    -- result index is split at the rank of @b@. Splitting at the rank of @a@
    -- is only correct when both arrays have the same shape.
    r = rank b

-- | Contract an array by applying the supplied (folding) function on diagonal elements of the dimensions.
--
-- This generalises a tensor contraction by allowing the number of contracting diagonals to be other than 2.
--
-- The dimensions @ds@ are contracted and the remaining dimensions (@ds'@)
-- index the result. At each result position, the subarray over @ds@ is
-- extracted, its diagonal is taken with 'diag' (its length is the smallest
-- contracted dimension, per the 'MinDim' constraint), and the function is
-- applied to that diagonal.
--
-- >>> pretty $ contract (Dims @[1,2]) sum (expand (*) m (transpose m))
-- [[5,14],
--  [14,50]]
contract ::
  forall a b s ss se s' ds ds'.
  ( KnownNats se,
    se ~ Eval (DeleteDims ds' s),
    KnownNats ds',
    KnownNats s,
    KnownNats ss,
    KnownNats s',
    s' ~ Eval (GetDims ds' s),
    ss ~ Eval (MinDim se),
    ds' ~ Eval (ExceptDims ds s)
  ) =>
  Dims ds ->
  (Array ss a -> b) ->
  Array s a ->
  Array s' b
-- Extract over the complementary dimensions @ds'@, so that each inner
-- subarray spans exactly the contracted dimensions @ds@.
contract SNats f a = f . diag <$> extracts (Dims @ds') a

-- | Expand two arrays and then contract the result using the supplied matching dimensions.
--
-- The dimensions @ds0@ of the first array are matched with the dimensions
-- @ds1@ of the second (both must have shape @si@). Each result position
-- (shape @so0 ++ so1@) is computed by applying the binary function to the
-- pairs of elements along the matched dimensions, and folding the resulting
-- @si@-shaped array with the contraction function.
--
-- >>> pretty $ prod (Dims @'[1]) (Dims @'[0]) sum (*) (range @[2,3]) (range @[3,2])
-- [[10,13],
--  [28,40]]
--
-- With full laziness, this computation would be equivalent to:
--
-- > f . diag <$> extracts (Dims @ds') (expand g a b)
prod ::
  forall a b c d s0 s1 so0 so1 si st ds0 ds1.
  ( KnownNats so0,
    KnownNats so1,
    KnownNats si,
    KnownNats s0,
    KnownNats s1,
    KnownNats st,
    KnownNats ds0,
    KnownNats ds1,
    so0 ~ Eval (DeleteDims ds0 s0),
    so1 ~ Eval (DeleteDims ds1 s1),
    si ~ Eval (GetDims ds0 s0),
    si ~ Eval (GetDims ds1 s1),
    st ~ Eval ((++) so0 so1)
  ) =>
  Dims ds0 ->
  Dims ds1 ->
  (Array si c -> d) ->
  (a -> b -> c) ->
  Array s0 a ->
  Array s1 b ->
  Array st d
prod SNats SNats g f a b =
  unsafeTabulate $ \outer ->
    -- Split a result index into its @so0@ part (for @a@) and @so1@ part (for @b@).
    let (o0, o1) = List.splitAt sp outer
     in g $
          unsafeTabulate $ \inner ->
            -- Re-insert the matched-dimension index to address each input.
            f
              (unsafeIndex a (S.insertDims (valuesOf @ds0) inner o0))
              (unsafeIndex b (S.insertDims (valuesOf @ds1) inner o1))
  where
    -- Rank of @so0@: the first array's rank minus its matched dimensions.
    sp = rank a - rankOf @ds0

-- | A generalisation of a dot operation, which is a multiplicative expansion of two arrays and sum contraction along the middle two dimensions.
--
-- This is 'prod' with the last dimension of the first array matched against
-- the first dimension of the second array (fixed by the @ds0@ and @ds1@
-- constraints).
--
-- matrix multiplication
--
-- >>> pretty $ dot sum (*) m (transpose m)
-- [[5,14],
--  [14,50]]
--
-- inner product
--
-- >>> pretty $ dot sum (*) v v
-- 5
--
-- matrix-vector multiplication
-- Note that an Array with shape [3] is neither a row vector nor column vector.
--
-- >>> pretty $ dot sum (*) v (transpose m)
-- [5,14]
--
-- >>> pretty $ dot sum (*) m v
-- [5,14]
dot ::
  forall a b c d ds0 ds1 s0 s1 so0 so1 st si.
  ( KnownNats s0,
    KnownNats s1,
    KnownNats ds0,
    KnownNats ds1,
    KnownNats so0,
    KnownNats so1,
    KnownNats st,
    KnownNats si,
    so0 ~ Eval (DeleteDims ds0 s0),
    so1 ~ Eval (DeleteDims ds1 s1),
    si ~ Eval (GetDims ds0 s0),
    si ~ Eval (GetDims ds1 s1),
    st ~ Eval ((++) so0 so1),
    ds0 ~ '[Eval ((Fcf.-) (Eval (Rank s0)) 1)],
    ds1 ~ '[0]
  ) =>
  (Array si c -> d) ->
  (a -> b -> c) ->
  Array s0 a ->
  Array s1 b ->
  Array st d
dot f g a b = prod (Dims @ds0) (Dims @ds1) f g a b

-- | Array multiplication.
--
-- This is 'dot' with multiplication as the pairing function and 'sum' as the
-- contraction.
--
-- matrix multiplication
--
-- >>> pretty $ mult m (transpose m)
-- [[5,14],
--  [14,50]]
--
-- inner product
--
-- >>> pretty $ mult v v
-- 5
--
-- matrix-vector multiplication
--
-- >>> pretty $ mult v (transpose m)
-- [5,14]
--
-- >>> pretty $ mult m v
-- [5,14]
mult ::
  forall a ds0 ds1 s0 s1 so0 so1 st si.
  ( Num a,
    KnownNats s0,
    KnownNats s1,
    KnownNats ds0,
    KnownNats ds1,
    KnownNats so0,
    KnownNats so1,
    KnownNats st,
    KnownNats si,
    so0 ~ Eval (DeleteDims ds0 s0),
    so1 ~ Eval (DeleteDims ds1 s1),
    si ~ Eval (GetDims ds0 s0),
    si ~ Eval (GetDims ds1 s1),
    st ~ Eval ((++) so0 so1),
    ds0 ~ '[Eval ((Fcf.-) (Eval (Rank s0)) 1)],
    ds1 ~ '[0]
  ) =>
  Array s0 a ->
  Array s1 a ->
  Array st a
mult = dot sum (*)

-- | @windows w a@ are the @w@-sized windows of the array @a@.
--
-- The result shape lists the window positions first, then the window
-- extents, then any dimensions of @a@ beyond the rank of @w@.
-- For example, a @[2,2]@ window over shape @[4,3,2]@ gives @[3,2,2,2,2]@.
--
-- >>> shape $ windows (Dims @[2,2]) (range @[4,3,2])
-- [3,2,2,2,2]
windows ::
  forall w s ws a.
  ( KnownNats s,
    KnownNats ws,
    ws ~ Eval (ExpandWindows w s)
  ) =>
  SNats w -> Array s a -> Array ws a
-- Matching on 'SNats' supplies the @KnownNats w@ evidence used by 'rankOf'.
windows SNats a = unsafeBackpermute (S.indexWindows (rankOf @w)) a

-- | Find the starting positions of occurrences of one array in another.
--
-- The pattern is reranked to the rank of the searched array, every window
-- of the pattern's size is taken (see 'windows'), and each window is compared
-- with the pattern. The result is 'True' at the starting positions of matches.
--
-- >>> a = cycle @[4,4] (range @'[3])
-- >>> i = array @[2,2] [1,2,2,0]
-- >>> pretty $ find i a
-- [[False,True,False],
--  [True,False,False],
--  [False,False,True]]
find ::
  forall s' si s a r i' re ws.
  ( Eq a,
    KnownNats si,
    KnownNats s,
    KnownNats s',
    KnownNats re,
    KnownNats i',
    KnownNat r,
    KnownNats ws,
    ws ~ Eval (ExpandWindows i' s),
    r ~ Eval (Rank s),
    i' ~ Eval (Rerank r si),
    re ~ Eval (DimWindows ws s),
    i' ~ Eval (DeleteDims re ws),
    s' ~ Eval (GetDims re ws)
  ) =>
  Array si a -> Array s a -> Array s' Bool
find i a = xs
  where
    -- The pattern brought to rank r, the rank of the searched array.
    i' :: Array i' a
    i' = rerank (SNat @r) i

    -- All windows of the pattern's size over the searched array.
    ws :: Array ws a
    ws = windows (SNats @i') a

    -- Compare every window (indexed by window position) with the pattern.
    xs :: Array s' Bool
    xs = fmap (== i') (extracts (SNats @re) ws)

-- | Find the starting positions of one array in another, except where the occurrence overlaps another retained copy.
--
-- Positions are visited in row-major (lexicographic index) order. A position
-- is 'True' when the pattern occurs there (see 'find') and no earlier 'True'
-- position holds a copy of the pattern that overlaps it.
--
-- >>> a = konst @[5,5] @Int 1
-- >>> i = konst @[2,2] @Int 1
-- >>> pretty $ findNoOverlap i a
-- [[True,False,True,False],
--  [False,False,False,False],
--  [True,False,True,False],
--  [False,False,False,False]]
findNoOverlap ::
  forall s' si s a r i' re ws.
  ( Eq a,
    KnownNats si,
    KnownNats s,
    KnownNats s',
    KnownNats re,
    KnownNats i',
    KnownNat r,
    KnownNats ws,
    ws ~ Eval (ExpandWindows i' s),
    r ~ Eval (Rank s),
    i' ~ Eval (Rerank r si),
    re ~ Eval (DimWindows ws s),
    i' ~ Eval (DeleteDims re ws),
    s' ~ Eval (GetDims re ws)
  ) =>
  Array si a -> Array s a -> Array s' Bool
findNoOverlap i a = r
  where
    -- Every starting position of the pattern, overlaps included.
    f :: Array s' Bool
    f = find i a

    -- Offsets from a candidate position to every earlier position whose copy
    -- of a pattern with shape @sh@ would overlap it. Each component lies in
    -- @[1 - n, n - 1]@ for the corresponding pattern extent @n@.
    cl :: [Int] -> [[Int]]
    cl sh =
      List.filter isEarlier $
        A.arrayAs $
          A.tabulate ((\x -> 2 * x - 1) <$> sh) (\s -> List.zipWith (\x x0 -> x - x0 + 1) s sh)

    -- An offset points to an earlier (row-major) position exactly when its
    -- first non-zero component is negative. From rank 3 up, an earlier offset
    -- such as @[-1,1,0]@ can have a positive component before the last one,
    -- so a sign test on the leading components alone is not sufficient.
    isEarlier :: [Int] -> Bool
    isEarlier d = case List.dropWhile (== 0) d of
      (x : _) -> x < 0
      [] -> False

    -- A position is retained when the pattern occurs there and no earlier,
    -- in-bounds, overlapping position was retained. @r'@ is the (lazily
    -- built) result itself; only strictly earlier positions are consulted,
    -- so the recursion is well-founded.
    go :: Array s' Bool -> [Int] -> Bool
    go r' s =
      index f (UnsafeFins s)
        && not (any (index r' . UnsafeFins) overlapping)
      where
        overlapping =
          List.filter (\x -> isFins x (shape f)) $
            fmap (List.zipWith (+) s) (cl (shape i))

    r :: Array s' Bool
    r = unsafeTabulate (go r)

-- | Check if the first array is a prefix of the second.
--
-- The second array is cut down to the shape of the first (see 'cut') and the
-- two are compared for equality.
--
-- >>> isPrefixOf (array @[2,2] [0,1,4,5]) a
-- True
isPrefixOf ::
  forall s' s r a.
  ( Eq a,
    KnownNats s,
    KnownNats s',
    KnownNat r,
    KnownNats (Eval (Rerank r s)),
    True ~ Eval (IsSubset s' s),
    r ~ Eval (Rank s')
  ) =>
  Array s' a -> Array s a -> Bool
isPrefixOf p a = p == cut a

-- | Check if the first array is a suffix of the second.
--
-- The trailing part of the second array with the shape of the first is taken
-- (see 'cutSuffix') and the two are compared for equality.
--
-- >>> isSuffixOf (array @[2,2] [18,19,22,23]) a
-- True
isSuffixOf ::
  forall s' s r a.
  ( Eq a,
    KnownNats s,
    KnownNats s',
    KnownNat r,
    KnownNats (Eval (Rerank r s)),
    r ~ Eval (Rank s'),
    True ~ Eval (IsSubset s' s)
  ) =>
  Array s' a -> Array s a -> Bool
isSuffixOf p a = p == cutSuffix a

-- | Check if the first array is an infix of the second.
--
-- Returns 'True' when any element of the match array produced by 'find'
-- is 'True'.
--
-- >>> isInfixOf (array @[2,2] [18,19,22,23]) a
-- True
isInfixOf ::
  forall s' si s a r i' re ws.
  ( Eq a,
    KnownNats si,
    KnownNats s,
    KnownNats s',
    KnownNats re,
    KnownNats i',
    KnownNat r,
    KnownNats ws,
    ws ~ Eval (ExpandWindows i' s),
    r ~ Eval (Rank s),
    i' ~ Eval (Rerank r si),
    re ~ Eval (DimWindows ws s),
    i' ~ Eval (DeleteDims re ws),
    s' ~ Eval (GetDims re ws)
  ) =>
  Array si a -> Array s a -> Bool
isInfixOf p a = or (find p a)

-- | Fill an array with the supplied value, without regard to the original
-- shape, or cut the array values to match the new size.
--
-- The flat values of the source array are extended at the end with the
-- supplied value, or truncated, so that there are exactly as many values as
-- the size of the new shape.
--
-- > validate (fill x a) == True
--
-- >>> pretty $ fill @'[3] 0 (array @'[0] [])
-- [0,0,0]
-- >>> pretty $ fill @'[3] 0 (array @'[4] [1..4])
-- [1,2,3]
fill ::
  forall s' a s.
  ( KnownNats s,
    KnownNats s'
  ) =>
  a -> Array s a -> Array s' a
fill x (Array v) = Array (V.take n (v <> V.replicate (n - V.length v) x))
  where
    -- Number of elements in the target shape.
    n :: Int
    n = S.size (valuesOf @s')

-- | Cut an array to form a new (smaller) shape. Errors if the new shape is larger. The old array is reranked to the rank of the new shape first.
--
-- Indices are kept unchanged, so the leading elements along every dimension
-- of the reranked array are retained.
--
-- >>> toDynamic $ cut @'[2] (array @'[4] @Int [0..3])
-- UnsafeArray [2] [0,1]
cut ::
  forall s' s r a.
  ( KnownNats s,
    KnownNats s',
    KnownNat r,
    KnownNats (Eval (Rerank r s)),
    True ~ Eval (IsSubset s' s),
    r ~ Eval (Rank s')
  ) =>
  Array s a ->
  Array s' a
cut a = unsafeBackpermute id (rerank (SNat @r) a)

-- | Cut an array to form a new (smaller) shape, using suffix elements. Errors if the new shape is larger. The old array is reranked to the rank of the new shape first.
--
-- >>> toDynamic $ cutSuffix @[2,2] a
-- UnsafeArray [2,2] [18,19,22,23]
cutSuffix ::
  forall s' s a r.
  ( KnownNats s,
    KnownNats s',
    KnownNat r,
    KnownNats (Eval (Rerank r s)),
    r ~ Eval (Rank s'),
    True ~ Eval (IsSubset s' s)
  ) =>
  Array s a ->
  Array s' a
cutSuffix a = unsafeBackpermute (List.zipWith (+) diffDim) a'
  where
    -- The source reranked to the target rank.
    a' = rerank (SNat @r) a
    -- Per-dimension offset of the kept trailing block within a'
    -- (old extent minus new extent); added to every target index.
    diffDim :: [Int]
    diffDim = List.zipWith (-) (shape a') (valuesOf @s')

-- | Pad an array to form a new shape, supplying a default value for elements outside the shape of the old array. The old array is reranked to the rank of the new shape first.
--
-- >>> toDynamic $ pad @'[5] 0 (array @'[4] @Int [0..3])
-- UnsafeArray [5] [0,1,2,3,0]
pad ::
  forall s' a s r.
  ( KnownNats s,
    KnownNats s',
    KnownNat r,
    KnownNats (Eval (Rerank r s)),
    r ~ Eval (Rank s')
  ) =>
  a ->
  Array s a ->
  Array s' a
pad d a =
  tabulate
    ( \s ->
        let i = fromFins s
         in -- Only index a' when the target index lies inside its shape;
            -- otherwise use the default. Re-wrapping with UnsafeFins is
            -- safe here because the bounds check guards the lookup.
            bool d (index a' (UnsafeFins i)) (i `S.isFins` shape a')
    )
  where
    -- The source reranked to the target rank.
    a' = rerank (SNat @r) a

-- | Left pad an array to form a new shape, supplying a default value for elements outside the shape of the old array.
--
-- The old array is reranked to the rank of the new shape first, and aligned
-- with the trailing end of every dimension of the new shape.
--
-- >>> toDynamic $ lpad @'[5] 0 (array @'[4] [0..3])
-- UnsafeArray [5] [0,0,1,2,3]
-- >>> pretty $ lpad @[3,3] 0 (range @[2,2])
-- [[0,0,0],
--  [0,0,1],
--  [0,2,3]]
lpad ::
  forall s' a s r.
  ( KnownNats s,
    KnownNats s',
    KnownNat r,
    KnownNats (Eval (Rerank r s)),
    r ~ Eval (Rank s')
  ) =>
  a ->
  Array s a ->
  Array s' a
lpad d a =
  tabulate (\s -> bool d (index a' (UnsafeFins (olds s))) (olds s `S.isFins` shape a'))
  where
    -- The source reranked to the target rank.
    a' = rerank (SNat @r) a
    -- Per-dimension amount of left padding (new extent minus old extent).
    gap :: [Int]
    gap = List.zipWith (-) (valuesOf @s') (shape a')
    -- Translate a target index into the corresponding index of a'; the
    -- result is only used for lookup after the bounds check above.
    olds :: Fins s' -> [Int]
    olds s = List.zipWith (-) (fromFins s) gap

-- | Reshape an array (with the same number of elements).
--
-- Each new index is flattened against the new shape and unflattened against
-- the old shape, so elements keep their row-major order.
--
-- >>> pretty $ reshape @[4,3,2] a
-- [[[0,1],
--   [2,3],
--   [4,5]],
--  [[6,7],
--   [8,9],
--   [10,11]],
--  [[12,13],
--   [14,15],
--   [16,17]],
--  [[18,19],
--   [20,21],
--   [22,23]]]
reshape ::
  forall s' s a.
  ( Eval (Size s) ~ Eval (Size s'),
    KnownNats s,
    KnownNats s'
  ) =>
  Array s a ->
  Array s' a
reshape = unsafeBackpermute (S.shapen oldShape . S.flatten newShape)
  where
    oldShape :: [Int]
    oldShape = valuesOf @s
    newShape :: [Int]
    newShape = valuesOf @s'

-- | Make an Array single dimensional.
--
-- Only the type-level shape changes; the underlying values are reused as-is.
--
-- >>> pretty $ flat (range @[2,2])
-- [0,1,2,3]
-- >>> pretty (flat $ toScalar 0)
-- [0]
flat ::
  forall s' s a.
  ( KnownNats s,
    KnownNats s',
    s' ~ '[Eval (Size s)]
  ) =>
  Array s a ->
  Array s' a
flat = unsafeModifyShape

-- | Reshape an array, repeating the original array. The shape of the array should be a suffix of the new shape.
--
-- Each new index has its leading (rank s' - rank s) components dropped and
-- the remainder is used to index the original array.
--
-- NOTE(review): the index mapping treats the old shape as the trailing
-- dimensions of the new shape (a suffix), but the type-level constraint
-- checks @IsPrefixOf s s'@. Confirm which check is intended.
--
-- >>> pretty $ repeat @[2,2,2] (array @'[2] [1,2])
-- [[[1,2],
--   [1,2]],
--  [[1,2],
--   [1,2]]]
--
-- > repeat @s (toScalar x) == konst @s x
repeat ::
  forall s' s a.
  ( KnownNats s,
    KnownNats s',
    Eval (IsPrefixOf s s') ~ True
  ) =>
  Array s a ->
  Array s' a
repeat a = unsafeBackpermute (List.drop extra) a
  where
    -- Number of leading dimensions in the new shape that the old one lacks.
    extra :: Int
    extra = S.rank (valuesOf @s') - rank a

-- | Reshape an array, cycling through the elements without regard to the original shape.
--
-- Each new index is flattened, wrapped modulo the size of the original
-- array, and unflattened against the original shape. The original array
-- must be non-empty whenever the new shape has elements, since the wrap is
-- a 'mod' by the original size.
--
-- >>> pretty $ cycle @[2,2,2] (array @'[3] [1,2,3])
-- [[[1,2],
--   [3,1]],
--  [[2,3],
--   [1,2]]]
cycle ::
  forall s' s a.
  ( KnownNats s,
    KnownNats s'
  ) =>
  Array s a ->
  Array s' a
cycle a = unsafeBackpermute source a
  where
    source :: [Int] -> [Int]
    source = S.shapen (shape a) . (`mod` size a) . S.flatten (valuesOf @s')

-- | Change rank by adding new dimensions at the front, if the new rank is greater, or combining dimensions (from left to right) into rows, if the new rank is lower.
--
-- Only the type-level shape changes; the underlying values are reused as-is.
--
-- >>> shape (rerank (SNat @4) a)
-- [1,2,3,4]
-- >>> shape (rerank (SNat @2) a)
-- [6,4]
--
-- > flat == rerank (SNat @1)
rerank ::
  forall r s s' a.
  ( KnownNats s,
    KnownNats s',
    s' ~ Eval (Rerank r s)
  ) =>
  SNat r -> Array s a -> Array s' a
rerank _ = unsafeModifyShape

-- | Change the order of dimensions.
--
-- Each new index is mapped back to an old index with 'S.insertDims' using
-- the supplied dimension list.
--
-- >>> pretty $ reorder (Dims @[2,0,1]) a
-- [[[0,4,8],
--   [12,16,20]],
--  [[1,5,9],
--   [13,17,21]],
--  [[2,6,10],
--   [14,18,22]],
--  [[3,7,11],
--   [15,19,23]]]
reorder ::
  forall ds s s' a.
  ( KnownNats s,
    KnownNats s',
    s' ~ Eval (Reorder s ds)
  ) =>
  SNats ds ->
  Array s a ->
  Array s' a
reorder SNats a = unsafeBackpermute (\i -> S.insertDims (valuesOf @ds) i []) a

-- | Remove single dimensions.
--
-- >>> let sq = array [1..24] :: Array '[2,1,3,4,1] Int
-- >>> shape $ squeeze sq
-- [2,3,4]
--
-- >>> shape $ squeeze (singleton 0)
-- []
squeeze ::
  forall s t a.
  ( KnownNats s,
    KnownNats t,
    t ~ Eval (Squeeze s)
  ) =>
  Array s a ->
  Array t a
-- Only the type-level shape changes; the result shape is computed by 'Squeeze'.
squeeze = unsafeModifyShape

-- | Insert a single dimension at the supplied position.
--
-- >>> shape $ elongate (SNat @1) a
-- [2,1,3,4]
-- >>> toDynamic $ elongate (SNat @0) (toScalar 1)
-- UnsafeArray [1] [1]
elongate ::
  ( KnownNats s,
    KnownNats s',
    s' ~ Eval (InsertDim d 1 s)
  ) =>
  Dim d ->
  Array s a ->
  Array s' a
-- The new shape is determined at the type level by 'InsertDim', so the Dim
-- argument is not inspected at runtime.
elongate _ a = unsafeModifyShape a

-- | Reverse indices, e.g. transposes the element A/ijk/ to A/kji/.
--
-- >>> (transpose a) ! [1,0,0] == a ! [0,0,1]
-- True
-- >>> pretty $ transpose (array @[2,2,2] [1..8])
-- [[[1,5],
--   [3,7]],
--  [[2,6],
--   [4,8]]]
transpose ::
  forall a s s'. (KnownNats s, KnownNats s', s' ~ Eval (Reverse s)) => Array s a -> Array s' a
-- Each target index is mapped back to the source index with its components reversed.
transpose a = unsafeBackpermute List.reverse a

-- | Inflate (or replicate) an array by inserting a new dimension given a supplied dimension and size.
--
-- >>> pretty $ inflate (SNat @0) (SNat @2) (array @'[3] [0,1,2])
-- [[0,1,2],
--  [0,1,2]]
inflate ::
  forall s' s d x a.
  ( KnownNats s,
    KnownNats s',
    s' ~ Eval (InsertDim d x s)
  ) =>
  Dim d ->
  SNat x ->
  Array s a ->
  Array s' a
-- Matching 'SNat' on the dimension provides @KnownNat d@ for 'valueOf'. The
-- new size @x@ is only needed at the type level. Dropping dimension @d@ from a
-- target index yields the source index, so every slice along @d@ repeats the input.
inflate SNat _ a = unsafeBackpermute (S.deleteDim (valueOf @d)) a

-- | Intercalate an array along dimensions.
--
-- >>> pretty $ intercalate (SNat @2) (konst @[2,3] 0) a
-- [[[0,0,1,0,2,0,3],
--   [4,0,5,0,6,0,7],
--   [8,0,9,0,10,0,11]],
--  [[12,0,13,0,14,0,15],
--   [16,0,17,0,18,0,19],
--   [20,0,21,0,22,0,23]]]
intercalate ::
  forall d ds n n' s si st a.
  ( KnownNats s,
    KnownNats si,
    KnownNats st,
    KnownNats ds,
    KnownNat n,
    KnownNat n',
    ds ~ '[d],
    si ~ Eval (DeleteDim d s),
    n ~ Eval (GetDim d s),
    n' ~ Eval ((Fcf.-) (Eval ((Fcf.+) n n)) 1),
    st ~ Eval (InsertDim d n' si)
  ) =>
  Dim d -> Array si a -> Array s a -> Array st a
-- Split the array into its n slices along d, interleave the separator @i@
-- between them (giving n + n - 1 slices), then join them back along d.
intercalate SNat i a =
  joins
    (Dims @ds)
    ( vector @n'
        ( List.intersperse
            i
            (toList (extracts (Dims @ds) a))
        )
    )

-- | Intersperse an element along dimensions.
--
-- >>> pretty $ intersperse (SNat @2) 0 a
-- [[[0,0,1,0,2,0,3],
--   [4,0,5,0,6,0,7],
--   [8,0,9,0,10,0,11]],
--  [[12,0,13,0,14,0,15],
--   [16,0,17,0,18,0,19],
--   [20,0,21,0,22,0,23]]]
intersperse ::
  forall d ds n n' s si st a.
  ( KnownNats s,
    KnownNats si,
    KnownNats st,
    KnownNats ds,
    KnownNat n,
    KnownNat n',
    ds ~ '[d],
    si ~ Eval (DeleteDim d s),
    n ~ Eval (GetDim d s),
    n' ~ n + n - 1,
    st ~ Eval (InsertDim d n' si)
  ) =>
  Dim d -> a -> Array s a -> Array st a
-- Build a constant slice of shape @si@ filled with @x@, and intercalate it.
intersperse (SNat :: SNat d) x a = intercalate (SNat @d) (konst @si x) a

-- | Concatenate dimensions, creating a new dimension at the supplied position.
--
-- >>> pretty $ concats (Dims @[0,1]) (SNat @1) a
-- [[0,4,8,12,16,20],
--  [1,5,9,13,17,21],
--  [2,6,10,14,18,22],
--  [3,7,11,15,19,23]]
concats ::
  forall s s' newd ds a.
  ( KnownNats s,
    KnownNats s',
    s' ~ Eval (ConcatDims ds newd s)
  ) =>
  Dims ds ->
  SNat newd ->
  Array s a ->
  Array s' a
-- The 'SNats' and 'SNat' matches supply @KnownNats ds@ and @KnownNat newd@,
-- which the where-bound values below depend on.
concats SNats SNat a = unsafeBackpermute (unconcatDimsIndex ds n (shape a)) a
  where
    -- position of the new, concatenated dimension
    n :: Int
    n = valueOf @newd
    -- the dimensions being concatenated
    ds :: [Int]
    ds = valuesOf @ds

-- | Reverses element order along specified dimensions.
--
-- >>> pretty $ reverses (Dims @[0,1]) a
-- [[[20,21,22,23],
--   [16,17,18,19],
--   [12,13,14,15]],
--  [[8,9,10,11],
--   [4,5,6,7],
--   [0,1,2,3]]]
reverses ::
  forall ds s a.
  (KnownNats s) =>
  Dims ds ->
  Array s a ->
  Array s a
-- Matching on 'SNats' provides @KnownNats ds@ for 'valuesOf'.
reverses SNats a = unsafeBackpermute (reverseIndex (valuesOf @ds) (shape a)) a

-- | Rotate an array by/along dimensions & offsets.
--
-- >>> pretty $ rotates (Dims @'[1]) [2] a
-- [[[8,9,10,11],
--   [0,1,2,3],
--   [4,5,6,7]],
--  [[20,21,22,23],
--   [12,13,14,15],
--   [16,17,18,19]]]
rotates ::
  forall a ds s.
  ( KnownNats s,
    True ~ Eval (IsDims ds s)
  ) =>
  Dims ds ->
  [Int] ->
  Array s a ->
  Array s a
-- Matching on 'SNats' provides @KnownNats ds@; @rs@ pairs up with the
-- dimensions in @ds@ as rotation offsets.
rotates SNats rs a = unsafeBackpermute (rotatesIndex (valuesOf @ds) rs (valuesOf @s)) a

-- | Sort an array along the supplied dimensions.
--
-- >>> pretty $ sorts (Dims @'[0]) (array @[2,2] [2,3,1,4])
-- [[1,4],
--  [2,3]]
-- >>> pretty $ sorts (Dims @'[1]) (array @[2,2] [2,3,1,4])
-- [[2,3],
--  [1,4]]
-- >>> pretty $ sorts (Dims @[0,1]) (array @[2,2] [2,3,1,4])
-- [[1,2],
--  [3,4]]
sorts ::
  forall ds s a si so.
  ( Ord a,
    KnownNats s,
    KnownNats si,
    KnownNats so,
    si ~ Eval (DeleteDims ds s),
    so ~ Eval (GetDims ds s),
    s ~ Eval (InsertDims ds so si)
  ) =>
  Dims ds -> Array s a -> Array s a
-- Extract the sub-arrays along @ds@, sort them as a vector, then join them back.
sorts SNats a = joins (Dims @ds) $ unsafeModifyVector sortV (extracts (Dims @ds) a)

-- | Sort an array along the supplied dimensions, ordering the extracted sub-arrays by the keys produced by the supplied function.
--
-- >>> import Data.Ord (Down (..))
-- >>> toDynamic $ sortsBy (Dims @'[0]) (fmap Down) (array @[2,2] [2,3,1,4])
-- UnsafeArray [2,2] [2,3,1,4]
sortsBy ::
  forall ds s a b si so.
  ( Ord b,
    KnownNats s,
    KnownNats si,
    KnownNats so,
    si ~ Eval (DeleteDims ds s),
    so ~ Eval (GetDims ds s),
    s ~ Eval (InsertDims ds so si)
  ) =>
  Dims ds -> (Array si a -> Array si b) -> Array s a -> Array s a
-- Like 'sorts', but each sub-array is compared via its key @c@.
sortsBy SNats c a = joins (Dims @ds) $ unsafeModifyVector (sortByV c) (extracts (Dims @ds) a)

-- | The indices into the array if it were sorted along the dimensions supplied.
--
-- >>> orders (Dims @'[0]) (array @[2,2] [2,3,1,4])
-- [1,0]
orders ::
  forall ds s a si so.
  ( Ord a,
    KnownNats s,
    KnownNats si,
    KnownNats so,
    si ~ Eval (DeleteDims ds s),
    so ~ Eval (GetDims ds s),
    s ~ Eval (InsertDims ds so si)
  ) =>
  Dims ds -> Array s a -> Array so Int
-- Extract the sub-arrays along @ds@ and replace them by their sorting order.
orders SNats a = unsafeModifyVector orderV (extracts (Dims @ds) a)

-- | The indices into the array if it were sorted along the dimensions supplied, comparing the keys produced by the supplied function.
--
-- >>> import Data.Ord (Down (..))
-- >>> ordersBy (Dims @'[0]) (fmap Down) (array @[2,2] [2,3,1,4])
-- [0,1]
ordersBy ::
  forall ds s a b si so.
  ( Ord b,
    KnownNats s,
    KnownNats si,
    KnownNats so,
    si ~ Eval (DeleteDims ds s),
    so ~ Eval (GetDims ds s),
    s ~ Eval (InsertDims ds so si)
  ) =>
  Dims ds -> (Array si a -> Array si b) -> Array s a -> Array so Int
-- Like 'orders', but each sub-array is compared via its key @c@.
ordersBy SNats c a = unsafeModifyVector (orderByV c) (extracts (Dims @ds) a)

-- | Apply a binary array function to two arrays with matching shapes across the supplied (matching) dimensions.
--
-- >>> a = array @[2,3] [0..5]
-- >>> b = array @'[3] [6..8]
-- >>> pretty $ telecasts (Dims @'[1]) (Dims @'[0]) (concatenate (SNat @0)) a b
-- [[0,3,6],
--  [1,4,7],
--  [2,5,8]]
telecasts ::
  forall sa sb sc sia sib sic ma mb a b c soa sob ds.
  ( KnownNats sa,
    KnownNats sb,
    KnownNats sc,
    KnownNats sia,
    KnownNats sib,
    KnownNats sic,
    KnownNats soa,
    KnownNats sob,
    KnownNats ds,
    ds ~ Eval (DimsOf soa),
    sia ~ Eval (DeleteDims ma sa),
    sib ~ Eval (DeleteDims mb sb),
    soa ~ Eval (GetDims ma sa),
    sob ~ Eval (GetDims mb sb),
    soa ~ sob,
    sc ~ Eval (InsertDims ds soa sic)
  ) =>
  SNats ma -> SNats mb -> (Array sia a -> Array sib b -> Array sic c) -> Array sa a -> Array sb b -> Array sc c
-- The two 'SNats' matches supply @KnownNats ma@ and @KnownNats mb@. Both
-- inputs are extracted over their matching dimensions, zipped slice-by-slice
-- with @f@, and the results are joined.
telecasts SNats SNats f a b = join (zipWith f (extracts (SNats @ma) a) (extracts (SNats @mb) b))

-- | Apply a binary array function to two arrays where the shape of the first array is a prefix of the second array.
--
-- >>> a = array @[2,3] [0..5]
-- >>> pretty $ transmit (zipWith (+)) (toScalar 1) a
-- [[1,2,3],
--  [4,5,6]]
transmit ::
  forall sa sb sc a b c ds sib sic sob.
  ( KnownNats sa,
    KnownNats sb,
    KnownNats sc,
    KnownNats ds,
    KnownNats sib,
    KnownNats sic,
    KnownNats sob,
    ds ~ Eval (EnumFromTo (Eval (Rank sa)) (Eval (Rank sb) - 1)),
    sib ~ Eval (DeleteDims ds sb),
    sob ~ Eval (GetDims ds sb),
    sb ~ Eval (InsertDims ds sob sib),
    sc ~ Eval (InsertDims ds sob sic),
    True ~ Eval (IsPrefixOf sa sb)
  ) =>
  (Array sa a -> Array sib b -> Array sic c) -> Array sa a -> Array sb b -> Array sc c
-- Partially apply @f@ to the whole of @a@, then map it over the sub-arrays of
-- @b@ along @ds@ (the dimensions of @sb@ beyond the rank of @sa@).
transmit f a b = maps (Dims @ds) (f a) b

-- | A one-dimensional array whose single dimension has length @n@.
type Vector n a = Array '[n] a

-- | Create a one-dimensional array.
--
-- >>> pretty $ vector @3 @Int [2,3,4]
-- [2,3,4]
vector ::
  forall n a t.
  ( FromVector t a,
    KnownNat n
  ) =>
  t ->
  Array '[n] a
-- A rank-1 specialisation of 'array'.
vector xs = array xs

-- | 'vector' with an explicit 'SNat' rather than a 'KnownNat' constraint.
--
-- >>> pretty $ vector' @Int (SNat @3) [2,3,4]
-- [2,3,4]
vector' ::
  forall a n t.
  (FromVector t a) =>
  SNat n ->
  t ->
  Array '[n] a
-- 'withKnownNat' turns the singleton into the 'KnownNat' evidence 'vector' needs.
vector' n xs = withKnownNat n (vector xs)

-- | Vector specialisation of 'range'
--
-- >>> toDynamic $ iota @5
-- UnsafeArray [5] [0,1,2,3,4]
iota :: forall n. (KnownNat n) => Vector n Int
iota = range

-- | A two-dimensional array: @rows@ is the first dimension, @cols@ the second.
type Matrix rows cols a = Array '[rows, cols] a

-- * row (first dimension) specializations

-- | Add a new row
--
-- >>> pretty $ cons (array @'[2] [0,1]) (array @[2,2] [2,3,4,5])
-- [[0,1],
--  [2,3],
--  [4,5]]
cons ::
  forall st s sh a.
  ( KnownNats st,
    KnownNats s,
    KnownNats sh,
    True ~ Eval (InsertOk 0 st sh),
    s ~ Eval (IncAt 0 st),
    sh ~ Eval (DeleteDim 0 st)
  ) =>
  Array sh a -> Array st a -> Array s a
-- Prepending along dimension 0 is adding a row at the front.
cons =
  prepend (SNat @0)

-- | Add a new row at the end
--
-- >>> pretty $ snoc (array @[2,2] [0,1,2,3]) (array @'[2] [4,5])
-- [[0,1],
--  [2,3],
--  [4,5]]
snoc ::
  forall si s sl a.
  ( KnownNats si,
    KnownNats s,
    KnownNats sl,
    True ~ Eval (InsertOk 0 si sl),
    s ~ Eval (IncAt 0 si),
    sl ~ Eval (DeleteDim 0 si)
  ) =>
  Array si a -> Array sl a -> Array s a
-- Appending along dimension 0 is adding a row at the back.
snoc = append (SNat @0)

-- | Split an array into the first row and the remaining rows.
--
-- >>> import Data.Bifunctor (bimap)
-- >>> bimap toDynamic toDynamic $ uncons (array @[3,2] [0..5])
-- (UnsafeArray [2] [0,1],UnsafeArray [2,2] [2,3,4,5])
uncons ::
  forall a s sh st ls os ds.
  ( KnownNats s,
    KnownNats sh,
    KnownNats st,
    ds ~ '[0],
    sh ~ Eval (DeleteDims ds s),
    KnownNats ls,
    KnownNats os,
    os ~ Eval (Replicate (Eval (Rank ds)) 1),
    ls ~ Eval (GetLastPositions ds s),
    Eval (SlicesOk ds os ls s) ~ True,
    st ~ Eval (SetDims ds ls s)
  ) =>
  Array s a -> (Array sh a, Array st a)
-- Both halves are taken along dimension 0 (ds ~ '[0]).
uncons a = (heads (Dims @ds) a, tails (Dims @ds) a)

-- | Split an array into the initial rows and the last row.
--
-- >>> import Data.Bifunctor (bimap)
-- >>> bimap toDynamic toDynamic $ unsnoc (array @[3,2] [0..5])
-- (UnsafeArray [2,2] [0,1,2,3],UnsafeArray [2] [4,5])
unsnoc ::
  forall ds os s a ls si sl.
  ( KnownNats s,
    KnownNats ds,
    KnownNats si,
    KnownNats ls,
    KnownNats os,
    KnownNats sl,
    ds ~ '[0],
    Eval (SlicesOk ds os ls s) ~ True,
    os ~ Eval (Replicate (Eval (Rank ds)) 0),
    ls ~ Eval (GetLastPositions ds s),
    si ~ Eval (SetDims ds ls s),
    sl ~ Eval (DeleteDims ds s)
  ) =>
  Array s a -> (Array si a, Array sl a)
-- Both halves are taken along dimension 0 (ds ~ '[0]).
unsnoc a = (inits (Dims @ds) a, lasts (Dims @ds) a)

-- | Convenience pattern for row extraction and consolidation at the beginning of an Array.
--
-- >>> (x:<xs) = array @'[4] [0..3]
-- >>> toDynamic x
-- UnsafeArray [] [0]
-- >>> toDynamic xs
-- UnsafeArray [3] [1,2,3]
-- >>> toDynamic (x:<xs)
-- UnsafeArray [4] [0,1,2,3]
pattern (:<) ::
  forall s sh st a os ls ds.
  ( KnownNats s,
    KnownNats sh,
    KnownNats st,
    True ~ Eval (InsertOk 0 st sh),
    s ~ Eval (IncAt 0 st),
    ds ~ '[0],
    sh ~ Eval (DeleteDims ds s),
    KnownNats ls,
    KnownNats os,
    Eval (SlicesOk ds os ls s) ~ True,
    os ~ Eval (Replicate (Eval (Rank ds)) 1),
    ls ~ Eval (GetLastPositions ds s),
    st ~ Eval (SetDims ds ls s)
  ) =>
  Array sh a -> Array st a -> Array s a
-- Matching splits with 'uncons'; construction rebuilds with 'cons'.
pattern x :< xs <- (uncons -> (x, xs))
  where
    x :< xs = cons x xs

infix 5 :<

{-# COMPLETE (:<) :: Array #-}

-- | Convenience pattern for row extraction and consolidation at the end of an Array.
--
-- >>> (xs:>x) = array @'[4] [0..3]
-- >>> toDynamic x
-- UnsafeArray [] [3]
-- >>> toDynamic xs
-- UnsafeArray [3] [0,1,2]
-- >>> toDynamic (xs:>x)
-- UnsafeArray [4] [0,1,2,3]
pattern (:>) ::
  forall si sl s a ds ls os.
  ( KnownNats si,
    KnownNats sl,
    KnownNats s,
    True ~ Eval (InsertOk 0 si sl),
    s ~ Eval (IncAt 0 si),
    KnownNats ds,
    KnownNats ls,
    KnownNats os,
    sl ~ Eval (DeleteDim 0 si),
    ds ~ '[0],
    Eval (SlicesOk ds os ls s) ~ True,
    os ~ Eval (Replicate (Eval (Rank ds)) 0),
    ls ~ Eval (GetLastPositions ds s),
    si ~ Eval (SetDims ds ls s),
    sl ~ Eval (DeleteDims ds s)
  ) =>
  Array si a -> Array sl a -> Array s a
-- Matching splits with 'unsnoc'; construction rebuilds with 'snoc'.
pattern xs :> x <- (unsnoc -> (xs, x))
  where
    xs :> x = snoc xs x

infix 5 :>

{-# COMPLETE (:>) :: Array #-}

-- | Generate an array of uniform random variates between a range.
--
-- >>> import System.Random.Stateful hiding (uniform)
-- >>> g <- newIOGenM (mkStdGen 42)
-- >>> u <- uniform @[2,3,4] @Int g (0,9)
-- >>> pretty u
-- [[[0,7,0,2],
--   [1,7,4,2],
--   [5,9,8,2]],
--  [[9,8,1,0],
--   [2,2,8,2],
--   [2,8,0,6]]]
uniform ::
  forall s a g m.
  ( StatefulGen g m,
    UniformRange a,
    KnownNats s
  ) =>
  g -> (a, a) -> m (Array s a)
uniform gen interval = fmap array draws
  where
    -- One independent draw from @interval@ per element of shape @s@; the flat
    -- vector of draws is then shaped into the result by 'array'.
    draws = V.replicateM (S.size (valuesOf @s)) (uniformRM interval gen)

-- | Inverse of a symmetric positive-definite matrix, computed from its
-- Cholesky factor.
--
-- > mult (inverse a) a == ident -- up to floating-point rounding
--
-- >>> e = array @[3,3] @Double [4,12,-16,12,37,-43,-16,-43,98]
-- >>> pretty (inverse e)
-- [[49.36111111111111,-13.555555555555554,2.1111111111111107],
--  [-13.555555555555554,3.7777777777777772,-0.5555555555555555],
--  [2.1111111111111107,-0.5555555555555555,0.1111111111111111]]
inverse :: (Eq a, Floating a, KnownNat m) => Matrix m m a -> Matrix m m a
inverse a = mult (invtri (transpose l)) (invtri l)
  where
    -- Cholesky factor, so a == mult l (transpose l) and therefore
    -- a^-1 == (l^T)^-1 * l^-1. The decomposition only exists for symmetric
    -- positive-definite input. It is computed once and shared by both
    -- triangular inversions.
    l = chol a

-- | [Inversion of a Triangular Matrix](https://math.stackexchange.com/questions/1003801/inverse-of-an-invertible-upper-triangular-matrix-of-order-3)
--
-- >>> t = array @[3,3] @Double [1,0,1,0,1,2,0,0,1]
-- >>> pretty (invtri t)
-- [[1.0,0.0,-1.0],
--  [0.0,1.0,-2.0],
--  [0.0,0.0,1.0]]
--
-- >>> ident == mult t (invtri t)
-- True
invtri :: forall a n. (KnownNat n, Floating a, Eq a) => Matrix n n a -> Matrix n n a
invtri a = mult series ti
  where
    -- Split the (assumed triangular, non-zero diagonal) input as a = d + u,
    -- where d is its diagonal and u the strictly triangular remainder.
    -- ti is d^-1 and tl is u.
    ti = undiag (fmap recip (diag a))
    tl = zipWith (-) a (undiag (diag a))
    -- l = -(d^-1 u). For triangular input this is strictly triangular and so
    -- nilpotent (l^n == 0), giving
    --   a^-1 = (I + d^-1 u)^-1 d^-1 = (l^0 + l^1 + ... + l^(n-1)) d^-1.
    l = fmap negate (dot sum (*) ti tl)
    zero' = konst @[n, n] 0
    add = zipWith (+)
    -- One fold step over the indices 0 .. n-1, carrying
    -- (partial sum, current power). Each step adds l^k to the sum and derives
    -- l^(k+1) = l * l^k from the power already held, so every term costs a
    -- single matrix multiplication. The summation order (zero' + l^0 + l^1 +
    -- ...) and each power's multiplication order are the same as building
    -- each l^k independently, so the floating-point results are unchanged.
    step :: (Matrix n n a, Matrix n n a) -> Int -> (Matrix n n a, Matrix n n a)
    step (acc, p) _ = (add acc p, mult l p)
    (series, _) = foldl' step (zero', ident @[n, n]) (range @'[n])

-- | Cholesky decomposition using the <https://en.wikipedia.org/wiki/Cholesky_decomposition#The_Cholesky_algorithm Cholesky-Crout> algorithm.
--
-- >>> e = array @[3,3] @Double [4,12,-16,12,37,-43,-16,-43,98]
-- >>> pretty (chol e)
-- [[2.0,0.0,0.0],
--  [6.0,1.0,0.0],
--  [-8.0,5.0,3.0]]
-- >>> mult (chol e) (transpose (chol e)) == e
-- True
chol :: (KnownNat m, Floating a, KnownNats '[m, m]) => Matrix m m a -> Matrix m m a
chol a = l
  where
    -- l is defined in terms of itself. An entry (i, j) on or below the
    -- diagonal reads l only at (j, j) and at columns k < j of rows i and j,
    -- all of which are on or below the diagonal.
    l = tabulate entry
    -- Strictly upper entries of the lower-triangular factor are set to exactly
    -- zero. Running the recurrence there only cancels to zero for exactly
    -- symmetric input, and can leave rounding residue. Lower and diagonal
    -- entries never read upper ones, so they are unaffected by this choice.
    entry s = bool (norm_ 1 l s (index a s - cross_ l s)) 0 (isUpper s)
    isUpper s = i < j
      where
        [i, j] = fromFins s

-- | Final step of a Cholesky-Crout entry, used by 'chol'.
--
-- Given the residual @x@ for position @s@:
--
-- * at a diagonal position ('S.isDiag'), return @sqrt x@;
-- * otherwise, scale @x@ by the reciprocal of the diagonal element of @l@
--   picked by @S.getDim d s@ (presumably component @d@ of @s@).
norm_ :: (Floating a, KnownNat m) => Int -> Matrix m m a -> Fins '[m, m] -> a -> a
norm_ d l (UnsafeFins s) = bool offDiagonal sqrt (S.isDiag s)
  where
    -- Diagonal element of l used to scale off-diagonal residuals.
    pivot = diag l ! [S.getDim d s]
    -- Written as (1 / pivot) * x, not x / pivot, to keep identical rounding.
    offDiagonal x = (1 / pivot) * x

-- | For the position @s = [i, j]@, the sum over @k@ drawn from
-- @A.range [j]@ (presumably @0 .. j-1@) of @l[i,k] * l[j,k]@. This is the
-- term subtracted from @a[i,j]@ in the Cholesky-Crout recurrence of 'chol'.
cross_ :: (Num a, KnownNat m) => Matrix m m a -> Fins '[m, m] -> a
cross_ l s = sum (fmap term (A.range [j]))
  where
    [i, j] = fromFins s
    term k = (l ! [i, k]) * (l ! [j, k])