module Test.Tensor (
Tensor(..)
, getScalar
, getTensor
, scalar
, dim1
, dim2
, dim3
, dim4
, dim5
, dim6
, dim7
, dim8
, dim9
, Size
, size
, sizeAtLeast
, zipWith
, replicate
, rotate
, distrib
, transpose
, foreach
, foreachWith
, subs
, subsWithStride
, convolve
, convolveWithStride
, padWith
, padWith'
, Lists
, toLists
, fromLists
, fromList
, arbitraryOfSize
, shrinkWith
, shrinkWith'
, shrinkElem
, Axe(..)
, allAxes
, axeWith
, axeSize
, Zero(..)
, zero
, zeroWith
, toStorable
, fromStorable
, unsafeWithCArray
, unsafeFromCArray
, unsafeFromPrealloc
, unsafeFromPrealloc_
) where
import Prelude hiding (zipWith, replicate)
import Control.Monad.Trans.State (StateT(..), evalStateT)
import Data.Bifunctor
import Data.Foldable (foldl')
import Data.Foldable qualified as Foldable
import Data.List qualified as L
import Data.Maybe (catMaybes)
import Data.Ord
import Data.Proxy
import Data.Type.Nat
import Data.Vec.Lazy (Vec(..))
import Data.Vec.Lazy qualified as Vec
import Data.Vector.Storable qualified as Storable (Vector)
import Data.Vector.Storable qualified as Vector
import Foreign hiding (rotate)
import GHC.Show (appPrec1, showSpace)
import GHC.Stack
import Numeric.Natural
import Test.QuickCheck (Arbitrary(..), Arbitrary1(..), Gen)
import Test.QuickCheck qualified as QC
-- | A rank-@n@ tensor: rank 0 is a single 'Scalar'; rank @S n@ is a list of
-- rank-@n@ sub-tensors.
--
-- The type does not force sibling sub-tensors to have equal size. Some
-- functions in this module (e.g. 'size', which only inspects the first
-- sub-tensor) assume the tensor is rectangular.
data Tensor n a where
  Scalar :: a -> Tensor 'Z a
  Tensor :: [Tensor n a] -> Tensor ('S n) a

deriving stock instance Eq a => Eq (Tensor n a)
deriving stock instance Functor     (Tensor n)
deriving stock instance Foldable    (Tensor n)
deriving stock instance Traversable (Tensor n)
-- | Extract the element of a rank-0 tensor.
getScalar :: Tensor Z a -> a
getScalar = \case
    Scalar value -> value
-- | The list of sub-tensors along the outermost dimension.
getTensor :: Tensor (S n) a -> [Tensor n a]
getTensor = \case
    Tensor subTensors -> subTensors
-- | One length per dimension, outermost dimension first.
type Size n = Vec n Int
-- | The size of a tensor along each dimension, outermost first.
--
-- Inner sizes are read from the /first/ sub-tensor only, so the result is
-- only meaningful for rectangular tensors.
--
-- If a dimension is empty, the sizes of the dimensions below it cannot be
-- recovered. Demanding them calls 'error' with a descriptive message rather
-- than failing inside a partial 'L.head'. The outer lengths stay available
-- because the recursive tail is only evaluated on demand.
size :: Tensor n a -> Size n
size (Scalar _)  = VNil
size (Tensor xs) = L.length xs ::: innerSize xs
  where
    -- Sizes of the remaining dimensions, taken from the first sub-tensor.
    innerSize :: [Tensor m b] -> Size m
    innerSize (x : _) = size x
    innerSize []      =
      error "Test.Tensor.size: empty dimension; cannot determine inner sizes"
-- | @sizeAtLeast bound t@ holds when every dimension of @t@ is at least as
-- large as the corresponding entry of @bound@.
sizeAtLeast :: Size n -> Tensor n a -> Bool
sizeAtLeast lowerBound tensor =
    and $ L.zipWith (<=)
            (Foldable.toList lowerBound)
            (Foldable.toList (size tensor))
-- | Combine two tensors elementwise.
--
-- At every level the sub-tensor lists are paired with 'L.zip', so a
-- dimension of the result is as long as the shorter of the two inputs.
zipWith :: (a -> b -> c) -> Tensor n a -> Tensor n b -> Tensor n c
zipWith f (Scalar x)  (Scalar y)  = Scalar (f x y)
zipWith f (Tensor xs) (Tensor ys) =
    Tensor [ zipWith f x y | (x, y) <- L.zip xs ys ]
-- | A tensor of the given size with every element equal to the given value.
replicate :: Size n -> a -> Tensor n a
replicate sz x = case sz of
    VNil          -> Scalar x
    len ::: inner -> Tensor $ L.replicate len $ replicate inner x
-- | Reverse the order of sub-tensors along every dimension (for a matrix,
-- a rotation by 180 degrees).
rotate :: Tensor n a -> Tensor n a
rotate t = case t of
    Scalar x  -> Scalar x
    Tensor xs -> Tensor (L.reverse (map rotate xs))
-- | Turn a non-empty list of tensors into one tensor of lists. Each element
-- collects, in list order, the corresponding elements of the inputs.
--
-- Calls 'error' on the empty list.
distrib :: [Tensor n a] -> Tensor n [a]
distrib []       = error "distrib: empty list"
distrib (t : ts) = reverse <$> foldl' step ((: []) <$> t) ts
  where
    -- Elements are accumulated in reverse; the final 'reverse' restores order.
    step acc next = zipWith (:) next acc
-- | Transpose a rank-2 tensor by round-tripping through nested lists.
transpose :: Tensor Nat2 a -> Tensor Nat2 a
transpose matrix = dim2 (L.transpose (toLists matrix))
-- | Apply a function to every sub-tensor along the outermost dimension.
foreach :: Tensor (S n) a -> (Tensor n a -> Tensor m b) -> Tensor (S m) b
foreach (Tensor subTensors) f = Tensor (f <$> subTensors)
-- | Like 'foreach', but also passes the matching element of an auxiliary
-- list. The result is truncated to the shorter of the two.
foreachWith ::
     Tensor (S n) a
  -> [x]
  -> (Tensor n a -> x -> Tensor m b)
  -> Tensor (S m) b
foreachWith (Tensor subTensors) extras f =
    Tensor [ f sub extra | (sub, extra) <- L.zip subTensors extras ]
-- | All windows of the given size, with stride 1 in every dimension.
subs :: SNatI n => Size n -> Tensor n a -> Tensor n (Tensor n a)
subs windowSize tensor = subsWithStride (pure 1) windowSize tensor
-- | All windows of the given size, advancing by the given stride per
-- dimension.
--
-- At each level, the windows of every sub-tensor are computed first. Runs of
-- @k@ consecutive sub-results are then grouped with 'consecutive', every
-- @s@-th group is kept with 'everyNth', and each group is combined into one
-- window with 'distrib'.
subsWithStride :: Vec n Int -> Size n -> Tensor n a -> Tensor n (Tensor n a)
subsWithStride VNil VNil (Scalar x) = Scalar (Scalar x)
subsWithStride (s ::: ss) (k ::: ks) (Tensor xs) =
      Tensor
    . map (fmap Tensor . distrib)
    . everyNth s
    . consecutive k
    $ map (subsWithStride ss ks) xs
-- | Convolution with stride 1 in every dimension.
-- See 'convolveWithStride' for argument order.
convolve ::
     (SNatI n, Num a)
  => Tensor n a
  -> Tensor n a
  -> Tensor n a
convolve kernel input = convolveWithStride (pure 1) kernel input
-- | Slide the kernel over the input with the given stride. Each output
-- element is the sum of elementwise products of the kernel and one window.
-- The kernel is not flipped.
convolveWithStride :: forall n a.
     Num a
  => Vec n Int
  -> Tensor n a
  -> Tensor n a
  -> Tensor n a
convolveWithStride stride kernel input =
    fmap dot (subsWithStride stride (size kernel) input)
  where
    -- Strict left-to-right sum, starting from 0.
    dot :: Tensor n a -> a
    dot window = foldl' (+) 0 (zipWith (*) kernel window)
-- | Pad every dimension with the same amount on both sides.
padWith :: SNatI n => a -> Int -> Tensor n a -> Tensor n a
padWith padding amount = padWith' padding (pure (amount, amount))
-- | Pad every dimension independently with the given element.
--
-- @paddingSize@ holds one @(before, after)@ pair per dimension, outermost
-- first. Along each dimension, @before@ padding slices are prepended and
-- @after@ slices appended. Each padding slice has the already-padded size of
-- the dimensions below it, so a rectangular input gives a rectangular result.
padWith' :: forall n a. a -> Vec n (Int, Int) -> Tensor n a -> Tensor n a
padWith' padding paddingSize tensor = pad paddingSize paddedSize tensor
  where
    -- Target size: each dimension grows by before + after.
    paddedSize :: Size n
    paddedSize = Vec.zipWith (\(lo, hi) d -> d + lo + hi) paddingSize (size tensor)

    -- Walk the padding spec, the padded size and the tensor in lockstep.
    pad :: forall m. Vec m (Int, Int) -> Size m -> Tensor m a -> Tensor m a
    pad VNil VNil (Scalar x) = Scalar x
    pad ((lo, hi) ::: ps) (_ ::: inner) (Tensor xs) =
        Tensor (fill lo ++ map (pad ps inner) xs ++ fill hi)
      where
        -- k slices of padding at the padded inner size.
        fill k = L.replicate k (replicate inner padding)
-- | Generate a tensor of exactly the given size, running the element
-- generator once per element.
arbitraryOfSize :: Size n -> Gen a -> Gen (Tensor n a)
arbitraryOfSize sz gen = sequenceA (replicate sz gen)
-- | A contiguous slice along one dimension, used for shrinking.
--
-- @'AxeHere' (offset, len)@ selects @len@ consecutive sub-tensors starting at
-- index @offset@ of the outermost dimension. @'AxeNested' axe@ applies @axe@
-- inside every sub-tensor of the outermost dimension.
data Axe (n :: Nat) where
  AxeHere   :: (Int, Int) -> Axe ('S n)
  AxeNested :: Axe n      -> Axe ('S n)

deriving stock instance Show (Axe n)
-- | Number of elements covered by an axe in a tensor of the given size:
-- the slice length times the sizes of all other dimensions.
axeSize :: Size n -> Axe n -> Int
axeSize sz axe = measure axe sz
  where
    measure :: Axe m -> Size m -> Int
    measure (AxeHere (_, len)) (_ ::: rest) = len * product rest
    measure (AxeNested inner)  (d ::: rest) = d   * measure inner rest
-- | Every candidate axe for a tensor of the given size, largest first
-- (measured by 'axeSize'; the sort is stable).
--
-- For a dimension of length @d@, the slice lengths tried are @d `div` 2@,
-- @d `div` 4@, … down to 1. Each length tiles the dimension from offset 0 in
-- non-overlapping chunks that fit entirely inside it.
allAxes :: Size n -> [Axe n]
allAxes sz = L.sortBy (comparing (Down . axeSize sz)) (candidates sz)
  where
    candidates :: Size m -> [Axe m]
    candidates VNil         = []
    candidates (d ::: rest) =
           [ AxeHere chunk
           | k     <- takeWhile (> 0) (iterate (`div` 2) (d `div` 2))
           , chunk <- removes 0 k d
           ]
        ++ L.map AxeNested (candidates rest)

    -- @(offset, k)@ chunks laid end to end while at least @k@ elements remain.
    removes :: Int -> Int -> Int -> [(Int, Int)]
    removes offset k remaining
      | remaining < k = []
      | otherwise     = (offset, k) : removes (offset + k) k (remaining - k)
-- | Remove the slice selected by the axe. Sub-tensors outside the slice are
-- kept in order.
axeWith :: Axe n -> Tensor n a -> Tensor n a
axeWith (AxeHere (offset, len)) (Tensor xss) =
    Tensor (kept ++ L.drop len rest)
  where
    (kept, rest) = L.splitAt offset xss
axeWith (AxeNested inner) (Tensor xss) =
    Tensor (axeWith inner <$> xss)
-- | The element used to zero out slices during shrinking, packaged with the
-- 'Eq' evidence needed to check whether a slice is already all zeroes.
data Zero a where
  Zero :: Eq a => a -> Zero a
-- | The numeric literal @0@ as the zero element.
zero :: (Num a, Eq a) => Zero a
zero = Zero 0
-- | Replace the slice selected by the axe with zero elements.
--
-- Returns 'Nothing' if every element in the slice was already zero, i.e. the
-- result would equal the input. Otherwise returns @'Just'@ the new tensor.
--
-- The number of zero slices inserted equals the number actually removed, so
-- the size of the tensor never changes. Previously @len@ slices were always
-- inserted. An axe reaching past the end of a dimension (possible because
-- 'Axe' constructors are exported) then grew the tensor or made it ragged.
zeroWith :: forall n a. Zero a -> Axe n -> Tensor n a -> Maybe (Tensor n a)
zeroWith (Zero z) = \axe tensor ->
    case go axe (size tensor) tensor of
      (_      , False) -> Nothing
      (tensor', True ) -> Just tensor'
  where
    -- Returns the rewritten tensor and whether any non-zero was replaced.
    go :: forall n'. Axe n' -> Size n' -> Tensor n' a -> (Tensor n' a, Bool)
    go (AxeHere (offset, len)) (_ ::: ns) (Tensor xss) = (
          Tensor $ before <> L.replicate (L.length dropped) (replicate ns z) <> after
        , any (/= z) (Tensor dropped)
        )
      where
        (before, dropFrom) = L.splitAt offset xss
        (dropped, after)   = L.splitAt len dropFrom
    go (AxeNested axe) (_ ::: ns) (Tensor xss) =
        bimap Tensor or $ L.unzip $ L.map (go axe ns) xss
-- | Shrink a tensor using every axe from 'allAxes' for its current size,
-- followed by the element-level shrinks of 'shrinkElem'.
shrinkWith ::
     Maybe (Zero a)
  -> (a -> [a])
  -> Tensor n a -> [Tensor n a]
shrinkWith mZero f tensor = shrinkWith' axes mZero f tensor
  where
    axes = allAxes (size tensor)
-- | Shrink a tensor: first remove each given axe in turn, then try the
-- element-level shrinks of 'shrinkElem'.
shrinkWith' :: forall n a.
     [Axe n]
  -> Maybe (Zero a)
  -> (a -> [a])
  -> Tensor n a -> [Tensor n a]
shrinkWith' axes mZero f xss =
    map (`axeWith` xss) axes ++ shrinkElem mZero f xss
-- | Element-level shrinks that keep the tensor's size.
--
-- If a zero element is supplied, the first candidates zero out each axe
-- covering more than one element, skipping axes that were already all zero.
-- After those, each element is shrunk individually with the supplied
-- function, one element at a time.
shrinkElem :: forall n a.
     Maybe (Zero a)
  -> (a -> [a])
  -> Tensor n a -> [Tensor n a]
shrinkElem mZero f tensor = zeroed ++ shrinkOne tensor
  where
    zeroed :: [Tensor n a]
    zeroed = case mZero of
      Nothing -> []
      Just z  ->
        [ tensor'
        | axe <- allAxes overallSize
        , axeSize overallSize axe > 1
        , Just tensor' <- [zeroWith z axe tensor]
        ]

    overallSize :: Size n
    overallSize = size tensor

    -- Shrink exactly one element, leaving all others unchanged.
    shrinkOne :: forall n'. Tensor n' a -> [Tensor n' a]
    shrinkOne (Scalar x)   = map Scalar (f x)
    shrinkOne (Tensor xss) =
      [ Tensor (before ++ xs' : after)
      | (before, xs, after) <- pickOne xss
      , xs' <- shrinkOne xs
      ]
-- | Generation is delegated to the 'Arbitrary1' instance. Shrinking uses
-- 'shrinkWith' with @0@ as the zero element and the element type's own
-- 'shrink'.
instance (SNatI n, Arbitrary a, Num a, Eq a) => Arbitrary (Tensor n a) where
  arbitrary = liftArbitrary arbitrary
  shrink    = shrinkWith (Just zero) shrink
-- | Dimension lengths are drawn from @[1, 1 + s]@, where @s@ is QuickCheck's
-- size parameter. The size vector comes from the 'Arbitrary1' instance of
-- 'Vec'; presumably each dimension is drawn independently (TODO confirm).
-- In that case the element count can reach @(1 + s) ^ n@. Shrinking does not
-- zero out slices, because no zero element is available here.
instance SNatI n => Arbitrary1 (Tensor n) where
  liftArbitrary genElem = QC.sized $ \sizeParam ->
      liftArbitrary (QC.choose (1, 1 + sizeParam)) >>= \sz ->
        arbitraryOfSize sz genElem
  liftShrink = shrinkWith Nothing
-- | Flatten a tensor into a storable vector in 'Foldable' order: outer
-- dimensions vary slowest, i.e. row-major.
toStorable :: Storable a => Tensor n a -> Storable.Vector a
toStorable tensor = Vector.fromList (Foldable.toList tensor)
-- | Rebuild a tensor of the given size from a flat storable vector, using
-- the same element order as 'toStorable'.
fromStorable ::
     (HasCallStack, Storable a)
  => Size n -> Storable.Vector a -> Tensor n a
fromStorable sz vec = fromList sz (Vector.toList vec)
-- | Run an action with a pointer to a flattened copy of the tensor, as built
-- by 'toStorable'. Writes through the pointer do not affect the original
-- tensor.
unsafeWithCArray :: Storable a => Tensor n a -> (Ptr a -> IO r) -> IO r
unsafeWithCArray tensor action = Vector.unsafeWith (toStorable tensor) action
-- | View a foreign buffer of @product sz@ elements as a tensor of size @sz@.
--
-- The buffer is wrapped without copying, so the caller must ensure it holds
-- at least that many initialised elements and is not modified afterwards.
-- NOTE(review): these requirements are inferred from the "unsafe" naming
-- and the no-copy wrap; confirm the intended contract.
unsafeFromCArray :: Storable a => Size n -> ForeignPtr a -> Tensor n a
unsafeFromCArray sz fptr =
    fromStorable sz (Vector.unsafeFromForeignPtr0 fptr numElems)
  where
    numElems :: Int
    numElems = product sz
-- | Allocate a buffer of @product sz@ elements and let the callback fill it.
-- Returns the buffer viewed as a tensor, together with the callback's
-- result.
--
-- The tensor reads the buffer via 'unsafeFromCArray'.
-- NOTE(review): writes made after the callback returns (through a retained
-- pointer) may be visible in the tensor; confirm callers never do this.
unsafeFromPrealloc ::
     Storable a
  => Size n -> (Ptr a -> IO r) -> IO (Tensor n a, r)
unsafeFromPrealloc sz k = do
    fptr <- mallocForeignPtrArray (product sz)
    res  <- withForeignPtr fptr k
    pure (unsafeFromCArray sz fptr, res)
-- | 'unsafeFromPrealloc' for callbacks that return nothing of interest.
unsafeFromPrealloc_ ::
     Storable a
  => Size n -> (Ptr a -> IO ()) -> IO (Tensor n a)
unsafeFromPrealloc_ sz k = fst <$> unsafeFromPrealloc sz k
-- | A rank-0 tensor holding one element.
scalar :: a -> Tensor Nat0 a
scalar = Scalar
-- | Build a rank-1 tensor from a list.
dim1 :: [a] -> Tensor Nat1 a
dim1 = Tensor . map Scalar
-- | Build a rank-2 tensor from nested lists, outermost dimension first.
dim2 :: [[a]] -> Tensor Nat2 a
dim2 = Tensor . map dim1
-- | Build a rank-3 tensor from nested lists, outermost dimension first.
dim3 :: [[[a]]] -> Tensor Nat3 a
dim3 = Tensor . map dim2
-- | Build a rank-4 tensor from nested lists, outermost dimension first.
dim4 :: [[[[a]]]] -> Tensor Nat4 a
dim4 = Tensor . map dim3
-- | Build a rank-5 tensor from nested lists, outermost dimension first.
dim5 :: [[[[[a]]]]] -> Tensor Nat5 a
dim5 = Tensor . map dim4
-- | Build a rank-6 tensor from nested lists, outermost dimension first.
dim6 :: [[[[[[a]]]]]] -> Tensor Nat6 a
dim6 = Tensor . map dim5
-- | Build a rank-7 tensor from nested lists, outermost dimension first.
dim7 :: [[[[[[[a]]]]]]] -> Tensor Nat7 a
dim7 = Tensor . map dim6
-- | Build a rank-8 tensor from nested lists, outermost dimension first.
dim8 :: [[[[[[[[a]]]]]]]] -> Tensor Nat8 a
dim8 = Tensor . map dim7
-- | Build a rank-9 tensor from nested lists, outermost dimension first.
dim9 :: [[[[[[[[[a]]]]]]]]] -> Tensor Nat9 a
dim9 = Tensor . map dim8
-- | The nested-list representation of a tensor of dimension @n@ with
-- elements of type @a@. Dimension zero is the bare element. Each further
-- dimension wraps the representation of the next-smaller dimension in
-- one more list.
type family Lists n a where
  Lists 'Z     a = a
  Lists ('S m) a = [Lists m a]
-- | Convert a tensor to its nested-list representation ('Lists').
--
-- A scalar becomes its element. A higher-dimensional tensor becomes the
-- list of its converted sub-tensors.
toLists :: Tensor n a -> Lists n a
toLists t = case t of
    Scalar x  -> x
    Tensor ts -> fmap toLists ts
-- | Build a tensor from its nested-list representation ('Lists').
--
-- The dimension comes from the 'SNatI' instance. The lists are converted
-- level by level without checking that they are rectangular, so ragged
-- input yields a ragged tensor.
fromLists :: SNatI n => Lists n a -> Tensor n a
fromLists = build snat
  where
    -- Recurse on the dimension singleton. Each 'SS' level obtains the
    -- singleton for the next-smaller dimension from its 'SNatI' evidence.
    build :: SNat m -> Lists m a -> Tensor m a
    build sn xs = case sn of
      SZ -> Scalar xs
      SS -> Tensor [build snat x | x <- xs]
-- | Build a tensor of the given 'Size' from a flat list.
--
-- Elements are consumed from the front of the list in the traversal order
-- of the 'Traversable' instance, so the outermost index varies slowest.
-- Calls 'error' if the list has fewer elements than the tensor needs.
-- Surplus elements are silently ignored.
fromList :: forall n a. Size n -> [a] -> Tensor n a
fromList sz xs =
    case evalStateT (sequenceA (replicate sz popElem)) xs of
      Just t  -> t
      Nothing -> error "fromList: insufficient elements"
  where
    -- Take the next element from the remaining input; fails when it is empty.
    popElem :: StateT [a] Maybe a
    popElem = StateT L.uncons
-- | Run the continuation @k@ with a @Show (Lists n a)@ instance in scope.
--
-- The proof is by induction on the dimension singleton:
--
-- * at 'Z', @Lists 'Z a@ is @a@, which @Show a@ covers;
-- * at @'S m@, @Lists ('S m) a@ is @[Lists m a]@. The standard list
--   instance covers it once the recursive call has provided
--   @Show (Lists m a)@.
showLists :: Show a => Proxy a -> SNat n -> (Show (Lists n a) => r) -> r
showLists proxy sn k = case sn of
    SZ      -> k
    SS' sn' -> showLists proxy sn' k
-- | Render the name of the function used to rebuild a tensor of the given
-- dimension:
--
-- * @scalar@ for dimension 0;
-- * @dim1@ … @dim9@ for dimensions 1 to 9;
-- * @fromLists \@…@, with the dimension given as an explicit type
--   application, for anything larger.
--
-- The precedence @p@ is only used to render that explicit dimension.
showConstructor :: Int -> SNat n -> ShowS
showConstructor p sn = case snatToNatural sn of
    0 -> showString "scalar"
    d | d <= 9    -> showString "dim" . shows d
      | otherwise -> showString "fromLists @" . explicitShowsPrec p (snatToNat sn)
-- | Show a tensor as an expression that rebuilds it. The output is the
-- function chosen by 'showConstructor' applied to the nested-list form
-- produced by 'toLists', parenthesised at application precedence.
--
-- NOTE(review): the dimension is recovered with 'tensorSNat'. That fails
-- when the first sub-tensor at some level is empty (e.g. @dim1 []@), so
-- such tensors cannot be shown.
instance Show a => Show (Tensor n a) where
  showsPrec p tensor =
      showLists (Proxy @a) dimension $
        showParen (p >= appPrec1) $
            showConstructor appPrec1 dimension
          . showSpace
          . showsPrec appPrec1 (toLists tensor)
    where
      dimension = tensorSNat tensor
-- | Recover 'SNatI' evidence for a tensor's dimension and pass it to @k@.
--
-- The evidence is rebuilt by descending into the first sub-tensor at
-- every level until a 'Scalar' is reached. The descent may meet an empty
-- list of sub-tensors (e.g. @dim1 []@ or @dim2 [[]]@). In that case there
-- is nothing to descend into, the remaining dimension cannot be recovered
-- at runtime, and this function calls 'error' with a descriptive message.
tensorSNatI :: Tensor n a -> (SNatI n => r) -> r
tensorSNatI (Scalar _)  k = k
tensorSNatI (Tensor ts) k = case ts of
    t : _ -> tensorSNatI t k
    []    -> error "tensorSNatI: cannot determine the dimension of an empty tensor"
-- | The singleton for a tensor's dimension, obtained through
-- 'tensorSNatI'. It shares that function's limitation: the dimension
-- cannot be recovered if the first sub-tensor at some level is empty.
tensorSNat :: forall n a. Tensor n a -> SNat n
tensorSNat t = tensorSNatI t (snat @n)
-- | All contiguous windows of exactly @n@ elements, in order.
--
-- A window is started at every position of the input. Trailing windows
-- shorter than @n@ are dropped. For @n == 0@ the result is one empty
-- window per suffix (@length xs + 1@ of them). For @n < 0@ it is empty.
consecutive :: Int -> [a] -> [[a]]
consecutive n xs = L.takeWhile isFull windows
  where
    windows  = map (L.take n) (L.tails xs)
    isFull w = length w == n
-- | Keep the first element and then every @n@-th element after it
-- (indices 0, n, 2n, …).
--
-- Calls 'error' when @n@ is not strictly positive. The check only runs
-- once the result (with the list supplied) is evaluated.
everyNth :: forall a. Int -> [a] -> [a]
everyNth n xs
  | n <= 0    = error "everyNth: n should be strictly positive"
  | otherwise = L.unfoldr step xs
  where
    -- Emit the head, then skip the following n - 1 elements.
    step :: [a] -> Maybe (a, [a])
    step []         = Nothing
    step (y : rest) = Just (y, drop (n - 1) rest)
-- | Every way of splitting a non-empty list around one chosen element, as
-- @(elements before, chosen element, elements after)@, in list order.
--
-- Calls 'error' on the empty list.
pickOne :: forall a. [a] -> [([a], a, [a])]
pickOne xs = case xs of
    [] -> error "pickOne: empty list"
    _  -> zip3 (L.inits xs) xs (drop 1 (L.tails xs))