-- | Tensors (n-dimensional arrays)
--
-- This is an implementation of tensors that emphasizes simplicity above all; it
-- is meant for use in QuickCheck tests.
--
-- Intended for qualified import.
--
-- > import Test.Tensor (Tensor)
-- > import Test.Tensor qualified as Tensor
module Test.Tensor (
    -- * Definition
    Tensor(..)
  , getScalar
  , getTensor
    -- ** Convenience constructors
  , scalar
  , dim1
  , dim2
  , dim3
  , dim4
  , dim5
  , dim6
  , dim7
  , dim8
  , dim9
    -- * Size
  , Size
  , size
  , sizeAtLeast
    -- * Standard operations
  , zipWith
  , replicate
  , rotate
  , distrib
  , transpose
  , foreach
  , foreachWith
    -- * Subtensors
  , subs
  , subsWithStride
  , convolve
  , convolveWithStride
  , padWith
  , padWith'
    -- * Conversions
  , Lists
  , toLists
  , fromLists
  , fromList
    -- * QuickCheck support
    -- ** Generation
  , arbitraryOfSize
    -- ** Shrinking
  , shrinkWith
  , shrinkWith'
  , shrinkElem
    -- *** Axes
  , Axe(..)
  , allAxes
  , axeWith
  , axeSize
    -- *** Zeroing
  , Zero(..)
  , zero
  , zeroWith
    -- * FFI
  , toStorable
  , fromStorable
  , unsafeWithCArray
  , unsafeFromCArray
  , unsafeFromPrealloc
  , unsafeFromPrealloc_
  ) where

import Prelude hiding (zipWith, replicate)

