{-# LANGUAGE DataKinds #-}
{-# LANGUAGE ExistentialQuantification #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE ImportQualifiedPost #-}
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE MultiWayIf #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE TupleSections #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeOperators #-}
{-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-}
{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-}
module Data.Array.Strided.Arith.Internal where
import Control.Monad
import Data.Bifunctor (second)
import Data.Bits
import Data.Int
import Data.List (sort, zip4)
import Data.Proxy
import Data.Type.Equality
import Data.Vector.Storable qualified as VS
import Data.Vector.Storable.Mutable qualified as VSM
import Foreign.C.Types
import Foreign.Ptr
import Foreign.Storable
import GHC.TypeLits
import GHC.TypeNats qualified as TypeNats
import Language.Haskell.TH
import System.IO (hFlush, stdout)
import System.IO.Unsafe
import Data.Array.Strided.Arith.Internal.Foreign
import Data.Array.Strided.Arith.Internal.Lists
import Data.Array.Strided.Array
-- | The value carried by a type-level natural witness, converted to an 'Int'.
-- Like 'fromIntegral', this wraps silently if the value exceeds 'maxBound'.
fromSNat' :: SNat n -> Int
fromSNat' sn = fromInteger (fromSNat sn)
-- | Reified evidence for the constraint @c@. Constructing a 'Dict' requires
-- @c@ to hold, and pattern matching on 'Dict' brings @c@ back into scope.
data Dict c where
  Dict :: c => Dict c
-- | Render the metadata of an array for debugging: its type-level rank, its
-- shape, strides and offset, and the length of the backing vector. The
-- element values themselves are not printed.
debugShow :: forall n a. (Storable a, KnownNat n) => Array n a -> String
debugShow (Array shape strideVec off backing) =
  concat
    [ "Array @", show (natVal (Proxy @n))
    , " ", show shape
    , " ", show strideVec
    , " ", show off
    , " <_*", show (VS.length backing)
    , ">"
    ]
-- | Lift a unary elementwise operation, implemented by the strided kernel
-- @cf_strided@, to a strided array.
--
-- If the elements addressed by the array form one contiguous block of the
-- backing vector (as decided by 'stridesDense'), the kernel is run once over
-- just that block, viewed as a rank-1 array. The result then reuses the
-- original shape and strides, with the offset re-based to the start of the
-- block. Otherwise the kernel is run on the full rank-@n@ array.
-- An array with an empty dimension yields an empty result without calling
-- the kernel.
liftOpEltwise1 :: Storable a
               => SNat n
               -> (Ptr a -> Ptr b)
               -> (Int64 -> Ptr b -> Ptr Int64 -> Ptr Int64 -> Ptr b -> IO ())
               -> Array n a -> Array n a
liftOpEltwise1 sn@SNat ptrconv cf_strided arr@(Array sh strides offset vec)
  | Just (blockOff, blockSz) <- stridesDense sh offset strides =
      if blockSz == 0
        then -- some dimension has size <= 0: nothing to compute
             Array sh (map (const 0) strides) 0 VS.empty
        else
          let -- The contiguous block is viewed as a rank-1 array (one shape
              -- entry, one stride entry), so the rank witness passed with it
              -- must be 1, not n (cf. the dense paths of liftOpEltwise2).
              block = Array [blockSz] [1] blockOff vec
              resvec = arrValues $ wrapUnary (SNat @1) ptrconv cf_strided block
          in Array sh strides (offset - blockOff) resvec
  | otherwise = wrapUnary sn ptrconv cf_strided arr
-- | Lift a binary elementwise operation to two strided arrays of equal shape.
-- The cheapest applicable strategy is chosen:
--
-- * both operands address a single (possibly replicated) element: apply
--   @scalarOp@ directly and return a one-element result;
-- * exactly one operand addresses a single element: use the scalar/vector
--   kernel @kernelSV@ or the vector/scalar kernel @kernelVS@. This runs on
--   just the contiguous block of the other operand when it has one (as a
--   rank-1 array), and on the full strided operand otherwise;
-- * both operands are contiguous with identical strides: run @kernelVV@ once
--   on the two blocks as rank-1 arrays;
-- * anything else: run @kernelVV@ on the full strided arrays.
--
-- Calls 'error' if the shapes differ. An array with an empty dimension gives
-- an empty result.
liftOpEltwise2 :: Storable a
               => SNat n
               -> (a -> b)
               -> (Ptr a -> Ptr b)
               -> (a -> a -> a)
               -> (Int64 -> Ptr Int64 -> Ptr b -> b -> Ptr Int64 -> Ptr b -> IO ())
               -> (Int64 -> Ptr Int64 -> Ptr b -> Ptr Int64 -> Ptr b -> b -> IO ())
               -> (Int64 -> Ptr Int64 -> Ptr b -> Ptr Int64 -> Ptr b -> Ptr Int64 -> Ptr b -> IO ())
               -> Array n a -> Array n a -> Array n a
liftOpEltwise2 sn@SNat valconv ptrconv scalarOp kernelSV kernelVS kernelVV
    lhs@(Array shL stridesL offL vecL)
    rhs@(Array shR stridesR offR vecR)
  | shL /= shR =
      error $ "liftOpEltwise2: shapes unequal: " ++ show shL ++ " vs " ++ show shR
  | any (<= 0) shL = Array shL (0 <$ stridesL) 0 VS.empty
  | otherwise =
      case (stridesDense shL offL stridesL, stridesDense shR offR stridesR) of
        -- scalar `op` scalar
        (Just (_, 1), Just (_, 1)) ->
          Array shL stridesL 0 (VS.singleton (scalarOp scalarL scalarR))
        -- scalar `op` contiguous block
        (Just (_, 1), Just (blk, sz)) ->
          let denseR = arrayFromVector [sz] (VS.slice blk sz vecR)
              out = wrapBinarySV (SNat @1) valconv ptrconv kernelSV scalarL denseR
          in Array shL stridesR (offR - blk) (arrValues out)
        -- scalar `op` general strided array
        (Just (_, 1), Nothing) ->
          wrapBinarySV sn valconv ptrconv kernelSV scalarL rhs
        -- contiguous block `op` scalar
        (Just (blk, sz), Just (_, 1)) ->
          let denseL = arrayFromVector [sz] (VS.slice blk sz vecL)
              out = wrapBinaryVS (SNat @1) valconv ptrconv kernelVS denseL scalarR
          in Array shL stridesL (offL - blk) (arrValues out)
        -- general strided array `op` scalar
        (Nothing, Just (_, 1)) ->
          wrapBinaryVS sn valconv ptrconv kernelVS lhs scalarR
        -- two contiguous blocks laid out identically
        (Just (blkL, szL), Just (blkR, szR))
          | stridesL == stridesR ->
              -- Equal shapes and equal strides imply equal block sizes and
              -- equal offsets within the blocks.
              if szL /= szR || offL - blkL /= offR - blkR
                then error $ "Data.Array.Strided.Ops.Internal(liftOpEltwise2): Internal error: cannot happen " ++ show (stridesL, (blkL, szL), stridesR, (blkR, szR))
                else
                  let denseL = arrayFromVector [szL] (VS.slice blkL szL vecL)
                      denseR = arrayFromVector [szL] (VS.slice blkR szR vecR)
                      out = wrapBinaryVV (SNat @1) ptrconv kernelVV denseL denseR
                  in Array shL stridesL (offL - blkL) (arrValues out)
        -- fallback: fully general strided kernel
        _ -> wrapBinaryVV sn ptrconv kernelVV lhs rhs
  where
    -- The single element addressed by an operand; only forced in the
    -- branches where that operand was found to be a (replicated) scalar.
    scalarL = vecL VS.! offL
    scalarR = vecR VS.! offR
-- | Decide whether the elements addressed by a strided view (shape, offset,
-- strides) occupy one contiguous range of the backing vector.
--
-- Returns @Just (start, count)@ when they do. @start@ is the lowest backing
-- index touched and @count@ is the length of the contiguous range. Returns
-- 'Nothing' otherwise. Negative strides are first normalised with
-- 'flipReverseds', and stride-0 (replicated) dimensions are ignored.
--
-- * A shape with any size @<= 0@ gives @Just (offset, 0)@.
-- * A view without any nonzero stride gives a range of length 1.
stridesDense :: [Int] -> Int -> [Int] -> Maybe (Int, Int)
stridesDense shape off0 strides0
  | any (<= 0) shape = Just (off0, 0)
  | otherwise =
      let (start, posStrides) = flipReverseds shape off0 strides0
          -- (stride, size) pairs in ascending order, without stride 0
          used = [ p | p@(s, _) <- sort (zip posStrides shape), s /= 0 ]
      in case used of
           [] -> Just (start, 1)
           (1, n) : rest -> fmap (\sz -> (start, sz)) (grow n rest)
           -- smallest nonzero stride is not 1: there must be gaps
           _ -> Nothing
  where
    -- Extend a dense prefix of length @covered@ by the remaining dimensions
    -- (strides >= 1, ascending). A stride larger than the prefix would leave
    -- a gap, so the view is not dense.
    grow :: Int -> [(Int, Int)] -> Maybe Int
    grow covered [] = Just covered
    grow covered ((s, n) : rest)
      | s <= covered = grow (covered + (n - 1) * s) rest
      | otherwise = Nothing
-- | Normalise a (shape, offset, strides) view so that every stride is
-- non-negative.
--
-- Each dimension with a negative stride is flipped. Its stride is negated,
-- and the offset moves by @(size - 1) * stride@ to what used to be that
-- dimension's last element. Non-negative strides are kept as they are.
-- Calls 'error' if the shape and stride lists differ in length.
flipReverseds :: [Int] -> Int -> [Int] -> (Int, [Int])
flipReverseds = go
  where
    go [] off [] = (off, [])
    go (n : ns) off (s : ss)
      | s < 0 =
          let (off', rest) = go ns (off + (n - 1) * s) ss
          in (off', negate s : rest)
      | otherwise =
          let (off', rest) = go ns off ss
          in (off', s : rest)
    go _ _ _ = error "flipReverseds: invalid arguments"
-- | An array together with the information needed to undo the removal of its
-- stride-0 ("replicated") dimensions. It is intended to hold the result of
-- 'unreplicateStrides'. The rank @n'@ of the stored array is existentially
-- hidden, with its 'KnownNat' evidence packed alongside.
data Unreplicated a where
  Unreplicated
    :: KnownNat n'
    => Array n' a        -- the array restricted to its nonzero-stride dimensions
    -> Int               -- product of the sizes of the removed dimensions
    -> ([Int] -> [Int])  -- reinserts a 0 at each removed dimension's position
    -> Unreplicated a
-- | Drop every dimension whose stride is 0 from an array.
--
-- The returned 'Unreplicated' holds:
--
--   * the array restricted to the dimensions with nonzero stride (same
--     offset and backing vector, rank determined at runtime);
--   * the product of the sizes of the dropped dimensions;
--   * a function that takes a list for the kept dimensions and reinserts a
--     0 at every dropped position. It calls 'error' if the list has the
--     wrong length.
unreplicateStrides :: Array n a -> Unreplicated a
unreplicateStrides (Array shape strideVec off values) =
  TypeNats.withSomeSNat (fromIntegral (length keptShape)) $ \(SNat :: SNat k) ->
    Unreplicated (Array @k keptShape keptStrides off values) replFactor (fillZeros isRepl)
  where
    isRepl = map (== 0) strideVec
    kept = filter ((/= 0) . snd) (zip shape strideVec)
    keptShape = map fst kept
    keptStrides = map snd kept
    replFactor = product [n | (n, True) <- zip shape isRepl]

    fillZeros :: [Bool] -> [Int] -> [Int]
    fillZeros (True : bs) xs = 0 : fillZeros bs xs
    fillZeros (False : bs) (x : xs) = x : fillZeros bs xs
    fillZeros [] [] = []
    fillZeros (False : _) [] = error "unreplicateStrides: Internal error: reply strides too short"
    fillZeros [] (_ : _) = error "unreplicateStrides: Internal error: reply strides too long"
-- | Normalise the layout of an array and pass the result to the continuation
-- @k@.
--
-- Dimensions with a negative stride are flipped via 'arrayRevDims' (assumed to
-- reverse exactly the flagged dimensions; confirm against its definition).
-- Dimensions with stride 0 are then removed via 'unreplicateStrides'.
-- The continuation receives:
--
--   1. the simplified array (no stride-0 dimensions);
--   2. the product of the sizes of the removed stride-0 dimensions;
--   3. a function mapping an index into the simplified array to the
--      corresponding index into the original array (0 at removed
--      dimensions, mirrored at flipped dimensions);
--   4. a function turning an array of the simplified shape back into one of
--      the original shape and orientation;
--   5. the same for an array whose innermost dimension has been reduced
--      away. This one calls 'error' when the original innermost stride was 0.
--
-- Functions 4 and 5 call 'error' when given an array of the wrong shape.
simplifyArray :: Array n a
              -> (forall n'. KnownNat n'
                    => Array n' a
                    -> Int
                    -> ([Int] -> [Int])
                    -> (Array n' a -> Array n a)
                    -> (Array (n' - 1) a -> Array (n - 1) a)
                    -> r)
              -> r
simplifyArray array k
  | let revDims = map (< 0) (arrStrides array)
  , Unreplicated array' unrepSize rereplicate <- unreplicateStrides (arrayRevDims revDims array)
  = k array'
      unrepSize
      -- 'revDims' and 'arrShape array' have one entry per ORIGINAL dimension,
      -- whereas @idx@ only covers the non-replicated ones. Reinsert the zeros
      -- first, then mirror the flipped dimensions against the full shape
      -- (same order of operations as in simplifyArray2). Replicated
      -- dimensions have stride 0, so they are never flagged in 'revDims'.
      (\idx -> zipWith3 (\b n i -> if b then n - 1 - i else i)
                        revDims (arrShape array) (rereplicate idx))
      (\(Array sh' strides' offset' vec') ->
         if sh' == arrShape array'
           then arrayRevDims revDims (Array (arrShape array) (rereplicate strides') offset' vec')
           else error $ "simplifyArray: Internal error: reply shape wrong (reply " ++ show sh' ++ ", unreplicated " ++ show (arrShape array') ++ ")")
      (\(Array sh' strides' offset' vec') ->
         -- Check the inner stride first: if every dimension was replicated,
         -- 'arrShape array'' is empty and 'init' on it would crash with an
         -- uninformative error before reaching this check.
         if | last (arrStrides array) == 0 ->
                error "simplifyArray: Internal error: reduction reply handler used while inner stride was 0"
            | sh' /= init (arrShape array') ->
                error $ "simplifyArray: Internal error: reply shape wrong (reply " ++ show sh' ++ ", unreplicated " ++ show (arrShape array') ++ ")"
            | otherwise ->
                -- The innermost original dimension is not replicated, so it
                -- is also the innermost dimension of the simplified array. A
                -- placeholder stride for it is appended, expanded back to all
                -- original dimensions, and dropped again with 'init'.
                arrayRevDims (init revDims)
                  (Array (init (arrShape array))
                         (init (rereplicate (strides' ++ [0])))
                         offset' vec'))
simplifyArray2 :: Array n a -> Array n a
-> (forall n'. KnownNat n'
=> Array n' a
-> Array n' a
-> Int
-> ([Int] -> [Int])
-> (Array n' a -> Array n a)
-> (Array (n' - 1) a -> Array (n - 1) a)
-> r)
-> r
simplifyArray2 :: forall (n :: Nat) a r.
Array n a
-> Array n a
-> (forall (n' :: Nat).
KnownNat n' =>
Array n' a
-> Array n' a
-> Int
-> ([Int] -> [Int])
-> (Array n' a -> Array n a)
-> (Array (n' - 1) a -> Array (n - 1) a)
-> r)
-> r
simplifyArray2 arr1 :: Array n a
arr1@(Array [Int]
sh [Int]
_ Int
_ Vector a
_) arr2 :: Array n a
arr2@(Array [Int]
sh2 [Int]
_ Int
_ Vector a
_) forall (n' :: Nat).
KnownNat n' =>
Array n' a
-> Array n' a
-> Int
-> ([Int] -> [Int])
-> (Array n' a -> Array n a)
-> (Array (n' - 1) a -> Array (n - 1) a)
-> r
k
| [Int]
sh [Int] -> [Int] -> Bool
forall a. Eq a => a -> a -> Bool
/= [Int]
sh2 = String -> r
forall a. HasCallStack => String -> a
error String
"simplifyArray2: Unequal shapes"
| let revDims :: [Bool]
revDims = (Int -> Int -> Bool) -> [Int] -> [Int] -> [Bool]
forall a b c. (a -> b -> c) -> [a] -> [b] -> [c]
zipWith (\Int
s1 Int
s2 -> Int
s1 Int -> Int -> Bool
forall a. Ord a => a -> a -> Bool
< Int
0 Bool -> Bool -> Bool
&& Int
s2 Int -> Int -> Bool
forall a. Ord a => a -> a -> Bool
< Int
0) (Array n a -> [Int]
forall (n :: Nat) a. Array n a -> [Int]
arrStrides Array n a
arr1) (Array n a -> [Int]
forall (n :: Nat) a. Array n a -> [Int]
arrStrides Array n a
arr2)
, Array [Int]
_ [Int]
strides1 Int
offset1 Vector a
vec1 <- [Bool] -> Array n a -> Array n a
forall (n :: Nat) a. [Bool] -> Array n a -> Array n a
arrayRevDims [Bool]
revDims Array n a
arr1
, Array [Int]
_ [Int]
strides2 Int
offset2 Vector a
vec2 <- [Bool] -> Array n a -> Array n a
forall (n :: Nat) a. [Bool] -> Array n a -> Array n a
arrayRevDims [Bool]
revDims Array n a
arr2
, let replDims :: [Bool]
replDims = (Int -> Int -> Bool) -> [Int] -> [Int] -> [Bool]
forall a b c. (a -> b -> c) -> [a] -> [b] -> [c]
zipWith (\Int
s1 Int
s2 -> Int
s1 Int -> Int -> Bool
forall a. Eq a => a -> a -> Bool
== Int
0 Bool -> Bool -> Bool
&& Int
s2 Int -> Int -> Bool
forall a. Eq a => a -> a -> Bool
== Int
0) [Int]
strides1 [Int]
strides2
, let ([Int]
shF, [Int]
strides1F, [Int]
strides2F) = [(Int, Int, Int)] -> ([Int], [Int], [Int])
forall a b c. [(a, b, c)] -> ([a], [b], [c])
unzip3 [(Int
n, Int
s1, Int
s2) | (Int
n, Int
s1, Int
s2, Bool
False) <- [Int] -> [Int] -> [Int] -> [Bool] -> [(Int, Int, Int, Bool)]
forall a b c d. [a] -> [b] -> [c] -> [d] -> [(a, b, c, d)]
zip4 [Int]
sh [Int]
strides1 [Int]
strides2 [Bool]
replDims]
, let reinsertZeros :: [Bool] -> [a] -> [a]
reinsertZeros (Bool
False : [Bool]
zeros) (a
s : [a]
strides') = a
s a -> [a] -> [a]
forall a. a -> [a] -> [a]
: [Bool] -> [a] -> [a]
reinsertZeros [Bool]
zeros [a]
strides'
reinsertZeros (Bool
True : [Bool]
zeros) [a]
strides' = a
0 a -> [a] -> [a]
forall a. a -> [a] -> [a]
: [Bool] -> [a] -> [a]
reinsertZeros [Bool]
zeros [a]
strides'
reinsertZeros [] [] = []
reinsertZeros (Bool
False : [Bool]
_) [] = String -> [a]
forall a. HasCallStack => String -> a
error String
"simplifyArray2: Internal error: reply strides too short"
reinsertZeros [] (a
_:[a]
_) = String -> [a]
forall a. HasCallStack => String -> a
error String
"simplifyArray2: Internal error: reply strides too long"
, let unrepSize :: Int
unrepSize = [Int] -> Int
forall a. Num a => [a] -> a
forall (t :: * -> *) a. (Foldable t, Num a) => t a -> a
product [Int
n | (Int
n, Bool
True) <- [Int] -> [Bool] -> [(Int, Bool)]
forall a b. [a] -> [b] -> [(a, b)]
zip [Int]
sh [Bool]
replDims]
= Nat -> (forall {n :: Nat}. SNat n -> r) -> r
forall r. Nat -> (forall (n :: Nat). SNat n -> r) -> r
TypeNats.withSomeSNat (Int -> Nat
forall a b. (Integral a, Num b) => a -> b
fromIntegral ([Int] -> Int
forall a. [a] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [Int]
shF)) ((forall {n :: Nat}. SNat n -> r) -> r)
-> (forall {n :: Nat}. SNat n -> r) -> r
forall a b. (a -> b) -> a -> b
$ \(SNat n
SNat :: SNat lenshF) ->
forall (n' :: Nat).
KnownNat n' =>
Array n' a
-> Array n' a
-> Int
-> ([Int] -> [Int])
-> (Array n' a -> Array n a)
-> (Array (n' - 1) a -> Array (n - 1) a)
-> r
k @lenshF
([Int] -> [Int] -> Int -> Vector a -> Array n a
forall (n :: Nat) a. [Int] -> [Int] -> Int -> Vector a -> Array n a
Array [Int]
shF [Int]
strides1F Int
offset1 Vector a
vec1)
([Int] -> [Int] -> Int -> Vector a -> Array n a
forall (n :: Nat) a. [Int] -> [Int] -> Int -> Vector a -> Array n a
Array [Int]
shF [Int]
strides2F Int
offset2 Vector a
vec2)
Int
unrepSize
(\[Int]
idx -> (Bool -> Int -> Int -> Int) -> [Bool] -> [Int] -> [Int] -> [Int]
forall a b c d. (a -> b -> c -> d) -> [a] -> [b] -> [c] -> [d]
zipWith3 (\Bool
b Int
n Int
i -> if Bool
b then Int
n Int -> Int -> Int
forall a. Num a => a -> a -> a
- Int
1 Int -> Int -> Int
forall a. Num a => a -> a -> a
- Int
i else Int
i)
[Bool]
revDims [Int]
sh ([Bool] -> [Int] -> [Int]
forall {a}. Num a => [Bool] -> [a] -> [a]
reinsertZeros [Bool]
replDims [Int]
idx))
(\(Array [Int]
sh' [Int]
strides' Int
offset' Vector a
vec') ->
if [Int]
sh' [Int] -> [Int] -> Bool
forall a. Eq a => a -> a -> Bool
/= [Int]
shF then String -> Array n a
forall a. HasCallStack => String -> a
error (String -> Array n a) -> String -> Array n a
forall a b. (a -> b) -> a -> b
$ String
"simplifyArray2: Internal error: reply shape wrong (reply " String -> String -> String
forall a. [a] -> [a] -> [a]
++ [Int] -> String
forall a. Show a => a -> String
show [Int]
sh' String -> String -> String
forall a. [a] -> [a] -> [a]
++ String
", unreplicated " String -> String -> String
forall a. [a] -> [a] -> [a]
++ [Int] -> String
forall a. Show a => a -> String
show [Int]
shF String -> String -> String
forall a. [a] -> [a] -> [a]
++ String
")"
else [Bool] -> Array n a -> Array n a
forall (n :: Nat) a. [Bool] -> Array n a -> Array n a
arrayRevDims [Bool]
revDims ([Int] -> [Int] -> Int -> Vector a -> Array n a
forall (n :: Nat) a. [Int] -> [Int] -> Int -> Vector a -> Array n a
Array [Int]
sh ([Bool] -> [Int] -> [Int]
forall {a}. Num a => [Bool] -> [a] -> [a]
reinsertZeros [Bool]
replDims [Int]
strides') Int
offset' Vector a
vec'))
(\(Array [Int]
sh' [Int]
strides' Int
offset' Vector a
vec') ->
if | [Int]
sh' [Int] -> [Int] -> Bool
forall a. Eq a => a -> a -> Bool
/= [Int] -> [Int]
forall a. HasCallStack => [a] -> [a]
init [Int]
shF ->
String -> Array (n - 1) a
forall a. HasCallStack => String -> a
error (String -> Array (n - 1) a) -> String -> Array (n - 1) a
forall a b. (a -> b) -> a -> b
$ String
"simplifyArray2: Internal error: reply shape wrong (reply " String -> String -> String
forall a. [a] -> [a] -> [a]
++ [Int] -> String
forall a. Show a => a -> String
show [Int]
sh' String -> String -> String
forall a. [a] -> [a] -> [a]
++ String
", unreplicated " String -> String -> String
forall a. [a] -> [a] -> [a]
++ [Int] -> String
forall a. Show a => a -> String
show [Int]
shF String -> String -> String
forall a. [a] -> [a] -> [a]
++ String
")"
| [Bool] -> Bool
forall a. HasCallStack => [a] -> a
last [Bool]
replDims ->
String -> Array (n - 1) a
forall a. HasCallStack => String -> a
error String
"simplifyArray2: Internal error: reduction reply handler used while inner dimension was unreplicated"
| Bool
otherwise ->
[Bool] -> Array (n - 1) a -> Array (n - 1) a
forall (n :: Nat) a. [Bool] -> Array n a -> Array n a
arrayRevDims ([Bool] -> [Bool]
forall a. HasCallStack => [a] -> [a]
init [Bool]
revDims) ([Int] -> [Int] -> Int -> Vector a -> Array (n - 1) a
forall (n :: Nat) a. [Int] -> [Int] -> Int -> Vector a -> Array n a
Array ([Int] -> [Int]
forall a. HasCallStack => [a] -> [a]
init [Int]
sh) ([Bool] -> [Int] -> [Int]
forall {a}. Num a => [Bool] -> [a] -> [a]
reinsertZeros ([Bool] -> [Bool]
forall a. HasCallStack => [a] -> [a]
init [Bool]
replDims) [Int]
strides') Int
offset' Vector a
vec'))
-- | Apply an elementwise unary strided kernel to an array.
--
-- The input is handed to 'simplifyArray'. The kernel then runs on the array that
-- 'simplifyArray' passes to the continuation (presumably one with replicated or
-- reversed dimensions removed — see 'simplifyArray'). The kernel writes into a
-- freshly allocated buffer of @product sh@ elements. That buffer becomes an
-- array of the simplified shape and is mapped back with the @restore@ function
-- supplied by 'simplifyArray'.
--
-- Kernel argument order: rank, output pointer, shape, strides, input data
-- pointer (already advanced by the array offset).
--
-- 'unsafePerformIO' is used on a private, freshly allocated buffer that is only
-- frozen after the kernel returns. Because 'VSM.unsafeNew' leaves memory
-- uninitialised, this assumes the kernel writes every output element.
{-# NOINLINE wrapUnary #-}
wrapUnary :: forall a b n. Storable a
          => SNat n
          -> (Ptr a -> Ptr b)
          -> (Int64 -> Ptr b -> Ptr Int64 -> Ptr Int64 -> Ptr b -> IO ())
          -> Array n a
          -> Array n a
wrapUnary _ ptrconv cf_strided array =
  simplifyArray array $ \(Array sh strides offset vec) _ _ restore _ ->
    unsafePerformIO $ do
      let nDims = length sh
          elemBytes = sizeOf (undefined :: a)
      outv <- VSM.unsafeNew (product sh)
      VSM.unsafeWith outv $ \poutv ->
        withInt64Array nDims sh $ \psh ->
        withInt64Array nDims strides $ \pstrides ->
        VS.unsafeWith vec $ \pv ->
          cf_strided (fromIntegral nDims) (ptrconv poutv) psh pstrides
                     (pv `plusPtr` (offset * elemBytes))
      frozen <- VS.unsafeFreeze outv
      pure (restore (arrayFromVector sh frozen))
  where
    -- Marshal an index list as a temporary Int64 buffer of exactly @len@ entries.
    withInt64Array :: Int -> [Int] -> (Ptr Int64 -> IO r) -> IO r
    withInt64Array len xs = VS.unsafeWith (VS.fromListN len (map fromIntegral xs))
-- | Apply a scalar-vector strided kernel: the scalar @x@ (converted with
-- @valconv@) is combined with every element of @array@.
--
-- As in 'wrapUnary', the array goes through 'simplifyArray'. The kernel fills a
-- fresh buffer of @product sh@ elements, which is then mapped back with the
-- @restore@ function supplied by 'simplifyArray'.
--
-- Kernel argument order: rank, shape, output pointer, scalar, strides, input
-- data pointer (already advanced by the array offset).
--
-- 'unsafePerformIO' is used on a private, freshly allocated buffer that is only
-- frozen after the kernel returns. This assumes the kernel writes every output
-- element, because 'VSM.unsafeNew' does not initialise memory.
{-# NOINLINE wrapBinarySV #-}
wrapBinarySV :: forall a b n. Storable a
             => SNat n
             -> (a -> b)
             -> (Ptr a -> Ptr b)
             -> (Int64 -> Ptr Int64 -> Ptr b -> b -> Ptr Int64 -> Ptr b -> IO ())
             -> a -> Array n a
             -> Array n a
wrapBinarySV SNat valconv ptrconv cf_strided x array =
  simplifyArray array $ \(Array sh strides offset vec) _ _ restore _ ->
    unsafePerformIO $ do
      let nDims = length sh
          scalar = valconv x
          elemBytes = sizeOf (undefined :: a)
      outv <- VSM.unsafeNew (product sh)
      VSM.unsafeWith outv $ \poutv ->
        withInt64Array nDims sh $ \psh ->
        withInt64Array nDims strides $ \pstrides ->
        VS.unsafeWith vec $ \pv ->
          cf_strided (fromIntegral nDims) psh (ptrconv poutv) scalar pstrides
                     (pv `plusPtr` (offset * elemBytes))
      frozen <- VS.unsafeFreeze outv
      pure (restore (arrayFromVector sh frozen))
  where
    -- Marshal an index list as a temporary Int64 buffer of exactly @len@ entries.
    withInt64Array :: Int -> [Int] -> (Ptr Int64 -> IO r) -> IO r
    withInt64Array len xs = VS.unsafeWith (VS.fromListN len (map fromIntegral xs))
-- | Vector-scalar counterpart of 'wrapBinarySV': the array comes first and the
-- scalar @y@ second.
--
-- The kernel takes the scalar as its last argument: rank, shape, output
-- pointer, strides, input data pointer, scalar. This function delegates to
-- 'wrapBinarySV' and only permutes the kernel's arguments.
wrapBinaryVS :: Storable a
             => SNat n
             -> (a -> b)
             -> (Ptr a -> Ptr b)
             -> (Int64 -> Ptr Int64 -> Ptr b -> Ptr Int64 -> Ptr b -> b -> IO ())
             -> Array n a -> a
             -> Array n a
wrapBinaryVS sn valconv ptrconv cf_strided arr y =
  wrapBinarySV sn valconv ptrconv scalarLast y arr
  where
    -- 'wrapBinarySV' supplies the scalar before the strides and data pointer,
    -- whereas this kernel expects it as the final argument.
    scalarLast nDims psh poutv scalar pstrides pv =
      cf_strided nDims psh poutv pstrides pv scalar
-- | Apply an elementwise binary strided kernel to two arrays of equal shape,
-- producing a fresh dense array of that shape.
--
-- Unlike 'wrapUnary' and 'wrapBinarySV', the inputs are not passed through
-- 'simplifyArray'. Each array is handed to the kernel with its own strides and
-- its own offset-adjusted data pointer.
--
-- Kernel argument order: rank, shape, output pointer, strides of the first
-- input, first data pointer, strides of the second input, second data pointer.
--
-- Calls 'error' if the two shapes differ. If the shape has a zero-length (or
-- non-positive) dimension, there are no elements to compute. In that case the
-- kernel is not called and an empty array of that shape is returned, using the
-- same empty-array representation 'vectorRedInnerOp' produces.
--
-- 'unsafePerformIO' is used on a private, freshly allocated buffer that is only
-- frozen after the kernel returns. This assumes the kernel writes every output
-- element, because 'VSM.unsafeNew' does not initialise memory.
{-# NOINLINE wrapBinaryVV #-}
wrapBinaryVV :: forall a b n. Storable a
             => SNat n
             -> (Ptr a -> Ptr b)
             -> (Int64 -> Ptr Int64 -> Ptr b -> Ptr Int64 -> Ptr b -> Ptr Int64 -> Ptr b -> IO ())
             -> Array n a -> Array n a
             -> Array n a
wrapBinaryVV sn@SNat ptrconv cf_strided
    (Array sh strides1 offset1 vec1)
    (Array sh2 strides2 offset2 vec2)
  | sh /= sh2 = error $ "wrapBinaryVV: unequal shapes: " ++ show sh ++ " and " ++ show sh2
  | any (<= 0) sh =
      -- Elementwise ops on empty arrays are well-defined: the result is empty.
      -- Previously this raised an error, which crashed e.g. @a + b@ on two
      -- equally-shaped arrays with a zero dimension.
      Array sh (0 <$ strides1) 0 VS.empty
  | otherwise = unsafePerformIO $ do
      let nDims = fromSNat' sn
          elemBytes = sizeOf (undefined :: a)
      outv <- VSM.unsafeNew (product sh)
      VSM.unsafeWith outv $ \poutv ->
        withInt64Array nDims sh $ \psh ->
        withInt64Array nDims strides1 $ \pstrides1 ->
        withInt64Array nDims strides2 $ \pstrides2 ->
        VS.unsafeWith vec1 $ \pv1 ->
        VS.unsafeWith vec2 $ \pv2 ->
          cf_strided (fromIntegral nDims) psh (ptrconv poutv)
                     pstrides1 (pv1 `plusPtr` (offset1 * elemBytes))
                     pstrides2 (pv2 `plusPtr` (offset2 * elemBytes))
      arrayFromVector sh <$> VS.unsafeFreeze outv
  where
    -- Marshal an index list as a temporary Int64 buffer of exactly @len@ entries.
    withInt64Array :: Int -> [Int] -> (Ptr Int64 -> IO r) -> IO r
    withInt64Array len xs = VS.unsafeWith (VS.fromListN len (map fromIntegral xs))
-- | Reduce away the innermost dimension of an array using strided kernels.
--
-- Arguments:
--
-- * the result rank;
-- * @valconv@ / @ptrconv@, which convert values and pointers to the kernels'
--   element type;
-- * @fscale@, a scalar-vector kernel in the calling convention of
--   'wrapBinarySV';
-- * @fred@, the reduction kernel (rank, output pointer, shape, strides, input
--   data pointer).
--
-- The guards are tried in this order:
--
-- * inner dimension of length 0: the result is 'arrayFromConstant' of the outer
--   shape with value @0@;
-- * some outer dimension is empty: an empty result array;
-- * inner dimension of length 1: the inner dimension is dropped without copying;
-- * inner stride 0 (all inner positions address one element): @fscale@ is
--   applied via 'wrapBinarySV' with the inner length as the scalar.
--
-- NOTE(review): these shortcuts assume a sum-like reduction, with 0 as the empty
-- result and @fscale@ multiplying by the count. Confirm that all callers pass
-- such kernels.
--
-- Otherwise the array goes through 'simplifyArray'. @fred@ fills a fresh buffer
-- of @product (init sh')@ elements, and the reduction restore function supplied
-- by 'simplifyArray' maps the rank-(n'-1) result back to rank @n@.
{-# NOINLINE vectorRedInnerOp #-}
vectorRedInnerOp :: forall a b n. (Num a, Storable a)
                 => SNat n
                 -> (a -> b)
                 -> (Ptr a -> Ptr b)
                 -> (Int64 -> Ptr Int64 -> Ptr b -> b -> Ptr Int64 -> Ptr b -> IO ())
                 -> (Int64 -> Ptr b -> Ptr Int64 -> Ptr Int64 -> Ptr b -> IO ())
                 -> Array (n + 1) a -> Array n a
vectorRedInnerOp sn@SNat valconv ptrconv fscale fred array@(Array sh strides offset vec)
  | null sh = error "unreachable"
  | innerLen <= 0 = arrayFromConstant outerSh 0
  | any (<= 0) outerSh = Array outerSh (0 <$ outerStrides) 0 VS.empty
  -- From here on the input has at least one element.
  | innerLen == 1 = innerDropped
  | innerStride == 0 =
      wrapBinarySV sn valconv ptrconv fscale (fromIntegral @Int @a innerLen) innerDropped
  -- From here on the inner dimension carries distinct data to reduce.
  | otherwise =
      simplifyArray array $ \(Array sh' strides' offset' vec' :: Array n' a) _ _ _ restore ->
        unsafePerformIO $ do
          let nDims = length sh'
          outv <- VSM.unsafeNew (product (init sh'))
          VSM.unsafeWith outv $ \poutv ->
            withInt64Array nDims sh' $ \psh ->
            withInt64Array nDims strides' $ \pstrides ->
            VS.unsafeWith vec' $ \pv ->
              fred (fromIntegral nDims) (ptrconv poutv) psh pstrides
                   (ptrconv (pv `plusPtr` (offset' * sizeOf (undefined :: a))))
          -- Establish at the type level that the result has rank n' - 1, so it
          -- can be built with 'arrayFromVector' and passed to @restore@.
          TypeNats.withSomeSNat (fromIntegral (nDims - 1)) $ \(SNat :: SNat n'm1) -> do
            (Dict :: Dict (1 <= n')) <- case cmpNat (natSing @1) (natSing @n') of
              LTI -> pure Dict
              EQI -> pure Dict
              _ -> error "impossible"
            case sameNat (natSing @(n' - 1)) (natSing @n'm1) of
              Just Refl -> restore . arrayFromVector @_ @n'm1 (init sh') <$> VS.unsafeFreeze outv
              Nothing -> error "impossible"
  where
    -- Lazily bound; only forced after the @null sh@ guard has failed.
    innerLen = last sh
    innerStride = last strides
    outerSh = init sh
    outerStrides = init strides

    -- The input viewed without its inner dimension (same storage, no copy).
    innerDropped :: Array n a
    innerDropped = Array outerSh outerStrides offset vec

    -- Marshal an index list as a temporary Int64 buffer of exactly @len@ entries.
    withInt64Array :: Int -> [Int] -> (Ptr Int64 -> IO r) -> IO r
    withInt64Array len xs = VS.unsafeWith (VS.fromListN len (map fromIntegral xs))
{-# NOINLINE vectorRedFullOp #-}
vectorRedFullOp :: forall a b n. (Num a, Storable a)
=> SNat n
-> (a -> Int -> a)
-> (b -> a)
-> (Ptr a -> Ptr b)
-> (Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr b -> IO b)
-> Array n a -> a
vectorRedFullOp :: forall a b (n :: Nat).
(Num a, Storable a) =>
SNat n
-> (a -> Int -> a)
-> (b -> a)
-> (Ptr a -> Ptr b)
-> (Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr b -> IO b)
-> Array n a
-> a
vectorRedFullOp SNat n
_ a -> Int -> a
scaleval b -> a
valbackconv Ptr a -> Ptr b
ptrconv Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr b -> IO b
fred array :: Array n a
array@(Array [Int]
sh [Int]
strides Int
offset Vector a
vec)
| [Int] -> Bool
forall a. [a] -> Bool
forall (t :: * -> *) a. Foldable t => t a -> Bool
null [Int]
sh = Vector a
vec Vector a -> Int -> a
forall a. Storable a => Vector a -> Int -> a
VS.! Int
offset
| (Int -> Bool) -> [Int] -> Bool
forall (t :: * -> *) a. Foldable t => (a -> Bool) -> t a -> Bool
any (Int -> Int -> Bool
forall a. Ord a => a -> a -> Bool
<= Int
0) [Int]
sh = a
0
| (Int -> Bool) -> [Int] -> Bool
forall (t :: * -> *) a. Foldable t => (a -> Bool) -> t a -> Bool
all (Int -> Int -> Bool
forall a. Eq a => a -> a -> Bool
== Int
0) [Int]
strides = Int -> a
forall a b. (Integral a, Num b) => a -> b
fromIntegral ([Int] -> Int
forall a. Num a => [a] -> a
forall (t :: * -> *) a. (Foldable t, Num a) => t a -> a
product [Int]
sh) a -> a -> a
forall a. Num a => a -> a -> a
* Vector a
vec Vector a -> Int -> a
forall a. Storable a => Vector a -> Int -> a
VS.! Int
offset
| Bool
otherwise =
Array n a
-> (forall {n' :: Nat}.
KnownNat n' =>
Array n' a
-> Int
-> ([Int] -> [Int])
-> (Array n' a -> Array n a)
-> (Array (n' - 1) a -> Array (n - 1) a)
-> a)
-> a
forall (n :: Nat) a r.
Array n a
-> (forall (n' :: Nat).
KnownNat n' =>
Array n' a
-> Int
-> ([Int] -> [Int])
-> (Array n' a -> Array n a)
-> (Array (n' - 1) a -> Array (n - 1) a)
-> r)
-> r
simplifyArray Array n a
array ((forall {n' :: Nat}.
KnownNat n' =>
Array n' a
-> Int
-> ([Int] -> [Int])
-> (Array n' a -> Array n a)
-> (Array (n' - 1) a -> Array (n - 1) a)
-> a)
-> a)
-> (forall {n' :: Nat}.
KnownNat n' =>
Array n' a
-> Int
-> ([Int] -> [Int])
-> (Array n' a -> Array n a)
-> (Array (n' - 1) a -> Array (n - 1) a)
-> a)
-> a
forall a b. (a -> b) -> a -> b
$ \(Array [Int]
sh' [Int]
strides' Int
offset' Vector a
vec') Int
unrepSize [Int] -> [Int]
_ Array n' a -> Array n a
_ Array (n' - 1) a -> Array (n - 1) a
_ -> IO a -> a
forall a. IO a -> a
unsafePerformIO (IO a -> a) -> IO a -> a
forall a b. (a -> b) -> a -> b
$ do
let ndims' :: Int
ndims' = [Int] -> Int
forall a. [a] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [Int]
sh'
Vector Int64 -> (Ptr Int64 -> IO a) -> IO a
forall a b. Storable a => Vector a -> (Ptr a -> IO b) -> IO b
VS.unsafeWith (Int -> [Int64] -> Vector Int64
forall a. Storable a => Int -> [a] -> Vector a
VS.fromListN Int
ndims' ((Int -> Int64) -> [Int] -> [Int64]
forall a b. (a -> b) -> [a] -> [b]
map Int -> Int64
forall a b. (Integral a, Num b) => a -> b
fromIntegral [Int]
sh')) ((Ptr Int64 -> IO a) -> IO a) -> (Ptr Int64 -> IO a) -> IO a
forall a b. (a -> b) -> a -> b
$ \Ptr Int64
psh ->
Vector Int64 -> (Ptr Int64 -> IO a) -> IO a
forall a b. Storable a => Vector a -> (Ptr a -> IO b) -> IO b
VS.unsafeWith (Int -> [Int64] -> Vector Int64
forall a. Storable a => Int -> [a] -> Vector a
VS.fromListN Int
ndims' ((Int -> Int64) -> [Int] -> [Int64]
forall a b. (a -> b) -> [a] -> [b]
map Int -> Int64
forall a b. (Integral a, Num b) => a -> b
fromIntegral [Int]
strides')) ((Ptr Int64 -> IO a) -> IO a) -> (Ptr Int64 -> IO a) -> IO a
forall a b. (a -> b) -> a -> b
$ \Ptr Int64
pstrides ->
Vector a -> (Ptr a -> IO a) -> IO a
forall a b. Storable a => Vector a -> (Ptr a -> IO b) -> IO b
VS.unsafeWith Vector a
vec' ((Ptr a -> IO a) -> IO a) -> (Ptr a -> IO a) -> IO a
forall a b. (a -> b) -> a -> b
$ \Ptr a
pv ->
let pv' :: Ptr a
pv' = Ptr a
pv Ptr a -> Int -> Ptr a
forall a b. Ptr a -> Int -> Ptr b
`plusPtr` (Int
offset' Int -> Int -> Int
forall a. Num a => a -> a -> a
* a -> Int
forall a. Storable a => a -> Int
sizeOf (a
forall a. HasCallStack => a
undefined :: a))
in (a -> Int -> a
`scaleval` Int
unrepSize) (a -> a) -> (b -> a) -> b -> a
forall b c a. (b -> c) -> (a -> b) -> a -> c
. b -> a
valbackconv
(b -> a) -> IO b -> IO a
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr b -> IO b
fred (Int -> Int64
forall a b. (Integral a, Num b) => a -> b
fromIntegral Int
ndims') Ptr Int64
psh Ptr Int64
pstrides (Ptr a -> Ptr b
ptrconv Ptr a
pv')
{-# NOINLINE vectorExtremumOp #-}
-- | Compute the multi-index of an extremal element of an array with a C kernel.
--
-- @ptrconv@ converts the element pointer to the pointer type the kernel takes.
-- @fextrem@ is called as @fextrem outIndex rank shape strides firstElem@, where
-- @outIndex@ is a freshly allocated buffer of @rank@ 'Int64's that the kernel
-- fills in, and @firstElem@ already has the array offset applied.
--
-- * Rank 0: returns @[]@.
-- * Any dimension <= 0: calls 'error'.
-- * All strides 0: every index addresses the same element, so the all-zero
--   index is returned without running the kernel.
-- * Otherwise the array is first simplified with 'simplifyArray', and the
--   kernel's result is passed through the @upindex@ function that
--   'simplifyArray' supplies (presumably mapping it back to the original
--   shape).
vectorExtremumOp :: forall a b n. Storable a
                 => (Ptr a -> Ptr b)
                 -> (Ptr Int64 -> Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr b -> IO ())
                 -> Array n a -> [Int]
vectorExtremumOp ptrconv fextrem array@(Array sh strides _ _)
  | null sh = []
  | any (<= 0) sh = error "Extremum (minindex/maxindex): empty array"
  -- From here on the array has at least one element.
  | all (== 0) strides = map (const 0) sh
  | otherwise =
      simplifyArray array $ \(Array shS stridesS offS vecS) _ upindex _ _ ->
        let rank = length shS
            toI64s xs = VS.fromListN rank (map fromIntegral xs)
            firstElem base = base `plusPtr` (offS * sizeOf (undefined :: a))
        in unsafePerformIO $ do
             idxBuf <- VSM.unsafeNew rank
             VSM.unsafeWith idxBuf $ \pIdx ->
               VS.unsafeWith (toI64s shS) $ \pShape ->
               VS.unsafeWith (toI64s stridesS) $ \pStrides ->
               VS.unsafeWith vecS $ \pBase ->
                 fextrem pIdx (fromIntegral rank) pShape pStrides (ptrconv (firstElem pBase))
             idxs <- VS.unsafeFreeze idxBuf
             pure (upindex (map (fromIntegral @Int64 @Int) (VS.toList idxs)))
{-# NOINLINE vectorDotprodInnerOp #-}
-- | Contract the innermost dimension of two arrays of rank @n + 1@ with a dot
-- product, giving an array of rank @n@.
--
-- The arrays must have the same shape; otherwise this calls 'error'. Cheap
-- special cases avoid the dot-product kernel:
--
-- * innermost dimension of length <= 0: an array of zeros via 'arrayFromConstant';
-- * some outer dimension <= 0: an empty array;
-- * innermost dimension of length 1: elementwise product of the outer slices via @fmul@;
-- * innermost stride 0 in one argument: @fmul@ of that argument's first inner
--   slice with the inner reduction ('vectorRedInnerOp') of the other argument.
--
-- In the general case both arrays are simplified together with
-- 'simplifyArray2'. The kernel @fdotinner@ writes into a fresh buffer, and
-- @restore@ converts the result back to the caller's rank.
--
-- Arguments after the 'SNat', in order:
--
-- * element conversion, forwarded to 'vectorRedInnerOp';
-- * pointer conversion, forwarded to 'vectorRedInnerOp' and also used for the kernel;
-- * elementwise multiplication;
-- * the scale and reduction kernels for 'vectorRedInnerOp';
-- * the dot-product kernel.
vectorDotprodInnerOp :: forall a b n. (Num a, Storable a)
                     => SNat n
                     -> (a -> b)
                     -> (Ptr a -> Ptr b)
                     -> (SNat n -> Array n a -> Array n a -> Array n a)
                     -> (Int64 -> Ptr Int64 -> Ptr b -> b -> Ptr Int64 -> Ptr b -> IO ())
                     -> (Int64 -> Ptr b -> Ptr Int64 -> Ptr Int64 -> Ptr b -> IO ())
                     -> (Int64 -> Ptr Int64 -> Ptr b -> Ptr Int64 -> Ptr b -> Ptr Int64 -> Ptr b -> IO ())
                     -> Array (n + 1) a -> Array (n + 1) a -> Array n a
vectorDotprodInnerOp sn@SNat valconv ptrconv fmul fscale fred fdotinner
                     arr1@(Array sh1 strides1 offset1 vec1)
                     arr2@(Array sh2 strides2 offset2 vec2)
  | null sh1 || null sh2 = error "unreachable"
  | sh1 /= sh2 =
      error $ "vectorDotprodInnerOp: shapes unequal: " ++ show sh1 ++ " vs " ++ show sh2
  | innerLen <= 0 = arrayFromConstant outerSh 0
  | any (<= 0) outerSh = Array outerSh (map (const 0) (init strides1)) 0 VS.empty
  -- From here on both arrays are nonempty.
  | innerLen == 1 = fmul sn outer1 outer2
  | last strides1 == 0 = fmul sn outer1 (reduceInner arr2)
  | last strides2 == 0 = fmul sn (reduceInner arr1) outer2
  | otherwise =
      simplifyArray2 arr1 arr2 $
        \(Array sh' strides1' offset1' vec1' :: Array n' a)
         (Array _ strides2' offset2' vec2') _ _ _ restore ->
        unsafePerformIO $ do
          let inrank = length sh'
              toI64s xs = VS.fromListN inrank (map fromIntegral xs)
              elemOffset off = off * sizeOf (undefined :: a)
          outv <- VSM.unsafeNew (product (init sh'))
          VSM.unsafeWith outv $ \poutv ->
            VS.unsafeWith (toI64s sh') $ \psh ->
            VS.unsafeWith (toI64s strides1') $ \pstrides1 ->
            VS.unsafeWith vec1' $ \pvec1 ->
            VS.unsafeWith (toI64s strides2') $ \pstrides2 ->
            VS.unsafeWith vec2' $ \pvec2 ->
              fdotinner (fromIntegral @Int @Int64 inrank) psh (ptrconv poutv)
                        pstrides1 (ptrconv pvec1 `plusPtr` elemOffset offset1')
                        pstrides2 (ptrconv pvec2 `plusPtr` elemOffset offset2')
          -- Obtain a type-level witness for the result rank (inrank - 1).
          TypeNats.withSomeSNat (fromIntegral (inrank - 1)) $ \(SNat :: SNat n'm1) -> do
            -- Evidence that 1 <= n'; presumably needed so the type-checker
            -- plugins can supply KnownNat (n' - 1) for natSing @(n' - 1).
            (Dict :: Dict (1 <= n')) <- case cmpNat (natSing @1) (natSing @n') of
              LTI -> pure Dict
              EQI -> pure Dict
              GTI -> error "impossible"
            case sameNat (natSing @(n' - 1)) (natSing @n'm1) of
              Just Refl -> restore . arrayFromVector (init sh') <$> VS.unsafeFreeze outv
              Nothing -> error "impossible"
  where
    innerLen :: Int
    innerLen = last sh1

    outerSh :: [Int]
    outerSh = init sh1

    -- The arguments with their innermost dimension dropped (index 0 of it).
    outer1, outer2 :: Array n a
    outer1 = Array (init sh1) (init strides1) offset1 vec1
    outer2 = Array (init sh2) (init strides2) offset2 vec2

    reduceInner :: Array (n + 1) a -> Array n a
    reduceInner = vectorRedInnerOp sn valconv ptrconv fscale fred
-- | Multiply a value by an 'Int' count, converting the count with
-- 'fromIntegral' first. The value is the left operand of '*'.
mulWithInt :: Num a => a -> Int -> a
mulWithInt x = (x *) . fromIntegral
-- For every element type in 'typesList' and every binary op, generate
-- @<op>Vector<Type> :: SNat n -> Array n t -> Array n t -> Array n t@, built
-- with 'liftOpEltwise2' from the op's scalar function and the
-- @c_binary_<ctype>_{sv,vs,vv}_strided@ kernels applied to the op's enum code.
$(fmap concat . forM typesList $ \arithtype -> do
    let eltT = conT (atType arithtype)
        typeSuffix = nameBase (atType arithtype)
        cPrefix = "c_binary_" ++ atCName arithtype
    fmap concat . forM [minBound .. maxBound] $ \arithop -> do
      let fname = mkName (aboName arithop ++ "Vector" ++ typeSuffix)
          opCode = litE (integerL (fromIntegral (aboEnum arithop)))
          kernel variant = varE (mkName (cPrefix ++ variant)) `appE` opCode
          scalarOp = varE (aboNumOp arithop)
          svOp = kernel "_sv_strided"
          vsOp = kernel "_vs_strided"
          vvOp = kernel "_vv_strided"
      sig <- SigD fname <$> [t| forall n. SNat n -> Array n $eltT -> Array n $eltT -> Array n $eltT |]
      body <- [| \sn -> liftOpEltwise2 sn id id $scalarOp $svOp $vsOp $vvOp |]
      pure [sig, FunD fname [Clause [] (NormalB body) []]])
-- For every integral element type in 'intTypesList' and every integer-only
-- binary op, generate @<op>Vector<Type>@ via 'liftOpEltwise2' and the
-- @c_ibinary_<ctype>_{sv,vs,vv}_strided@ kernels.
$(fmap concat . forM intTypesList $ \arithtype -> do
    let eltT = conT (atType arithtype)
        typeSuffix = nameBase (atType arithtype)
        cPrefix = "c_ibinary_" ++ atCName arithtype
    fmap concat . forM [minBound .. maxBound] $ \arithop -> do
      let fname = mkName (aiboName arithop ++ "Vector" ++ typeSuffix)
          opCode = litE (integerL (fromIntegral (aiboEnum arithop)))
          kernel variant = varE (mkName (cPrefix ++ variant)) `appE` opCode
          scalarOp = varE (aiboNumOp arithop)
          svOp = kernel "_sv_strided"
          vsOp = kernel "_vs_strided"
          vvOp = kernel "_vv_strided"
      sig <- SigD fname <$> [t| forall n. SNat n -> Array n $eltT -> Array n $eltT -> Array n $eltT |]
      body <- [| \sn -> liftOpEltwise2 sn id id $scalarOp $svOp $vsOp $vvOp |]
      pure [sig, FunD fname [Clause [] (NormalB body) []]])
-- For every floating element type in 'floatTypesList' and every
-- floating-only binary op, generate @<op>Vector<Type>@ via 'liftOpEltwise2'
-- and the @c_fbinary_<ctype>_{sv,vs,vv}_strided@ kernels.
$(fmap concat . forM floatTypesList $ \arithtype -> do
    let eltT = conT (atType arithtype)
        typeSuffix = nameBase (atType arithtype)
        cPrefix = "c_fbinary_" ++ atCName arithtype
    fmap concat . forM [minBound .. maxBound] $ \arithop -> do
      let fname = mkName (afboName arithop ++ "Vector" ++ typeSuffix)
          opCode = litE (integerL (fromIntegral (afboEnum arithop)))
          kernel variant = varE (mkName (cPrefix ++ variant)) `appE` opCode
          scalarOp = varE (afboNumOp arithop)
          svOp = kernel "_sv_strided"
          vsOp = kernel "_vs_strided"
          vvOp = kernel "_vv_strided"
      sig <- SigD fname <$> [t| forall n. SNat n -> Array n $eltT -> Array n $eltT -> Array n $eltT |]
      body <- [| \sn -> liftOpEltwise2 sn id id $scalarOp $svOp $vsOp $vvOp |]
      pure [sig, FunD fname [Clause [] (NormalB body) []]])
-- For every element type and every unary op, generate
-- @<op>Vector<Type> :: SNat n -> Array n t -> Array n t@ via 'liftOpEltwise1'
-- and the @c_unary_<ctype>_strided@ kernel applied to the op's enum code.
$(fmap concat . forM typesList $ \arithtype -> do
    let eltT = conT (atType arithtype)
        kernelName = mkName ("c_unary_" ++ atCName arithtype ++ "_strided")
    fmap concat . forM [minBound .. maxBound] $ \arithop -> do
      let fname = mkName (auoName arithop ++ "Vector" ++ nameBase (atType arithtype))
          kernel = varE kernelName `appE` litE (integerL (fromIntegral (auoEnum arithop)))
      sig <- SigD fname <$> [t| forall n. SNat n -> Array n $eltT -> Array n $eltT |]
      body <- [| \sn -> liftOpEltwise1 sn id $kernel |]
      pure [sig, FunD fname [Clause [] (NormalB body) []]])
-- For every floating element type and every floating-only unary op, generate
-- @<op>Vector<Type>@ via 'liftOpEltwise1' and the
-- @c_funary_<ctype>_strided@ kernel.
$(fmap concat . forM floatTypesList $ \arithtype -> do
    let eltT = conT (atType arithtype)
        kernelName = mkName ("c_funary_" ++ atCName arithtype ++ "_strided")
    fmap concat . forM [minBound .. maxBound] $ \arithop -> do
      let fname = mkName (afuoName arithop ++ "Vector" ++ nameBase (atType arithtype))
          kernel = varE kernelName `appE` litE (integerL (fromIntegral (afuoEnum arithop)))
      sig <- SigD fname <$> [t| forall n. SNat n -> Array n $eltT -> Array n $eltT |]
      body <- [| \sn -> liftOpEltwise1 sn id $kernel |]
      pure [sig, FunD fname [Clause [] (NormalB body) []]])
-- For every element type and every reduction op, generate two functions:
--
-- * @<op>1Vector<Type> :: SNat n -> Array (n + 1) t -> Array n t@, built with
--   'vectorRedInnerOp' and the @c_reduce1_<ctype>@ kernel;
-- * @<op>FullVector<Type> :: SNat n -> Array n t -> t@, built with
--   'vectorRedFullOp' and the @c_reducefull_<ctype>@ kernel.
--
-- The scalar scaling function given to 'vectorRedFullOp' is 'mulWithInt' for
-- RO_SUM and '(^)' for RO_PRODUCT. The @c_binary_<ctype>_sv_strided@ kernel
-- with BO_MUL is the scale kernel for 'vectorRedInnerOp'.
$(fmap concat . forM typesList $ \arithtype -> do
    let eltT = conT (atType arithtype)
        typeSuffix = nameBase (atType arithtype)
        cname = atCName arithtype
        scaleKernel = varE (mkName ("c_binary_" ++ cname ++ "_sv_strided"))
                        `appE` litE (integerL (fromIntegral (aboEnum BO_MUL)))
    fmap concat . forM [minBound .. maxBound] $ \arithop -> do
      let opCode = litE (integerL (fromIntegral (aroEnum arithop)))
          name1 = mkName (aroName arithop ++ "1Vector" ++ typeSuffix)
          namefull = mkName (aroName arithop ++ "FullVector" ++ typeSuffix)
          red1Kernel = varE (mkName ("c_reduce1_" ++ cname)) `appE` opCode
          redFullKernel = varE (mkName ("c_reducefull_" ++ cname)) `appE` opCode
          scaleVar = case arithop of
            RO_SUM -> varE 'mulWithInt
            RO_PRODUCT -> varE '(^)
      sig1 <- SigD name1 <$> [t| forall n. SNat n -> Array (n + 1) $eltT -> Array n $eltT |]
      body1 <- [| \sn -> vectorRedInnerOp sn id id $scaleKernel $red1Kernel |]
      sigFull <- SigD namefull <$> [t| forall n. SNat n -> Array n $eltT -> $eltT |]
      bodyFull <- [| \sn -> vectorRedFullOp sn $scaleVar id id $redFullKernel |]
      pure [ sig1, FunD name1 [Clause [] (NormalB body1) []]
           , sigFull, FunD namefull [Clause [] (NormalB bodyFull) []]
           ])
-- For every element type, generate @minindexVector<Type>@ and
-- @maxindexVector<Type> :: Array n t -> [Int]@ via 'vectorExtremumOp' and the
-- @c_extremum_{min,max}_<ctype>@ kernels.
$(fmap concat . forM typesList $ \arithtype -> do
    let eltT = conT (atType arithtype)
    fmap concat . forM ["min", "max"] $ \which -> do
      let fname = mkName (which ++ "indexVector" ++ nameBase (atType arithtype))
          kernel = varE (mkName ("c_extremum_" ++ which ++ "_" ++ atCName arithtype))
      sig <- SigD fname <$> [t| forall n. Array n $eltT -> [Int] |]
      body <- [| vectorExtremumOp id $kernel |]
      pure [sig, FunD fname [Clause [] (NormalB body) []]])
-- For every element type, generate
-- @dotprodinnerVector<Type> :: SNat n -> Array (n + 1) t -> Array (n + 1) t -> Array n t@
-- via 'vectorDotprodInnerOp'. It uses:
--
-- * the generated @mulVector<Type>@ for elementwise multiplication;
-- * the @c_binary_<ctype>_sv_strided@ kernel with BO_MUL for scaling;
-- * the @c_reduce1_<ctype>@ kernel with RO_SUM for reduction;
-- * the @c_dotprodinner_<ctype>@ kernel for the dot product.
$(fmap concat . forM typesList $ \arithtype -> do
    let eltT = conT (atType arithtype)
        typeSuffix = nameBase (atType arithtype)
        cname = atCName arithtype
        fname = mkName ("dotprodinnerVector" ++ typeSuffix)
        dotKernel = varE (mkName ("c_dotprodinner_" ++ cname))
        mulFun = varE (mkName ("mulVector" ++ typeSuffix))
        scaleKernel = varE (mkName ("c_binary_" ++ cname ++ "_sv_strided"))
                        `appE` litE (integerL (fromIntegral (aboEnum BO_MUL)))
        sumKernel = varE (mkName ("c_reduce1_" ++ cname))
                      `appE` litE (integerL (fromIntegral (aroEnum RO_SUM)))
    sig <- SigD fname <$> [t| forall n. SNat n -> Array (n + 1) $eltT -> Array (n + 1) $eltT -> Array n $eltT |]
    body <- [| \sn -> vectorDotprodInnerOp sn id id $mulFun $scaleKernel $sumKernel $dotKernel |]
    pure [sig, FunD fname [Clause [] (NormalB body) []]])
-- | FFI binding to the C symbol @oxarrays_stats_enable@. 'statisticsEnable'
-- calls it with 1 for 'True' and 0 for 'False'; what is recorded is decided
-- on the C side (not visible here).
foreign import ccall unsafe "oxarrays_stats_enable" c_stats_enable :: Int32 -> IO ()
-- | FFI binding to the C symbol @oxarrays_stats_print_all@, which prints the
-- statistics gathered by the C side.
foreign import ccall unsafe "oxarrays_stats_print_all" c_stats_print_all :: IO ()
-- | Switch statistics collection in the C support library on ('True') or
-- off ('False'). The flag is passed to C as 1 or 0.
statisticsEnable :: Bool -> IO ()
statisticsEnable b = c_stats_enable (if b then 1 else 0)
-- | Print all statistics gathered by the C support library.
--
-- The Haskell 'stdout' buffer is flushed first, so pending Haskell output
-- is emitted before whatever the C side prints.
-- NOTE(review): assumes the C side writes to the process's standard output;
-- confirm.
statisticsPrintAll :: IO ()
statisticsPrintAll = do
  hFlush stdout
  c_stats_print_all
-- | Run a unary elementwise C kernel on an array of integral type @i@,
-- choosing the kernel by the bit width of @i@.
--
-- If @i@ is 32 bits wide, @f32@ is used; if it is 64 bits wide, @f64@ is
-- used. The element pointer is reinterpreted with 'castPtr' so a
-- same-width integral type can share the 'Int32' / 'Int64' kernel. Any
-- other width raises an 'error'.
--
-- The @b ~ Int32@ / @b ~ Int64@ constraints fix each kernel's element type
-- while still letting 'liftOpEltwise1' choose @b@ through 'castPtr'.
intWidBranch1 :: forall i n. (FiniteBits i, Storable i)
              => (forall b. b ~ Int32 => Int64 -> Ptr b -> Ptr Int64 -> Ptr Int64 -> Ptr b -> IO ())
              -> (forall b. b ~ Int64 => Int64 -> Ptr b -> Ptr Int64 -> Ptr Int64 -> Ptr b -> IO ())
              -> (SNat n -> Array n i -> Array n i)
intWidBranch1 f32 f64 sn = case width of
  32 -> liftOpEltwise1 sn castPtr f32
  64 -> liftOpEltwise1 sn castPtr f64
  _ -> error "Unsupported Int width"
  where
    -- 'undefined' only carries the type @i@; 'finiteBitSize' does not
    -- inspect the value for the standard integer types.
    width = finiteBitSize (undefined :: i)
-- | Run a binary elementwise operation on two arrays of integral type @i@,
-- choosing the C kernels by the bit width of @i@.
--
-- Arguments:
--
-- * @ss@: the operation on two Haskell scalars of type @i@.
-- * @sv@ / @vs@ / @vv@ kernels: judging by their names and argument shapes,
--   these take a scalar on the left, a scalar on the right, or two array
--   operands respectively. One triple is supplied for 32-bit elements and
--   one for 64-bit elements.
--
-- Scalars are converted with 'fromIntegral' and pointers with 'castPtr'
-- before reaching 'liftOpEltwise2'. A width other than 32 or 64 raises an
-- 'error'.
intWidBranch2 :: forall i n. (FiniteBits i, Storable i, Integral i)
              => (i -> i -> i)
              -> (forall b. b ~ Int32 => Int64 -> Ptr Int64 -> Ptr b -> b -> Ptr Int64 -> Ptr b -> IO ())
              -> (forall b. b ~ Int32 => Int64 -> Ptr Int64 -> Ptr b -> Ptr Int64 -> Ptr b -> b -> IO ())
              -> (forall b. b ~ Int32 => Int64 -> Ptr Int64 -> Ptr b -> Ptr Int64 -> Ptr b -> Ptr Int64 -> Ptr b -> IO ())
              -> (forall b. b ~ Int64 => Int64 -> Ptr Int64 -> Ptr b -> b -> Ptr Int64 -> Ptr b -> IO ())
              -> (forall b. b ~ Int64 => Int64 -> Ptr Int64 -> Ptr b -> Ptr Int64 -> Ptr b -> b -> IO ())
              -> (forall b. b ~ Int64 => Int64 -> Ptr Int64 -> Ptr b -> Ptr Int64 -> Ptr b -> Ptr Int64 -> Ptr b -> IO ())
              -> (SNat n -> Array n i -> Array n i -> Array n i)
intWidBranch2 ss sv32 vs32 vv32 sv64 vs64 vv64 sn = case width of
  32 -> liftOpEltwise2 sn fromIntegral castPtr ss sv32 vs32 vv32
  64 -> liftOpEltwise2 sn fromIntegral castPtr ss sv64 vs64 vv64
  _ -> error "Unsupported Int width"
  where
    -- 'undefined' only carries the type @i@; 'finiteBitSize' does not
    -- inspect the value for the standard integer types.
    width = finiteBitSize (undefined :: i)
-- | Reduce one dimension of an integral array, turning rank @n + 1@ into
-- rank @n@. The C kernels are chosen by the bit width of @i@.
--
-- For each width two kernels are passed to 'vectorRedInnerOp':
--
-- * @fsc@: a scaling kernel with the scalar–vector binary-kernel shape.
--   The generated @<op>1Vector<Type>@ functions pass
--   @c_binary_*_sv_strided@ for 'BO_MUL' in this position.
-- * @fred@: the reduction kernel (@c_reduce1_*@ in the generated functions).
--
-- Scalars are converted with 'fromIntegral' and pointers with 'castPtr'. A
-- width other than 32 or 64 raises an 'error'.
intWidBranchRed1 :: forall i n. (FiniteBits i, Storable i, Integral i)
                 => (forall b. b ~ Int32 => Int64 -> Ptr Int64 -> Ptr b -> b -> Ptr Int64 -> Ptr b -> IO ())
                 -> (forall b. b ~ Int32 => Int64 -> Ptr b -> Ptr Int64 -> Ptr Int64 -> Ptr b -> IO ())
                 -> (forall b. b ~ Int64 => Int64 -> Ptr Int64 -> Ptr b -> b -> Ptr Int64 -> Ptr b -> IO ())
                 -> (forall b. b ~ Int64 => Int64 -> Ptr b -> Ptr Int64 -> Ptr Int64 -> Ptr b -> IO ())
                 -> (SNat n -> Array (n + 1) i -> Array n i)
intWidBranchRed1 fsc32 fred32 fsc64 fred64 sn = case width of
  32 -> vectorRedInnerOp @i @Int32 sn fromIntegral castPtr fsc32 fred32
  64 -> vectorRedInnerOp @i @Int64 sn fromIntegral castPtr fsc64 fred64
  _ -> error "Unsupported Int width"
  where
    -- 'undefined' only carries the type @i@; 'finiteBitSize' does not
    -- inspect the value for the standard integer types.
    width = finiteBitSize (undefined :: i)
-- | Reduce a whole integral array to a single element, choosing the C
-- reduction kernel by the bit width of @i@.
--
-- @fsc@ is the Haskell-side @i -> Int -> i@ combinator passed to
-- 'vectorRedFullOp'. The generated @<op>FullVector<Type>@ functions use
-- 'mulWithInt' for sums and '(^)' for products in this position. @fred32@
-- and @fred64@ are the C kernels for 32- and 64-bit elements. Their result
-- is converted back to @i@ with 'fromIntegral'. A width other than 32 or 64
-- raises an 'error'.
intWidBranchRedFull :: forall i n. (FiniteBits i, Storable i, Integral i)
                    => (i -> Int -> i)
                    -> (forall b. b ~ Int32 => Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr b -> IO b)
                    -> (forall b. b ~ Int64 => Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr b -> IO b)
                    -> (SNat n -> Array n i -> i)
intWidBranchRedFull fsc fred32 fred64 sn = case width of
  32 -> vectorRedFullOp @i @Int32 sn fsc fromIntegral castPtr fred32
  64 -> vectorRedFullOp @i @Int64 sn fsc fromIntegral castPtr fred64
  _ -> error "Unsupported Int width"
  where
    -- 'undefined' only carries the type @i@; 'finiteBitSize' does not
    -- inspect the value for the standard integer types.
    width = finiteBitSize (undefined :: i)
-- | Locate an extremal element of an integral array, choosing the C kernel
-- by the bit width of @i@.
--
-- The result is the @[Int]@ produced by 'vectorExtremumOp'. The element
-- pointer is reinterpreted with 'castPtr' for the 32- or 64-bit kernel. A
-- width other than 32 or 64 raises an 'error'.
intWidBranchExtr :: forall i n. (FiniteBits i, Storable i)
                 => (forall b. b ~ Int32 => Ptr Int64 -> Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr b -> IO ())
                 -> (forall b. b ~ Int64 => Ptr Int64 -> Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr b -> IO ())
                 -> (Array n i -> [Int])
intWidBranchExtr fextr32 fextr64 = case width of
  32 -> vectorExtremumOp @i @Int32 castPtr fextr32
  64 -> vectorExtremumOp @i @Int64 castPtr fextr64
  _ -> error "Unsupported Int width"
  where
    -- 'undefined' only carries the type @i@; 'finiteBitSize' does not
    -- inspect the value for the standard integer types.
    width = finiteBitSize (undefined :: i)
-- | Dot product of two integral arrays along one dimension, turning rank
-- @n + 1@ into rank @n@. The C kernels are chosen by the bit width of @i@.
--
-- For each width three kernels are passed to 'vectorDotprodInnerOp':
--
-- * @fsc@: a scaling kernel with the scalar–vector shape;
-- * @fred@: a one-dimension reduction kernel;
-- * @fdot@: the fused dot-product kernel.
--
-- 'numEltMul' is passed as the Haskell-level elementwise multiplication,
-- which is why @NumElt i@ is required. Scalars are converted with
-- 'fromIntegral' and pointers with 'castPtr'. A width other than 32 or 64
-- raises an 'error'.
intWidBranchDotprod :: forall i n. (FiniteBits i, Storable i, Integral i, NumElt i)
                    => (forall b. b ~ Int32 => Int64 -> Ptr Int64 -> Ptr b -> b -> Ptr Int64 -> Ptr b -> IO ())
                    -> (forall b. b ~ Int32 => Int64 -> Ptr b -> Ptr Int64 -> Ptr Int64 -> Ptr b -> IO ())
                    -> (forall b. b ~ Int32 => Int64 -> Ptr Int64 -> Ptr b -> Ptr Int64 -> Ptr b -> Ptr Int64 -> Ptr b -> IO ())
                    -> (forall b. b ~ Int64 => Int64 -> Ptr Int64 -> Ptr b -> b -> Ptr Int64 -> Ptr b -> IO ())
                    -> (forall b. b ~ Int64 => Int64 -> Ptr b -> Ptr Int64 -> Ptr Int64 -> Ptr b -> IO ())
                    -> (forall b. b ~ Int64 => Int64 -> Ptr Int64 -> Ptr b -> Ptr Int64 -> Ptr b -> Ptr Int64 -> Ptr b -> IO ())
                    -> (SNat n -> Array (n + 1) i -> Array (n + 1) i -> Array n i)
intWidBranchDotprod fsc32 fred32 fdot32 fsc64 fred64 fdot64 sn = case width of
  32 -> vectorDotprodInnerOp @i @Int32 sn fromIntegral castPtr numEltMul fsc32 fred32 fdot32
  64 -> vectorDotprodInnerOp @i @Int64 sn fromIntegral castPtr numEltMul fsc64 fred64 fdot64
  _ -> error "Unsupported Int width"
  where
    -- 'undefined' only carries the type @i@; 'finiteBitSize' does not
    -- inspect the value for the standard integer types.
    width = finiteBitSize (undefined :: i)
-- | Element types that support the strided arithmetic kernels on 'Array'.
--
-- Every method receives an @'SNat' n@ value, a runtime witness for the
-- type-level index @n@ of the arrays it operates on.
class NumElt a where
  -- | Elementwise addition of two arrays.
  numEltAdd :: SNat n -> Array n a -> Array n a -> Array n a
  -- | Elementwise subtraction of two arrays.
  numEltSub :: SNat n -> Array n a -> Array n a -> Array n a
  -- | Elementwise multiplication of two arrays.
  numEltMul :: SNat n -> Array n a -> Array n a -> Array n a
  -- | Elementwise negation.
  numEltNeg :: SNat n -> Array n a -> Array n a
  -- | Elementwise absolute value.
  numEltAbs :: SNat n -> Array n a -> Array n a
  -- | Elementwise signum.
  numEltSignum :: SNat n -> Array n a -> Array n a
  -- | Sum along the inner dimension; the index drops from @n + 1@ to @n@.
  numEltSum1Inner :: SNat n -> Array (n + 1) a -> Array n a
  -- | Product along the inner dimension; the index drops from @n + 1@ to @n@.
  numEltProduct1Inner :: SNat n -> Array (n + 1) a -> Array n a
  -- | Sum of every element of the array.
  numEltSumFull :: SNat n -> Array n a -> a
  -- | Product of every element of the array.
  numEltProductFull :: SNat n -> Array n a -> a
  -- | Position of a minimal element, returned as a list of 'Int' indices
  -- (presumably one per dimension -- confirm against the kernels).
  numEltMinIndex :: SNat n -> Array n a -> [Int]
  -- | Position of a maximal element, in the same format as 'numEltMinIndex'.
  numEltMaxIndex :: SNat n -> Array n a -> [Int]
  -- | Dot product of two arrays along their inner dimension.
  numEltDotprodInner :: SNat n -> Array (n + 1) a -> Array (n + 1) a -> Array n a
-- | 'Int32' elements: every method forwards directly to the matching
-- @…VectorInt32@ function (defined elsewhere). The index-returning methods
-- do not need the 'SNat' witness and discard it.
instance NumElt Int32 where
  numEltAdd           = addVectorInt32
  numEltSub           = subVectorInt32
  numEltMul           = mulVectorInt32
  numEltNeg           = negVectorInt32
  numEltAbs           = absVectorInt32
  numEltSignum        = signumVectorInt32
  numEltSum1Inner     = sum1VectorInt32
  numEltProduct1Inner = product1VectorInt32
  numEltSumFull       = sumFullVectorInt32
  numEltProductFull   = productFullVectorInt32
  numEltMinIndex      = const minindexVectorInt32
  numEltMaxIndex      = const maxindexVectorInt32
  numEltDotprodInner  = dotprodinnerVectorInt32
-- | 'Int64' elements: every method forwards directly to the matching
-- @…VectorInt64@ function (defined elsewhere). The index-returning methods
-- do not need the 'SNat' witness and discard it.
instance NumElt Int64 where
  numEltAdd           = addVectorInt64
  numEltSub           = subVectorInt64
  numEltMul           = mulVectorInt64
  numEltNeg           = negVectorInt64
  numEltAbs           = absVectorInt64
  numEltSignum        = signumVectorInt64
  numEltSum1Inner     = sum1VectorInt64
  numEltProduct1Inner = product1VectorInt64
  numEltSumFull       = sumFullVectorInt64
  numEltProductFull   = productFullVectorInt64
  numEltMinIndex      = const minindexVectorInt64
  numEltMaxIndex      = const maxindexVectorInt64
  numEltDotprodInner  = dotprodinnerVectorInt64
-- | 'Float' elements: every method forwards directly to the matching
-- @…VectorFloat@ function (defined elsewhere). The index-returning methods
-- do not need the 'SNat' witness and discard it.
instance NumElt Float where
  numEltAdd           = addVectorFloat
  numEltSub           = subVectorFloat
  numEltMul           = mulVectorFloat
  numEltNeg           = negVectorFloat
  numEltAbs           = absVectorFloat
  numEltSignum        = signumVectorFloat
  numEltSum1Inner     = sum1VectorFloat
  numEltProduct1Inner = product1VectorFloat
  numEltSumFull       = sumFullVectorFloat
  numEltProductFull   = productFullVectorFloat
  numEltMinIndex      = const minindexVectorFloat
  numEltMaxIndex      = const maxindexVectorFloat
  numEltDotprodInner  = dotprodinnerVectorFloat
-- | 'Double' elements: every method forwards directly to the matching
-- @…VectorDouble@ function (defined elsewhere). The index-returning methods
-- do not need the 'SNat' witness and discard it.
instance NumElt Double where
  numEltAdd           = addVectorDouble
  numEltSub           = subVectorDouble
  numEltMul           = mulVectorDouble
  numEltNeg           = negVectorDouble
  numEltAbs           = absVectorDouble
  numEltSignum        = signumVectorDouble
  numEltSum1Inner     = sum1VectorDouble
  numEltProduct1Inner = product1VectorDouble
  numEltSumFull       = sumFullVectorDouble
  numEltProductFull   = productFullVectorDouble
  numEltMinIndex      = const minindexVectorDouble
  numEltMaxIndex      = const maxindexVectorDouble
  numEltDotprodInner  = dotprodinnerVectorDouble
-- | 'Int' elements have no dedicated kernels. Each method hands both a
-- 32-bit and a 64-bit kernel set to one of the @intWidBranch*@ helpers,
-- which pick one according to the width of 'Int' (the visible tail of the
-- dot-product helper branches on @finiteBitSize@; the others presumably do
-- the same). The @where@-bound @op@ / @mul@ / @red@ are the C operation codes,
-- shared by both widths.
instance NumElt Int where
  numEltAdd = intWidBranch2 @Int (+)
                (c_binary_i32_sv_strided op) (c_binary_i32_vs_strided op) (c_binary_i32_vv_strided op)
                (c_binary_i64_sv_strided op) (c_binary_i64_vs_strided op) (c_binary_i64_vv_strided op)
    where op = aboEnum BO_ADD

  numEltSub = intWidBranch2 @Int (-)
                (c_binary_i32_sv_strided op) (c_binary_i32_vs_strided op) (c_binary_i32_vv_strided op)
                (c_binary_i64_sv_strided op) (c_binary_i64_vs_strided op) (c_binary_i64_vv_strided op)
    where op = aboEnum BO_SUB

  numEltMul = intWidBranch2 @Int (*)
                (c_binary_i32_sv_strided op) (c_binary_i32_vs_strided op) (c_binary_i32_vv_strided op)
                (c_binary_i64_sv_strided op) (c_binary_i64_vs_strided op) (c_binary_i64_vv_strided op)
    where op = aboEnum BO_MUL

  numEltNeg = intWidBranch1 @Int (c_unary_i32_strided op) (c_unary_i64_strided op)
    where op = auoEnum UO_NEG

  numEltAbs = intWidBranch1 @Int (c_unary_i32_strided op) (c_unary_i64_strided op)
    where op = auoEnum UO_ABS

  numEltSignum = intWidBranch1 @Int (c_unary_i32_strided op) (c_unary_i64_strided op)
    where op = auoEnum UO_SIGNUM

  numEltSum1Inner = intWidBranchRed1 @Int
                      (c_binary_i32_sv_strided mul) (c_reduce1_i32 red)
                      (c_binary_i64_sv_strided mul) (c_reduce1_i64 red)
    where mul = aboEnum BO_MUL
          red = aroEnum RO_SUM

  -- NOTE(review): this passes the same BO_MUL scalar kernel as
  -- numEltSum1Inner, whereas numEltProductFull combines with (^) where
  -- numEltSumFull uses (*). If intWidBranchRed1 uses its first argument to
  -- collapse a replicated inner dimension, a product would need
  -- exponentiation rather than multiplication -- confirm against
  -- intWidBranchRed1.
  numEltProduct1Inner = intWidBranchRed1 @Int
                          (c_binary_i32_sv_strided mul) (c_reduce1_i32 red)
                          (c_binary_i64_sv_strided mul) (c_reduce1_i64 red)
    where mul = aboEnum BO_MUL
          red = aroEnum RO_PRODUCT

  numEltSumFull = intWidBranchRedFull @Int (*) (c_reducefull_i32 red) (c_reducefull_i64 red)
    where red = aroEnum RO_SUM

  numEltProductFull = intWidBranchRedFull @Int (^) (c_reducefull_i32 red) (c_reducefull_i64 red)
    where red = aroEnum RO_PRODUCT

  -- The extremum helpers need no SNat witness, so it is discarded.
  numEltMinIndex _ = intWidBranchExtr @Int c_extremum_min_i32 c_extremum_min_i64
  numEltMaxIndex _ = intWidBranchExtr @Int c_extremum_max_i32 c_extremum_max_i64

  numEltDotprodInner = intWidBranchDotprod @Int
                         (c_binary_i32_sv_strided mul) (c_reduce1_i32 red) c_dotprodinner_i32
                         (c_binary_i64_sv_strided mul) (c_reduce1_i64 red) c_dotprodinner_i64
    where mul = aboEnum BO_MUL
          red = aroEnum RO_SUM
instance NumElt CInt where
numEltAdd :: forall (n :: Nat).
SNat n -> Array n CInt -> Array n CInt -> Array n CInt
numEltAdd = forall i (n :: Nat).
(FiniteBits i, Storable i, Integral i) =>
(i -> i -> i)
-> (forall b.
(b ~ Int32) =>
Int64 -> Ptr Int64 -> Ptr b -> b -> Ptr Int64 -> Ptr b -> IO ())
-> (forall b.
(b ~ Int32) =>
Int64 -> Ptr Int64 -> Ptr b -> Ptr Int64 -> Ptr b -> b -> IO ())
-> (forall b.
(b ~ Int32) =>
Int64
-> Ptr Int64
-> Ptr b
-> Ptr Int64
-> Ptr b
-> Ptr Int64
-> Ptr b
-> IO ())
-> (forall b.
(b ~ Int64) =>
Int64 -> Ptr Int64 -> Ptr b -> b -> Ptr Int64 -> Ptr b -> IO ())
-> (forall b.
(b ~ Int64) =>
Int64 -> Ptr Int64 -> Ptr b -> Ptr Int64 -> Ptr b -> b -> IO ())
-> (forall b.
(b ~ Int64) =>
Int64
-> Ptr Int64
-> Ptr b
-> Ptr Int64
-> Ptr b
-> Ptr Int64
-> Ptr b
-> IO ())
-> SNat n
-> Array n i
-> Array n i
-> Array n i
intWidBranch2 @CInt CInt -> CInt -> CInt
forall a. Num a => a -> a -> a
(+)
(CInt
-> Int64
-> Ptr Int64
-> Ptr Int32
-> Int32
-> Ptr Int64
-> Ptr Int32
-> IO ()
c_binary_i32_sv_strided (ArithBOp -> CInt
aboEnum ArithBOp
BO_ADD)) (CInt
-> Int64
-> Ptr Int64
-> Ptr Int32
-> Ptr Int64
-> Ptr Int32
-> Int32
-> IO ()
c_binary_i32_vs_strided (ArithBOp -> CInt
aboEnum ArithBOp
BO_ADD)) (CInt
-> Int64
-> Ptr Int64
-> Ptr Int32
-> Ptr Int64
-> Ptr Int32
-> Ptr Int64
-> Ptr Int32
-> IO ()
c_binary_i32_vv_strided (ArithBOp -> CInt
aboEnum ArithBOp
BO_ADD))
(CInt
-> Int64
-> Ptr Int64
-> Ptr Int64
-> Int64
-> Ptr Int64
-> Ptr Int64
-> IO ()
c_binary_i64_sv_strided (ArithBOp -> CInt
aboEnum ArithBOp
BO_ADD)) (CInt
-> Int64
-> Ptr Int64
-> Ptr Int64
-> Ptr Int64
-> Ptr Int64
-> Int64
-> IO ()
c_binary_i64_vs_strided (ArithBOp -> CInt
aboEnum ArithBOp
BO_ADD)) (CInt
-> Int64
-> Ptr Int64
-> Ptr Int64
-> Ptr Int64
-> Ptr Int64
-> Ptr Int64
-> Ptr Int64
-> IO ()
c_binary_i64_vv_strided (ArithBOp -> CInt
aboEnum ArithBOp
BO_ADD))
numEltSub :: forall (n :: Nat).
SNat n -> Array n CInt -> Array n CInt -> Array n CInt
numEltSub = forall i (n :: Nat).
(FiniteBits i, Storable i, Integral i) =>
(i -> i -> i)
-> (forall b.
(b ~ Int32) =>
Int64 -> Ptr Int64 -> Ptr b -> b -> Ptr Int64 -> Ptr b -> IO ())
-> (forall b.
(b ~ Int32) =>
Int64 -> Ptr Int64 -> Ptr b -> Ptr Int64 -> Ptr b -> b -> IO ())
-> (forall b.
(b ~ Int32) =>
Int64
-> Ptr Int64
-> Ptr b
-> Ptr Int64
-> Ptr b
-> Ptr Int64
-> Ptr b
-> IO ())
-> (forall b.
(b ~ Int64) =>
Int64 -> Ptr Int64 -> Ptr b -> b -> Ptr Int64 -> Ptr b -> IO ())
-> (forall b.
(b ~ Int64) =>
Int64 -> Ptr Int64 -> Ptr b -> Ptr Int64 -> Ptr b -> b -> IO ())
-> (forall b.
(b ~ Int64) =>
Int64
-> Ptr Int64
-> Ptr b
-> Ptr Int64
-> Ptr b
-> Ptr Int64
-> Ptr b
-> IO ())
-> SNat n
-> Array n i
-> Array n i
-> Array n i
intWidBranch2 @CInt (-)
(CInt
-> Int64
-> Ptr Int64
-> Ptr Int32
-> Int32
-> Ptr Int64
-> Ptr Int32
-> IO ()
c_binary_i32_sv_strided (ArithBOp -> CInt
aboEnum ArithBOp
BO_SUB)) (CInt
-> Int64
-> Ptr Int64
-> Ptr Int32
-> Ptr Int64
-> Ptr Int32
-> Int32
-> IO ()
c_binary_i32_vs_strided (ArithBOp -> CInt
aboEnum ArithBOp
BO_SUB)) (CInt
-> Int64
-> Ptr Int64
-> Ptr Int32
-> Ptr Int64
-> Ptr Int32
-> Ptr Int64
-> Ptr Int32
-> IO ()
c_binary_i32_vv_strided (ArithBOp -> CInt
aboEnum ArithBOp
BO_SUB))
(CInt
-> Int64
-> Ptr Int64
-> Ptr Int64
-> Int64
-> Ptr Int64
-> Ptr Int64
-> IO ()
c_binary_i64_sv_strided (ArithBOp -> CInt
aboEnum ArithBOp
BO_SUB)) (CInt
-> Int64
-> Ptr Int64
-> Ptr Int64
-> Ptr Int64
-> Ptr Int64
-> Int64
-> IO ()
c_binary_i64_vs_strided (ArithBOp -> CInt
aboEnum ArithBOp
BO_SUB)) (CInt
-> Int64
-> Ptr Int64
-> Ptr Int64
-> Ptr Int64
-> Ptr Int64
-> Ptr Int64
-> Ptr Int64
-> IO ()
c_binary_i64_vv_strided (ArithBOp -> CInt
aboEnum ArithBOp
BO_SUB))
numEltMul :: forall (n :: Nat).
SNat n -> Array n CInt -> Array n CInt -> Array n CInt
numEltMul = forall i (n :: Nat).
(FiniteBits i, Storable i, Integral i) =>
(i -> i -> i)
-> (forall b.
(b ~ Int32) =>
Int64 -> Ptr Int64 -> Ptr b -> b -> Ptr Int64 -> Ptr b -> IO ())
-> (forall b.
(b ~ Int32) =>
Int64 -> Ptr Int64 -> Ptr b -> Ptr Int64 -> Ptr b -> b -> IO ())
-> (forall b.
(b ~ Int32) =>
Int64
-> Ptr Int64
-> Ptr b
-> Ptr Int64
-> Ptr b
-> Ptr Int64
-> Ptr b
-> IO ())
-> (forall b.
(b ~ Int64) =>
Int64 -> Ptr Int64 -> Ptr b -> b -> Ptr Int64 -> Ptr b -> IO ())
-> (forall b.
(b ~ Int64) =>
Int64 -> Ptr Int64 -> Ptr b -> Ptr Int64 -> Ptr b -> b -> IO ())
-> (forall b.
(b ~ Int64) =>
Int64
-> Ptr Int64
-> Ptr b
-> Ptr Int64
-> Ptr b
-> Ptr Int64
-> Ptr b
-> IO ())
-> SNat n
-> Array n i
-> Array n i
-> Array n i
intWidBranch2 @CInt CInt -> CInt -> CInt
forall a. Num a => a -> a -> a
(*)
(CInt
-> Int64
-> Ptr Int64
-> Ptr Int32
-> Int32
-> Ptr Int64
-> Ptr Int32
-> IO ()
c_binary_i32_sv_strided (ArithBOp -> CInt
aboEnum ArithBOp
BO_MUL)) (CInt
-> Int64
-> Ptr Int64
-> Ptr Int32
-> Ptr Int64
-> Ptr Int32
-> Int32
-> IO ()
c_binary_i32_vs_strided (ArithBOp -> CInt
aboEnum ArithBOp
BO_MUL)) (CInt
-> Int64
-> Ptr Int64
-> Ptr Int32
-> Ptr Int64
-> Ptr Int32
-> Ptr Int64
-> Ptr Int32
-> IO ()
c_binary_i32_vv_strided (ArithBOp -> CInt
aboEnum ArithBOp
BO_MUL))
(CInt
-> Int64
-> Ptr Int64
-> Ptr Int64
-> Int64
-> Ptr Int64
-> Ptr Int64
-> IO ()
c_binary_i64_sv_strided (ArithBOp -> CInt
aboEnum ArithBOp
BO_MUL)) (CInt
-> Int64
-> Ptr Int64
-> Ptr Int64
-> Ptr Int64
-> Ptr Int64
-> Int64
-> IO ()
c_binary_i64_vs_strided (ArithBOp -> CInt
aboEnum ArithBOp
BO_MUL)) (CInt
-> Int64
-> Ptr Int64
-> Ptr Int64
-> Ptr Int64
-> Ptr Int64
-> Ptr Int64
-> Ptr Int64
-> IO ()
c_binary_i64_vv_strided (ArithBOp -> CInt
aboEnum ArithBOp
BO_MUL))
numEltNeg :: forall (n :: Nat). SNat n -> Array n CInt -> Array n CInt
numEltNeg = forall i (n :: Nat).
(FiniteBits i, Storable i) =>
(forall b.
(b ~ Int32) =>
Int64 -> Ptr b -> Ptr Int64 -> Ptr Int64 -> Ptr b -> IO ())
-> (forall b.
(b ~ Int64) =>
Int64 -> Ptr b -> Ptr Int64 -> Ptr Int64 -> Ptr b -> IO ())
-> SNat n
-> Array n i
-> Array n i
intWidBranch1 @CInt (CInt
-> Int64
-> Ptr Int32
-> Ptr Int64
-> Ptr Int64
-> Ptr Int32
-> IO ()
c_unary_i32_strided (ArithUOp -> CInt
auoEnum ArithUOp
UO_NEG)) (CInt
-> Int64
-> Ptr Int64
-> Ptr Int64
-> Ptr Int64
-> Ptr Int64
-> IO ()
c_unary_i64_strided (ArithUOp -> CInt
auoEnum ArithUOp
UO_NEG))
numEltAbs :: forall (n :: Nat). SNat n -> Array n CInt -> Array n CInt
numEltAbs = forall i (n :: Nat).
(FiniteBits i, Storable i) =>
(forall b.
(b ~ Int32) =>
Int64 -> Ptr b -> Ptr Int64 -> Ptr Int64 -> Ptr b -> IO ())
-> (forall b.
(b ~ Int64) =>
Int64 -> Ptr b -> Ptr Int64 -> Ptr Int64 -> Ptr b -> IO ())
-> SNat n
-> Array n i
-> Array n i
intWidBranch1 @CInt (CInt
-> Int64
-> Ptr Int32
-> Ptr Int64
-> Ptr Int64
-> Ptr Int32
-> IO ()
c_unary_i32_strided (ArithUOp -> CInt
auoEnum ArithUOp
UO_ABS)) (CInt
-> Int64
-> Ptr Int64
-> Ptr Int64
-> Ptr Int64
-> Ptr Int64
-> IO ()
c_unary_i64_strided (ArithUOp -> CInt
auoEnum ArithUOp
UO_ABS))
numEltSignum :: forall (n :: Nat). SNat n -> Array n CInt -> Array n CInt
numEltSignum = forall i (n :: Nat).
(FiniteBits i, Storable i) =>
(forall b.
(b ~ Int32) =>
Int64 -> Ptr b -> Ptr Int64 -> Ptr Int64 -> Ptr b -> IO ())
-> (forall b.
(b ~ Int64) =>
Int64 -> Ptr b -> Ptr Int64 -> Ptr Int64 -> Ptr b -> IO ())
-> SNat n
-> Array n i
-> Array n i
intWidBranch1 @CInt (CInt
-> Int64
-> Ptr Int32
-> Ptr Int64
-> Ptr Int64
-> Ptr Int32
-> IO ()
c_unary_i32_strided (ArithUOp -> CInt
auoEnum ArithUOp
UO_SIGNUM)) (CInt
-> Int64
-> Ptr Int64
-> Ptr Int64
-> Ptr Int64
-> Ptr Int64
-> IO ()
c_unary_i64_strided (ArithUOp -> CInt
auoEnum ArithUOp
UO_SIGNUM))
numEltSum1Inner :: forall (n :: Nat). SNat n -> Array (n + 1) CInt -> Array n CInt
numEltSum1Inner = forall i (n :: Nat).
(FiniteBits i, Storable i, Integral i) =>
(forall b.
(b ~ Int32) =>
Int64 -> Ptr Int64 -> Ptr b -> b -> Ptr Int64 -> Ptr b -> IO ())
-> (forall b.
(b ~ Int32) =>
Int64 -> Ptr b -> Ptr Int64 -> Ptr Int64 -> Ptr b -> IO ())
-> (forall b.
(b ~ Int64) =>
Int64 -> Ptr Int64 -> Ptr b -> b -> Ptr Int64 -> Ptr b -> IO ())
-> (forall b.
(b ~ Int64) =>
Int64 -> Ptr b -> Ptr Int64 -> Ptr Int64 -> Ptr b -> IO ())
-> SNat n
-> Array (n + 1) i
-> Array n i
intWidBranchRed1 @CInt
(CInt
-> Int64
-> Ptr Int64
-> Ptr Int32
-> Int32
-> Ptr Int64
-> Ptr Int32
-> IO ()
c_binary_i32_sv_strided (ArithBOp -> CInt
aboEnum ArithBOp
BO_MUL)) (CInt
-> Int64
-> Ptr Int32
-> Ptr Int64
-> Ptr Int64
-> Ptr Int32
-> IO ()
c_reduce1_i32 (ArithRedOp -> CInt
aroEnum ArithRedOp
RO_SUM))
(CInt
-> Int64
-> Ptr Int64
-> Ptr Int64
-> Int64
-> Ptr Int64
-> Ptr Int64
-> IO ()
c_binary_i64_sv_strided (ArithBOp -> CInt
aboEnum ArithBOp
BO_MUL)) (CInt
-> Int64
-> Ptr Int64
-> Ptr Int64
-> Ptr Int64
-> Ptr Int64
-> IO ()
c_reduce1_i64 (ArithRedOp -> CInt
aroEnum ArithRedOp
RO_SUM))
numEltProduct1Inner :: forall (n :: Nat). SNat n -> Array (n + 1) CInt -> Array n CInt
numEltProduct1Inner = forall i (n :: Nat).
(FiniteBits i, Storable i, Integral i) =>
(forall b.
(b ~ Int32) =>
Int64 -> Ptr Int64 -> Ptr b -> b -> Ptr Int64 -> Ptr b -> IO ())
-> (forall b.
(b ~ Int32) =>
Int64 -> Ptr b -> Ptr Int64 -> Ptr Int64 -> Ptr b -> IO ())
-> (forall b.
(b ~ Int64) =>
Int64 -> Ptr Int64 -> Ptr b -> b -> Ptr Int64 -> Ptr b -> IO ())
-> (forall b.
(b ~ Int64) =>
Int64 -> Ptr b -> Ptr Int64 -> Ptr Int64 -> Ptr b -> IO ())
-> SNat n
-> Array (n + 1) i
-> Array n i
intWidBranchRed1 @CInt
(CInt
-> Int64
-> Ptr Int64
-> Ptr Int32
-> Int32
-> Ptr Int64
-> Ptr Int32
-> IO ()
c_binary_i32_sv_strided (ArithBOp -> CInt
aboEnum ArithBOp
BO_MUL)) (CInt
-> Int64
-> Ptr Int32
-> Ptr Int64
-> Ptr Int64
-> Ptr Int32
-> IO ()
c_reduce1_i32 (ArithRedOp -> CInt
aroEnum ArithRedOp
RO_PRODUCT))
(CInt
-> Int64
-> Ptr Int64
-> Ptr Int64
-> Int64
-> Ptr Int64
-> Ptr Int64
-> IO ()
c_binary_i64_sv_strided (ArithBOp -> CInt
aboEnum ArithBOp
BO_MUL)) (CInt
-> Int64
-> Ptr Int64
-> Ptr Int64
-> Ptr Int64
-> Ptr Int64
-> IO ()
c_reduce1_i64 (ArithRedOp -> CInt
aroEnum ArithRedOp
RO_PRODUCT))
numEltSumFull :: forall (n :: Nat). SNat n -> Array n CInt -> CInt
numEltSumFull = forall i (n :: Nat).
(FiniteBits i, Storable i, Integral i) =>
(i -> Int -> i)
-> (forall b.
(b ~ Int32) =>
Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr b -> IO b)
-> (forall b.
(b ~ Int64) =>
Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr b -> IO b)
-> SNat n
-> Array n i
-> i
intWidBranchRedFull @CInt CInt -> Int -> CInt
forall a. Num a => a -> Int -> a
mulWithInt (CInt -> Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr Int32 -> IO Int32
c_reducefull_i32 (ArithRedOp -> CInt
aroEnum ArithRedOp
RO_SUM)) (CInt -> Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr Int64 -> IO Int64
c_reducefull_i64 (ArithRedOp -> CInt
aroEnum ArithRedOp
RO_SUM))
numEltProductFull :: forall (n :: Nat). SNat n -> Array n CInt -> CInt
numEltProductFull = forall i (n :: Nat).
(FiniteBits i, Storable i, Integral i) =>
(i -> Int -> i)
-> (forall b.
(b ~ Int32) =>
Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr b -> IO b)
-> (forall b.
(b ~ Int64) =>
Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr b -> IO b)
-> SNat n
-> Array n i
-> i
intWidBranchRedFull @CInt CInt -> Int -> CInt
forall a b. (Num a, Integral b) => a -> b -> a
(^) (CInt -> Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr Int32 -> IO Int32
c_reducefull_i32 (ArithRedOp -> CInt
aroEnum ArithRedOp
RO_PRODUCT)) (CInt -> Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr Int64 -> IO Int64
c_reducefull_i64 (ArithRedOp -> CInt
aroEnum ArithRedOp
RO_PRODUCT))
numEltMinIndex :: forall (n :: Nat). SNat n -> Array n CInt -> [Int]
numEltMinIndex SNat n
_ = forall i (n :: Nat).
(FiniteBits i, Storable i) =>
(forall b.
(b ~ Int32) =>
Ptr Int64 -> Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr b -> IO ())
-> (forall b.
(b ~ Int64) =>
Ptr Int64 -> Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr b -> IO ())
-> Array n i
-> [Int]
intWidBranchExtr @CInt Ptr Int64 -> Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr b -> IO ()
Ptr Int64 -> Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr Int32 -> IO ()
forall b.
(b ~ Int32) =>
Ptr Int64 -> Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr b -> IO ()
c_extremum_min_i32 Ptr Int64 -> Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr b -> IO ()
Ptr Int64 -> Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr Int64 -> IO ()
forall b.
(b ~ Int64) =>
Ptr Int64 -> Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr b -> IO ()
c_extremum_min_i64
numEltMaxIndex :: forall (n :: Nat). SNat n -> Array n CInt -> [Int]
numEltMaxIndex SNat n
_ = forall i (n :: Nat).
(FiniteBits i, Storable i) =>
(forall b.
(b ~ Int32) =>
Ptr Int64 -> Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr b -> IO ())
-> (forall b.
(b ~ Int64) =>
Ptr Int64 -> Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr b -> IO ())
-> Array n i
-> [Int]
intWidBranchExtr @CInt Ptr Int64 -> Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr b -> IO ()
Ptr Int64 -> Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr Int32 -> IO ()
forall b.
(b ~ Int32) =>
Ptr Int64 -> Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr b -> IO ()
c_extremum_max_i32 Ptr Int64 -> Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr b -> IO ()
Ptr Int64 -> Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr Int64 -> IO ()
forall b.
(b ~ Int64) =>
Ptr Int64 -> Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr b -> IO ()
c_extremum_max_i64
numEltDotprodInner :: forall (n :: Nat).
SNat n -> Array (n + 1) CInt -> Array (n + 1) CInt -> Array n CInt
numEltDotprodInner = forall i (n :: Nat).
(FiniteBits i, Storable i, Integral i, NumElt i) =>
(forall b.
(b ~ Int32) =>
Int64 -> Ptr Int64 -> Ptr b -> b -> Ptr Int64 -> Ptr b -> IO ())
-> (forall b.
(b ~ Int32) =>
Int64 -> Ptr b -> Ptr Int64 -> Ptr Int64 -> Ptr b -> IO ())
-> (forall b.
(b ~ Int32) =>
Int64
-> Ptr Int64
-> Ptr b
-> Ptr Int64
-> Ptr b
-> Ptr Int64
-> Ptr b
-> IO ())
-> (forall b.
(b ~ Int64) =>
Int64 -> Ptr Int64 -> Ptr b -> b -> Ptr Int64 -> Ptr b -> IO ())
-> (forall b.
(b ~ Int64) =>
Int64 -> Ptr b -> Ptr Int64 -> Ptr Int64 -> Ptr b -> IO ())
-> (forall b.
(b ~ Int64) =>
Int64
-> Ptr Int64
-> Ptr b
-> Ptr Int64
-> Ptr b
-> Ptr Int64
-> Ptr b
-> IO ())
-> SNat n
-> Array (n + 1) i
-> Array (n + 1) i
-> Array n i
intWidBranchDotprod @CInt (CInt
-> Int64
-> Ptr Int64
-> Ptr Int32
-> Int32
-> Ptr Int64
-> Ptr Int32
-> IO ()
c_binary_i32_sv_strided (ArithBOp -> CInt
aboEnum ArithBOp
BO_MUL)) (CInt
-> Int64
-> Ptr Int32
-> Ptr Int64
-> Ptr Int64
-> Ptr Int32
-> IO ()
c_reduce1_i32 (ArithRedOp -> CInt
aroEnum ArithRedOp
RO_SUM)) Int64
-> Ptr Int64
-> Ptr b
-> Ptr Int64
-> Ptr b
-> Ptr Int64
-> Ptr b
-> IO ()
Int64
-> Ptr Int64
-> Ptr Int32
-> Ptr Int64
-> Ptr Int32
-> Ptr Int64
-> Ptr Int32
-> IO ()
forall b.
(b ~ Int32) =>
Int64
-> Ptr Int64
-> Ptr b
-> Ptr Int64
-> Ptr b
-> Ptr Int64
-> Ptr b
-> IO ()
c_dotprodinner_i32
(CInt
-> Int64
-> Ptr Int64
-> Ptr Int64
-> Int64
-> Ptr Int64
-> Ptr Int64
-> IO ()
c_binary_i64_sv_strided (ArithBOp -> CInt
aboEnum ArithBOp
BO_MUL)) (CInt
-> Int64
-> Ptr Int64
-> Ptr Int64
-> Ptr Int64
-> Ptr Int64
-> IO ()
c_reduce1_i64 (ArithRedOp -> CInt
aroEnum ArithRedOp
RO_SUM)) Int64
-> Ptr Int64
-> Ptr b
-> Ptr Int64
-> Ptr b
-> Ptr Int64
-> Ptr b
-> IO ()
Int64
-> Ptr Int64
-> Ptr Int64
-> Ptr Int64
-> Ptr Int64
-> Ptr Int64
-> Ptr Int64
-> IO ()
forall b.
(b ~ Int64) =>
Int64
-> Ptr Int64
-> Ptr b
-> Ptr Int64
-> Ptr b
-> Ptr Int64
-> Ptr b
-> IO ()
c_dotprodinner_i64
class NumElt a => IntElt a where
intEltQuot :: SNat n -> Array n a -> Array n a -> Array n a
intEltRem :: SNat n -> Array n a -> Array n a -> Array n a
instance IntElt Int32 where
intEltQuot :: forall (n :: Nat).
SNat n -> Array n Int32 -> Array n Int32 -> Array n Int32
intEltQuot = SNat n -> Array n Int32 -> Array n Int32 -> Array n Int32
forall (n :: Nat).
SNat n -> Array n Int32 -> Array n Int32 -> Array n Int32
quotVectorInt32
intEltRem :: forall (n :: Nat).
SNat n -> Array n Int32 -> Array n Int32 -> Array n Int32
intEltRem = SNat n -> Array n Int32 -> Array n Int32 -> Array n Int32
forall (n :: Nat).
SNat n -> Array n Int32 -> Array n Int32 -> Array n Int32
remVectorInt32
instance IntElt Int64 where
intEltQuot :: forall (n :: Nat).
SNat n -> Array n Int64 -> Array n Int64 -> Array n Int64
intEltQuot = SNat n -> Array n Int64 -> Array n Int64 -> Array n Int64
forall (n :: Nat).
SNat n -> Array n Int64 -> Array n Int64 -> Array n Int64
quotVectorInt64
intEltRem :: forall (n :: Nat).
SNat n -> Array n Int64 -> Array n Int64 -> Array n Int64
intEltRem = SNat n -> Array n Int64 -> Array n Int64 -> Array n Int64
forall (n :: Nat).
SNat n -> Array n Int64 -> Array n Int64 -> Array n Int64
remVectorInt64
instance IntElt Int where
intEltQuot :: forall (n :: Nat).
SNat n -> Array n Int -> Array n Int -> Array n Int
intEltQuot = forall i (n :: Nat).
(FiniteBits i, Storable i, Integral i) =>
(i -> i -> i)
-> (forall b.
(b ~ Int32) =>
Int64 -> Ptr Int64 -> Ptr b -> b -> Ptr Int64 -> Ptr b -> IO ())
-> (forall b.
(b ~ Int32) =>
Int64 -> Ptr Int64 -> Ptr b -> Ptr Int64 -> Ptr b -> b -> IO ())
-> (forall b.
(b ~ Int32) =>
Int64
-> Ptr Int64
-> Ptr b
-> Ptr Int64
-> Ptr b
-> Ptr Int64
-> Ptr b
-> IO ())
-> (forall b.
(b ~ Int64) =>
Int64 -> Ptr Int64 -> Ptr b -> b -> Ptr Int64 -> Ptr b -> IO ())
-> (forall b.
(b ~ Int64) =>
Int64 -> Ptr Int64 -> Ptr b -> Ptr Int64 -> Ptr b -> b -> IO ())
-> (forall b.
(b ~ Int64) =>
Int64
-> Ptr Int64
-> Ptr b
-> Ptr Int64
-> Ptr b
-> Ptr Int64
-> Ptr b
-> IO ())
-> SNat n
-> Array n i
-> Array n i
-> Array n i
intWidBranch2 @Int Int -> Int -> Int
forall a. Integral a => a -> a -> a
quot
(CInt
-> Int64
-> Ptr Int64
-> Ptr Int32
-> Int32
-> Ptr Int64
-> Ptr Int32
-> IO ()
c_binary_i32_sv_strided (ArithIBOp -> CInt
aiboEnum ArithIBOp
IB_QUOT)) (CInt
-> Int64
-> Ptr Int64
-> Ptr Int32
-> Ptr Int64
-> Ptr Int32
-> Int32
-> IO ()
c_binary_i32_vs_strided (ArithIBOp -> CInt
aiboEnum ArithIBOp
IB_QUOT)) (CInt
-> Int64
-> Ptr Int64
-> Ptr Int32
-> Ptr Int64
-> Ptr Int32
-> Ptr Int64
-> Ptr Int32
-> IO ()
c_binary_i32_vv_strided (ArithIBOp -> CInt
aiboEnum ArithIBOp
IB_QUOT))
(CInt
-> Int64
-> Ptr Int64
-> Ptr Int64
-> Int64
-> Ptr Int64
-> Ptr Int64
-> IO ()
c_binary_i64_sv_strided (ArithIBOp -> CInt
aiboEnum ArithIBOp
IB_QUOT)) (CInt
-> Int64
-> Ptr Int64
-> Ptr Int64
-> Ptr Int64
-> Ptr Int64
-> Int64
-> IO ()
c_binary_i64_vs_strided (ArithIBOp -> CInt
aiboEnum ArithIBOp
IB_QUOT)) (CInt
-> Int64
-> Ptr Int64
-> Ptr Int64
-> Ptr Int64
-> Ptr Int64
-> Ptr Int64
-> Ptr Int64
-> IO ()
c_binary_i64_vv_strided (ArithIBOp -> CInt
aiboEnum ArithIBOp
IB_QUOT))
intEltRem :: forall (n :: Nat).
SNat n -> Array n Int -> Array n Int -> Array n Int
intEltRem = forall i (n :: Nat).
(FiniteBits i, Storable i, Integral i) =>
(i -> i -> i)
-> (forall b.
(b ~ Int32) =>
Int64 -> Ptr Int64 -> Ptr b -> b -> Ptr Int64 -> Ptr b -> IO ())
-> (forall b.
(b ~ Int32) =>
Int64 -> Ptr Int64 -> Ptr b -> Ptr Int64 -> Ptr b -> b -> IO ())
-> (forall b.
(b ~ Int32) =>
Int64
-> Ptr Int64
-> Ptr b
-> Ptr Int64
-> Ptr b
-> Ptr Int64
-> Ptr b
-> IO ())
-> (forall b.
(b ~ Int64) =>
Int64 -> Ptr Int64 -> Ptr b -> b -> Ptr Int64 -> Ptr b -> IO ())
-> (forall b.
(b ~ Int64) =>
Int64 -> Ptr Int64 -> Ptr b -> Ptr Int64 -> Ptr b -> b -> IO ())
-> (forall b.
(b ~ Int64) =>
Int64
-> Ptr Int64
-> Ptr b
-> Ptr Int64
-> Ptr b
-> Ptr Int64
-> Ptr b
-> IO ())
-> SNat n
-> Array n i
-> Array n i
-> Array n i
intWidBranch2 @Int Int -> Int -> Int
forall a. Integral a => a -> a -> a
rem
(CInt
-> Int64
-> Ptr Int64
-> Ptr Int32
-> Int32
-> Ptr Int64
-> Ptr Int32
-> IO ()
c_binary_i32_sv_strided (ArithIBOp -> CInt
aiboEnum ArithIBOp
IB_REM)) (CInt
-> Int64
-> Ptr Int64
-> Ptr Int32
-> Ptr Int64
-> Ptr Int32
-> Int32
-> IO ()
c_binary_i32_vs_strided (ArithIBOp -> CInt
aiboEnum ArithIBOp
IB_REM)) (CInt
-> Int64
-> Ptr Int64
-> Ptr Int32
-> Ptr Int64
-> Ptr Int32
-> Ptr Int64
-> Ptr Int32
-> IO ()
c_binary_i32_vv_strided (ArithIBOp -> CInt
aiboEnum ArithIBOp
IB_REM))
(CInt
-> Int64
-> Ptr Int64
-> Ptr Int64
-> Int64
-> Ptr Int64
-> Ptr Int64
-> IO ()
c_binary_i64_sv_strided (ArithIBOp -> CInt
aiboEnum ArithIBOp
IB_REM)) (CInt
-> Int64
-> Ptr Int64
-> Ptr Int64
-> Ptr Int64
-> Ptr Int64
-> Int64
-> IO ()
c_binary_i64_vs_strided (ArithIBOp -> CInt
aiboEnum ArithIBOp
IB_REM)) (CInt
-> Int64
-> Ptr Int64
-> Ptr Int64
-> Ptr Int64
-> Ptr Int64
-> Ptr Int64
-> Ptr Int64
-> IO ()
c_binary_i64_vv_strided (ArithIBOp -> CInt
aiboEnum ArithIBOp
IB_REM))
instance IntElt CInt where
intEltQuot :: forall (n :: Nat).
SNat n -> Array n CInt -> Array n CInt -> Array n CInt
intEltQuot = forall i (n :: Nat).
(FiniteBits i, Storable i, Integral i) =>
(i -> i -> i)
-> (forall b.
(b ~ Int32) =>
Int64 -> Ptr Int64 -> Ptr b -> b -> Ptr Int64 -> Ptr b -> IO ())
-> (forall b.
(b ~ Int32) =>
Int64 -> Ptr Int64 -> Ptr b -> Ptr Int64 -> Ptr b -> b -> IO ())
-> (forall b.
(b ~ Int32) =>
Int64
-> Ptr Int64
-> Ptr b
-> Ptr Int64
-> Ptr b
-> Ptr Int64
-> Ptr b
-> IO ())
-> (forall b.
(b ~ Int64) =>
Int64 -> Ptr Int64 -> Ptr b -> b -> Ptr Int64 -> Ptr b -> IO ())
-> (forall b.
(b ~ Int64) =>
Int64 -> Ptr Int64 -> Ptr b -> Ptr Int64 -> Ptr b -> b -> IO ())
-> (forall b.
(b ~ Int64) =>
Int64
-> Ptr Int64
-> Ptr b
-> Ptr Int64
-> Ptr b
-> Ptr Int64
-> Ptr b
-> IO ())
-> SNat n
-> Array n i
-> Array n i
-> Array n i
intWidBranch2 @CInt CInt -> CInt -> CInt
forall a. Integral a => a -> a -> a
quot
(CInt
-> Int64
-> Ptr Int64
-> Ptr Int32
-> Int32
-> Ptr Int64
-> Ptr Int32
-> IO ()
c_binary_i32_sv_strided (ArithIBOp -> CInt
aiboEnum ArithIBOp
IB_QUOT)) (CInt
-> Int64
-> Ptr Int64
-> Ptr Int32
-> Ptr Int64
-> Ptr Int32
-> Int32
-> IO ()
c_binary_i32_vs_strided (ArithIBOp -> CInt
aiboEnum ArithIBOp
IB_QUOT)) (CInt
-> Int64
-> Ptr Int64
-> Ptr Int32
-> Ptr Int64
-> Ptr Int32
-> Ptr Int64
-> Ptr Int32
-> IO ()
c_binary_i32_vv_strided (ArithIBOp -> CInt
aiboEnum ArithIBOp
IB_QUOT))
(CInt
-> Int64
-> Ptr Int64
-> Ptr Int64
-> Int64
-> Ptr Int64
-> Ptr Int64
-> IO ()
c_binary_i64_sv_strided (ArithIBOp -> CInt
aiboEnum ArithIBOp
IB_QUOT)) (CInt
-> Int64
-> Ptr Int64
-> Ptr Int64
-> Ptr Int64
-> Ptr Int64
-> Int64
-> IO ()
c_binary_i64_vs_strided (ArithIBOp -> CInt
aiboEnum ArithIBOp
IB_QUOT)) (CInt
-> Int64
-> Ptr Int64
-> Ptr Int64
-> Ptr Int64
-> Ptr Int64
-> Ptr Int64
-> Ptr Int64
-> IO ()
c_binary_i64_vv_strided (ArithIBOp -> CInt
aiboEnum ArithIBOp
IB_QUOT))
intEltRem :: forall (n :: Nat).
SNat n -> Array n CInt -> Array n CInt -> Array n CInt
intEltRem = forall i (n :: Nat).
(FiniteBits i, Storable i, Integral i) =>
(i -> i -> i)
-> (forall b.
(b ~ Int32) =>
Int64 -> Ptr Int64 -> Ptr b -> b -> Ptr Int64 -> Ptr b -> IO ())
-> (forall b.
(b ~ Int32) =>
Int64 -> Ptr Int64 -> Ptr b -> Ptr Int64 -> Ptr b -> b -> IO ())
-> (forall b.
(b ~ Int32) =>
Int64
-> Ptr Int64
-> Ptr b
-> Ptr Int64
-> Ptr b
-> Ptr Int64
-> Ptr b
-> IO ())
-> (forall b.
(b ~ Int64) =>
Int64 -> Ptr Int64 -> Ptr b -> b -> Ptr Int64 -> Ptr b -> IO ())
-> (forall b.
(b ~ Int64) =>
Int64 -> Ptr Int64 -> Ptr b -> Ptr Int64 -> Ptr b -> b -> IO ())
-> (forall b.
(b ~ Int64) =>
Int64
-> Ptr Int64
-> Ptr b
-> Ptr Int64
-> Ptr b
-> Ptr Int64
-> Ptr b
-> IO ())
-> SNat n
-> Array n i
-> Array n i
-> Array n i
intWidBranch2 @CInt CInt -> CInt -> CInt
forall a. Integral a => a -> a -> a
rem
(CInt
-> Int64
-> Ptr Int64
-> Ptr Int32
-> Int32
-> Ptr Int64
-> Ptr Int32
-> IO ()
c_binary_i32_sv_strided (ArithIBOp -> CInt
aiboEnum ArithIBOp
IB_REM)) (CInt
-> Int64
-> Ptr Int64
-> Ptr Int32
-> Ptr Int64
-> Ptr Int32
-> Int32
-> IO ()
c_binary_i32_vs_strided (ArithIBOp -> CInt
aiboEnum ArithIBOp
IB_REM)) (CInt
-> Int64
-> Ptr Int64
-> Ptr Int32
-> Ptr Int64
-> Ptr Int32
-> Ptr Int64
-> Ptr Int32
-> IO ()
c_binary_i32_vv_strided (ArithIBOp -> CInt
aiboEnum ArithIBOp
IB_REM))
(CInt
-> Int64
-> Ptr Int64
-> Ptr Int64
-> Int64
-> Ptr Int64
-> Ptr Int64
-> IO ()
c_binary_i64_sv_strided (ArithIBOp -> CInt
aiboEnum ArithIBOp
IB_REM)) (CInt
-> Int64
-> Ptr Int64
-> Ptr Int64
-> Ptr Int64
-> Ptr Int64
-> Int64
-> IO ()
c_binary_i64_vs_strided (ArithIBOp -> CInt
aiboEnum ArithIBOp
IB_REM)) (CInt
-> Int64
-> Ptr Int64
-> Ptr Int64
-> Ptr Int64
-> Ptr Int64
-> Ptr Int64
-> Ptr Int64
-> IO ()
c_binary_i64_vv_strided (ArithIBOp -> CInt
aiboEnum ArithIBOp
IB_REM))
class NumElt a => FloatElt a where
floatEltDiv :: SNat n -> Array n a -> Array n a -> Array n a
floatEltPow :: SNat n -> Array n a -> Array n a -> Array n a
floatEltLogbase :: SNat n -> Array n a -> Array n a -> Array n a
floatEltRecip :: SNat n -> Array n a -> Array n a
floatEltExp :: SNat n -> Array n a -> Array n a
floatEltLog :: SNat n -> Array n a -> Array n a
floatEltSqrt :: SNat n -> Array n a -> Array n a
floatEltSin :: SNat n -> Array n a -> Array n a
floatEltCos :: SNat n -> Array n a -> Array n a
floatEltTan :: SNat n -> Array n a -> Array n a
floatEltAsin :: SNat n -> Array n a -> Array n a
floatEltAcos :: SNat n -> Array n a -> Array n a
floatEltAtan :: SNat n -> Array n a -> Array n a
floatEltSinh :: SNat n -> Array n a -> Array n a
floatEltCosh :: SNat n -> Array n a -> Array n a
floatEltTanh :: SNat n -> Array n a -> Array n a
floatEltAsinh :: SNat n -> Array n a -> Array n a
floatEltAcosh :: SNat n -> Array n a -> Array n a
floatEltAtanh :: SNat n -> Array n a -> Array n a
floatEltLog1p :: SNat n -> Array n a -> Array n a
floatEltExpm1 :: SNat n -> Array n a -> Array n a
floatEltLog1pexp :: SNat n -> Array n a -> Array n a
floatEltLog1mexp :: SNat n -> Array n a -> Array n a
floatEltAtan2 :: SNat n -> Array n a -> Array n a -> Array n a
instance FloatElt Float where
floatEltDiv :: forall (n :: Nat).
SNat n -> Array n Float -> Array n Float -> Array n Float
floatEltDiv = SNat n -> Array n Float -> Array n Float -> Array n Float
forall (n :: Nat).
SNat n -> Array n Float -> Array n Float -> Array n Float
divVectorFloat
floatEltPow :: forall (n :: Nat).
SNat n -> Array n Float -> Array n Float -> Array n Float
floatEltPow = SNat n -> Array n Float -> Array n Float -> Array n Float
forall (n :: Nat).
SNat n -> Array n Float -> Array n Float -> Array n Float
powVectorFloat
floatEltLogbase :: forall (n :: Nat).
SNat n -> Array n Float -> Array n Float -> Array n Float
floatEltLogbase = SNat n -> Array n Float -> Array n Float -> Array n Float
forall (n :: Nat).
SNat n -> Array n Float -> Array n Float -> Array n Float
logbaseVectorFloat
floatEltRecip :: forall (n :: Nat). SNat n -> Array n Float -> Array n Float
floatEltRecip = SNat n -> Array n Float -> Array n Float
forall (n :: Nat). SNat n -> Array n Float -> Array n Float
recipVectorFloat
floatEltExp :: forall (n :: Nat). SNat n -> Array n Float -> Array n Float
floatEltExp = SNat n -> Array n Float -> Array n Float
forall (n :: Nat). SNat n -> Array n Float -> Array n Float
expVectorFloat
floatEltLog :: forall (n :: Nat). SNat n -> Array n Float -> Array n Float
floatEltLog = SNat n -> Array n Float -> Array n Float
forall (n :: Nat). SNat n -> Array n Float -> Array n Float
logVectorFloat
floatEltSqrt :: forall (n :: Nat). SNat n -> Array n Float -> Array n Float
floatEltSqrt = SNat n -> Array n Float -> Array n Float
forall (n :: Nat). SNat n -> Array n Float -> Array n Float
sqrtVectorFloat
floatEltSin :: forall (n :: Nat). SNat n -> Array n Float -> Array n Float
floatEltSin = SNat n -> Array n Float -> Array n Float
forall (n :: Nat). SNat n -> Array n Float -> Array n Float
sinVectorFloat
floatEltCos :: forall (n :: Nat). SNat n -> Array n Float -> Array n Float
floatEltCos = SNat n -> Array n Float -> Array n Float
forall (n :: Nat). SNat n -> Array n Float -> Array n Float
cosVectorFloat
floatEltTan :: forall (n :: Nat). SNat n -> Array n Float -> Array n Float
floatEltTan = SNat n -> Array n Float -> Array n Float
forall (n :: Nat). SNat n -> Array n Float -> Array n Float
tanVectorFloat
floatEltAsin :: forall (n :: Nat). SNat n -> Array n Float -> Array n Float
floatEltAsin = SNat n -> Array n Float -> Array n Float
forall (n :: Nat). SNat n -> Array n Float -> Array n Float
asinVectorFloat
floatEltAcos :: forall (n :: Nat). SNat n -> Array n Float -> Array n Float
floatEltAcos = SNat n -> Array n Float -> Array n Float
forall (n :: Nat). SNat n -> Array n Float -> Array n Float
acosVectorFloat
floatEltAtan :: forall (n :: Nat). SNat n -> Array n Float -> Array n Float
floatEltAtan = SNat n -> Array n Float -> Array n Float
forall (n :: Nat). SNat n -> Array n Float -> Array n Float
atanVectorFloat
floatEltSinh :: forall (n :: Nat). SNat n -> Array n Float -> Array n Float
floatEltSinh = SNat n -> Array n Float -> Array n Float
forall (n :: Nat). SNat n -> Array n Float -> Array n Float
sinhVectorFloat
floatEltCosh :: forall (n :: Nat). SNat n -> Array n Float -> Array n Float
floatEltCosh = SNat n -> Array n Float -> Array n Float
forall (n :: Nat). SNat n -> Array n Float -> Array n Float
coshVectorFloat
floatEltTanh :: forall (n :: Nat). SNat n -> Array n Float -> Array n Float
floatEltTanh = SNat n -> Array n Float -> Array n Float
forall (n :: Nat). SNat n -> Array n Float -> Array n Float
tanhVectorFloat
floatEltAsinh :: forall (n :: Nat). SNat n -> Array n Float -> Array n Float
floatEltAsinh = SNat n -> Array n Float -> Array n Float
forall (n :: Nat). SNat n -> Array n Float -> Array n Float
asinhVectorFloat
floatEltAcosh :: forall (n :: Nat). SNat n -> Array n Float -> Array n Float
floatEltAcosh = SNat n -> Array n Float -> Array n Float
forall (n :: Nat). SNat n -> Array n Float -> Array n Float
acoshVectorFloat
floatEltAtanh :: forall (n :: Nat). SNat n -> Array n Float -> Array n Float
floatEltAtanh = SNat n -> Array n Float -> Array n Float
forall (n :: Nat). SNat n -> Array n Float -> Array n Float
atanhVectorFloat
floatEltLog1p :: forall (n :: Nat). SNat n -> Array n Float -> Array n Float
floatEltLog1p = SNat n -> Array n Float -> Array n Float
forall (n :: Nat). SNat n -> Array n Float -> Array n Float
log1pVectorFloat
floatEltExpm1 :: forall (n :: Nat). SNat n -> Array n Float -> Array n Float
floatEltExpm1 = SNat n -> Array n Float -> Array n Float
forall (n :: Nat). SNat n -> Array n Float -> Array n Float
expm1VectorFloat
floatEltLog1pexp :: forall (n :: Nat). SNat n -> Array n Float -> Array n Float
floatEltLog1pexp = SNat n -> Array n Float -> Array n Float
forall (n :: Nat). SNat n -> Array n Float -> Array n Float
log1pexpVectorFloat
floatEltLog1mexp :: forall (n :: Nat). SNat n -> Array n Float -> Array n Float
floatEltLog1mexp = SNat n -> Array n Float -> Array n Float
forall (n :: Nat). SNat n -> Array n Float -> Array n Float
log1mexpVectorFloat
floatEltAtan2 :: forall (n :: Nat).
SNat n -> Array n Float -> Array n Float -> Array n Float
floatEltAtan2 = SNat n -> Array n Float -> Array n Float -> Array n Float
forall (n :: Nat).
SNat n -> Array n Float -> Array n Float -> Array n Float
atan2VectorFloat
instance FloatElt Double where
floatEltDiv :: forall (n :: Nat).
SNat n -> Array n Double -> Array n Double -> Array n Double
floatEltDiv = SNat n -> Array n Double -> Array n Double -> Array n Double
forall (n :: Nat).
SNat n -> Array n Double -> Array n Double -> Array n Double
divVectorDouble
floatEltPow :: forall (n :: Nat).
SNat n -> Array n Double -> Array n Double -> Array n Double
floatEltPow = SNat n -> Array n Double -> Array n Double -> Array n Double
forall (n :: Nat).
SNat n -> Array n Double -> Array n Double -> Array n Double
powVectorDouble
floatEltLogbase :: forall (n :: Nat).
SNat n -> Array n Double -> Array n Double -> Array n Double
floatEltLogbase = SNat n -> Array n Double -> Array n Double -> Array n Double
forall (n :: Nat).
SNat n -> Array n Double -> Array n Double -> Array n Double
logbaseVectorDouble
floatEltRecip :: forall (n :: Nat). SNat n -> Array n Double -> Array n Double
floatEltRecip = SNat n -> Array n Double -> Array n Double
forall (n :: Nat). SNat n -> Array n Double -> Array n Double
recipVectorDouble
floatEltExp :: forall (n :: Nat). SNat n -> Array n Double -> Array n Double
floatEltExp = SNat n -> Array n Double -> Array n Double
forall (n :: Nat). SNat n -> Array n Double -> Array n Double
expVectorDouble
floatEltLog :: forall (n :: Nat). SNat n -> Array n Double -> Array n Double
floatEltLog = SNat n -> Array n Double -> Array n Double
forall (n :: Nat). SNat n -> Array n Double -> Array n Double
logVectorDouble
floatEltSqrt :: forall (n :: Nat). SNat n -> Array n Double -> Array n Double
floatEltSqrt = SNat n -> Array n Double -> Array n Double
forall (n :: Nat). SNat n -> Array n Double -> Array n Double
sqrtVectorDouble
floatEltSin :: forall (n :: Nat). SNat n -> Array n Double -> Array n Double
floatEltSin = SNat n -> Array n Double -> Array n Double
forall (n :: Nat). SNat n -> Array n Double -> Array n Double
sinVectorDouble
floatEltCos :: forall (n :: Nat). SNat n -> Array n Double -> Array n Double
floatEltCos = SNat n -> Array n Double -> Array n Double
forall (n :: Nat). SNat n -> Array n Double -> Array n Double
cosVectorDouble
floatEltTan :: forall (n :: Nat). SNat n -> Array n Double -> Array n Double
floatEltTan = SNat n -> Array n Double -> Array n Double
forall (n :: Nat). SNat n -> Array n Double -> Array n Double
tanVectorDouble
floatEltAsin :: forall (n :: Nat). SNat n -> Array n Double -> Array n Double
floatEltAsin = SNat n -> Array n Double -> Array n Double
forall (n :: Nat). SNat n -> Array n Double -> Array n Double
asinVectorDouble
floatEltAcos :: forall (n :: Nat). SNat n -> Array n Double -> Array n Double
floatEltAcos = SNat n -> Array n Double -> Array n Double
forall (n :: Nat). SNat n -> Array n Double -> Array n Double
acosVectorDouble
floatEltAtan :: forall (n :: Nat). SNat n -> Array n Double -> Array n Double
floatEltAtan = SNat n -> Array n Double -> Array n Double
forall (n :: Nat). SNat n -> Array n Double -> Array n Double
atanVectorDouble
floatEltSinh :: forall (n :: Nat). SNat n -> Array n Double -> Array n Double
floatEltSinh = SNat n -> Array n Double -> Array n Double
forall (n :: Nat). SNat n -> Array n Double -> Array n Double
sinhVectorDouble
floatEltCosh :: forall (n :: Nat). SNat n -> Array n Double -> Array n Double
floatEltCosh = SNat n -> Array n Double -> Array n Double
forall (n :: Nat). SNat n -> Array n Double -> Array n Double
coshVectorDouble
floatEltTanh :: forall (n :: Nat). SNat n -> Array n Double -> Array n Double
floatEltTanh = SNat n -> Array n Double -> Array n Double
forall (n :: Nat). SNat n -> Array n Double -> Array n Double
tanhVectorDouble
floatEltAsinh :: forall (n :: Nat). SNat n -> Array n Double -> Array n Double
floatEltAsinh = SNat n -> Array n Double -> Array n Double
forall (n :: Nat). SNat n -> Array n Double -> Array n Double
asinhVectorDouble
floatEltAcosh :: forall (n :: Nat). SNat n -> Array n Double -> Array n Double
floatEltAcosh = SNat n -> Array n Double -> Array n Double
forall (n :: Nat). SNat n -> Array n Double -> Array n Double
acoshVectorDouble
floatEltAtanh :: forall (n :: Nat). SNat n -> Array n Double -> Array n Double
floatEltAtanh = SNat n -> Array n Double -> Array n Double
forall (n :: Nat). SNat n -> Array n Double -> Array n Double
atanhVectorDouble
floatEltLog1p :: forall (n :: Nat). SNat n -> Array n Double -> Array n Double
floatEltLog1p = SNat n -> Array n Double -> Array n Double
forall (n :: Nat). SNat n -> Array n Double -> Array n Double
log1pVectorDouble
floatEltExpm1 :: forall (n :: Nat). SNat n -> Array n Double -> Array n Double
floatEltExpm1 = SNat n -> Array n Double -> Array n Double
forall (n :: Nat). SNat n -> Array n Double -> Array n Double
expm1VectorDouble
floatEltLog1pexp :: forall (n :: Nat). SNat n -> Array n Double -> Array n Double
floatEltLog1pexp = SNat n -> Array n Double -> Array n Double
forall (n :: Nat). SNat n -> Array n Double -> Array n Double
log1pexpVectorDouble
floatEltLog1mexp :: forall (n :: Nat). SNat n -> Array n Double -> Array n Double
floatEltLog1mexp = SNat n -> Array n Double -> Array n Double
forall (n :: Nat). SNat n -> Array n Double -> Array n Double
log1mexpVectorDouble
floatEltAtan2 :: forall (n :: Nat).
SNat n -> Array n Double -> Array n Double -> Array n Double
floatEltAtan2 = SNat n -> Array n Double -> Array n Double -> Array n Double
forall (n :: Nat).
SNat n -> Array n Double -> Array n Double -> Array n Double
atan2VectorDouble