{-# LANGUAGE DerivingStrategies #-}
{-# LANGUAGE FunctionalDependencies #-}
{-# LANGUAGE PatternSynonyms #-}
{-# LANGUAGE RoleAnnotations #-}
{-# LANGUAGE ViewPatterns #-}
{-# OPTIONS_GHC -fno-warn-incomplete-uni-patterns #-}

-- | Arrays with shape information and computations at a value-level.
module Harpie.Array
  ( -- * Usage
    -- $usage

    -- * Harpie Arrays
    Array (..),
    array,
    (><),
    validate,
    safeArray,
    unsafeModifyShape,
    unsafeModifyVector,

    -- * Dimensions
    Dim,
    Dims,

    -- * Conversion
    FromVector (..),
    FromArray (..),

    -- * Shape Access
    shape,
    rank,
    size,
    length,
    isNull,

    -- * Indexing
    index,
    (!),
    (!?),
    tabulate,
    backpermute,

    -- * Scalars
    fromScalar,
    toScalar,
    isScalar,
    asSingleton,
    asScalar,

    -- * Array Creation
    empty,
    range,
    corange,
    indices,
    ident,
    konst,
    singleton,
    diag,
    undiag,

    -- * Element-level functions
    zipWith,
    zipWithSafe,
    modify,
    imap,

    -- * Function generalisers
    rowWise,
    colWise,
    dimsWise,

    -- * Single-dimension functions
    take,
    drop,
    select,
    insert,
    delete,
    append,
    prepend,
    concatenate,
    couple,
    slice,
    rotate,

    -- * Multi-dimension functions
    takes,
    drops,
    indexes,
    slices,
    heads,
    lasts,
    tails,
    inits,

    -- * Function application
    extracts,
    reduces,
    joins,
    joinsSafe,
    join,
    joinSafe,
    traverses,
    maps,
    filters,
    zips,
    zipsSafe,
    modifies,
    diffs,

    -- * Array expansion & contraction
    expand,
    coexpand,
    contract,
    prod,
    dot,
    mult,
    windows,

    -- * Search
    find,
    findNoOverlap,
    findIndices,
    isPrefixOf,
    isSuffixOf,
    isInfixOf,

    -- * Shape manipulations
    fill,
    cut,
    cutSuffix,
    pad,
    lpad,
    reshape,
    flat,
    repeat,
    cycle,
    rerank,
    reorder,
    squeeze,
    elongate,
    transpose,
    inflate,
    intercalate,
    intersperse,
    concats,
    reverses,
    rotates,

    -- * Sorting
    sorts,
    sortsBy,
    orders,
    ordersBy,

    -- * Transmission
    transmit,
    transmitSafe,
    transmitOp,
    telecasts,
    telecastsSafe,

    -- * Row specializations
    pattern (:<),
    cons,
    uncons,
    pattern (:>),
    snoc,
    unsnoc,

    -- * Shape specializations
    iota,

    -- * Math
    uniform,
    invtri,
    inverse,
    chol,
  )
where

import Control.Monad hiding (join)
import Data.Bool
import Data.Foldable hiding (find, length, minimum)
import Data.Function
import Data.List qualified as List
import Data.Vector qualified as V
import GHC.Generics
import Harpie.Shape hiding (asScalar, asSingleton, concatenate, range, rank, reorder, rerank, rotate, size, squeeze)
import Harpie.Shape qualified as S
import Harpie.Sort
import Prettyprinter hiding (dot, fill)
import System.Random hiding (uniform)
import System.Random.Stateful hiding (uniform)
import Prelude as P hiding (cycle, drop, length, repeat, take, zip, zipWith)

-- $setup
-- >>> :m -Prelude
-- >>> import Prelude hiding (take, drop, zipWith, length, cycle, repeat)
-- >>> import Harpie.Array as A
-- >>> import Harpie.Shape qualified as S
-- >>> import Data.Vector qualified as V
-- >>> import Prettyprinter hiding (dot, fill)
-- >>> import Data.List qualified as List
-- >>> let s = 1 :: Array Int
-- >>> s
-- UnsafeArray [] [1]
-- >>> pretty s
-- 1
-- >>> let v = range [3]
-- >>> v
-- UnsafeArray [3] [0,1,2]
-- >>> let m = range [2,3]
-- >>> pretty m
-- [[0,1,2],
--  [3,4,5]]
-- >>> let a = range [2,3,4]
-- >>> a
-- UnsafeArray [2,3,4] [0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23]
-- >>> pretty a
-- [[[0,1,2,3],
--   [4,5,6,7],
--   [8,9,10,11]],
--  [[12,13,14,15],
--   [16,17,18,19],
--   [20,21,22,23]]]

-- $usage
--
-- Several names used in @harpie@ conflict with [Prelude](https://hackage.haskell.org/package/base/docs/Prelude.html):
--
-- >>> import Prelude hiding (cycle, repeat, take, drop, zipWith, length)
--
-- In general, 'Array' functionality is contained in @Harpie.Array@ and shape functionality is contained in @Harpie.Shape@. These two modules also have name clashes and at least one needs to be qualified:
--
-- >>> import Harpie.Array as A
-- >>> import Harpie.Shape qualified as S
--
-- [@prettyprinter@](https://hackage.haskell.org/package/prettyprinter) is used to prettily render arrays to better visualise shape.
--
-- >>> import Prettyprinter hiding (dot,fill)
--
-- Examples of arrays:
--
-- An array with no dimensions (a scalar).
--
-- >>> s = 1 :: Array Int
-- >>> s
-- UnsafeArray [] [1]
-- >>> shape s
-- []
-- >>> pretty s
-- 1
--
-- A single-dimension array (a vector).
--
-- >>> let v = range [3]
-- >>> pretty v
-- [0,1,2]
--
-- A two-dimensional array (a matrix).
--
-- >>> let m = range [2,3]
-- >>> pretty m
-- [[0,1,2],
--  [3,4,5]]
--
-- An n-dimensional array (n should be finite).
--
-- >>> a = range [2,3,4]
-- >>> a
-- UnsafeArray [2,3,4] [0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23]
-- >>> pretty a
-- [[[0,1,2,3],
--   [4,5,6,7],
--   [8,9,10,11]],
--  [[12,13,14,15],
--   [16,17,18,19],
--   [20,21,22,23]]]

-- | A hyperrectangular (or multidimensional) array with a value-level shape.
--
-- An 'Array' pairs a shape (the size of each dimension, outermost first)
-- with a flat vector of its elements. The constructor performs no check
-- that the vector length matches the product of the shape, hence the name
-- @UnsafeArray@; use 'safeArray' or 'validate' when that matters.
--
-- The derived 'Eq' and 'Ord' instances compare the shape first and then
-- the element vector.
--
-- >>> let a = array [2,3,4] [1..24] :: Array Int
-- >>> a
-- UnsafeArray [2,3,4] [1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24]
--
-- >>> pretty a
-- [[[1,2,3,4],
--   [5,6,7,8],
--   [9,10,11,12]],
--  [[13,14,15,16],
--   [17,18,19,20],
--   [21,22,23,24]]]
data Array a = UnsafeArray [Int] (V.Vector a)
  deriving stock (Generic)
  deriving stock (Eq, Ord, Show)

-- The element type only appears inside the element vector, so an
-- @Array a@ may be coerced to @Array b@ whenever @a@ and @b@ are Coercible.
type role Array representational

-- | Maps a function over every element; the shape is unchanged.
instance Functor Array where
  fmap f = unsafeModifyVector (V.map f)

-- | Folds over the elements in the order of the underlying vector.
instance Foldable Array where
  foldr f x0 a = V.foldr f x0 (asVector a)

-- | Traverses the elements in vector order, re-attaching the original shape.
instance Traversable Array where
  traverse f (UnsafeArray s v) =
    array s <$> traverse f v

-- | Render an array as nested, comma-separated, bracketed rows.
--
-- A rank-0 array renders its single element, and a rank-1 array renders
-- its vector via 'Show'. Higher ranks render each element of
-- @extracts [0] a@ (the sub-arrays along the first dimension) recursively,
-- one per line.
instance (Show a) => Pretty (Array a) where
  pretty a@(UnsafeArray _ v) = case rank a of
    0 -> viaShow (V.head v)
    1 -> viaShow v
    _ ->
      pretty "["
        <> indent
          0
          ( vsep
              ( punctuate comma $
                  pretty
                    <$> toList (extracts [0] a)
              )
          )
        <> pretty "]"

-- * conversions

-- | Element-wise arithmetic.
--
-- '(+)' and '(-)' combine arrays element-by-element via 'zipWith'. This
-- does not check shapes: the result takes the left operand's shape.
-- Integer literals become scalars via 'toScalar'. '(*)' is not defined
-- and raises an error.
instance (Num a) => Num (Array a) where
  (+) = zipWith (+)
  (-) = zipWith (-)
  (*) = error "multiplication not defined"
  abs = fmap abs
  signum = fmap signum
  -- The class default, @negate x = 0 - x@, would zip against the scalar
  -- @fromInteger 0@ (a shape-[] array with one element) and so collapse
  -- any non-scalar argument to a single element. Negate element-wise.
  negate = fmap negate
  fromInteger x = toScalar (fromInteger x)

-- | Conversion to and from a `V.Vector`
--
-- Note that conversion of an 'Array' to a vector drops shape information, so that:
--
-- > asVector . vectorAs == id
-- > vectorAs . asVector == flat -- for t ~ Array a (it is id for lists and vectors)
--
-- Converting an 'Array' back from a vector always yields a one-dimensional
-- array whose single dimension is the vector length.
--
-- >>> asVector (range [2,3])
-- [0,1,2,3,4,5]
--
-- >>> vectorAs (V.fromList [0..5]) :: Array Int
-- UnsafeArray [6] [0,1,2,3,4,5]
class FromVector t a | t -> a where
  -- | View the elements of @t@ as a flat vector.
  asVector :: t -> V.Vector a
  -- | Build a @t@ from a flat vector.
  vectorAs :: V.Vector a -> t

-- | A vector is its own vector representation.
instance FromVector (V.Vector a) a where
  asVector = id
  vectorAs = id

-- | Lists convert to and from vectors element-for-element.
instance FromVector [a] a where
  asVector = V.fromList
  vectorAs = V.toList

-- | 'asVector' drops the shape; 'vectorAs' builds a one-dimensional array
-- whose single dimension is the vector length.
instance FromVector (Array a) a where
  asVector (UnsafeArray _ v) = v
  vectorAs v = UnsafeArray [V.length v] v

-- | Conversion to and from an `Array`
--
-- Note that conversion of an 'Array' to a `FromArray` likely drops shape information, so that:
--
-- > arrayAs . asArray == id
-- > asArray . arrayAs == flat -- for t ~ [a] or V.Vector a (it is id for Array)
--
-- >>> asArray ([0..5::Int])
-- UnsafeArray [6] [0,1,2,3,4,5]
--
-- >>> arrayAs (range [2,3]) :: [Int]
-- [0,1,2,3,4,5]
class FromArray t a | t -> a where
  -- | Convert a @t@ into an 'Array'.
  asArray :: t -> Array a
  -- | Convert an 'Array' into a @t@, possibly discarding its shape.
  arrayAs :: Array a -> t

-- | An 'Array' is its own array representation.
instance FromArray (Array a) a where
  asArray = id
  arrayAs = id

-- | A list becomes a one-dimensional array of its own length; converting
-- back drops the shape and lists the elements in vector order.
instance FromArray [a] a where
  -- Take the dimension from the built vector, as the sibling
  -- FromArray (V.Vector a) and FromVector (Array a) instances do.
  asArray l = UnsafeArray [V.length v] v
    where
      v = V.fromList l
  arrayAs (UnsafeArray _ v) = V.toList v

-- | A vector becomes a one-dimensional array of its own length; converting
-- back drops the shape.
instance FromArray (V.Vector a) a where
  asArray v = UnsafeArray [V.length v] v
  arrayAs (UnsafeArray _ v) = v

-- | Construct an array from a shape and a value without any shape validation.
--
-- The value may be anything with a 'FromVector' instance (a list, a vector,
-- or another array, whose shape is discarded).
--
-- >>> array [2,3] [0..5]
-- UnsafeArray [2,3] [0,1,2,3,4,5]
array :: (FromVector t a) => [Int] -> t -> Array a
array s (asVector -> v) = UnsafeArray s v

infixl 4 ><

-- | Construct an Array; an infix synonym for 'array' (no shape validation).
--
-- >>> pretty $ [2,3] >< [0..5]
-- [[0,1,2],
--  [3,4,5]]
(><) :: (FromVector t a) => [Int] -> t -> Array a
(><) = array

-- | Validate the size and shape of an array: 'True' when the number of
-- elements implied by the shape equals the length of the element vector.
--
-- >>> validate (array [2,3,4] [1..23] :: Array Int)
-- False
validate :: Array a -> Bool
validate a = size a == V.length (asVector a)

-- | Construct an Array, checking shape.
--
-- Returns 'Nothing' when the shape does not account for exactly the
-- supplied number of elements (see 'validate').
--
-- >>> safeArray [2,3,4] [0..23] == Just a
-- True
safeArray :: (FromVector t a) => [Int] -> t -> Maybe (Array a)
safeArray s v = bool Nothing (Just a) (validate a)
  where
    a = UnsafeArray s (asVector v)

-- | Unsafely modify an array shape.
--
-- The element vector is left untouched, and the new shape is not checked
-- against it.
--
-- >>> unsafeModifyShape (fmap (+1) :: [Int] -> [Int]) (array [2,3] [0..5])
-- UnsafeArray [3,4] [0,1,2,3,4,5]
unsafeModifyShape :: ([Int] -> [Int]) -> Array a -> Array a
unsafeModifyShape f (UnsafeArray s v) = UnsafeArray (f s) v

-- | Unsafely modify an array vector.
--
-- The function may work on any 'FromVector' representation (e.g. a list
-- or a vector). The shape is kept as-is, and it is not checked against
-- the new element count.
--
-- >>> unsafeModifyVector (V.map (+1)) (array [2,3] [0..5])
-- UnsafeArray [2,3] [1,2,3,4,5,6]
unsafeModifyVector :: (FromVector u a, FromVector v b) => (u -> v) -> Array a -> Array b
unsafeModifyVector f (UnsafeArray s v) = UnsafeArray s (asVector (f (vectorAs v)))

-- | A dimension number: a position within a shape (an [Int]), counting the
-- outermost dimension as 0.
type Dim = Int

-- | A list of dimension numbers, each a position within a shape (an [Int]).
type Dims = [Int]

-- | The shape of an Array: the size of each dimension, outermost first.
--
-- >>> shape a
-- [2,3,4]
shape :: Array a -> [Int]
shape (UnsafeArray s _) = s

-- | The rank of an Array: its number of dimensions.
--
-- >>> rank a
-- 3
rank :: Array a -> Int
rank = List.length . shape

-- | The size of an Array, computed from its shape alone. This is the
-- total number of elements if the Array is valid (see 'validate').
--
-- >>> size a
-- 24
size :: Array a -> Int
size = S.size . shape

-- | Number of rows (first dimension size) in an Array. As a convention, a scalar value is still a single row.
--
-- >>> length a
-- 2
-- >>> length (toScalar 0)
-- 1
length :: Array a -> Int
length a = case shape a of
  [] -> 1
  (rows : _) -> rows

-- | Is the Array empty (does its shape imply no elements)?
--
-- >>> isNull ([2,0] >< [] :: Array ())
-- True
-- >>> isNull ([] >< [4] :: Array Int)
-- False
isNull :: Array a -> Bool
isNull = (0 ==) . size

-- | Extract an element at an index, unsafely.
--
-- The index is converted to a flat vector position with 'flatten' and read
-- with 'V.unsafeIndex', so an out-of-range index is not checked. Use '(!?)'
-- for a checked lookup.
--
-- >>> index a [1,2,3]
-- 23
index :: Array a -> [Int] -> a
index (UnsafeArray s v) i = V.unsafeIndex v (flatten s i)

infixl 9 !

-- | Extract an element at an index, unsafely; an infix synonym for 'index'.
--
-- >>> a ! [1,2,3]
-- 23
(!) :: Array a -> [Int] -> a
(!) = index

-- | Extract an element at an index, safely.
--
-- Returns 'Nothing' unless the index is within the array's shape (as
-- decided by 'isFins').
--
-- >>> a !? [1,2,3]
-- Just 23
-- >>> a !? [2,3,1]
-- Nothing
(!?) :: Array a -> [Int] -> Maybe a
(!?) a xs = bool Nothing (Just (a ! xs)) (xs `isFins` shape a)

-- | Tabulate an array supplying a shape and a tabulation function.
--
-- The element at each flat position is the function applied to the index
-- that 'shapen' recovers from that position.
--
-- >>> tabulate [2,3,4] (S.flatten [2,3,4]) == a
-- True
tabulate :: [Int] -> ([Int] -> a) -> Array a
tabulate ds f =
  UnsafeArray ds (V.generate (V.product (asVector ds)) (f . shapen ds))

-- | @backpermute@ is a tabulation where the contents of an array do not need to be accessed, and is thus a fulcrum for leveraging laziness and fusion via the rule:
--
-- > backpermute f g (backpermute f' g' a) == backpermute (f . f') (g' . g) a
--
-- Here @f@ maps the source shape to the result shape, and @g@ maps each
-- result index back to the source index it reads from. Because index maps
-- are applied backwards, they compose in the opposite order to the shape
-- maps.
--
-- Many functions in this module are examples of backpermute usage.
--
-- >>> pretty $ backpermute List.reverse List.reverse a
-- [[[0,12],
--   [4,16],
--   [8,20]],
--  [[1,13],
--   [5,17],
--   [9,21]],
--  [[2,14],
--   [6,18],
--   [10,22]],
--  [[3,15],
--   [7,19],
--   [11,23]]]
backpermute :: ([Int] -> [Int]) -> ([Int] -> [Int]) -> Array a -> Array a
backpermute f g a = tabulate (f (shape a)) (index a . g)
{-# INLINEABLE backpermute #-}

{- RULES
   "backpermute/backpermute" forall f f' g g' a. backpermute f g (backpermute f' g' a) = backpermute (f . f') (g' . g) a

-}

-- | Unwrap a scalar.
--
-- Intended for rank-0 arrays; it simply indexes with the empty index @[]@.
--
-- >>> let s = array [] [3] :: Array Int
-- >>> fromScalar s
-- 3
fromScalar :: Array a -> a
fromScalar a = index a ([] :: [Int])

-- | Wrap a value as a scalar: an array of shape @[]@ holding one element.
--
-- >>> :t toScalar 2
-- toScalar 2 :: Num a => Array a
toScalar :: a -> Array a
toScalar a = tabulate [] (const a)

-- | Is an array a scalar (rank 0)?
--
-- >>> isScalar (toScalar (2::Int))
-- True
isScalar :: Array a -> Bool
isScalar a = rank a == 0

-- | Convert a scalar to being a dimensioned array. Do nothing if not a scalar.
--
-- Only the shape changes (via 'S.asSingleton'); the elements are untouched.
--
-- >>> asSingleton (toScalar 4)
-- UnsafeArray [1] [4]
asSingleton :: Array a -> Array a
asSingleton = unsafeModifyShape S.asSingleton

-- | Convert an array with shape [1] to being a scalar (Do nothing if not a shape [1] array).
--
-- Only the shape changes (via 'S.asScalar'); the elements are untouched.
--
-- >>> asScalar (singleton 3)
-- UnsafeArray [] [3]
asScalar :: Array a -> Array a
asScalar = unsafeModifyShape S.asScalar

-- * Creation

-- | An array with no elements: shape @[0]@ and an empty vector.
--
-- >>> empty
-- UnsafeArray [0] []
empty :: Array a
empty = array [0] []

-- | An enumeration of row-major or [lexicographic](https://en.wikipedia.org/wiki/Lexicographic_order) order.
--
-- Each element is the flat row-major position of its own index.
--
-- >>> pretty $ range [2,3]
-- [[0,1,2],
--  [3,4,5]]
range :: [Int] -> Array Int
range xs = tabulate xs (flatten xs)

-- | An enumeration of col-major or [colexicographic](https://en.wikipedia.org/wiki/Lexicographic_order) order.
--
-- Each element is the position its index would have if both the shape and
-- the index were reversed before flattening.
--
-- >>> pretty (corange [2,3,4])
-- [[[0,6,12,18],
--   [2,8,14,20],
--   [4,10,16,22]],
--  [[1,7,13,19],
--   [3,9,15,21],
--   [5,11,17,23]]]
corange :: [Int] -> Array Int
corange xs = tabulate xs (flatten (List.reverse xs) . List.reverse)

-- | Indices of an array shape: each element is its own index.
--
-- >>> pretty $ indices [3,3]
-- [[[0,0],[0,1],[0,2]],
--  [[1,0],[1,1],[1,2]],
--  [[2,0],[2,1],[2,2]]]
indices :: [Int] -> Array [Int]
indices ds = tabulate ds id

-- | The identity array: 1 at indices accepted by 'isDiag', 0 elsewhere.
--
-- >>> pretty $ ident [3,3]
-- [[1,0,0],
--  [0,1,0],
--  [0,0,1]]
ident :: (Num a) => [Int] -> Array a
ident ds = tabulate ds (bool 0 1 . isDiag)

-- | Create an array of the given shape, filled with a single value.
--
-- >>> pretty $ konst [3,2] 1
-- [[1,1],
--  [1,1],
--  [1,1]]
konst :: [Int] -> a -> Array a
konst ds a = tabulate ds (const a)

-- | Create an array of shape [1].
--
-- This differs from 'toScalar' only in its shape, not in its elements:
--
-- >>> pretty $ singleton 1
-- [1]
-- >>> singleton 3 == toScalar 3
-- False
--
-- >>> asVector (singleton 3) == asVector (toScalar 3)
-- True
singleton :: a -> Array a
singleton a = UnsafeArray [1] (V.singleton a)

-- | Extract the diagonal of an array.
--
-- The result shape is @minDim (shape a)@. Result index @i@ reads the source
-- element at the index with @i@'s first component repeated once per
-- dimension of @a@.
--
-- >>> pretty $ diag (ident [3,3])
-- [1,1,1]
diag ::
  Array a ->
  Array a
diag a = backpermute minDim (replicate (rank a) . getDim 0) a

-- | Expand the array to form a diagonal array.
--
-- The result has shape @shape a <> shape a@. Indices accepted by 'isDiag'
-- take their value from @a@; all other elements are 0.
--
-- >>> pretty $ undiag (range [3])
-- [[0,0,0],
--  [0,1,0],
--  [0,0,2]]
undiag ::
  (Num a) =>
  Array a ->
  Array a
undiag a = tabulate (shape a <> shape a) onDiag
  where
    onDiag xs = bool 0 (index a xs) (isDiag xs)

-- | Zip two arrays at an element level.
--
-- No shape check is made. The result takes the shape of the first array,
-- and its vector is as long as the shorter of the two vectors (see
-- 'zipWithSafe' for a checked version).
--
-- >>> zipWith (-) v v
-- UnsafeArray [3] [0,0,0]
zipWith :: (a -> b -> c) -> Array a -> Array b -> Array c
zipWith f (UnsafeArray s v) (UnsafeArray _ v') = UnsafeArray s (V.zipWith f v v')

-- | Zip two arrays at an element level, checking for shape consistency.
--
-- Returns 'Nothing' unless both arrays have exactly the same shape.
--
-- >>> zipWithSafe (-) (range [3]) (range [4])
-- Nothing
zipWithSafe :: (a -> b -> c) -> Array a -> Array b -> Maybe (Array c)
zipWithSafe f (UnsafeArray s v) (UnsafeArray s' v') =
  bool Nothing (Just $ UnsafeArray s (V.zipWith f v v')) (s == s')

-- | Modify a single value at an index; all other elements are unchanged.
--
-- >>> pretty $ modify [0,0] (const 100) (range [3,2])
-- [[100,1],
--  [2,3],
--  [4,5]]
modify :: [Int] -> (a -> a) -> Array a -> Array a
modify ds f a = tabulate (shape a) (\s -> bool id f (s == ds) (index a s))

-- | Maps an index function at element-level.
--
-- >>> pretty $ imap (\xs x -> x - sum xs) a
-- [[[0,0,0,0],
--   [3,3,3,3],
--   [6,6,6,6]],
--  [[11,11,11,11],
--   [14,14,14,14],
--   [17,17,17,17]]]
imap ::
  ([Int] -> a -> b) ->
  Array a ->
  Array b
-- Pair every element with its index by zipping against an index array built
-- from the same shape.
imap f a = zipWith f positions a
  where
    positions = indices (shape a)

-- | With a function that takes dimensions and (type-level) parameters, apply the parameters to the initial dimensions, i.e.
--
-- > rowWise f xs = f [0..] xs
--
-- >>> rowWise indexes [1,0] a
-- UnsafeArray [4] [12,13,14,15]
rowWise :: (Dims -> [x] -> Array a -> Array a) -> [x] -> Array a -> Array a
-- One leading dimension per parameter: [0, 1 .. n - 1] where n = S.rank xs.
rowWise f xs a = f leading xs a
  where
    leading = [0 .. S.rank xs - 1]

-- | With a function that takes dimensions and (type-level) parameters, apply the parameters to the last dimensions, i.e.
--
-- > colWise f xs = f (List.reverse [0 .. (rank a - 1)]) xs
--
-- >>> colWise indexes [1,0] a
-- UnsafeArray [2] [1,13]
colWise :: (Dims -> [x] -> Array a -> Array a) -> [x] -> Array a -> Array a
-- One trailing dimension per parameter, listed from the last dimension
-- backwards: [r - 1, r - 2 .. r - n] where r = rank a and n = S.rank xs.
colWise f xs a = f trailing xs a
  where
    r = rank a
    trailing = List.reverse [r - S.rank xs .. r - 1]

-- | With a function that takes a dimension and a parameter, fold dimensions and parameters using the function.
--
-- >>> dimsWise take [0,2] [1,2] a
-- UnsafeArray [1,3,2] [0,1,4,5,8,9]
dimsWise :: (Dim -> x -> Array a -> Array a) -> Dims -> [x] -> Array a -> Array a
-- Apply f once per (dimension, parameter) pair, left to right, threading the
-- array through each step. Surplus dimensions or parameters are ignored because
-- List.zip truncates to the shorter list.
dimsWise f ds xs a = foldl' step a (List.zip ds xs)
  where
    step acc (d, x) = f d x acc

-- | Take the top-most elements across the specified dimension. Negative values take the bottom-most. No index check is performed.
--
-- > take d x == takes [(d,x)]
--
-- >>> pretty $ take 2 1 a
-- [[[0],
--   [4],
--   [8]],
--  [[12],
--   [16],
--   [20]]]
-- >>> pretty $ take 2 (-1) a
-- [[[3],
--   [7],
--   [11]],
--  [[15],
--   [19],
--   [23]]]
take ::
  Dim ->
  Int ->
  Array a ->
  Array a
-- The new shape has |t| elements along dimension d. Source indices along d are
-- offset by 0 for a non-negative t, or by (size of d + t) for a negative t so
-- the bottom-most |t| elements are read.
take d t a = backpermute (takeDim d (abs t)) (modifyDim d (+ offset)) a
  where
    offset
      | t < 0 = getDim d (shape a) + t
      | otherwise = 0

-- | Drop the top-most elements across the specified dimension. Negative values drop the bottom-most.
--
-- >>> pretty $ drop 2 1 a
-- [[[1,2,3],
--   [5,6,7],
--   [9,10,11]],
--  [[13,14,15],
--   [17,18,19],
--   [21,22,23]]]
-- >>> pretty $ drop 2 (-1) a
-- [[[0,1,2],
--   [4,5,6],
--   [8,9,10]],
--  [[12,13,14],
--   [16,17,18],
--   [20,21,22]]]
drop ::
  Dim ->
  Int ->
  Array a ->
  Array a
-- The new shape loses |t| elements along dimension d. For a non-negative t the
-- source index along d is shifted past the dropped top-most elements; for a
-- negative t no shift is needed since only bottom-most elements are removed.
drop d t a = backpermute (dropDim d (abs t)) (modifyDim d (+ offset)) a
  where
    offset = if t < 0 then 0 else t

-- | Select an index along a dimension.
--
-- >>> let s = select 2 3 a
-- >>> pretty s
-- [[3,7,11],
--  [15,19,23]]
select ::
  Dim ->
  Int ->
  Array a ->
  Array a
-- Dimension d is removed from the shape; each result index is widened back by
-- inserting x at position d to locate the source element.
select d x = backpermute (deleteDim d) (insertDim d x)

-- | Insert along a dimension at a position.
--
-- >>> pretty $ insert 2 0 a (konst [2,3] 0)
-- [[[0,0,1,2,3],
--   [0,4,5,6,7],
--   [0,8,9,10,11]],
--  [[0,12,13,14,15],
--   [0,16,17,18,19],
--   [0,20,21,22,23]]]
-- >>> insert 0 0 (toScalar 1) (toScalar 2)
-- UnsafeArray [2] [2,1]
insert ::
  Dim ->
  Int ->
  Array a ->
  Array a ->
  Array a
-- The result is one element longer along d than a. Positions before i read
-- from a, position i reads from b (indexed by the result index with dimension
-- d removed), and positions after i read from a shifted back by one along d.
insert d i a b = tabulate (incAt d (shape a)) pick
  where
    pick s = case compare (getDim d s) i of
      LT -> index a s
      EQ -> index b (deleteDim d s)
      GT -> index a (decAt d s)

-- | Delete along a dimension at a position.
--
-- >>> pretty $ delete 2 0 a
-- [[[1,2,3],
--   [5,6,7],
--   [9,10,11]],
--  [[13,14,15],
--   [17,18,19],
--   [21,22,23]]]
delete ::
  Dim ->
  Int ->
  Array a ->
  Array a
-- The result is one element shorter along d. Positions before i map straight
-- through; positions at or after i read one further along d, skipping the
-- deleted slice.
delete d i a = backpermute (decAt d) source a
  where
    source s
      | getDim d s < i = s
      | otherwise = incAt d s

-- | Insert along a dimension at the end.
--
-- >>> pretty $ append 2 a (konst [2,3] 0)
-- [[[0,1,2,3,0],
--   [4,5,6,7,0],
--   [8,9,10,11,0]],
--  [[12,13,14,15,0],
--   [16,17,18,19,0],
--   [20,21,22,23,0]]]
append ::
  Dim ->
  Array a ->
  Array a ->
  Array a
-- Inserting at position (size of a along d) places b after every existing slice.
append d a b = insert d end a b
  where
    end = getDim d (shape a)

-- | Insert along a dimension at the beginning.
--
-- >>> pretty $ prepend 2 (konst [2,3] 0) a
-- [[[0,0,1,2,3],
--   [0,4,5,6,7],
--   [0,8,9,10,11]],
--  [[0,12,13,14,15],
--   [0,16,17,18,19],
--   [0,20,21,22,23]]]
prepend ::
  Dim ->
  Array a ->
  Array a ->
  Array a
-- Argument order matters: the first array is the one placed in front, so it
-- is inserted at position 0 of the second array.
prepend d front rest = insert d 0 rest front

-- | Concatenate along a dimension.
--
-- >>> shape $ concatenate 1 a a
-- [2,6,4]
-- >>> concatenate 0 (toScalar 1) (toScalar 2)
-- UnsafeArray [2] [1,2]
-- >>> concatenate 0 (toScalar 0) (asArray [1..3])
-- UnsafeArray [4] [0,1,2,3]
concatenate ::
  Dim ->
  Array a ->
  Array a ->
  Array a
-- Result positions along d below the length of a0 read from a0; the remaining
-- positions read from a1 after subtracting that length along d.
concatenate d a0 a1 = tabulate (S.concatenate d ds0 (shape a1)) pick
  where
    ds0 = shape a0
    n0 = getDim d ds0
    pick s
      | getDim d s >= n0 = index a1 (insertDim d (getDim d s - n0) (deleteDim d s))
      | otherwise = index a0 s

-- | Combine two arrays as a new dimension of a new array.
--
-- >>> pretty $ couple 0 (asArray [1,2,3]) (asArray [4,5,6::Int])
-- [[1,2,3],
--  [4,5,6]]
couple :: Int -> Array a -> Array a -> Array a
-- Elongate each array at dimension d (per the doctest, adding a length-1
-- dimension there), then concatenate the two along it.
couple d a a' = concatenate d (grow a) (grow a')
  where
    grow = elongate d

-- | Slice along a dimension with the supplied offset & length.
--
-- >>> let s = slice 2 1 2 a
-- >>> pretty s
-- [[[1,2],
--   [5,6],
--   [9,10]],
--  [[13,14],
--   [17,18],
--   [21,22]]]
slice ::
  Dim ->
  Int ->
  Int ->
  Array a ->
  Array a
-- The shape along d becomes l; each result index along d is shifted by o to
-- read the source. This function does no bounds check on o + l itself.
slice d o l = backpermute (setDim d l) (modifyDim d (o +))

-- | Rotate an array along a dimension.
--
-- >>> pretty $ rotate 1 2 a
-- [[[8,9,10,11],
--   [0,1,2,3],
--   [4,5,6,7]],
--  [[20,21,22,23],
--   [12,13,14,15],
--   [16,17,18,19]]]
rotate ::
  Dim ->
  Int ->
  Array a ->
  Array a
-- The shape is unchanged; only the index mapping, delegated to rotateIndex,
-- moves elements along dimension d.
rotate d r a = backpermute id remap a
  where
    remap = rotateIndex d r (shape a)

-- * multi-dimension operators

-- | Take the top-most elements across the supplied dimensions. Negative values take the bottom-most.
--
-- > takes == dimsWise take
--
-- >>> pretty $ takes [0,2] [1,-3] a
-- [[[1,2,3],
--   [5,6,7],
--   [9,10,11]]]
takes ::
  Dims ->
  [Int] ->
  Array a ->
  Array a
-- Each listed dimension gets length |x|. A negative x starts reading at
-- (size + x) so the bottom-most elements are kept; non-negative x and unlisted
-- dimensions start at 0.
takes ds xs a = backpermute newShape (List.zipWith (+) offsets) a
  where
    newShape = setDims ds (fmap abs xs)
    -- xs spread over a full-rank list, with 0 for dimensions not in ds
    fullXs = setDims ds xs (replicate (rank a) 0)
    offsets = List.zipWith startAt fullXs (shape a)
    startAt x n = if x < 0 then n + x else 0

-- | Drops the top-most elements. Negative values drop the bottom-most.
--
-- >>> pretty $ drops [0,1,2] [1,2,-3] a
-- [[[20]]]
drops ::
  Dims ->
  [Int] ->
  Array a ->
  Array a
-- Each listed dimension shrinks by |x|. A non-negative x shifts source indices
-- past the dropped top-most elements; a negative x needs no shift because only
-- bottom-most elements are removed.
drops ds xs a = backpermute (dropDims ds (fmap abs xs)) (List.zipWith shift fullXs) a
  where
    -- xs spread over a full-rank list, with 0 for dimensions not in ds
    fullXs = setDims ds xs (replicate (rank a) 0)
    shift x i
      | x < 0 = i
      | otherwise = x + i

-- | Select by dimensions and indexes.
--
-- >>> let s = indexes [0,1] [1,1] a
-- >>> pretty s
-- [16,17,18,19]
indexes :: Dims -> [Int] -> Array a -> Array a
-- The listed dimensions disappear from the shape; each result index has the
-- fixed positions xs re-inserted at ds to locate the source element.
indexes ds xs = backpermute (deleteDims ds) (insertDims ds xs)

-- | Slice along dimensions with the supplied offsets and lengths.
--
-- >>> let s = slices [2,0] [1,1] [2,1] a
-- >>> pretty s
-- [[[13,14],
--   [17,18],
--   [21,22]]]
slices :: Dims -> [Int] -> [Int] -> Array a -> Array a
-- One 'slice' per listed dimension, pairing offsets and lengths positionally.
slices ds os ls = dimsWise sliceAt ds (List.zip os ls)
  where
    sliceAt d (o, l) = slice d o l

-- | Select the first element along the supplied dimensions.
--
-- >>> pretty $ heads [0,2] a
-- [0,4,8]
heads :: Dims -> Array a -> Array a
-- Fix index 0 along every listed dimension.
heads ds = indexes ds zeros
  where
    zeros = List.replicate (S.rank ds) 0

-- | Select the last element along the supplied dimensions.
--
-- >>> pretty $ lasts [0,2] a
-- [15,19,23]
lasts :: Dims -> Array a -> Array a
-- Fix index (size - 1) along every listed dimension.
lasts ds a = indexes ds (fmap lastOf ds) a
  where
    lastOf d = getDim d (shape a) - 1

-- | Select the tail elements along the supplied dimensions.
--
-- >>> pretty $ tails [0,2] a
-- [[[13,14,15],
--   [17,18,19],
--   [21,22,23]]]
tails :: Dims -> Array a -> Array a
-- Slice every listed dimension from offset 1. The lengths come from
-- getLastPositions, which (per the doctest) is size - 1 for each dimension.
tails ds a = slices ds (List.replicate (S.rank ds) 1) (getLastPositions ds (shape a)) a

-- | Select the init elements along the supplied dimensions.
--
-- >>> pretty $ inits [0,2] a
-- [[[0,1,2],
--   [4,5,6],
--   [8,9,10]]]
inits :: Dims -> Array a -> Array a
-- Slice every listed dimension from offset 0. The lengths come from
-- getLastPositions, which (per the doctest) is size - 1 for each dimension.
inits ds a = slices ds (List.replicate (S.rank ds) 0) (getLastPositions ds (shape a)) a

-- | Extracts dimensions to an outer layer.
--
-- >>> pretty $ shape <$> extracts [0] a
-- [[3,4],[3,4]]
extracts ::
  Dims ->
  Array a ->
  Array (Array a)
-- The outer array spans the listed dimensions; each of its elements is the
-- sub-array obtained by fixing those dimensions at the outer index.
extracts ds a = tabulate (getDims ds (shape a)) (\s -> indexes ds s a)

-- | Reduce along specified dimensions, using the supplied fold.
--
-- >>> pretty $ reduces [0] sum a
-- [66,210]
-- >>> pretty $ reduces [0,2] sum a
-- [[12,15,18,21],
--  [48,51,54,57]]
reduces ::
  Dims ->
  (Array a -> b) ->
  Array a ->
  Array b
-- Extract the listed dimensions as an outer layer, then collapse each inner
-- array with f.
reduces ds f = fmap f . extracts ds

-- | Join inner and outer dimension layers by supplied dimensions. No checks on shape.
--
-- >>> let e = extracts [1,0] a
-- >>> let j = joins [1,0] e
-- >>> a == j
-- True
joins ::
  Dims ->
  Array (Array a) ->
  Array a
-- The result shape interleaves the outer shape (placed at positions ds) with
-- the inner shape. The inner shape is read from the element at the all-zero
-- outer index, so all inner arrays are assumed to share it, and an empty outer
-- array has no such element to read.
joins ds a = tabulate (insertDims ds outerShape innerShape) pick
  where
    outerShape = shape a
    innerShape = shape (index a (replicate (rank a) 0))
    pick s = index (index a (getDims ds s)) (deleteDims ds s)

-- | Join inner and outer dimension layers by supplied dimensions. Check inner layer shape.
--
-- >>> let e = extracts [1,0] a
-- >>> (Just j) = joinsSafe [1,0] e
-- >>> a == j
-- True
joinsSafe ::
  Dims ->
  Array (Array a) ->
  Maybe (Array a)
-- Returns Nothing when the inner arrays disagree on shape, and also when the
-- outer array has no elements. 'joins' reads the inner shape from the element
-- at the all-zero outer index, which does not exist for an empty outer layer.
-- Without this guard the vacuous allEqual check would yield a Just wrapping an
-- array that fails when evaluated.
joinsSafe ds a
  | size a == 0 = Nothing
  | allEqual (fmap shape a) = Just (joins ds a)
  | otherwise = Nothing

-- | Join inner and outer dimension layers in outer dimension order.
--
-- >>> a == join (extracts [0,1] a)
-- True
join ::
  Array (Array a) ->
  Array a
-- Join along every outer dimension, as listed by S.dimsOf on the outer shape,
-- so the outer dimensions come first in the result.
join a = joins outerDims a
  where
    outerDims = S.dimsOf (shape a)

-- | Join inner and outer dimension layers in outer dimension order, checking for consistent inner dimension shape.
--
-- >>> joinSafe (extracts [0,1] a)
-- Just (UnsafeArray [2,3,4] [0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23])
joinSafe ::
  Array (Array a) ->
  Maybe (Array a)
-- Returns Nothing when the inner arrays disagree on shape, and also when the
-- outer array has no elements. 'join' (via 'joins') reads the inner shape from
-- the element at the all-zero outer index, which does not exist for an empty
-- outer layer. Without this guard the vacuous allEqual check would yield a
-- Just wrapping an array that fails when evaluated.
joinSafe a
  | size a == 0 = Nothing
  | allEqual (fmap shape a) = Just (join a)
  | otherwise = Nothing

-- | Check whether all elements of an array are equal (an empty array counts as all-equal).
allEqual :: (Eq a) => Array a -> Bool
-- Compare every element against the first one; an empty array is vacuously
-- all-equal.
allEqual a =
  case arrayAs a of
    (first : rest) -> all (== first) rest
    [] -> True

-- | Traverse along specified dimensions.
--
-- traverses [1] print (range [2,3])
-- 0
-- 3
-- 1
-- 4
-- 2
-- 5
-- UnsafeArray [2,3] [(),(),(),(),(),()]
traverses ::
  (Applicative f) =>
  Dims ->
  (a -> f b) ->
  Array a ->
  f (Array b)
-- Extract the listed dimensions as an outer layer and traverse it, running f
-- over each inner array in turn. The results are then joined back along the
-- same dimensions.
traverses ds f a = fmap (joins ds) (traverse (traverse f) (extracts ds a))

-- | Maps a function along specified dimensions.
--
-- >>> pretty $ maps [1] transpose a
-- [[[0,12],
--   [4,16],
--   [8,20]],
--  [[1,13],
--   [5,17],
--   [9,21]],
--  [[2,14],
--   [6,18],
--   [10,22]],
--  [[3,15],
--   [7,19],
--   [11,23]]]
maps ::
  Dims ->
  (Array a -> Array b) ->
  Array a ->
  Array b
-- Apply f to every sub-array extracted along ds, then join the results back.
-- The results of f are expected to share one shape, as 'joins' requires.
maps ds f = joins ds . fmap f . extracts ds

-- | Filters along specified dimensions (which are flattened).
--
-- >>> pretty $ filters [0,1] (any ((==0) . (`mod` 7))) a
-- [[0,1,2,3],
--  [4,5,6,7],
--  [12,13,14,15],
--  [20,21,22,23]]
filters ::
  Dims ->
  (Array a -> Bool) ->
  Array a ->
  Array a
-- Sub-arrays along ds are flattened into a vector, which discards the outer
-- shape. They are filtered by p, rebuilt into an outer array and joined.
filters ds p a = join (asArray kept)
  where
    kept = V.filter p (asVector (extracts ds a))

-- | Zips two arrays with a function along specified dimensions.
--
-- >>> pretty $ zips [0,1] (zipWith (,)) a (reverses [0] a)
-- [[[(0,12),(1,13),(2,14),(3,15)],
--   [(4,16),(5,17),(6,18),(7,19)],
--   [(8,20),(9,21),(10,22),(11,23)]],
--  [[(12,0),(13,1),(14,2),(15,3)],
--   [(16,4),(17,5),(18,6),(19,7)],
--   [(20,8),(21,9),(22,10),(23,11)]]]
zips ::
  Dims ->
  (Array a -> Array b -> Array c) ->
  Array a ->
  Array b ->
  Array c
-- Pair up the sub-arrays of a and b extracted along ds, combine each pair with
-- f, and join the results. No shape check is made here (see 'zipsSafe').
zips ds f a b = joins ds paired
  where
    paired = zipWith f (extracts ds a) (extracts ds b)

-- | Zips two arrays with a function along specified dimensions, checking shapes.
--
-- >>> zipsSafe [0] (zipWith (,)) (asArray [1::Int]) (asArray [1,2::Int])
-- Nothing
zipsSafe ::
  Dims ->
  (Array a -> Array b -> Array c) ->
  Array a ->
  Array b ->
  Maybe (Array c)
-- Only proceeds when both arrays have exactly the same full shape; otherwise
-- Nothing. The zip itself is the same computation as 'zips'.
zipsSafe ds f a b
  | shape a /= shape b = Nothing
  | otherwise = Just (zips ds f a b)

-- | Modify using the supplied function along dimensions & positions.
--
-- >>> pretty $ modifies (fmap (100+)) [2] [0] a
-- [[[100,1,2,3],
--   [104,5,6,7],
--   [108,9,10,11]],
--  [[112,13,14,15],
--   [116,17,18,19],
--   [120,21,22,23]]]
modifies ::
  (Array a -> Array a) ->
  Dims ->
  [Int] ->
  Array a ->
  Array a
-- Extract along ds, apply f to the single sub-array at outer position ps,
-- and join the layers back together.
modifies f ds ps = joins ds . modify ps f . extracts ds

-- | Apply a binary function between successive slices, across dimensions and lags.
--
-- >>> pretty $ diffs [1] [1] (zipWith (-)) a
-- [[[4,4,4,4],
--   [4,4,4,4]],
--  [[4,4,4,4],
--   [4,4,4,4]]]
diffs :: Dims -> [Int] -> (Array a -> Array a -> Array b) -> Array a -> Array b
-- Compare the array with itself shifted by the lags xs. The first argument to
-- f drops xs elements from the front of each listed dimension. The second
-- drops the negated lags, i.e. from the back. Both therefore lose |x| elements
-- per listed dimension and have the same shape.
diffs ds xs f a = zips ds f later earlier
  where
    later = drops ds xs a
    earlier = drops ds (fmap P.negate xs) a

-- | Take the product of two arrays using the supplied binary function.
--
-- For context, if the function is multiply, and the arrays are tensors,
-- then this can be interpreted as a [tensor product](https://en.wikipedia.org/wiki/Tensor_product).
-- The concept of a tensor product is a dense crossroad, and a complete treatment is elsewhere.  To quote the wiki article:
--
-- ... the tensor product can be extended to other categories of mathematical objects in addition to vector spaces, such as to matrices, tensors, algebras, topological vector spaces, and modules. In each such case the tensor product is characterized by a similar universal property: it is the freest bilinear operation. The general concept of a "tensor product" is captured by monoidal categories; that is, the class of all things that have a tensor product is a monoidal category.
--
-- The result has shape @shape a <> shape b@: the leading dimensions index the
-- first array and the trailing dimensions index the second.
--
-- >>> x = array [3] [1,2,3]
-- >>> pretty $ expand (*) x x
-- [[1,2,3],
--  [2,4,6],
--  [3,6,9]]
--
-- Alternatively, expand can be understood as representing the permutation of element pairs of two arrays, so like the Applicative List instance.
--
-- >>> i2 = indices [2,2]
-- >>> pretty $ expand (,) i2 i2
-- [[[[([0,0],[0,0]),([0,0],[0,1])],
--    [([0,0],[1,0]),([0,0],[1,1])]],
--   [[([0,1],[0,0]),([0,1],[0,1])],
--    [([0,1],[1,0]),([0,1],[1,1])]]],
--  [[[([1,0],[0,0]),([1,0],[0,1])],
--    [([1,0],[1,0]),([1,0],[1,1])]],
--   [[([1,1],[0,0]),([1,1],[0,1])],
--    [([1,1],[1,0]),([1,1],[1,1])]]]]
expand ::
  (a -> b -> c) ->
  Array a ->
  Array b ->
  Array c
expand f a b = tabulate (shape a <> shape b) go
  where
    r = rank a
    -- the first r index components address a, the remainder address b
    go i = f (index a (List.take r i)) (index b (List.drop r i))

-- | Like expand, but permutes the first array first, rather than the second.
--
-- The result has shape @shape b <> shape a@: the leading dimensions index the
-- second array and the trailing dimensions index the first, so elements of the
-- first array vary fastest.
--
-- >>> pretty $ expand (,) v (fmap (+3) v)
-- [[(0,3),(0,4),(0,5)],
--  [(1,3),(1,4),(1,5)],
--  [(2,3),(2,4),(2,5)]]
--
-- >>> pretty $ coexpand (,) v (fmap (+3) v)
-- [[(0,3),(1,3),(2,3)],
--  [(0,4),(1,4),(2,4)],
--  [(0,5),(1,5),(2,5)]]
--
-- >>> pretty $ coexpand (,) (array [2] [0,1] :: Array Int) (array [3] [3,4,5] :: Array Int)
-- [[(0,3),(1,3)],
--  [(0,4),(1,4)],
--  [(0,5),(1,5)]]
coexpand ::
  (a -> b -> c) ->
  Array a ->
  Array b ->
  Array c
coexpand f a b = tabulate (shape b <> shape a) go
  where
    -- The leading components must range over b's shape and the trailing ones
    -- over a's shape; the output shape is ordered to match. (Building the shape
    -- as shape a <> shape b while indexing this way reads out of range whenever
    -- the two shapes differ.)
    r = rank b
    go i = f (index a (List.drop r i)) (index b (List.take r i))

-- | Contract an array by applying the supplied (folding) function on diagonal elements of the dimensions.
--
-- This generalises a tensor contraction by allowing the number of contracting diagonals to be other than 2.
--
-- Every dimension not listed in @ds@ is kept; for each position in those kept
-- dimensions, the sub-array over @ds@ has its diagonal taken and folded with @f@.
--
-- >>> pretty $ contract [1,2] sum (expand (*) m (transpose m))
-- [[5,14],
--  [14,50]]
contract ::
  Dims ->
  (Array a -> b) ->
  Array a ->
  Array b
contract ds f a = fmap (f . diag) (extracts kept a)
  where
    kept = exceptDims ds (shape a)

-- | Take the product of two arrays using the supplied function, and then contract the result using the supplied matching dimensions and function.
--
-- @ds0@ and @ds1@ name the contracted dimensions of the first and second array.
-- The inner (contracted) shape is taken from the first array; the second
-- array's matching dimensions are presumably expected to agree (not checked
-- here).
--
-- >>> pretty $ prod [1] [0] sum (*) (range [2,3]) (range [3,2])
-- [[10,13],
--  [28,40]]
--
-- With full laziness, this computation would be equivalent to:
--
-- > g . diag <$> extracts ds' (expand f a b)
prod ::
  Dims ->
  Dims ->
  (Array c -> d) ->
  (a -> b -> c) ->
  Array a ->
  Array b ->
  Array d
prod ds0 ds1 g f a b = tabulate outerShape outer
  where
    outerShape = S.deleteDims ds0 (shape a) <> S.deleteDims ds1 (shape b)
    -- number of leading outer index components that belong to the first array
    sp = rank a - S.rank ds0
    outer so = g (tabulate (S.getDims ds0 (shape a)) (inner so))
    inner so si =
      f
        (index a (S.insertDims ds0 si (List.take sp so)))
        (index b (S.insertDims ds1 si (List.drop sp so)))

-- | A generalisation of a dot operation, which is a multiplicative expansion of two arrays and sum contraction along the middle two dimensions.
--
-- The "middle two" dimensions are the last dimension of the first array and
-- the first dimension of the second, once expanded.
--
-- matrix multiplication
--
-- >>> pretty $ dot sum (*) m (transpose m)
-- [[5,14],
--  [14,50]]
--
-- inner product
--
-- >>> pretty $ dot sum (*) v v
-- 5
--
-- matrix-vector multiplication
-- Note that an Array with shape [3] is neither a row vector nor column vector.
--
-- >>> pretty $ dot sum (*) v (transpose m)
-- [5,14]
--
-- >>> pretty $ dot sum (*) m v
-- [5,14]
dot ::
  (Array c -> d) ->
  (a -> b -> c) ->
  Array a ->
  Array b ->
  Array d
dot f g a b =
  let r = rank a
   in contract [r - 1, r] f (expand g a b)

-- | Array multiplication: 'dot' specialised to multiplication followed by a sum.
--
-- matrix multiplication
--
-- >>> pretty $ mult m (transpose m)
-- [[5,14],
--  [14,50]]
--
-- inner product
--
-- >>> pretty $ mult v v
-- 5
--
-- matrix-vector multiplication
--
-- >>> pretty $ mult v (transpose m)
-- [5,14]
--
-- >>> pretty $ mult m v
-- [5,14]
mult ::
  (Num a) =>
  Array a ->
  Array a ->
  Array a
mult x y = dot sum (*) x y

-- | @windows xs@ are xs-sized windows of an array.
--
-- The result is a backpermutation of the input: its shape comes from
-- @expandWindows xs@ and each element is looked up via @indexWindows@.
--
-- >>> shape $ windows [2,2] (range [4,3,2])
-- [3,2,2,2,2]
windows :: [Int] -> Array a -> Array a
windows xs = backpermute (expandWindows xs) (indexWindows (S.rank xs))

-- | Find the starting positions of occurrences of one array in another.
--
-- The pattern is first reranked to the rank of the searched array. Every
-- pattern-shaped window is then compared with it; the result holds 'True' at
-- each window start that matches.
--
-- >>> a = cycle [4,4] (range [3]) :: Array Int
-- >>> i = array [2,2] [1,2,2,0] :: Array Int
-- >>> pretty $ find i a
-- [[False,True,False],
--  [True,False,False],
--  [False,False,True]]
find :: (Eq a) => Array a -> Array a -> Array Bool
find i a = fmap (== pat) candidates
  where
    pat = rerank (rank a) i
    patShape = shape pat
    -- one sub-array per window start position
    candidates =
      extracts
        (dimWindows (expandWindows patShape (shape a)) (shape a))
        (windows patShape a)

-- | Find the starting positions of one array in another, except where a match would overlap an earlier one.
--
-- Candidate positions come from 'find'. They are accepted greedily in
-- row-major order: a candidate is kept only if no previously kept position
-- lies close enough for the two windows to overlap.
--
-- >>> a = konst [5,5] 1 :: Array Int
-- >>> i = konst [2,2] 1 :: Array Int
-- >>> pretty $ findNoOverlap i a
-- [[True,False,True,False],
--  [False,False,False,False],
--  [True,False,True,False],
--  [False,False,False,False]]
findNoOverlap :: (Eq a) => Array a -> Array a -> Array Bool
findNoOverlap i a = r
  where
    f :: Array Bool
    f = find i a

    -- 'find' matches against i reranked to the rank of a, so the window
    -- shape used for overlap offsets must be that reranked shape too.
    windowShape :: [Int]
    windowShape = shape (rerank (rank a) i)

    -- Offsets from a candidate start to every other start whose window would
    -- overlap it. Each component lies in [-(n-1), n-1] for window extent n.
    cl :: [Int] -> [[Int]]
    cl sh =
      List.filter precedes $
        arrayAs $
          tabulate ((\x -> 2 * x - 1) <$> sh) (\s -> List.zipWith (\x x0 -> x - x0 + 1) s sh)

    -- Keep only offsets that point earlier in row-major order (first non-zero
    -- component negative). This is correct for any rank. It also guarantees
    -- the self-reference in r only looks backwards.
    precedes :: [Int] -> Bool
    precedes o = case List.dropWhile (== 0) o of
      (x : _) -> x < 0
      [] -> False

    -- Loop-invariant: the same offsets apply to every candidate.
    offsets :: [[Int]]
    offsets = cl windowShape

    -- A candidate survives if it matches and no earlier, overlapping,
    -- in-bounds position has already survived.
    go :: Array Bool -> [Int] -> Bool
    go r' s =
      index f s
        && not
          ( any (index r') $
              List.filter (\x -> isFins x (shape f)) $
                fmap (List.zipWith (+) s) offsets
          )

    -- Lazily self-referential: each element only inspects earlier elements.
    r :: Array Bool
    r = tabulate (shape f) (go r)

-- | Find the indices of the starting location of one array in another.
--
-- Every index of the 'find' result is paired with its flag. Only the pairs
-- flagged 'True' are kept, and their indices are returned as a flat array.
--
-- >>> b = cycle [4,4] (range [3]) :: Array Int
-- >>> i = array [2,2] [1,2,2,0] :: Array Int
-- >>> pretty $ findIndices i b
-- [[0,1],[1,0],[2,2]]
findIndices :: (Eq a) => Array a -> Array a -> Array [Int]
findIndices i a =
  fst <$> vectorAs (V.filter snd (asVector (imap (,) (find i a))))

-- | Check if the first array is a prefix of the second.
--
-- The second array is 'cut' to the shape of the first and the two are compared.
--
-- >>> isPrefixOf (array [2,2] [0,1,4,5]) a
-- True
isPrefixOf :: (Eq a) => Array a -> Array a -> Bool
isPrefixOf p = (p ==) . cut (shape p)

-- | Check if the first array is a suffix of the second.
--
-- The second array is cut (via 'cutSuffix') to the shape of the first and the
-- two are compared.
--
-- >>> isSuffixOf (array [2,2] [18,19,22,23]) a
-- True
isSuffixOf :: (Eq a) => Array a -> Array a -> Bool
isSuffixOf p = (p ==) . cutSuffix (shape p)

-- | Check if the first array is an infix of the second.
--
-- True when 'find' reports at least one match.
--
-- >>> isInfixOf (array [2,2] [18,19,22,23]) a
-- True
isInfixOf :: (Eq a) => Array a -> Array a -> Bool
isInfixOf p = or . find p

-- * shape manipulation

-- | Fill an array with the supplied value without regard to the original shape, or cut the array values to match the array size.
--
-- > validate (fill x a) == True
--
-- >>> pretty $ fill 0 (array [3] [])
-- [0,0,0]
-- >>> pretty $ fill 0 (array [3] [1..4])
-- [1,2,3]
fill :: a -> Array a -> Array a
fill x (UnsafeArray s v) = UnsafeArray s (V.take n (v <> padding))
  where
    n = S.size s
    -- empty when v already has at least n elements (see the second example)
    padding = V.replicate (n - V.length v) x

-- | Cut an array to form a new (smaller) shape. Errors if the new shape is larger. The old array is reranked to the rank of the new shape first.
--
-- The size check is made against the reranked array, because that is the array
-- actually indexed.
--
-- >>> cut [2] (array [4] [0..3] :: Array Int)
-- UnsafeArray [2] [0,1]
cut ::
  [Int] ->
  Array a ->
  Array a
cut s' a
  | isSubset s' (shape a') = tabulate s' (index a')
  | otherwise = error "bad cut"
  where
    a' = rerank (S.rank s') a

-- | Cut an array to form a new (smaller) shape, using suffix elements. Errors if the new shape is larger. The old array is reranked to the rank of the new shape first.
--
-- As with 'cut', the size check is made against the reranked array, which is
-- also where the suffix offset is measured.
--
-- >>> cutSuffix [2,2] a
-- UnsafeArray [2,2] [18,19,22,23]
cutSuffix ::
  [Int] ->
  Array a ->
  Array a
cutSuffix s' a
  | isSubset s' (shape a') = tabulate s' (index a' . List.zipWith (+) diffDim)
  | otherwise = error "bad cut"
  where
    a' = rerank (S.rank s') a
    -- per-dimension offset from the start of a' to the start of the suffix
    diffDim = List.zipWith (-) (shape a') s'

-- | Pad an array to form a new shape, supplying a default value for elements outside the shape of the old array. The old array is reranked to the rank of the new shape first.
--
-- >>> pad 0 [5] (array [4] [0..3] :: Array Int)
-- UnsafeArray [5] [0,1,2,3,0]
pad ::
  a ->
  [Int] ->
  Array a ->
  Array a
pad d s' a = tabulate s' pick
  where
    a' = rerank (S.rank s') a
    -- positions inside the old shape keep their value; the rest get d
    pick s
      | s `isFins` shape a' = index a' s
      | otherwise = d

-- | Left pad an array to form a new shape, supplying a default value for elements outside the shape of the old array.
--
-- The old array is reranked to the rank of the new shape and aligned with the
-- end of each dimension.
--
-- >>> lpad 0 [5] (array [4] [0..3] :: Array Int)
-- UnsafeArray [5] [0,0,1,2,3]
-- >>> pretty $ lpad 0 [3,3] (range [2,2] :: Array Int)
-- [[0,0,0],
--  [0,0,1],
--  [0,2,3]]
lpad ::
  a ->
  [Int] ->
  Array a ->
  Array a
lpad d s' a = tabulate s' pick
  where
    a' = rerank (S.rank s') a
    -- how far the old array is shifted towards the end, per dimension
    gap = List.zipWith (-) s' (shape a')
    pick s =
      let o = List.zipWith (-) s gap
       in if o `S.isFins` shape a' then index a' o else d

-- | Reshape an array (with the same or less number of elements).
--
-- Each new index is flattened against the new shape and then unflattened
-- against the old shape, so elements keep their row-major order.
--
-- >>> pretty $ reshape [4,3,2] a
-- [[[0,1],
--   [2,3],
--   [4,5]],
--  [[6,7],
--   [8,9],
--   [10,11]],
--  [[12,13],
--   [14,15],
--   [16,17]],
--  [[18,19],
--   [20,21],
--   [22,23]]]
reshape ::
  [Int] ->
  Array a ->
  Array a
reshape s a = backpermute (const s) toOld a
  where
    toOld = shapen (shape a) . flatten s

-- | Make an Array single dimensional.
--
-- Only the shape changes (to a single dimension holding the total size); the
-- underlying elements are untouched.
--
-- >>> pretty $ flat (range [2,2])
-- [0,1,2,3]
-- >>> pretty (flat $ toScalar 0)
-- [0]
flat :: Array a -> Array a
flat = unsafeModifyShape (\s -> [S.size s])

-- | Reshape an array, repeating the original array. The shape of the array should be a suffix of the new shape.
--
-- Each new index is mapped back by discarding its leading (extra) components.
--
-- >>> pretty $ repeat [2,2,2] (array [2] [1,2])
-- [[[1,2],
--   [1,2]],
--  [[1,2],
--   [1,2]]]
--
-- > repeat ds (toScalar x) == konst ds x
repeat ::
  [Int] ->
  Array a ->
  Array a
repeat s a = backpermute (const s) (List.drop extra) a
  where
    extra = S.rank s - rank a

-- | Reshape an array, cycling through the elements without regard to the original shape.
--
-- Each new index is flattened, wrapped modulo the original size, and
-- unflattened against the original shape.
--
-- >>> pretty $ cycle [2,2,2] (array [3] [1,2,3])
-- [[[1,2],
--   [3,1]],
--  [[2,3],
--   [1,2]]]
cycle ::
  [Int] ->
  Array a ->
  Array a
cycle s a =
  backpermute
    (const s)
    (\i -> shapen (shape a) (flatten s i `mod` size a))
    a

-- | Change the rank of an array. A higher target rank adds new dimensions at
-- the front. A lower target rank merges leading dimensions, left to right,
-- into rows.
--
-- >>> shape (rerank 4 a)
-- [1,2,3,4]
-- >>> shape (rerank 2 a)
-- [6,4]
--
-- > flat == rerank 1
rerank :: Int -> Array a -> Array a
rerank r = unsafeModifyShape (S.rerank r)

-- | Permute the dimensions of an array according to the supplied list.
--
-- >>> pretty $ reorder [2,0,1] a
-- [[[0,4,8],
--   [12,16,20]],
--  [[1,5,9],
--   [13,17,21]],
--  [[2,6,10],
--   [14,18,22]],
--  [[3,7,11],
--   [15,19,23]]]
reorder ::
  Dims ->
  Array a ->
  Array a
reorder ds a = backpermute newShape sourceIndex a
  where
    -- shape of the result
    newShape s = S.reorder s ds
    -- maps a result index back to an index into the original
    sourceIndex s = insertDims ds s []

-- | Remove every dimension of length one. Only the shape is rewritten.
--
-- >>> let sq = array [2,1,3,4,1] [1..24] :: Array Int
-- >>> shape $ squeeze sq
-- [2,3,4]
--
-- >>> shape $ squeeze (singleton 0)
-- []
squeeze ::
  Array a ->
  Array a
squeeze = unsafeModifyShape S.squeeze

-- | Insert a new dimension of length one at the supplied position. Only the
-- shape is rewritten.
--
-- >>> shape $ elongate 1 a
-- [2,1,3,4]
-- >>> elongate 0 (toScalar 1)
-- UnsafeArray [1] [1]
elongate ::
  Dim ->
  Array a ->
  Array a
elongate d = unsafeModifyShape (insertDim d 1)

-- | Reverse the order of the indices, so that element A/ijk/ moves to A/kji/.
--
-- >>> index (transpose a) [1,0,0] == index a [0,0,1]
-- True
-- >>> pretty $ transpose (array [2,2,2] [1..8])
-- [[[1,5],
--   [3,7]],
--  [[2,6],
--   [4,8]]]
transpose :: Array a -> Array a
transpose = backpermute List.reverse List.reverse

-- | Inflate an array by inserting a new dimension of the given size at the
-- given position. Every position along the new dimension holds a copy of the
-- original.
--
-- alt name: replicate
--
-- >>> pretty $ inflate 0 2 (array [3] [0,1,2])
-- [[0,1,2],
--  [0,1,2]]
inflate ::
  Dim ->
  Int ->
  Array a ->
  Array a
inflate d n = backpermute (insertDim d n) (deleteDim d)

-- | Insert an array between each pair of adjacent slices of another array,
-- taken along a single dimension.
--
-- >>> pretty $ intercalate 2 (konst [2,3] 0) a
-- [[[0,0,1,0,2,0,3],
--   [4,0,5,0,6,0,7],
--   [8,0,9,0,10,0,11]],
--  [[12,0,13,0,14,0,15],
--   [16,0,17,0,18,0,19],
--   [20,0,21,0,22,0,23]]]
intercalate :: Dim -> Array a -> Array a -> Array a
intercalate d i a =
  -- slices along d -> list -> separator inserted -> back to an array -> joined along d
  joins [d] . asArray . List.intersperse i . arrayAs $ extracts [d] a

-- | Insert an element between each pair of adjacent slices of an array, taken
-- along a single dimension.
--
-- >>> pretty $ intersperse 2 0 a
-- [[[0,0,1,0,2,0,3],
--   [4,0,5,0,6,0,7],
--   [8,0,9,0,10,0,11]],
--  [[12,0,13,0,14,0,15],
--   [16,0,17,0,18,0,19],
--   [20,0,21,0,22,0,23]]]
intersperse :: Dim -> a -> Array a -> Array a
intersperse d i a = intercalate d filler a
  where
    -- a slice-shaped array filled with the separator element
    filler = konst (deleteDim d (shape a)) i

-- | Concatenate the supplied dimensions into a single new dimension, placed
-- at the supplied position.
--
-- >>> pretty $ concats [0,1] 1 a
-- [[0,4,8,12,16,20],
--  [1,5,9,13,17,21],
--  [2,6,10,14,18,22],
--  [3,7,11,15,19,23]]
concats ::
  Dims ->
  Int ->
  Array a ->
  Array a
concats ds n a = backpermute toShape fromIndex a
  where
    toShape = concatDims ds n
    fromIndex = unconcatDimsIndex ds n (shape a)

-- | Reverse the element order along each of the supplied dimensions. The
-- shape is unchanged.
--
-- >>> pretty $ reverses [0,1] a
-- [[[20,21,22,23],
--   [16,17,18,19],
--   [12,13,14,15]],
--  [[8,9,10,11],
--   [4,5,6,7],
--   [0,1,2,3]]]
reverses ::
  Dims ->
  Array a ->
  Array a
reverses ds a = backpermute id flipIndex a
  where
    flipIndex = reverseIndex ds (shape a)

-- | Rotate an array along the supplied dimensions by the matching offsets.
-- The shape is unchanged.
--
-- >>> pretty $ rotates [1] [2] a
-- [[[8,9,10,11],
--   [0,1,2,3],
--   [4,5,6,7]],
--  [[20,21,22,23],
--   [12,13,14,15],
--   [16,17,18,19]]]
rotates ::
  Dims ->
  [Int] ->
  Array a ->
  Array a
rotates ds rs a = backpermute id shift a
  where
    shift = rotatesIndex ds rs (shape a)

-- * sorting

-- | Sort an array along the supplied dimensions. The sub-arrays extracted
-- along those dimensions are sorted, then joined back together.
--
-- >>> sorts [0] (array [2,2] [2,3,1,4])
-- UnsafeArray [2,2] [1,4,2,3]
-- >>> sorts [1] (array [2,2] [2,3,1,4])
-- UnsafeArray [2,2] [2,3,1,4]
-- >>> sorts [0,1] (array [2,2] [2,3,1,4])
-- UnsafeArray [2,2] [1,2,3,4]
sorts :: (Ord a) => Dims -> Array a -> Array a
sorts ds = joins ds . unsafeModifyVector sortV . extracts ds

-- | Sort an array along the supplied dimensions. The sub-arrays extracted
-- along those dimensions are ordered by the value of a key function applied
-- to each, then joined back together. The result is the sorted array, not
-- indices (see 'ordersBy' for those).
--
-- >>> import Data.Ord (Down (..))
-- >>> sortsBy [0] (fmap Down) (array [2,2] [2,3,1,4])
-- UnsafeArray [2,2] [2,3,1,4]
sortsBy :: (Ord b) => Dims -> (Array a -> Array b) -> Array a -> Array a
sortsBy ds c a = joins ds $ unsafeModifyVector (sortByV c) (extracts ds a)

-- | The indices that would sort the sub-arrays extracted along the supplied
-- dimensions.
--
-- >>> orders [0] (array [2,2] [2,3,1,4])
-- UnsafeArray [2] [1,0]
orders :: (Ord a) => Dims -> Array a -> Array Int
orders ds = unsafeModifyVector orderV . extracts ds

-- | The indices that would sort the sub-arrays extracted along the supplied
-- dimensions, comparing sub-arrays by the value of a key function.
--
-- >>> import Data.Ord (Down (..))
-- >>> ordersBy [0] (fmap Down) (array [2,2] [2,3,1,4])
-- UnsafeArray [2] [0,1]
ordersBy :: (Ord b) => Dims -> (Array a -> Array b) -> Array a -> Array Int
ordersBy ds c = unsafeModifyVector (orderByV c) . extracts ds

-- * transmission

-- | Combine two arrays with a binary array function, pairing the sub-arrays
-- extracted along the supplied dimensions of each. The results are joined
-- along the first array's dimensions. Shapes are not checked.
--
-- >>> a = array [2,3] [0..5]
-- >>> b = array [3] [0..2]
-- >>> pretty $ telecasts [1] [0] (concatenate 0) a b
-- [[0,1,2],
--  [3,4,5],
--  [0,1,2]]
telecasts :: Dims -> Dims -> (Array a -> Array b -> Array c) -> Array a -> Array b -> Array c
telecasts dsa dsb f a b =
  joins dsa $ zipWith f (extracts dsa a) (extracts dsb b)

-- | Checked variant of 'telecasts'. Returns 'Nothing' unless the outer shapes
-- produced by extracting the supplied dimensions of each array are equal.
--
-- >>> a = array [2,3] [0..5]
-- >>> b = array [1] [1]
-- >>> telecastsSafe [0] [0] (zipWith (+)) a b
-- Nothing
telecastsSafe :: Dims -> Dims -> (Array a -> Array b -> Array c) -> Array a -> Array b -> Maybe (Array c)
telecastsSafe dsa dsb f a b
  | shape (extracts dsa a) /= (shape (extracts dsb b) :: [Int]) = Nothing
  | otherwise = Just (telecasts dsa dsb f a b)

-- | Apply a binary array function to two arrays where the shape of the first
-- is expected to be a prefix of the shape of the second. @f a@ is mapped over
-- the second array across its dimensions from @rank a@ up to @rank b - 1@.
-- Shapes are not checked.
--
-- >>> a = array [2,3] [0..5]
-- >>> pretty $ transmit (zipWith (+)) (toScalar 1) a
-- [[1,2,3],
--  [4,5,6]]
transmit :: (Array a -> Array b -> Array c) -> Array a -> Array b -> Array c
transmit f a b = maps [rank a .. rank b - 1] (f a) b

-- | Checked variant of 'transmit'. Returns 'Nothing' unless the shape of the
-- first array is a prefix of the shape of the second.
--
-- >>> a = array [2,3] [0..5]
-- >>> transmitSafe (zipWith (+)) (array [3] [1,2,3]) a
-- Nothing
transmitSafe :: (Array a -> Array b -> Array c) -> Array a -> Array b -> Maybe (Array c)
transmitSafe f a b
  | shape a `List.isPrefixOf` shape b = Just (transmit f a b)
  | otherwise = Nothing

-- | Apply an element-wise binary operation to two arrays.
--
-- * Equal shapes: the arrays are zipped directly.
-- * One shape is a prefix of the other: the smaller array is transmitted
--   over the larger (arguments to @f@ keep their original order).
-- * Otherwise: calls 'error' with "bad shapes".
--
-- >>> pretty $ transmitOp (*) a (asArray [1,2])
-- [[[0,1,2,3],
--   [4,5,6,7],
--   [8,9,10,11]],
--  [[24,26,28,30],
--   [32,34,36,38],
--   [40,42,44,46]]]
transmitOp :: (a -> b -> c) -> Array a -> Array b -> Array c
transmitOp f a b
  | sa == sb = zipWith f a b
  | sa `List.isPrefixOf` sb = transmit (zipWith f) a b
  | sb `List.isPrefixOf` sa = transmit (zipWith (flip f)) b a
  | otherwise = error "bad shapes"
  where
    sa = shape a
    sb = shape b

-- | One-dimensional specialisation of 'range': the integers from 0 to n-1.
--
-- >>> iota 5
-- UnsafeArray [5] [0,1,2,3,4]
iota :: Int -> Array Int
iota = range . pure

-- * row (first dimension) specializations

-- | Add a new row at the front (prepend along dimension 0).
--
-- >>> pretty $ cons (array [2] [0,1]) (array [2,2] [2,3,4,5])
-- [[0,1],
--  [2,3],
--  [4,5]]
cons :: Array a -> Array a -> Array a
cons x xs = prepend 0 x xs

-- | Split an array into its first row and the remaining rows. The input is
-- first passed through 'asSingleton'.
--
-- >>> uncons (array [3,2] [0..5])
-- (UnsafeArray [2] [0,1],UnsafeArray [2,2] [2,3,4,5])
uncons :: Array a -> (Array a, Array a)
uncons a =
  let rows = asSingleton a
   in (heads [0] rows, tails [0] rows)

-- | Bidirectional pattern for the first row of an Array and the rest.
-- Matching splits with 'uncons'; construction rebuilds with 'cons'.
--
-- >>> (x:<xs) = array [4] [0..3]
-- >>> x
-- UnsafeArray [] [0]
-- >>> xs
-- UnsafeArray [3] [1,2,3]
-- >>> (x:<xs)
-- UnsafeArray [4] [0,1,2,3]
pattern (:<) :: Array a -> Array a -> Array a
pattern front :< rest <- (uncons -> (front, rest))
  where
    front :< rest = cons front rest

infix 5 :<

{-# COMPLETE (:<) :: Array #-}

-- | Add a new row at the end (append along dimension 0).
--
-- >>> pretty $ snoc (array [2,2] [0,1,2,3]) (array [2] [4,5])
-- [[0,1],
--  [2,3],
--  [4,5]]
snoc :: Array a -> Array a -> Array a
snoc xs x = append 0 xs x

-- | Split an array into its initial rows and its last row. The input is first
-- passed through 'asSingleton'.
--
-- >>> unsnoc (array [3,2] [0..5])
-- (UnsafeArray [2,2] [0,1,2,3],UnsafeArray [2] [4,5])
unsnoc :: Array a -> (Array a, Array a)
unsnoc a =
  let rows = asSingleton a
   in (inits [0] rows, lasts [0] rows)

-- | Bidirectional pattern for the initial rows of an Array and its last row.
-- Matching splits with 'unsnoc'; construction rebuilds with 'snoc'.
--
-- >>> (xs:>x) = array [4] [0..3]
-- >>> x
-- UnsafeArray [] [3]
-- >>> xs
-- UnsafeArray [3] [0,1,2]
-- >>> (xs:>x)
-- UnsafeArray [4] [0,1,2,3]
pattern (:>) :: Array a -> Array a -> Array a
pattern rest :> final <- (unsnoc -> (rest, final))
  where
    rest :> final = snoc rest final

infix 5 :>

{-# COMPLETE (:>) :: Array #-}

-- * Math

-- | Generate an array of the given shape whose elements are uniform random
-- variates drawn with 'uniformRM' from the given range. One draw is made per
-- element.
--
-- >>> import System.Random.Stateful hiding (uniform)
-- >>> g <- newIOGenM (mkStdGen 42)
-- >>> u <- uniform g [2,3,4] (0,9 :: Int)
-- >>> pretty u
-- [[[0,7,0,2],
--   [1,7,4,2],
--   [5,9,8,2]],
--  [[9,8,1,0],
--   [2,2,8,2],
--   [2,8,0,6]]]
uniform :: (StatefulGen g m, UniformRange a) => g -> [Int] -> (a, a) -> m (Array a)
uniform g ds r = UnsafeArray ds <$> V.replicateM (S.size ds) (uniformRM r g)

-- | Inverse of a square matrix, computed from its Cholesky factor.
--
-- With a = L Lᵀ (see 'chol'), this returns (Lᵀ)⁻¹ L⁻¹, inverting each
-- triangular factor with 'invtri'. Because it goes through 'chol', the input
-- is assumed to be symmetric positive-definite; other inputs do not give a
-- meaningful inverse.
--
-- >>> e = array [3,3] [4,12,-16,12,37,-43,-16,-43,98] :: Array Double
-- >>> pretty (inverse e)
-- [[49.36111111111111,-13.555555555555554,2.1111111111111107],
--  [-13.555555555555554,3.7777777777777772,-0.5555555555555555],
--  [2.1111111111111107,-0.5555555555555555,0.1111111111111111]]
--
-- > mult (inverse a) a == ident (shape a)  -- up to floating-point error
inverse :: (Floating a) => Array a -> Array a
inverse a = mult (invtri (transpose l)) (invtri l)
  where
    -- Compute the Cholesky factor once and share it between both inversions.
    l = chol a

-- | [Inversion of a Triangular Matrix](https://math.stackexchange.com/questions/1003801/inverse-of-an-invertible-upper-triangular-matrix-of-order-3)
--
-- Split a = D + T, where D is the diagonal and T the rest. Then
-- a⁻¹ = (Σ_{k<n} (-D⁻¹T)^k) D⁻¹. The series is exact when T is strictly
-- triangular, because -D⁻¹T is then nilpotent. The input is assumed to be a
-- square triangular matrix with a non-zero diagonal.
--
-- >>> t = array [3,3] ([1,0,1,0,1,2,0,0,1] :: [Double]) :: Array Double
-- >>> pretty (invtri t)
-- [[1.0,0.0,-1.0],
--  [0.0,1.0,-2.0],
--  [0.0,0.0,1.0]]
-- >>> ident (shape t) == mult t (invtri t)
-- True
invtri :: (Fractional a) => Array a -> Array a
invtri a = mult (List.foldl' (zipWith (+)) zero' powers) ti
  where
    -- D^-1: reciprocals of the diagonal of a, as a diagonal matrix.
    ti = undiag (fmap recip (diag a))
    -- T: a with its diagonal removed.
    tl = zipWith (-) a (undiag (diag a))
    -- -(D^-1 T), the term whose powers make up the series.
    l = fmap negate (dot sum (*) ti tl)
    -- [I, l, l^2, .., l^(n-1)]. Each power costs one multiplication by l.
    -- Previously every power was rebuilt from I, costing O(n^2) products.
    -- The nesting (mult l (mult l (.. I))) is unchanged, so results are
    -- identical.
    powers = List.take n (List.iterate (mult l) (ident (shape l)))
    zero' = konst (shape a) 0
    n = S.getDim 0 (shape a)

-- | Cholesky decomposition using the <https://en.wikipedia.org/wiki/Cholesky_decomposition#The_Cholesky_algorithm Cholesky-Crout> algorithm.
--
-- Returns the lower-triangular L with a = L Lᵀ. The input is assumed to be a
-- square, symmetric positive-definite matrix. For 'Double', a non-positive
-- pivot produces NaN via 'sqrt'. Entries are defined lazily in terms of
-- earlier entries of the result itself.
--
-- >>> e = array [3,3] [4,12,-16,12,37,-43,-16,-43,98] :: Array Double
-- >>> pretty (chol e)
-- [[2.0,0.0,0.0],
--  [6.0,1.0,0.0],
--  [-8.0,5.0,3.0]]
-- >>> mult (chol e) (transpose (chol e)) == e
-- True
chol :: (Floating a) => Array a -> Array a
chol a = l
  where
    l =
      tabulate
        (shape a)
        ( \[i, j] -> case compare i j of
            -- Above the diagonal the factor is zero by definition. Computing
            -- it with the sub-diagonal formula only cancels to zero exactly
            -- for symmetric input without round-off. Any residue there would
            -- break the triangularity that 'invtri' relies on.
            LT -> 0
            -- Diagonal: sqrt (a_jj - sum_{k<j} l_jk^2)
            EQ ->
              sqrt
                ( index a [i, i]
                    - sum [index l [j, k] ^ (2 :: Int) | k <- [0 .. (j - 1)]]
                )
            -- Below the diagonal: (a_ij - sum_{k<j} l_ik l_jk) / l_jj
            GT ->
              1
                / index l [j, j]
                * ( index a [i, j]
                      - sum [index l [i, k] * index l [j, k] | k <- [0 .. (j - 1)]]
                  )
        )