import Control.Monad.Trans.State (StateT(..), evalStateT)
import Data.Bifunctor
import Data.Foldable (foldl')
import Data.Foldable qualified as Foldable
import Data.List qualified as L
import Data.Maybe (catMaybes)
import Data.Ord
import Data.Proxy
import Data.Type.Nat
import Data.Vec.Lazy (Vec(..))
import Data.Vec.Lazy qualified as Vec
import Data.Vector.Storable qualified as Storable (Vector)
import Data.Vector.Storable qualified as Vector
import Foreign hiding (rotate)
import GHC.Show (appPrec1, showSpace)
import GHC.Stack
import Numeric.Natural
import Test.QuickCheck (Arbitrary(..), Arbitrary1(..), Gen)
import Test.QuickCheck qualified as QC

{-------------------------------------------------------------------------------
  Definition
-------------------------------------------------------------------------------}

-- | An @n@-dimensional tensor with elements of type @a@
--
-- The number of dimensions is tracked at the type level: a 'Scalar' has
-- dimension 'Z', and a 'Tensor' built from a list of @n@-dimensional
-- subtensors has dimension @'S' n@. The type does not enforce that all
-- subtensors have the same size; callers are expected to keep them consistent
-- (for example, 'size' only inspects the first subtensor at each level).
data Tensor (n :: Nat) a where
  Scalar :: a -> Tensor Z a
  Tensor :: [Tensor n a] -> Tensor (S n) a

deriving stock instance Eq a => Eq (Tensor n a)

deriving stock instance Foldable    (Tensor n)
deriving stock instance Functor     (Tensor n)
deriving stock instance Traversable (Tensor n)

-- | Extract the value stored in a zero-dimensional tensor
getScalar :: Tensor Z a -> a
getScalar = \case
    Scalar value -> value

-- | The subtensors along the outermost dimension, in order
getTensor :: Tensor (S n) a -> [Tensor n a]
getTensor = \case
    Tensor subtensors -> subtensors

{-------------------------------------------------------------------------------
  Size
-------------------------------------------------------------------------------}

-- | Size of a tensor: one 'Int' per dimension, outermost dimension first
type Size (n :: Nat) = Vec n Int

-- | Analogue of 'List.length'
--
-- Returns the number of elements along each dimension, outermost first. Only
-- the first subtensor at each level is inspected, so the result is only
-- meaningful for rectangular tensors.
--
-- The sizes of the nested dimensions of an /empty/ tensor cannot be
-- determined, because there is no subtensor to inspect. The outermost entry
-- (@0@) is still available, but forcing any inner entry of such a size calls
-- 'error'.
size :: Tensor n a -> Size n
size (Scalar _)  = VNil
size (Tensor xs) =
    -- The inner size is only computed on demand, so that the outer length of
    -- an empty tensor remains observable
    L.length xs ::: case xs of
      x:_ -> size x
      []  -> error "Test.Tensor.size: inner size of an empty tensor is unknown"

-- | Check that each dimension has at least the specified size
--
-- The comparison is pointwise: dimension @i@ of the tensor (as reported by
-- 'size') must be at least entry @i@ of the given minimum.
sizeAtLeast :: Size n -> Tensor n a -> Bool
sizeAtLeast minSize tensor =
    Foldable.and (Vec.zipWith (<=) minSize (size tensor))

{-------------------------------------------------------------------------------
  Standard operations
-------------------------------------------------------------------------------}

-- | Analogue of 'List.zipWith'
--
-- Combines corresponding elements of two tensors with the given function. As
-- with the list version, at every level the result is truncated to the
-- shorter of the two inputs.
zipWith :: forall a b c n. (a -> b -> c) -> Tensor n a -> Tensor n b -> Tensor n c
zipWith f = go
  where
    go :: Tensor m a -> Tensor m b -> Tensor m c
    go (Scalar x)  (Scalar y)  = Scalar (f x y)
    go (Tensor xs) (Tensor ys) = Tensor (L.zipWith go xs ys)

-- | Analogue of 'List.replicate'
--
-- Builds a tensor of the given size in which every element is the given value.
replicate :: Size n -> a -> Tensor n a
replicate sz x = case sz of
    VNil     -> Scalar x
    d ::: ds -> Tensor (L.replicate d (replicate ds x))

-- | Analogue of 'List.reverse'
--
-- Reverses the order of the subtensors at every level. This amounts to a
-- 180 degree rotation of the tensor.
rotate :: Tensor n a -> Tensor n a
rotate t@(Scalar _) = t
rotate (Tensor xs)  = Tensor (L.reverse (map rotate xs))

-- | Distribute '[]' over 'Tensor'
--
-- Collects corresponding elements from all the given tensors. Each element of
-- the result is the list of elements found at that position in the inputs, in
-- the order in which the tensors were given. The tensors are combined with
-- 'zipWith', so mismatched sizes are truncated rather than rejected.
--
-- Calls 'error' on an empty list, since the shape of the result would then be
-- unknown.
distrib :: [Tensor n a] -> Tensor n [a]
distrib []     = error "distrib: empty list"
distrib (t:ts) = reverse <$> foldl' step (pure <$> t) ts
  where
    -- Prepend the next tensor's elements. The lists are therefore built in
    -- reverse order and flipped once at the end.
    step :: Tensor m [b] -> Tensor m b -> Tensor m [b]
    step acc next = zipWith (:) next acc

-- | Transpose a two-dimensional tensor
--
-- The tensor is converted to nested lists, transposed with 'L.transpose', and
-- converted back. This is essentially a special case of 'distrib'.
transpose :: Tensor Nat2 a -> Tensor Nat2 a
transpose matrix = fromLists (L.transpose (toLists matrix))

-- | Map element over the first dimension of the tensor
--
-- Applies the function to each subtensor along the outermost dimension.
foreach :: Tensor (S n) a -> (Tensor n a -> Tensor m b) -> Tensor (S m) b
foreach (Tensor slices) f = Tensor [f slice | slice <- slices]

-- | Variation of 'foreach' with an auxiliary list
--
-- Pairs each subtensor along the outermost dimension with the corresponding
-- element of the list. As with 'L.zipWith', the result is truncated to the
-- shorter of the two.
foreachWith ::
    Tensor (S n) a
 -> [x]
 -> (Tensor n a -> x -> Tensor m b)
 -> Tensor (S m) b
foreachWith (Tensor slices) extras f =
    Tensor [f slice extra | (slice, extra) <- L.zip slices extras]

{-------------------------------------------------------------------------------
  Subtensors
-------------------------------------------------------------------------------}

-- | Subtensors of the specified size
--
-- Equivalent to 'subsWithStride' with a stride of 1 in every dimension.
subs :: SNatI n => Size n -> Tensor n a -> Tensor n (Tensor n a)
subs windowSize = subsWithStride (pure 1) windowSize

-- | Generalization of 'subs' with non-default stride
--
-- At each level, the subtensors are first split recursively. The helpers
-- 'consecutive' and 'everyNth' (defined elsewhere in this module) then pick
-- windows of the requested length with the requested stride; presumably the
-- windows start @stride@ positions apart, so confirm against those helpers.
-- Each window's elements are regrouped with 'distrib'.
subsWithStride :: Vec n Int -> Size n -> Tensor n a -> Tensor n (Tensor n a)
subsWithStride VNil             VNil           (Scalar x)  =
    Scalar (Scalar x)
subsWithStride (step ::: steps) (len ::: lens) (Tensor xs) =
      Tensor
    . map (fmap Tensor . distrib)
    . everyNth step
    . consecutive len
    $ map (subsWithStride steps lens) xs

-- | Convolution
--
-- Uses a stride of 1 in every dimension (see 'convolveWithStride'). No
-- padding is applied here; see 'padWith' for adjusting boundary conditions.
convolve ::
     (SNatI n, Num a)
  => Tensor n a  -- ^ Kernel
  -> Tensor n a  -- ^ Input
  -> Tensor n a
convolve kernel input = convolveWithStride (pure 1) kernel input

-- | Generalization of 'convolve' when using a non-default stride
--
-- Each element of the result is the sum of the elementwise product of the
-- kernel with one window of the input. The windows are those extracted by
-- 'subsWithStride' using the kernel's size. The kernel is not flipped, so in
-- signal-processing terms this is a cross-correlation.
convolveWithStride :: forall n a.
     Num a
  => Vec n Int   -- ^ Stride
  -> Tensor n a  -- ^ Kernel
  -> Tensor n a  -- ^ Input
  -> Tensor n a
convolveWithStride stride kernel input =
    fmap weightedSum (subsWithStride stride (size kernel) input)
  where
    -- Dot product of the kernel with one window of the input
    weightedSum :: Tensor n a -> a
    weightedSum window = foldl' (+) 0 (zipWith (*) kernel window)

{-------------------------------------------------------------------------------
  Padding
-------------------------------------------------------------------------------}

-- | Add uniform padding
--
-- Inserts the given number of padding slices both before and after the
-- existing elements along every dimension (see 'padWith'').
padWith :: SNatI n => a -> Int -> Tensor n a -> Tensor n a
padWith padding amount tensor = padWith' padding (pure (amount, amount)) tensor

-- | Generalization of 'padWith' with different padding per dimension
--
-- Entry @i@ of the padding vector is a @(before, after)@ pair: the number of
-- slices filled with the padding value that are inserted before and after the
-- existing elements along dimension @i@. The inserted slices already have the
-- padded size in the remaining dimensions, so the result stays rectangular.
padWith' :: forall n a. a -> Vec n (Int, Int) -> Tensor n a -> Tensor n a
padWith' padding paddingSize tensor =
    pad paddingSize paddedSize tensor
  where
    -- Size of the result: each dimension grows by its before and after padding
    paddedSize :: Size n
    paddedSize = Vec.zipWith grow paddingSize (size tensor)

    grow :: (Int, Int) -> Int -> Int
    grow (before, after) d = d + before + after

    pad :: forall m. Vec m (Int, Int) -> Size m -> Tensor m a -> Tensor m a
    pad VNil                     VNil          (Scalar x)  = Scalar x
    pad ((before, after) ::: ps) (_ ::: inner) (Tensor xs) =
        Tensor (margin before ++ map (pad ps inner) xs ++ margin after)
      where
        -- @k@ slices filled entirely with the padding value
        margin k = L.replicate k (replicate inner padding)

{-------------------------------------------------------------------------------
  QuickCheck support
-------------------------------------------------------------------------------}

-- | Generate a tensor of the given size
--
-- The given generator is run once per element, in traversal order.
arbitraryOfSize :: Size n -> Gen a -> Gen (Tensor n a)
arbitraryOfSize sz gen = sequenceA (replicate sz gen)

-- | A selection of elements along exactly one dimension of a tensor
--
-- 'axeWith' removes the selected elements and 'zeroWith' overwrites them with
-- a zero element. 'allAxes' enumerates the candidates for a given size.
data Axe (n :: Nat) where
  -- | Axe some elements from the current dimension
  --
  -- We record which elements to drop as an @(offset, length)@ pair.
  AxeHere :: (Int, Int) -> Axe (S n)

  -- | Axe some elements from a nested dimension
  --
  -- In order to keep the tensor rectangular, we must apply the same axe to
  -- every element of the /current/ dimension.
  AxeNested :: Axe n -> Axe (S n)

deriving stock instance Show (Axe n)

-- | How many elements are removed by this axe?
--
-- Computed from the size alone. The result is the length of the selected
-- range times the number of elements in each slice of that dimension,
-- multiplied by the sizes of any enclosing dimensions the axe is nested under.
--
-- Examples:
--
-- > axeSize (2 ::: 100 ::: VNil) (AxeHere (0, 1))               == 100
-- > axeSize (2 ::: 100 ::: VNil) (AxeNested (AxeHere (0, 99)))  == 198
axeSize :: Size n -> Axe n -> Int
axeSize sz = \case
    AxeHere (_, len) ->
      case sz of
        _ ::: inner -> len * product inner
    AxeNested axe ->
      case sz of
        outer ::: inner -> outer * axeSize inner axe

-- | All possible ways to axe some elements
--
-- This is adapted from the implementation of 'shrinkList' (in a way, an 'Axe'
-- is an explanation of the decisions made by 'shrinkList', generalized to
-- multiple dimensions). Unlike 'shrinkList', the largest chunk tried is half
-- of a dimension, so no axe ever removes a dimension entirely.
--
-- Axes are sorted to remove as many elements as early as possible.
allAxes :: Size n -> [Axe n]
allAxes sz = L.sortOn (Down . axeSize sz) (candidates sz)
  where
    -- Every axe in generation order: those acting on the outermost dimension
    -- first, then those nested inside it
    candidates :: Size m -> [Axe m]
    candidates VNil       = []
    candidates (d ::: ds) =
           [ AxeHere range
           | chunk <- chunkSizes d
           , range <- ranges chunk d
           ]
        ++ map AxeNested (candidates ds)

    -- Chunk lengths to try: half the dimension, then repeatedly halved
    chunkSizes :: Int -> [Int]
    chunkSizes d = takeWhile (> 0) (iterate (`div` 2) (d `div` 2))

    -- Adjacent, non-overlapping @(offset, length)@ ranges of length @k@
    -- (for @k > 0@) that fit within a dimension of size @d@
    ranges :: Int -> Int -> [(Int, Int)]
    ranges k d = [ (offset, k) | offset <- takeWhile (\o -> o + k <= d) [0, k ..] ]

-- | Remove elements from the tensor (shrink dimensions)
--
-- 'AxeHere' drops the @(offset, length)@ range from the outermost dimension.
-- A range extending past the end is truncated by 'L.splitAt'. 'AxeNested'
-- applies the inner axe to every subtensor, so that all of them shrink in the
-- same way.
axeWith :: Axe n -> Tensor n a -> Tensor n a
axeWith (AxeHere (offset, len)) (Tensor slices) =
    let (kept, rest) = L.splitAt offset slices
    in  Tensor (kept ++ L.drop len rest)
axeWith (AxeNested inner) (Tensor slices) =
    Tensor (map (axeWith inner) slices)

-- | Zero element
--
-- Bundles a value with the 'Eq' instance needed to compare other elements
-- against it. 'zeroWith' uses this to detect whether a region is already zero.
data Zero a = Eq a => Zero a

-- | The default 'Zero': the numeric literal @0@, compared using '(==)'
zero :: (Num a, Eq a) => Zero a
zero = Zero 0

-- | Zero elements in the tensor (leaving dimensions the same)
--
-- Every element in the region selected by the axe is replaced by the zero
-- element. Returns 'Nothing' if the specified region was already zero
-- everywhere, i.e. if zeroing would not change the tensor.
--
-- Only the part of an 'AxeHere' range that actually lies within the tensor is
-- zeroed, so an out-of-range axe never changes the size of the tensor.
zeroWith :: forall n a. Zero a -> Axe n -> Tensor n a -> Maybe (Tensor n a)
zeroWith (Zero z) = \axe tensor ->
    case go axe (size tensor) tensor of
      (_, False)      -> Nothing
      (tensor', True) -> Just tensor'
  where
    -- Additionally returns whether anything changed
    go :: forall n'. Axe n' -> Size n' -> Tensor n' a -> (Tensor n' a, Bool)
    go (AxeHere (offset, len)) (_ ::: ns) (Tensor xss) = (
          Tensor $ before <> zeroed <> after
        , any (/= z) (Tensor dropped)
        )
      where
        (before, dropFrom) = L.splitAt offset xss
        (dropped, after)   = L.splitAt len dropFrom
        -- One zero slice for every slice actually in range. Using 'len' here
        -- would grow the tensor when the range runs past its end.
        zeroed = replicate ns z <$ dropped
    go (AxeNested axe) (_ ::: ns) (Tensor xss) =
        bimap Tensor or $ L.unzip $ L.map (go axe ns) xss

-- | Shrink tensor
--
-- Delegates to 'shrinkWith'' with every axe from 'allAxes' for the tensor's
-- current size.
shrinkWith ::
     Maybe (Zero a)  -- ^ Optional zero element (see 'shrinkElem')
  -> (a -> [a])      -- ^ Shrink individual elements
  -> Tensor n a -> [Tensor n a]
shrinkWith mZero shrinkElement tensor =
    shrinkWith' axes mZero shrinkElement tensor
  where
    axes = allAxes (size tensor)

-- | Generalization of 'shrinkWith'
--
-- First yields the tensor with each of the given axes applied (see
-- 'axeWith'), in order. It then yields the element-level shrinks from
-- 'shrinkElem'.
shrinkWith' :: forall n a.
     [Axe n]         -- ^ Shrink the size of the tensor (see 'allAxes')
  -> Maybe (Zero a)  -- ^ Optional zero element (see 'shrinkElem')
  -> (a -> [a])      -- ^ Shrink elements of the tensor
  -> Tensor n a -> [Tensor n a]
shrinkWith' axes mZero f xss =
       map (`axeWith` xss) axes
    ++ shrinkElem mZero f xss
    ]

-- | Shrink an element of the tensor, leaving the size of the tensor unchanged
--
-- If a zero element is specified, we will first try to replace entire regions
-- of the tensor by zeroes; this can dramatically speed up shrinking. Only
-- regions of more than one element are tried this way, in the order given by
-- 'allAxes'. Afterwards, each element is shrunk individually with the given
-- function.
shrinkElem :: forall n a.
     Maybe (Zero a)  -- ^ Optional zero element
  -> (a -> [a])      -- ^ Shrink individual elements
  -> Tensor n a -> [Tensor n a]
shrinkElem mZero f tensor = zeroings ++ shrinkOne tensor
  where
    -- Candidates obtained by zeroing out a whole region (empty without a zero)
    zeroings :: [Tensor n a]
    zeroings = case mZero of
      Nothing -> []
      Just z  -> catMaybes [
          zeroWith z axe tensor
        | axe <- allAxes overallSize
        , axeSize overallSize axe > 1
        ]

    overallSize :: Size n
    overallSize = size tensor

    -- Shrink exactly one element of the tensor, leaving all others unchanged
    shrinkOne :: forall n'. Tensor n' a -> [Tensor n' a]
    shrinkOne (Scalar x)   = map Scalar (f x)
    shrinkOne (Tensor xss) = [
          Tensor (before ++ xs' : after)
        | (before, xs, after) <- pickOne xss
        , xs' <- shrinkOne xs
        ]

-- | Generation delegates to the 'Arbitrary1' instance. Shrinking passes @0@
-- (as a 'Zero' element) to 'shrinkWith'; see 'shrinkElem' for how a zero
-- element is used to speed up shrinking.
instance (SNatI n, Arbitrary a, Num a, Eq a) => Arbitrary (Tensor n a) where
  arbitrary = liftArbitrary QC.arbitrary

  shrink = shrinkWith withZero QC.shrink
    where
      withZero :: Maybe (Zero a)
      withZero = Just (Zero 0)

-- | Lift generators and shrinkers
--
-- The size is generated by lifting @'QC.choose' (1, 1 + s)@ over the 'Size'
-- vector, where @s@ is the current QuickCheck size parameter; the elements
-- are then generated by 'arbitraryOfSize'.
--
-- NOTE: Since we cannot put any constraints on the type of the elements here,
-- we cannot use any zero elements. Using 'shrink' (or 'shrinkWith' directly)
-- might result in faster shrinking.
instance SNatI n => Arbitrary1 (Tensor n) where
  liftArbitrary :: forall a. Gen a -> Gen (Tensor n a)
  liftArbitrary genElem = QC.sized $ \s -> do
      let genDim :: Gen Int
          genDim = QC.choose (1, 1 + s)
      sz <- liftArbitrary genDim :: Gen (Size n)
      arbitraryOfSize sz genElem

  liftShrink :: forall a. (a -> [a]) -> Tensor n a -> [Tensor n a]
  liftShrink = shrinkWith Nothing

{-------------------------------------------------------------------------------
  FFI
-------------------------------------------------------------------------------}

-- | Translate to storable vector
--
-- The tensor is laid out in the order specified (outer dimensions before
-- inner), i.e. in the order in which 'Foldable.toList' produces the elements.
toStorable :: Storable a => Tensor n a -> Storable.Vector a
toStorable tensor = Vector.fromList (Foldable.toList tensor)

-- | Translate from storable vector
--
-- Elements are consumed in the layout used by 'toStorable'; elements beyond
-- those needed to fill a tensor of the given size are ignored.
--
-- Throws an exception if the vector does not contain enough elements.
fromStorable ::
     (HasCallStack, Storable a)
  => Size n -> Storable.Vector a -> Tensor n a
fromStorable sz vec = fromList sz (Vector.toList vec)

-- | Get pointer to elements of the tensor
--
-- The pointer refers to the storage of the vector built by 'toStorable'; see
-- there for discussion of the layout.
--
-- The data should not be modified through the pointer, and the pointer should
-- not be used outside its scope.
unsafeWithCArray :: Storable a => Tensor n a -> (Ptr a -> IO r) -> IO r
unsafeWithCArray tensor k = Vector.unsafeWith (toStorable tensor) k

-- | Construct tensor from C array
--
-- The array is wrapped as a storable vector (via
-- 'Vector.unsafeFromForeignPtr0') rather than copied up front. It must hold at
-- least as many elements as the product of the dimensions of the given size.
--
-- The data should not be modified through the pointer after the tensor has
-- been constructed.
unsafeFromCArray :: Storable a => Size n -> ForeignPtr a -> Tensor n a
unsafeFromCArray sz fptr = fromStorable sz storage
  where
    -- Total number of elements in a tensor of size @sz@
    numElems :: Int
    numElems = L.foldl' (*) 1 sz

    storage = Vector.unsafeFromForeignPtr0 fptr numElems

-- | Construct tensor from preallocated C array
--
-- Allocates sufficient memory to hold the elements of the tensor; writing more
-- data will result in invalid memory access. The pointer should not be used
-- outside its scope.
--
-- This function does not itself initialise the allocated memory, so the
-- callback is expected to write every element of the tensor.
unsafeFromPrealloc ::
     Storable a
  => Size n -> (Ptr a -> IO r) -> IO (Tensor n a, r)
unsafeFromPrealloc sz fill = do
    buffer <- mallocForeignPtrArray numElems
    result <- withForeignPtr buffer fill
    pure (unsafeFromCArray sz buffer, result)
  where
    -- Total number of elements in a tensor of size @sz@
    numElems :: Int
    numElems = L.foldl' (*) 1 sz

-- | Like 'unsafeFromPrealloc' but without an additional return value
--
-- The @()@ returned by the callback is discarded.
unsafeFromPrealloc_ ::
     Storable a
  => Size n -> (Ptr a -> IO ()) -> IO (Tensor n a)
unsafeFromPrealloc_ sz fill = fst <$> unsafeFromPrealloc sz fill

{-------------------------------------------------------------------------------
  Convenience constructors
-------------------------------------------------------------------------------}

-- | Construct a tensor of rank 0 holding a single element
scalar :: a -> Tensor Nat0 a
scalar x = fromLists x

-- | Construct a tensor of rank 1 from a list
dim1 :: [a] -> Tensor Nat1 a
dim1 xs = fromLists xs

-- | Construct a tensor of rank 2 from 2 levels of nested lists
-- (outermost list = outermost dimension)
dim2 :: [[a]] -> Tensor Nat2 a
dim2 xss = fromLists xss

-- | Construct a tensor of rank 3 from 3 levels of nested lists
-- (outermost list = outermost dimension)
dim3 :: [[[a]]] -> Tensor Nat3 a
dim3 xss = fromLists xss

-- | Construct a tensor of rank 4 from 4 levels of nested lists
-- (outermost list = outermost dimension)
dim4 :: [[[[a]]]] -> Tensor Nat4 a
dim4 xss = fromLists xss

-- | Construct a tensor of rank 5 from 5 levels of nested lists
-- (outermost list = outermost dimension)
dim5 :: [[[[[a]]]]] -> Tensor Nat5 a
dim5 xss = fromLists xss

-- | Construct a tensor of rank 6 from 6 levels of nested lists
-- (outermost list = outermost dimension)
dim6 :: [[[[[[a]]]]]] -> Tensor Nat6 a
dim6 xss = fromLists xss

-- | Construct a tensor of rank 7 from 7 levels of nested lists
-- (outermost list = outermost dimension)
dim7 :: [[[[[[[a]]]]]]] -> Tensor Nat7 a
dim7 xss = fromLists xss

-- | Construct a tensor of rank 8 from 8 levels of nested lists
-- (outermost list = outermost dimension)
dim8 :: [[[[[[[[a]]]]]]]] -> Tensor Nat8 a
dim8 xss = fromLists xss

-- | Construct a tensor of rank 9 from 9 levels of nested lists
-- (outermost list = outermost dimension)
dim9 :: [[[[[[[[[a]]]]]]]]] -> Tensor Nat9 a
dim9 xss = fromLists xss

{-------------------------------------------------------------------------------
  Conversions

  This is primarily useful for specifying tensor constants.
-------------------------------------------------------------------------------}

-- | Nested lists with one level of list nesting per dimension
--
-- For example, @Lists Nat2 a@ reduces to @[[a]]@.
type family Lists (n :: Nat) a where
  Lists 'Z     a = a
  Lists ('S n) a = [Lists n a]

-- | Convert a tensor to nested lists (one level of nesting per dimension)
--
-- Inverse to 'fromLists'.
toLists :: Tensor n a -> Lists n a
toLists = \case
    Scalar x  -> x
    Tensor xs -> toLists <$> xs

-- | Construct a tensor from nested lists (one level of nesting per dimension)
--
-- Inverse to 'toLists'. No check is made here that lists at the same level of
-- nesting have equal lengths.
fromLists :: SNatI n => Lists n a -> Tensor n a
fromLists = build snat
  where
    -- Recurse on the rank singleton, peeling off one list layer per level
    build :: SNat m -> Lists m b -> Tensor m b
    build = \case
      SZ -> Scalar
      SS -> Tensor . fmap (build snat)

-- | Inverse to 'Foldable.toList'
--
-- The elements of the list are distributed over a tensor of the given size by
-- traversing it with 'sequenceA'. Elements beyond those needed to fill the
-- tensor are ignored.
--
-- Throws a pure exception if the list does not contain enough elements. The
-- exception's call stack includes the caller of 'fromList' (and, through its
-- own 'HasCallStack' constraint, the caller of 'fromStorable').
fromList :: forall n a. HasCallStack => Size n -> [a] -> Tensor n a
fromList sz xs =
    checkEnoughElems . flip evalStateT xs $ sequenceA (replicate sz genElem)
  where
    -- Pop the next element off the remaining input; fails once it runs out
    genElem :: StateT [a] Maybe a
    genElem = StateT L.uncons

    -- HasCallStack here so that the 'error' below reports where 'fromList'
    -- was called from, rather than only this local definition
    checkEnoughElems :: HasCallStack => Maybe (Tensor n a) -> Tensor n a
    checkEnoughElems Nothing  = error "fromList: insufficient elements"
    checkEnoughElems (Just t) = t

{-------------------------------------------------------------------------------
  Show instance
-------------------------------------------------------------------------------}

-- | Bring a @Show (Lists n a)@ instance into scope for the continuation
--
-- The instance is built by induction on the 'SNat': at rank zero it is the
-- element's own 'Show' instance, and each further rank wraps it in the 'Show'
-- instance for lists.
showLists :: Show a => Proxy a -> SNat n -> (Show (Lists n a) => r) -> r
showLists proxy sn k =
    case sn of
      SZ      -> k
      SS' sn' -> showLists proxy sn' k

-- | Name of the function that reconstructs a tensor of the given rank
--
-- Rank 0 uses 'scalar' and ranks 1 through 9 use 'dim1' .. 'dim9'. Any higher
-- rank falls back to 'fromLists' with an explicit type application, whose
-- type argument is shown at the given precedence.
showConstructor :: Int -> SNat n -> ShowS
showConstructor p sn =
    case snatToNatural sn of
      0          -> showString "scalar"
      r | r <= 9 -> showString "dim" . shows r
      _          -> showString "fromLists @"
                  . explicitShowsPrec p (snatToNat sn)

-- | Shows a tensor as a call to the matching constructor function applied to
-- its nested lists, e.g. @dim2 [[1,2],[3,4]]@.
--
-- NOTE: the rank is recovered with 'tensorSNat', which cannot do so for a
-- tensor that contains an empty subtensor.
instance Show a => Show (Tensor n a) where
  showsPrec p tensor =
      showLists (Proxy @a) rank $
        showParen (p >= appPrec1) $
            showConstructor appPrec1 rank
          . showSpace
          . showsPrec appPrec1 (toLists tensor)
    where
      rank :: SNat n
      rank = tensorSNat tensor

{-------------------------------------------------------------------------------
  Internal auxiliary: SNat
-------------------------------------------------------------------------------}

-- | Bring the 'SNatI' evidence for the rank of a tensor into scope
--
-- The rank is recovered by descending into the first subtensor at every level
-- until a scalar is reached. That is impossible when an empty subtensor is
-- encountered on the way down (there is nothing to descend into), in which
-- case a descriptive error is thrown instead of the generic failure of a
-- partial 'head'.
tensorSNatI :: Tensor n a -> (SNatI n => r) -> r
tensorSNatI (Scalar _)     k = k
tensorSNatI (Tensor (x:_)) k = tensorSNatI x k
tensorSNatI (Tensor [])    _ =
    error "tensorSNatI: empty subtensor, cannot determine rank"

-- | Singleton for the rank of the tensor, obtained via 'tensorSNatI'
tensorSNat :: Tensor n a -> SNat n
tensorSNat t = tensorSNatI t snat

{-------------------------------------------------------------------------------
  Internal auxiliary: lists
-------------------------------------------------------------------------------}

-- | Consecutive elements
--
-- All windows of exactly @n@ consecutive elements, in order; windows that
-- would run past the end of the list are omitted.
--
-- >    consecutive 3 [1..5]
-- > == [[1,2,3],[2,3,4],[3,4,5]]
consecutive :: Int -> [a] -> [[a]]
consecutive n xs =
    L.takeWhile hasFullLength [L.take n suffix | suffix <- L.tails xs]
  where
    -- Once a window is short, every later window is shorter still
    hasFullLength window = length window == n

-- | Every nth element of the list, starting with the first
--
-- Examples
--
-- > everyNth 1 [0..9] == [0,1,2,3,4,5,6,7,8,9]
-- > everyNth 2 [0..9] == [0,2,4,6,8]
-- > everyNth 3 [0..9] == [0,3,6,9]
--
-- If @n@ is not strictly positive, an error is thrown when the result is
-- evaluated.
everyNth :: forall a. Int -> [a] -> [a]
everyNth n = \xs ->
    if n > 0
      then go xs
      else error "everyNth: n should be strictly positive"
  where
    -- Keep the head, then skip the next @n - 1@ elements
    go :: [a] -> [a]
    go []     = []
    go (x:xs) = x : go (drop (n - 1) xs)

-- | Single out an element from the list
--
-- Returns one triple per element: the elements before it, the element itself,
-- and the elements after it.
--
-- >    pickOne [1..4]
-- > == [ ( []      , 1 , [2,3,4] )
-- >    , ( [1]     , 2 , [3,4]   )
-- >    , ( [1,2]   , 3 , [4]     )
-- >    , ( [1,2,3] , 4 , []      )
-- >    ]
--
-- The empty list has no element to single out, so @pickOne []@ is @[]@. This
-- keeps callers such as element shrinking total on tensors with an empty
-- dimension.
pickOne :: forall a. [a] -> [([a], a, [a])]
pickOne xs = zip3 (L.inits xs) xs (drop 1 (L.tails xs))