{-# LANGUAGE DataKinds #-}
{-# LANGUAGE ImportQualifiedPost #-}
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
module Data.Array.Strided.Array where

import Data.List.NonEmpty qualified as NE
import Data.Proxy
import Data.Vector.Storable qualified as VS
import Foreign.Storable
import GHC.TypeLits


-- | A strided view onto a flat storable vector, indexed by the type-level
-- rank @n@.
--
-- The element at multi-index @is@ lives at position
-- @arrOffset + sum (zipWith (*) is arrStrides)@ in 'arrValues'. This is the
-- convention 'arrayFromVector' (row-major strides, offset 0) and
-- 'arrayRevDims' (offset shift on stride negation) are written against.
data Array (n :: Nat) a = Array
  { arrShape   :: ![Int]          -- ^ Extent of each dimension; the last one is the innermost.
  , arrStrides :: ![Int]          -- ^ Step in 'arrValues' per unit step along each dimension
                                  --   (may be zero for broadcast data or negative after reversal).
  , arrOffset  :: !Int            -- ^ Position in 'arrValues' of the element at the all-zero index.
  , arrValues  :: !(VS.Vector a)  -- ^ Flat backing storage shared by the view.
  }

-- | Build an array of rank @n@ from a shape and a vector holding the
-- elements in normalised order (the inner dimension, i.e. the last in the
-- shape list, iterates fastest).
--
-- The result shares the given vector, uses row-major strides and has
-- offset 0.
--
-- Calls 'error' when
--
--   * the vector length differs from the product of the shape,
--   * the shape length differs from the type-level rank @n@, or
--   * the shape contains a negative extent.
arrayFromVector :: forall a n. (Storable a, KnownNat n) => [Int] -> VS.Vector a -> Array n a
arrayFromVector sh vec
  | VS.length vec /= shsize =
      error $ "arrayFromVector: Shape " ++ show sh
              ++ " does not match vector length " ++ show (VS.length vec)
  | length sh /= rank =
      error $ "arrayFromVector: Shape " ++ show sh ++ " has length " ++ show (length sh)
              ++ " but the type-level rank is " ++ show rank
  -- With an even number of negative extents the product check above can
  -- still succeed, so negatives must be rejected explicitly.
  | any (< 0) sh =
      error $ "arrayFromVector: Shape " ++ show sh ++ " has a negative dimension"
  | otherwise = Array sh strides 0 vec
  where
    rank :: Int
    rank = fromIntegral (natVal (Proxy @n))

    shsize :: Int
    shsize = product sh

    -- Row-major strides: each dimension's stride is the product of all
    -- extents after it. For [2,3,4], scanr gives 24 :| [12,4,1] and the tail
    -- is [12,4,1]; for the rank-0 shape [] it is [].
    strides :: [Int]
    strides = NE.tail (NE.scanr (*) 1 sh)

-- | An array of the given shape in which every element equals @x@.
--
-- Only one copy of @x@ is stored, in a one-element vector. Every stride is
-- zero and the offset is zero, so every index maps to that single element.
-- The shape length is not checked against the type-level rank @n@.
arrayFromConstant :: Storable a => [Int] -> a -> Array n a
arrayFromConstant sh x =
  let zeroStrides = map (const 0) sh
      storage = VS.singleton x
   in Array sh zeroStrides 0 storage

-- | Reverse the array along every dimension whose flag is 'True'.
--
-- No element data is copied. Each selected dimension has its stride negated,
-- and the offset is moved to what was previously the last position along
-- that dimension.
--
-- A reversed dimension of extent 0 leaves the offset untouched. It has no
-- elements, so there is no "last position" to move to, and shifting by
-- @-stride@ could push the offset outside the backing vector.
--
-- Calls 'error' if the number of flags differs from the rank (the length of
-- the shape).
arrayRevDims :: [Bool] -> Array n a -> Array n a
arrayRevDims bs (Array sh strides offset vec)
  | length bs == length sh =
      Array sh
            (zipWith (\b s -> if b then negate s else s) bs strides)
            (offset + sum (zipWith3 shiftFor bs sh strides))
            vec
  | otherwise =
      error $ "arrayRevDims: " ++ show (length bs) ++ " booleans given but rank "
              ++ show (length sh)
  where
    -- Offset contribution of one dimension: reversing a dimension of extent
    -- d moves the start to index d-1 along it. Clamped for d == 0.
    shiftFor :: Bool -> Int -> Int -> Int
    shiftFor True d s = max 0 (d - 1) * s
    shiftFor False _ _ = 0