{-# LANGUAGE ImportQualifiedPost #-}
module Data.Array.Strided.Orthotope (
  module Data.Array.Strided.Orthotope,
  module Data.Array.Strided.Arith,
) where

import Data.Array.Internal qualified as OI
import Data.Array.Internal.RankedG qualified as RG
import Data.Array.Internal.RankedS qualified as RS

import Data.Array.Strided qualified as AS
import Data.Array.Strided.Arith

-- for liftVEltwise1
import Data.Array.Strided.Arith.Internal (stridesDense)
import Data.Vector.Storable qualified as VS
import Foreign.Storable
import GHC.TypeLits


-- | Reinterpret an orthotope 'RS.Array' as a strided 'AS.Array'.
--
-- The shape, strides, offset and backing vector are moved across field by
-- field; the backing vector itself is reused, so no element data is copied.
fromO :: RS.Array n a -> AS.Array n a
fromO (RS.A (RG.A sh (OI.T strides offset vec))) =
  AS.Array sh strides offset vec

-- | Reinterpret a strided 'AS.Array' as an orthotope 'RS.Array'; the inverse
-- of 'fromO'.
--
-- Shape, strides, offset and backing vector are reused as-is, so no element
-- data is copied.
toO :: AS.Array n a -> RS.Array n a
toO (AS.Array sh strides offset vec) =
  RS.A (RG.A sh (OI.T strides offset vec))

-- | Lift a unary operation on strided arrays to one on orthotope arrays.
--
-- The argument is converted with 'fromO' and the result with 'toO'; both
-- conversions only rewrap the array's fields.
liftO1 :: (AS.Array n a -> AS.Array n' b)
       -> RS.Array n a -> RS.Array n' b
liftO1 f = toO . f . fromO

-- | Lift a binary operation on strided arrays to one on orthotope arrays.
--
-- Both arguments are converted with 'fromO' and the result with 'toO'.
liftO2 :: (AS.Array n a -> AS.Array n1 b -> AS.Array n2 c)
       -> RS.Array n a -> RS.Array n1 b -> RS.Array n2 c
liftO2 f x y = toO (f (fromO x) (fromO y))

-- | Apply a function on flat storable vectors to the elements of an
-- orthotope array.
--
-- @f@ must return a vector of the same length as its input. In the first
-- branch below, its result is indexed with the input array's original
-- strides. NOTE(review): @f@ is presumably meant to be element-wise;
-- confirm against callers.
--
-- * If 'stridesDense' yields @Just (blockOff, blockSz)@, @f@ is applied only
--   to the slice @[blockOff, blockOff + blockSz)@ of the backing vector. The
--   result keeps the original shape and strides, with the offset shifted by
--   @-blockOff@ so that it indexes into the new (sliced) vector. This
--   presumably means every element the array addresses lies inside that
--   block; the exact contract is defined by 'stridesDense'.
-- * Otherwise the array is flattened with 'RS.toVector', @f@ is applied to
--   the flat vector, and a fresh array of shape @sh@ is built with
--   'RS.fromVector'.
--
-- The 'SNat' argument is matched only to bring @KnownNat n@ into scope,
-- which 'RS.fromVector' requires.
liftVEltwise1 :: (Storable a, Storable b)
              => SNat n
              -> (VS.Vector a -> VS.Vector b)
              -> RS.Array n a -> RS.Array n b
liftVEltwise1 SNat f arr@(RS.A (RG.A sh (OI.T strides offset vec)))
  | Just (blockOff, blockSz) <- stridesDense sh offset strides =
      let vec' = f (VS.slice blockOff blockSz vec)
      in RS.A (RG.A sh (OI.T strides (offset - blockOff) vec'))
  | otherwise = RS.fromVector sh (f (RS.toVector arr))