{-# LANGUAGE DataKinds #-}
{-# LANGUAGE ExistentialQuantification #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE ImportQualifiedPost #-}
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE MultiWayIf #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE TupleSections #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeOperators #-}
{-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-}
{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-}
module Data.Array.Strided.Arith.Internal where

import Control.Monad
import Data.Bifunctor (second)
import Data.Bits
import Data.Int
import Data.List (sort, zip4)
import Data.Proxy
import Data.Type.Equality
import Data.Vector.Storable qualified as VS
import Data.Vector.Storable.Mutable qualified as VSM
import Foreign.C.Types
import Foreign.Ptr
import Foreign.Storable
import GHC.TypeLits
import GHC.TypeNats qualified as TypeNats
import Language.Haskell.TH
import System.IO (hFlush, stdout)
import System.IO.Unsafe

import Data.Array.Strided.Arith.Internal.Foreign
import Data.Array.Strided.Arith.Internal.Lists
import Data.Array.Strided.Array


-- TODO: need to sort strides for reduction-like functions so that the C inner-loop specialisation has some chance of working even after transposition


-- TODO: move this to a utilities module
-- | Demote a type-level natural, given as a singleton, to an 'Int'.
-- No range check is performed on the conversion from 'Integer'.
fromSNat' :: SNat n -> Int
fromSNat' sn = fromInteger (fromSNat sn)

-- | A reified constraint: pattern matching on the 'Dict' constructor brings
-- the constraint @c@ into scope.
data Dict c = c => Dict

-- | Describe an array's metadata for debugging: its type-level rank, shape,
-- strides, offset, and the length of its backing vector. Element values are
-- not shown.
debugShow :: forall n a. (Storable a, KnownNat n) => Array n a -> String
debugShow (Array sh strides offset vec) =
  unwords
    [ "Array @" ++ show (natVal (Proxy @n))
    , show sh
    , show strides
    , show offset
    , "<_*" ++ show (VS.length vec) ++ ">"
    ]


-- TODO: test all the cases of this thing with various input strides
-- | Apply a strided unary C kernel elementwise (via 'wrapUnary').
--
-- If the elements the array touches form a dense block of its backing vector
-- (see 'stridesDense'), the kernel runs once over that block viewed as a
-- rank-1 array, and the result vector is re-wrapped with the original shape
-- and strides. An array with an empty dimension short-circuits to an empty
-- result. Otherwise the kernel runs on the array as-is.
liftOpEltwise1 :: Storable a
               => SNat n
               -> (Ptr a -> Ptr b)
               -> (Int64 -> Ptr b -> Ptr Int64 -> Ptr Int64 -> Ptr b -> IO ())
               -> Array n a -> Array n a
liftOpEltwise1 sn@SNat ptrconv cf_strided arr@(Array sh strides offset vec)
  | Just (blockOff, blockSz) <- stridesDense sh offset strides =
      if blockSz == 0
        then Array sh (map (const 0) strides) 0 VS.empty
        else
          let -- The dense block as a genuine rank-1 array. Its type-level rank
              -- must agree with its shape length (1), so the kernel is driven
              -- with 'SNat @1' rather than the rank-n singleton 'sn'.
              block = Array @1 [blockSz] [1] blockOff vec
              resvec = arrValues $ wrapUnary (SNat @1) ptrconv cf_strided block
          in -- 'resvec' starts at backing index 'blockOff', so shift the
             -- original offset accordingly.
             Array sh strides (offset - blockOff) resvec
  | otherwise = wrapUnary sn ptrconv cf_strided arr

-- TODO: test all the cases of this thing with various input strides
-- | Apply a binary operation elementwise to two arrays of equal shape using
-- strided C kernels. @f_ss@ combines two scalars directly; @f_sv@, @f_vs@ and
-- @f_vv@ are the scalar-vector, vector-scalar and vector-vector kernels,
-- driven through 'wrapBinarySV', 'wrapBinaryVS' and 'wrapBinaryVV'.
--
-- An input whose touched elements form a dense block of its backing vector
-- (see 'stridesDense') is processed as a flat rank-1 array over that block,
-- and the result is re-wrapped with that input's shape and strides. An input
-- that touches a single element is treated as a scalar. Calls 'error' if the
-- two shapes differ.
liftOpEltwise2 :: Storable a
               => SNat n
               -> (a -> b)
               -> (Ptr a -> Ptr b)
               -> (a -> a -> a)
               -> (Int64 -> Ptr Int64 -> Ptr b -> b -> Ptr Int64 -> Ptr b -> IO ())  -- ^ sv
               -> (Int64 -> Ptr Int64 -> Ptr b -> Ptr Int64 -> Ptr b -> b -> IO ())  -- ^ vs
               -> (Int64 -> Ptr Int64 -> Ptr b -> Ptr Int64 -> Ptr b -> Ptr Int64 -> Ptr b -> IO ())  -- ^ vv
               -> Array n a -> Array n a -> Array n a
liftOpEltwise2 sn@SNat valconv ptrconv f_ss f_sv f_vs f_vv
    arr1@(Array sh1 strides1 offset1 vec1)
    arr2@(Array sh2 strides2 offset2 vec2)
  | sh1 /= sh2 = error $ "liftOpEltwise2: shapes unequal: " ++ show sh1 ++ " vs " ++ show sh2
  | any (<= 0) sh1 = Array sh1 (0 <$ strides1) 0 VS.empty
  | otherwise = case (stridesDense sh1 offset1 strides1, stridesDense sh2 offset2 strides2) of
      (Just (_, 1), Just (_, 1)) ->  -- both are a (potentially replicated) scalar; just apply f to the scalars
        let vec' = VS.singleton (f_ss (vec1 VS.! offset1) (vec2 VS.! offset2))
        in Array sh1 strides1 0 vec'

      (Just (_, 1), Just (blockOff, blockSz)) ->  -- scalar * dense
        let resvec = arrValues $
              wrapBinarySV (SNat @1) valconv ptrconv f_sv (vec1 VS.! offset1)
                           (denseBlock blockOff blockSz vec2)
        in Array sh1 strides2 (offset2 - blockOff) resvec

      (Just (_, 1), Nothing) ->  -- scalar * array
        wrapBinarySV sn valconv ptrconv f_sv (vec1 VS.! offset1) arr2

      (Just (blockOff, blockSz), Just (_, 1)) ->  -- dense * scalar
        let resvec = arrValues $
              wrapBinaryVS (SNat @1) valconv ptrconv f_vs
                           (denseBlock blockOff blockSz vec1) (vec2 VS.! offset2)
        in Array sh1 strides1 (offset1 - blockOff) resvec

      (Nothing, Just (_, 1)) ->  -- array * scalar
        wrapBinaryVS sn valconv ptrconv f_vs arr1 (vec2 VS.! offset2)

      (Just (blockOff1, blockSz1), Just (blockOff2, blockSz2))
        | strides1 == strides2
        ->  -- dense * dense but the strides match
          -- Equal shapes and strides imply equal block sizes and equal
          -- offsets relative to the block start; anything else is a bug.
          if blockSz1 /= blockSz2 || offset1 - blockOff1 /= offset2 - blockOff2
            then error $ "Data.Array.Strided.Arith.Internal(liftOpEltwise2): Internal error: cannot happen "
                         ++ show (strides1, (blockOff1, blockSz1), strides2, (blockOff2, blockSz2))
            else
              let resvec = arrValues $
                    wrapBinaryVV (SNat @1) ptrconv f_vv
                                 (denseBlock blockOff1 blockSz1 vec1)
                                 (denseBlock blockOff2 blockSz2 vec2)
              in Array sh1 strides1 (offset1 - blockOff1) resvec

      (_, _) ->  -- fallback case
        wrapBinaryVV sn ptrconv f_vv arr1 arr2
  where
    -- View the elements [off, off + sz) of a backing vector as a rank-1 array.
    denseBlock :: Storable e => Int -> Int -> VS.Vector e -> Array 1 e
    denseBlock off sz v = arrayFromVector [sz] (VS.slice off sz v)

-- | Given shape vector, offset and stride vector, check whether the set of
-- backing-vector indices touched by this virtual array is a contiguous range.
-- If so, return the first index of that range (an absolute index into the
-- backing vector, i.e. after applying the offset and any stride reversal) and
-- the number of elements in it.
--
-- Dimensions with stride 0 (replicated) or size 1 cannot move the index, so
-- they are ignored. An array with an empty dimension touches nothing and
-- yields @Just (offset, 0)@; an array in which every dimension is ignored
-- touches a single element and yields a block size of 1.
stridesDense :: [Int] -> Int -> [Int] -> Maybe (Int, Int)
stridesDense sh offset _ | any (<= 0) sh = Just (offset, 0)
stridesDense sh offsetNeg stridesNeg =
  -- First reverse all dimensions with negative stride, so that the first used
  -- value is at 'offset' and the rest is >= offset.
  let (offset, strides) = flipReverseds sh offsetNeg stridesNeg
      -- Only dimensions that actually move the index matter for density.
      movingDims = [(s, n) | (s, n) <- zip strides sh, s /= 0, n /= 1]
  in -- sort dimensions on their stride, ascending
     case sort movingDims of
       [] -> Just (offset, 1)
       (1, n) : pairs -> (\sz -> (offset, sz)) <$> checkCover n pairs
       _ -> Nothing  -- if the smallest stride is not 1, it will never be dense
  where
    -- Given size of currently densely covered region at beginning of the
    -- array and the remaining (stride, size) pairs with all strides >=1,
    -- return whether this all together covers a dense prefix of the array. If
    -- it does, return the number of elements in this prefix.
    checkCover :: Int -> [(Int, Int)] -> Maybe Int
    checkCover block [] = Just block
    checkCover block ((s, n) : pairs) =
      -- Copies of the covered block spaced 's' apart stay contiguous only if
      -- they do not leave gaps, i.e. s <= block.
      guard (s <= block) >> checkCover ((n - 1) * s + block) pairs

    -- Given shape, offset and strides, returns new (offset, strides) such that all strides are >=0
    flipReverseds :: [Int] -> Int -> [Int] -> (Int, [Int])
    flipReverseds [] off [] = (off, [])
    flipReverseds (n : sh') off (s : str')
      | s >= 0 = second (s :) (flipReverseds sh' off str')
      | otherwise =
          let off' = off + (n - 1) * s
          in second ((-s) :) (flipReverseds sh' off' str')
    flipReverseds _ _ _ = error "flipReverseds: invalid arguments"

-- | An array with its replicated dimensions removed, as produced by
-- 'unreplicateStrides'. Let the original array, with replicated dimensions,
-- be called A.
data Unreplicated a where
  Unreplicated
    :: KnownNat n'
    => Array n' a
       -- ^ An array with all strides /= 0. Call this array U. It has the same
       -- shape as A, except with all the replicated (stride == 0) dimensions
       -- removed. The shape of U is the "unreplicated shape".
    -> Int
       -- ^ Product of sizes of the unreplicated dimensions.
       -- NOTE(review): 'unreplicateStrides' fills this with the product of
       -- the sizes of the /replicated/ (stride 0) dimensions of A; confirm
       -- which of the two is intended.
    -> ([Int] -> [Int])
       -- ^ Given the stride vector of an array with the unreplicated shape,
       -- this function reinserts zeros so that it may be combined with the
       -- original shape of A.
    -> Unreplicated a

-- | Removes all replicated dimensions (i.e. those with stride == 0) from the
-- array.
--
-- The result packs the array U containing only the kept (stride /= 0)
-- dimensions, with the same offset and backing vector. It also packs the
-- product of the sizes of the dimensions that were dropped, and a function
-- that re-inserts a 0 at every dropped position of a list that has one entry
-- per kept dimension (e.g. a stride vector for U's shape).
--
-- NOTE(review): the size is the product over the *dropped* (replicated)
-- dimensions; some documentation calls it the product of the "unreplicated"
-- dimensions — confirm which is intended.
unreplicateStrides :: Array n a -> Unreplicated a
unreplicateStrides (Array sh strides offset vec) =
  TypeNats.withSomeSNat (fromIntegral (length keptShape)) $ \(SNat :: SNat rank) ->
    Unreplicated (Array @rank keptShape keptStrides offset vec)
                 replicatedSize
                 (fillZeros isReplicated)
  where
    -- One flag per original dimension: True iff the dimension is replicated.
    isReplicated = map (== 0) strides

    -- Shape and strides of the dimensions that are kept (stride /= 0).
    (keptShape, keptStrides) = unzip (filter ((/= 0) . snd) (zip sh strides))

    -- Product of the sizes of the dimensions that are dropped (stride == 0).
    replicatedSize = product [n | (n, s) <- zip sh strides, s == 0]

    -- Walk the flags, emitting 0 for a replicated dimension and consuming the
    -- next entry of the given list for a kept one.
    fillZeros :: [Bool] -> [Int] -> [Int]
    fillZeros (True : flags) ss = 0 : fillZeros flags ss
    fillZeros (False : flags) (s : ss) = s : fillZeros flags ss
    fillZeros [] [] = []
    fillZeros (False : _) [] = error "unreplicateStrides: Internal error: reply strides too short"
    fillZeros [] (_ : _) = error "unreplicateStrides: Internal error: reply strides too long"

-- | Normalise an array for a strided kernel and pass the result to a
-- continuation. Every dimension with a negative stride is flipped (so all
-- strides become >= 0), and then every replicated (stride 0) dimension is
-- removed with 'unreplicateStrides'. The remaining array is called U. The
-- continuation also receives helpers that map results computed on U back to
-- the original array's shape and orientation.
simplifyArray :: Array n a
              -> (forall n'. KnownNat n'
                          => Array n' a  -- U
                          -- Product of sizes of the unreplicated dimensions.
                          -- NOTE(review): 'unreplicateStrides' computes this
                          -- as the product over the replicated (stride 0)
                          -- dimensions; confirm which is intended.
                          -> Int
                          -- Convert index in U back to index into original
                          -- array. Replicated dimensions get 0.
                          -> ([Int] -> [Int])
                          -- Given a new array of the same shape as U, convert
                          -- it back to the original shape and iteration order.
                          -> (Array n' a -> Array n a)
                          -- Do the same except without the INNER dimension.
                          -- This throws an error if the inner dimension had
                          -- stride 0.
                          -> (Array (n' - 1) a -> Array (n - 1) a)
                          -> r)
              -> r
simplifyArray array k
  | let revDims = map (< 0) (arrStrides array)
  , Unreplicated array' unrepSize rereplicate <- unreplicateStrides (arrayRevDims revDims array)
  = k array'
      unrepSize
      -- An index into U has one entry per *kept* dimension, whereas 'revDims'
      -- and the original shape have one entry per original dimension.
      -- Reinsert the replicated dimensions (as 0) first so that the lists
      -- line up, and only then undo the flip of every reversed dimension.
      (\idx -> zipWith3 (\b n i -> if b then n - 1 - i else i)
                        revDims (arrShape array) (rereplicate idx))
      (\(Array sh' strides' offset' vec') ->
         if sh' == arrShape array'
           then arrayRevDims revDims (Array (arrShape array) (rereplicate strides') offset' vec')
           else error $ "simplifyArray: Internal error: reply shape wrong (reply " ++ show sh' ++ ", unreplicated " ++ show (arrShape array') ++ ")")
      (\(Array sh' strides' offset' vec') ->
         if | sh' /= init (arrShape array') ->
                error $ "simplifyArray: Internal error: reply shape wrong (reply " ++ show sh' ++ ", unreplicated " ++ show (arrShape array') ++ ")"
            | last (arrStrides array) == 0 ->
                error "simplifyArray: Internal error: reduction reply handler used while inner stride was 0"
            | otherwise ->
                -- Append a dummy stride for U's inner dimension so that
                -- 'rereplicate' receives a full-length list, then drop the
                -- corresponding (inner) original dimension again.
                arrayRevDims (init revDims) (Array (init (arrShape array)) (init (rereplicate (strides' ++ [0]))) offset' vec'))

-- | Jointly simplify two arrays of the same shape before running a strided
-- binary kernel on them, and pass the simplified arrays to a continuation.
--
-- Two normalisations are applied, each only on dimensions where it is valid
-- for /both/ inputs at once:
--
-- * dimensions with a negative stride in both inputs are reversed using
--   'arrayRevDims';
-- * dimensions with stride 0 in both inputs (replicated in both) are dropped.
--
-- The two input arrays must have the same shape; this is checked.
simplifyArray2 :: Array n a -> Array n a
               -> (forall n'. KnownNat n'
                           => Array n' a  -- U1
                           -> Array n' a  -- U2 (same shape as U1)
                           -- Product of sizes of the dimensions that are
                           -- replicated (stride 0) in both inputs, i.e. the
                           -- dimensions that were dropped to obtain U{1,2}
                           -> Int
                           -- Convert index in U{1,2} back to index into original
                           -- arrays. Dimensions that are replicated in both
                           -- inputs get 0.
                           -> ([Int] -> [Int])
                           -- Given a new array of the same shape as U1 (& U2),
                           -- convert it back to the original shape and
                           -- iteration order.
                           -> (Array n' a -> Array n a)
                           -- Do the same except without the INNER dimension.
                           -- This throws an error if the inner dimension had
                           -- stride 0 in both inputs.
                           -> (Array (n' - 1) a -> Array (n - 1) a)
                           -> r)
               -> r
simplifyArray2 arr1@(Array sh _ _ _) arr2@(Array sh2 _ _ _) k
  | sh /= sh2 = error "simplifyArray2: Unequal shapes"

    -- Only reverse a dimension if both inputs have a negative stride on it,
    -- so that one reversal pattern applies to both arrays.
  | let revDims = zipWith (\s1 s2 -> s1 < 0 && s2 < 0) (arrStrides arr1) (arrStrides arr2)
  , Array _ strides1 offset1 vec1 <- arrayRevDims revDims arr1
  , Array _ strides2 offset2 vec2 <- arrayRevDims revDims arr2

    -- Likewise, a dimension can only be dropped if it is replicated in both.
  , let replDims = zipWith (\s1 s2 -> s1 == 0 && s2 == 0) strides1 strides2
  , let (shF, strides1F, strides2F) =
          unzip3 [ (d, s1, s2)
                 | (d, s1, s2, False) <- zip4 sh strides1 strides2 replDims ]

    -- Inverse of the filtering above: put a 0 back at every dropped
    -- position, consuming the given list for the kept positions.
  , let reinsertZeros (False : zeros) (s : strides') = s : reinsertZeros zeros strides'
        reinsertZeros (True : zeros) strides' = 0 : reinsertZeros zeros strides'
        reinsertZeros [] [] = []
        reinsertZeros (False : _) [] = error "simplifyArray2: Internal error: reply strides too short"
        reinsertZeros [] (_:_) = error "simplifyArray2: Internal error: reply strides too long"

    -- Size product of the dropped (replicated-in-both) dimensions.
  , let unrepSize = product [d | (d, True) <- zip sh replDims]

  = TypeNats.withSomeSNat (fromIntegral (length shF)) $ \(SNat :: SNat lenshF) ->
    k @lenshF
      (Array shF strides1F offset1 vec1)
      (Array shF strides2F offset2 vec2)
      unrepSize
      (\idx -> zipWith3 (\b d i -> if b then d - 1 - i else i)
                        revDims sh (reinsertZeros replDims idx))
      (\(Array sh' strides' offset' vec') ->
         if sh' /= shF
           then error $ "simplifyArray2: Internal error: reply shape wrong (reply " ++ show sh' ++ ", unreplicated " ++ show shF ++ ")"
           else arrayRevDims revDims (Array sh (reinsertZeros replDims strides') offset' vec'))
      -- The inner-dimension check comes first: if the inner dimension was
      -- dropped, 'shF' does not end with it, so the shape comparison would be
      -- against the wrong dimensions (or crash in 'init' if 'shF' is empty).
      (\(Array sh' strides' offset' vec') ->
         if | last replDims ->
                error "simplifyArray2: Internal error: reduction reply handler used while inner dimension was unreplicated"
            | sh' /= init shF ->
                error $ "simplifyArray2: Internal error: reply shape wrong (reply " ++ show sh' ++ ", unreplicated " ++ show shF ++ ")"
            | otherwise ->
                arrayRevDims (init revDims) (Array (init sh) (reinsertZeros (init replDims) strides') offset' vec'))

{-# NOINLINE wrapUnary #-}
-- | Apply a strided unary C kernel element-wise to an array.
--
-- The input is first normalised with 'simplifyArray'. The kernel writes into
-- a freshly allocated buffer of @product shape@ elements, and the result is
-- mapped back to the original shape and iteration order with the @restore@
-- function supplied by 'simplifyArray'.
--
-- Kernel arguments, in order: rank, output pointer, shape pointer, strides
-- pointer, input pointer (already advanced by the array offset).
wrapUnary :: forall a b n. Storable a
          => SNat n
          -> (Ptr a -> Ptr b)
          -> (Int64 -> Ptr b -> Ptr Int64 -> Ptr Int64 -> Ptr b -> IO ())
          -> Array n a
          -> Array n a
wrapUnary _ ptrconv cf_strided array =
  simplifyArray array $ \(Array shape strides offset vec) _ _ restore _ ->
    unsafePerformIO $ do
      let rank = length shape
          marshal :: [Int] -> VS.Vector Int64
          marshal = VS.fromListN rank . map fromIntegral
          inputByteOffset = offset * sizeOf (undefined :: a)
      outv <- VSM.unsafeNew (product shape)
      VSM.unsafeWith outv $ \poutv ->
        VS.unsafeWith (marshal shape) $ \psh ->
          VS.unsafeWith (marshal strides) $ \pstrides ->
            VS.unsafeWith vec $ \pv ->
              cf_strided (fromIntegral rank) (ptrconv poutv) psh pstrides
                         (pv `plusPtr` inputByteOffset)
      result <- VS.unsafeFreeze outv
      return (restore (arrayFromVector shape result))

{-# NOINLINE wrapBinarySV #-}
-- | Combine a scalar (left operand) with every element of an array (right
-- operand) using a strided C kernel.
--
-- The array is normalised with 'simplifyArray' first. The kernel writes into
-- a freshly allocated buffer of @product sh@ elements, and the result is
-- mapped back with the @restore@ function supplied by 'simplifyArray'.
--
-- Kernel arguments, in order: rank, shape pointer, output pointer, scalar
-- (converted with @valconv@), strides pointer, input pointer (advanced by
-- the array offset).
wrapBinarySV :: forall a b n. Storable a
             => SNat n
             -> (a -> b)
             -> (Ptr a -> Ptr b)
             -> (Int64 -> Ptr Int64 -> Ptr b -> b -> Ptr Int64 -> Ptr b -> IO ())
             -> a -> Array n a
             -> Array n a
wrapBinarySV SNat valconv ptrconv cf_strided x array =
  simplifyArray array $ \(Array sh strides offset vec) _ _ restore _ ->
    unsafePerformIO $ do
      let rank = length sh
          int64s = VS.fromListN rank . map (fromIntegral :: Int -> Int64)
          scalar = valconv x
      outv <- VSM.unsafeNew (product sh)
      VSM.unsafeWith outv $ \poutv ->
        VS.unsafeWith (int64s sh) $ \psh ->
        VS.unsafeWith (int64s strides) $ \pstrides ->
        VS.unsafeWith vec $ \pv -> do
          let pin = pv `plusPtr` (offset * sizeOf (undefined :: a))
          cf_strided (fromIntegral rank) psh (ptrconv poutv) scalar pstrides pin
      fmap (restore . arrayFromVector sh) (VS.unsafeFreeze outv)

-- | Combine every element of an array (left operand) with a scalar (right
-- operand) using a strided C kernel.
--
-- Implemented on top of 'wrapBinarySV'. That function passes the scalar
-- before the strides/data pointers, while the kernel given here expects the
-- scalar as its last argument, so the kernel is adapted accordingly.
wrapBinaryVS :: Storable a
             => SNat n
             -> (a -> b)
             -> (Ptr a -> Ptr b)
             -> (Int64 -> Ptr Int64 -> Ptr b -> Ptr Int64 -> Ptr b -> b -> IO ())
             -> Array n a -> a
             -> Array n a
wrapBinaryVS sn valconv ptrconv cf_strided arr y =
  wrapBinarySV sn valconv ptrconv scalarFirst y arr
  where
    -- Same kernel, with the scalar moved from the last position to the
    -- fourth, as 'wrapBinarySV' calls it.
    scalarFirst rank psh poutv scalar pstrides pv =
      cf_strided rank psh poutv pstrides pv scalar

-- | Apply a strided binary C kernel element-wise to two arrays.
--
-- The two shapes must be equal; this is checked. If some dimension has size
-- 0, the kernel is not called and an empty array of the same shape is
-- returned (mirroring how empty inputs are handled in 'vectorRedInnerOp').
-- Negative dimension sizes are rejected with an error.
--
-- Kernel arguments, in order: rank, shape pointer, output pointer, then the
-- strides pointer and data pointer (advanced by the array offset) of the
-- first input, then those of the second input.
{-# NOINLINE wrapBinaryVV #-}
wrapBinaryVV :: forall a b n. Storable a
             => SNat n
             -> (Ptr a -> Ptr b)
             -> (Int64 -> Ptr Int64 -> Ptr b -> Ptr Int64 -> Ptr b -> Ptr Int64 -> Ptr b -> IO ())
             -> Array n a -> Array n a
             -> Array n a
-- TODO: do unreversing and unreplication on the input arrays (but
-- simultaneously: can only unreplicate if _both_ are replicated on that
-- dimension)
wrapBinaryVV sn@SNat ptrconv cf_strided
    (Array sh strides1 offset1 vec1)
    (Array sh2 strides2 offset2 vec2)
  | sh /= sh2 = error $ "wrapBinaryVV: unequal shapes: " ++ show sh ++ " and " ++ show sh2
  | any (< 0) sh = error $ "wrapBinaryVV: negative dimension in shape: " ++ show sh
  -- An empty result needs no computation: skip the kernel entirely rather
  -- than handing it a zero-sized dimension.
  | any (== 0) sh = Array sh (0 <$ sh) 0 VS.empty
  | otherwise = unsafePerformIO $ do
      let rank = fromSNat' sn
          toC :: [Int] -> VS.Vector Int64
          toC = VS.fromListN rank . map fromIntegral
          elemBytes = sizeOf (undefined :: a)
      outv <- VSM.unsafeNew (product sh)
      VSM.unsafeWith outv $ \poutv ->
        VS.unsafeWith (toC sh) $ \psh ->
        VS.unsafeWith (toC strides1) $ \pstrides1 ->
        VS.unsafeWith (toC strides2) $ \pstrides2 ->
        VS.unsafeWith vec1 $ \pv1 ->
        VS.unsafeWith vec2 $ \pv2 ->
          cf_strided (fromIntegral rank) psh (ptrconv poutv)
                     pstrides1 (pv1 `plusPtr` (offset1 * elemBytes))
                     pstrides2 (pv2 `plusPtr` (offset2 * elemBytes))
      arrayFromVector sh <$> VS.unsafeFreeze outv

-- TODO: test handling of negative strides
-- | Reduce along the inner dimension
{-# NOINLINE vectorRedInnerOp #-}
vectorRedInnerOp :: forall a b n. (Num a, Storable a)
                 => SNat n
                 -> (a -> b)
                 -> (Ptr a -> Ptr b)
                 -> (Int64 -> Ptr Int64 -> Ptr b -> b -> Ptr Int64 -> Ptr b -> IO ())  -- ^ scale by constant
                 -> (Int64 -> Ptr b -> Ptr Int64 -> Ptr Int64 -> Ptr b -> IO ())  -- ^ reduction kernel
                 -> Array (n + 1) a -> Array n a
vectorRedInnerOp :: forall a b (n :: Nat).
(Num a, Storable a) =>
SNat n
-> (a -> b)
-> (Ptr a -> Ptr b)
-> (Int64
    -> Ptr Int64 -> Ptr b -> b -> Ptr Int64 -> Ptr b -> IO ())
-> (Int64 -> Ptr b -> Ptr Int64 -> Ptr Int64 -> Ptr b -> IO ())
-> Array (n + 1) a
-> Array n a
vectorRedInnerOp sn :: SNat n
sn@SNat n
SNat a -> b
valconv Ptr a -> Ptr b
ptrconv Int64 -> Ptr Int64 -> Ptr b -> b -> Ptr Int64 -> Ptr b -> IO ()
fscale Int64 -> Ptr b -> Ptr Int64 -> Ptr Int64 -> Ptr b -> IO ()
fred array :: Array (n + 1) a
array@(Array [Int]
sh [Int]
strides Int
offset Vector a
vec)
  | [Int] -> Bool
forall a. [a] -> Bool
forall (t :: * -> *) a. Foldable t => t a -> Bool
null [Int]
sh = String -> Array n a
forall a. HasCallStack => String -> a
error String
"unreachable"
  | [Int] -> Int
forall a. HasCallStack => [a] -> a
last [Int]
sh Int -> Int -> Bool
forall a. Ord a => a -> a -> Bool
<= Int
0 = [Int] -> a -> Array n a
forall a (n :: Nat). Storable a => [Int] -> a -> Array n a
arrayFromConstant ([Int] -> [Int]
forall a. HasCallStack => [a] -> [a]
init [Int]
sh) a
0
  | (Int -> Bool) -> [Int] -> Bool
forall (t :: * -> *) a. Foldable t => (a -> Bool) -> t a -> Bool
any (Int -> Int -> Bool
forall a. Ord a => a -> a -> Bool
<= Int
0) ([Int] -> [Int]
forall a. HasCallStack => [a] -> [a]
init [Int]
sh) = [Int] -> [Int] -> Int -> Vector a -> Array n a
forall (n :: Nat) a. [Int] -> [Int] -> Int -> Vector a -> Array n a
Array ([Int] -> [Int]
forall a. HasCallStack => [a] -> [a]
init [Int]
sh) (Int
0 Int -> [Int] -> [Int]
forall a b. a -> [b] -> [a]
forall (f :: * -> *) a b. Functor f => a -> f b -> f a
<$ [Int] -> [Int]
forall a. HasCallStack => [a] -> [a]
init [Int]
strides) Int
0 Vector a
forall a. Storable a => Vector a
VS.empty
  -- now the input array is nonempty
  | [Int] -> Int
forall a. HasCallStack => [a] -> a
last [Int]
sh Int -> Int -> Bool
forall a. Eq a => a -> a -> Bool
== Int
1 = [Int] -> [Int] -> Int -> Vector a -> Array n a
forall (n :: Nat) a. [Int] -> [Int] -> Int -> Vector a -> Array n a
Array ([Int] -> [Int]
forall a. HasCallStack => [a] -> [a]
init [Int]
sh) ([Int] -> [Int]
forall a. HasCallStack => [a] -> [a]
init [Int]
strides) Int
offset Vector a
vec
  | [Int] -> Int
forall a. HasCallStack => [a] -> a
last [Int]
strides Int -> Int -> Bool
forall a. Eq a => a -> a -> Bool
== Int
0 =
      SNat n
-> (a -> b)
-> (Ptr a -> Ptr b)
-> (Int64
    -> Ptr Int64 -> Ptr b -> b -> Ptr Int64 -> Ptr b -> IO ())
-> a
-> Array n a
-> Array n a
forall a b (n :: Nat).
Storable a =>
SNat n
-> (a -> b)
-> (Ptr a -> Ptr b)
-> (Int64
    -> Ptr Int64 -> Ptr b -> b -> Ptr Int64 -> Ptr b -> IO ())
-> a
-> Array n a
-> Array n a
wrapBinarySV SNat n
sn a -> b
valconv Ptr a -> Ptr b
ptrconv Int64 -> Ptr Int64 -> Ptr b -> b -> Ptr Int64 -> Ptr b -> IO ()
fscale (forall a b. (Integral a, Num b) => a -> b
fromIntegral @Int @a ([Int] -> Int
forall a. HasCallStack => [a] -> a
last [Int]
sh))
                   ([Int] -> [Int] -> Int -> Vector a -> Array n a
forall (n :: Nat) a. [Int] -> [Int] -> Int -> Vector a -> Array n a
Array ([Int] -> [Int]
forall a. HasCallStack => [a] -> [a]
init [Int]
sh) ([Int] -> [Int]
forall a. HasCallStack => [a] -> [a]
init [Int]
strides) Int
offset Vector a
vec)
  -- now there is useful work along the inner dimension
  -- Note that unreplication keeps the inner dimension intact, because `last strides /= 0` at this point.
  | Bool
otherwise =
      Array (n + 1) a
-> (forall {n' :: Nat}.
    KnownNat n' =>
    Array n' a
    -> Int
    -> ([Int] -> [Int])
    -> (Array n' a -> Array (n + 1) a)
    -> (Array (n' - 1) a -> Array ((n + 1) - 1) a)
    -> Array n a)
-> Array n a
forall (n :: Nat) a r.
Array n a
-> (forall (n' :: Nat).
    KnownNat n' =>
    Array n' a
    -> Int
    -> ([Int] -> [Int])
    -> (Array n' a -> Array n a)
    -> (Array (n' - 1) a -> Array (n - 1) a)
    -> r)
-> r
simplifyArray Array (n + 1) a
array ((forall {n' :: Nat}.
  KnownNat n' =>
  Array n' a
  -> Int
  -> ([Int] -> [Int])
  -> (Array n' a -> Array (n + 1) a)
  -> (Array (n' - 1) a -> Array ((n + 1) - 1) a)
  -> Array n a)
 -> Array n a)
-> (forall {n' :: Nat}.
    KnownNat n' =>
    Array n' a
    -> Int
    -> ([Int] -> [Int])
    -> (Array n' a -> Array (n + 1) a)
    -> (Array (n' - 1) a -> Array ((n + 1) - 1) a)
    -> Array n a)
-> Array n a
forall a b. (a -> b) -> a -> b
$ \(Array [Int]
sh' [Int]
strides' Int
offset' Vector a
vec' :: Array n' a) Int
_ [Int] -> [Int]
_ Array n' a -> Array (n + 1) a
_ Array (n' - 1) a -> Array ((n + 1) - 1) a
restore -> IO (Array n a) -> Array n a
forall a. IO a -> a
unsafePerformIO (IO (Array n a) -> Array n a) -> IO (Array n a) -> Array n a
forall a b. (a -> b) -> a -> b
$ do
        let ndims' :: Int
ndims' = [Int] -> Int
forall a. [a] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [Int]
sh'
        IOVector a
outv <- Int -> IO (MVector (PrimState IO) a)
forall (m :: * -> *) a.
(PrimMonad m, Storable a) =>
Int -> m (MVector (PrimState m) a)
VSM.unsafeNew ([Int] -> Int
forall a. Num a => [a] -> a
forall (t :: * -> *) a. (Foldable t, Num a) => t a -> a
product ([Int] -> [Int]
forall a. HasCallStack => [a] -> [a]
init [Int]
sh'))
        IOVector a -> (Ptr a -> IO ()) -> IO ()
forall a b. Storable a => IOVector a -> (Ptr a -> IO b) -> IO b
VSM.unsafeWith IOVector a
outv ((Ptr a -> IO ()) -> IO ()) -> (Ptr a -> IO ()) -> IO ()
forall a b. (a -> b) -> a -> b
$ \Ptr a
poutv ->
          Vector Int64 -> (Ptr Int64 -> IO ()) -> IO ()
forall a b. Storable a => Vector a -> (Ptr a -> IO b) -> IO b
VS.unsafeWith (Int -> [Int64] -> Vector Int64
forall a. Storable a => Int -> [a] -> Vector a
VS.fromListN Int
ndims' ((Int -> Int64) -> [Int] -> [Int64]
forall a b. (a -> b) -> [a] -> [b]
map Int -> Int64
forall a b. (Integral a, Num b) => a -> b
fromIntegral [Int]
sh')) ((Ptr Int64 -> IO ()) -> IO ()) -> (Ptr Int64 -> IO ()) -> IO ()
forall a b. (a -> b) -> a -> b
$ \Ptr Int64
psh ->
          Vector Int64 -> (Ptr Int64 -> IO ()) -> IO ()
forall a b. Storable a => Vector a -> (Ptr a -> IO b) -> IO b
VS.unsafeWith (Int -> [Int64] -> Vector Int64
forall a. Storable a => Int -> [a] -> Vector a
VS.fromListN Int
ndims' ((Int -> Int64) -> [Int] -> [Int64]
forall a b. (a -> b) -> [a] -> [b]
map Int -> Int64
forall a b. (Integral a, Num b) => a -> b
fromIntegral [Int]
strides')) ((Ptr Int64 -> IO ()) -> IO ()) -> (Ptr Int64 -> IO ()) -> IO ()
forall a b. (a -> b) -> a -> b
$ \Ptr Int64
pstrides ->
          Vector a -> (Ptr a -> IO ()) -> IO ()
forall a b. Storable a => Vector a -> (Ptr a -> IO b) -> IO b
VS.unsafeWith Vector a
vec' ((Ptr a -> IO ()) -> IO ()) -> (Ptr a -> IO ()) -> IO ()
forall a b. (a -> b) -> a -> b
$ \Ptr a
pv ->
            let pv' :: Ptr a
pv' = Ptr a
pv Ptr a -> Int -> Ptr a
forall a b. Ptr a -> Int -> Ptr b
`plusPtr` (Int
offset' Int -> Int -> Int
forall a. Num a => a -> a -> a
* a -> Int
forall a. Storable a => a -> Int
sizeOf (a
forall a. HasCallStack => a
undefined :: a))
            in Int64 -> Ptr b -> Ptr Int64 -> Ptr Int64 -> Ptr b -> IO ()
fred (Int -> Int64
forall a b. (Integral a, Num b) => a -> b
fromIntegral Int
ndims') (Ptr a -> Ptr b
ptrconv Ptr a
poutv) Ptr Int64
psh Ptr Int64
pstrides (Ptr a -> Ptr b
ptrconv Ptr a
pv')
        Nat
-> (forall {n :: Nat}. SNat n -> IO (Array n a)) -> IO (Array n a)
forall r. Nat -> (forall (n :: Nat). SNat n -> r) -> r
TypeNats.withSomeSNat (Int -> Nat
forall a b. (Integral a, Num b) => a -> b
fromIntegral (Int
ndims' Int -> Int -> Int
forall a. Num a => a -> a -> a
- Int
1)) ((forall {n :: Nat}. SNat n -> IO (Array n a)) -> IO (Array n a))
-> (forall {n :: Nat}. SNat n -> IO (Array n a)) -> IO (Array n a)
forall a b. (a -> b) -> a -> b
$ \(SNat n
SNat :: SNat n'm1) -> do
          (Dict (1 <= n')
Dict :: Dict (1 <= n')) <- case SNat 1 -> SNat n' -> OrderingI 1 n'
forall (a :: Nat) (b :: Nat) (proxy1 :: Nat -> *)
       (proxy2 :: Nat -> *).
(KnownNat a, KnownNat b) =>
proxy1 a -> proxy2 b -> OrderingI a b
cmpNat (forall (n :: Nat). KnownNat n => SNat n
natSing @1) (forall (n :: Nat). KnownNat n => SNat n
natSing @n') of
                                        OrderingI 1 n'
LTI -> Dict
  (Assert (OrdCond (CmpNat 1 n') 'True 'True 'False) (TypeError ...))
-> IO
     (Dict
        (Assert
           (OrdCond (CmpNat 1 n') 'True 'True 'False) (TypeError ...)))
forall a. a -> IO a
forall (f :: * -> *) a. Applicative f => a -> f a
pure Dict (() :: Constraint)
Dict
  (Assert (OrdCond (CmpNat 1 n') 'True 'True 'False) (TypeError ...))
forall (c :: Constraint). c => Dict c
Dict
                                        OrderingI 1 n'
EQI -> Dict
  (Assert (OrdCond (CmpNat 1 n') 'True 'True 'False) (TypeError ...))
-> IO
     (Dict
        (Assert
           (OrdCond (CmpNat 1 n') 'True 'True 'False) (TypeError ...)))
forall a. a -> IO a
forall (f :: * -> *) a. Applicative f => a -> f a
pure Dict (() :: Constraint)
Dict
  (Assert (OrdCond (CmpNat 1 n') 'True 'True 'False) (TypeError ...))
forall (c :: Constraint). c => Dict c
Dict
                                        OrderingI 1 n'
_ -> String
-> IO
     (Dict
        (Assert
           (OrdCond (CmpNat 1 n') 'True 'True 'False) (TypeError ...)))
forall a. HasCallStack => String -> a
error String
"impossible"  -- because `last strides /= 0`
          case SNat (n' - 1) -> SNat n -> Maybe ((n' - 1) :~: n)
forall (a :: Nat) (b :: Nat) (proxy1 :: Nat -> *)
       (proxy2 :: Nat -> *).
(KnownNat a, KnownNat b) =>
proxy1 a -> proxy2 b -> Maybe (a :~: b)
sameNat (forall (n :: Nat). KnownNat n => SNat n
natSing @(n' - 1)) (forall (n :: Nat). KnownNat n => SNat n
natSing @n'm1) of
            Just (n' - 1) :~: n
Refl -> Array n a -> Array n a
Array (n' - 1) a -> Array ((n + 1) - 1) a
restore (Array n a -> Array n a)
-> (Vector a -> Array n a) -> Vector a -> Array n a
forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall a (n :: Nat).
(Storable a, KnownNat n) =>
[Int] -> Vector a -> Array n a
arrayFromVector @_ @n'm1 ([Int] -> [Int]
forall a. HasCallStack => [a] -> [a]
init [Int]
sh') (Vector a -> Array n a) -> IO (Vector a) -> IO (Array n a)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> MVector (PrimState IO) a -> IO (Vector a)
forall a (m :: * -> *).
(Storable a, PrimMonad m) =>
MVector (PrimState m) a -> m (Vector a)
VS.unsafeFreeze IOVector a
MVector (PrimState IO) a
outv
            Maybe ((n' - 1) :~: n)
Nothing -> String -> IO (Array n a)
forall a. HasCallStack => String -> a
error String
"impossible"

-- TODO: test handling of negative strides
-- | Reduce a full array to a single value.
--
-- Arguments, in order:
--
-- * a rank witness (not inspected here);
-- * @scaleval x k@, which must combine @k@ copies of @x@ under the reduction
--   (presumably 'mulWithInt' for a sum -- confirm against the callers);
-- * a conversion from the kernel's result type @b@ back to @a@;
-- * a pointer conversion from @a@ to the kernel's element type;
-- * the C reduction kernel, called as @fred ndims shapePtr stridesPtr dataPtr@.
--
-- Before the kernel is called, the array is simplified with 'simplifyArray'.
-- The size that 'simplifyArray' reports for the dimensions it removed
-- (@unrepSize@) is folded back into the kernel's result with @scaleval@.
{-# NOINLINE vectorRedFullOp #-}
vectorRedFullOp :: forall a b n. (Num a, Storable a)
                => SNat n
                -> (a -> Int -> a)
                -> (b -> a)
                -> (Ptr a -> Ptr b)
                -> (Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr b -> IO b)  -- ^ reduction kernel
                -> Array n a -> a
vectorRedFullOp _ scaleval valbackconv ptrconv fred array@(Array sh strides offset vec)
  | null sh = vec VS.! offset  -- 0D array has one element
  -- An empty array combines zero elements. @scaleval x 0@ is exactly that
  -- "empty combination", i.e. the reduction's identity (0 for 'mulWithInt').
  -- A hard-coded 0 would only be right for sums.
  | any (<= 0) sh = scaleval 1 0
  -- now the input array is nonempty
  -- Every stride is 0, so the single element at @offset@ occurs @product sh@
  -- times. Combine it with @scaleval@, as the general branch below does for
  -- replicated dimensions. Hard-coding multiplication would only be correct
  -- for sums.
  | all (== 0) strides = scaleval (vec VS.! offset) (product sh)
  -- now there is at least one non-replicated dimension
  | otherwise =
      simplifyArray array $ \(Array sh' strides' offset' vec') unrepSize _ _ _ -> unsafePerformIO $ do
        let ndims' = length sh'
        VS.unsafeWith (VS.fromListN ndims' (map fromIntegral sh')) $ \psh ->
          VS.unsafeWith (VS.fromListN ndims' (map fromIntegral strides')) $ \pstrides ->
          VS.unsafeWith vec' $ \pv ->
            let pv' :: Ptr a
                pv' = pv `plusPtr` (offset' * sizeOf (undefined :: a))
            in (`scaleval` unrepSize) . valbackconv
                 <$> fred (fromIntegral ndims') psh pstrides (ptrconv pv')

-- TODO: test this function
-- | Locate an extremum of a full array and return its index, with one entry
-- per dimension, so the result has length @n@. The kernel passed in decides
-- whether it is the minimum ("argmin") or the maximum.
--
-- The kernel runs on the array after 'simplifyArray' and is called as
-- @fextrem outPtr ndims shapePtr stridesPtr dataPtr@. The index it writes is
-- mapped back to the original dimensions with the @upindex@ function that
-- 'simplifyArray' supplies.
{-# NOINLINE vectorExtremumOp #-}
vectorExtremumOp :: forall a b n. Storable a
                 => (Ptr a -> Ptr b)
                 -> (Ptr Int64 -> Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr b -> IO ())  -- ^ extremum kernel
                 -> Array n a -> [Int]  -- result length: n
vectorExtremumOp ptrconv fextrem array@(Array sh strides _ _)
  | null sh = []
  | any (<= 0) sh = error "Extremum (minindex/maxindex): empty array"
  -- the array is nonempty from here on
  -- all strides 0: every position reads the same element, so report index 0s
  | all (== 0) strides = map (const 0) sh
  -- at least one dimension has a nonzero stride
  | otherwise =
      simplifyArray array $ \(Array simpSh simpStrides simpOff simpVec) _ upindex _ _ ->
        unsafePerformIO $ do
          let rank = length simpSh
          idxBuf <- VSM.unsafeNew rank
          VSM.unsafeWith idxBuf $ \pIdx ->
            withInt64Array rank simpSh $ \pShape ->
              withInt64Array rank simpStrides $ \pStrides ->
                VS.unsafeWith simpVec $ \pBase ->
                  let pFirst :: Ptr a
                      pFirst = pBase `plusPtr` (simpOff * sizeOf (undefined :: a))
                  in fextrem pIdx (fromIntegral rank) pShape pStrides (ptrconv pFirst)
          rawIdx <- VS.unsafeFreeze idxBuf
          pure (upindex (map (fromIntegral @Int64 @Int) (VS.toList rawIdx)))
  where
    -- Marshal an Int list into a temporary Int64 buffer holding (at most)
    -- the given number of elements, for the duration of the continuation.
    withInt64Array :: Int -> [Int] -> (Ptr Int64 -> IO r) -> IO r
    withInt64Array len xs = VS.unsafeWith (VS.fromListN len (map fromIntegral xs))

-- | Dot product along the innermost dimension of two arrays of equal shape.
-- The result has the inner dimension removed.
--
-- Degenerate inputs are handled here rather than in C:
--
-- * an inner dimension of length 0 yields a constant-0 array;
-- * any other dimension of length 0 yields an empty array;
-- * an inner dimension of length 1 is a plain elementwise product (@fmul@);
-- * if one operand has inner stride 0 (constant along the inner axis), the
--   result is @fmul@ of that operand's inner slice with the other operand
--   reduced along its inner dimension by 'vectorRedInnerOp'.
--
-- Otherwise both arrays are simplified together with 'simplifyArray2' and
-- passed to the C kernel, called as
-- @fdotinner rank shapePtr outPtr strides1Ptr data1Ptr strides2Ptr data2Ptr@.
{-# NOINLINE vectorDotprodInnerOp #-}
vectorDotprodInnerOp :: forall a b n. (Num a, Storable a)
                     => SNat n
                     -> (a -> b)
                     -> (Ptr a -> Ptr b)
                     -> (SNat n -> Array n a -> Array n a -> Array n a)  -- ^ elementwise multiplication
                     -> (Int64 -> Ptr Int64 -> Ptr b -> b -> Ptr Int64 -> Ptr b -> IO ())  -- ^ scale by constant
                     -> (Int64 -> Ptr b -> Ptr Int64 -> Ptr Int64 -> Ptr b -> IO ())  -- ^ reduction kernel
                     -> (Int64 -> Ptr Int64 -> Ptr b -> Ptr Int64 -> Ptr b -> Ptr Int64 -> Ptr b -> IO ())  -- ^ dotprod kernel
                     -> Array (n + 1) a -> Array (n + 1) a -> Array n a
vectorDotprodInnerOp sn@SNat valconv ptrconv fmul fscale fred fdotinner
                     arr1@(Array sh1 strides1 _ _) arr2@(Array sh2 strides2 _ _)
  | null sh1 || null sh2 = error "unreachable"
  | sh1 /= sh2 =
      error $ "vectorDotprodInnerOp: shapes unequal: " ++ show sh1 ++ " vs " ++ show sh2
  | innerLen <= 0 = arrayFromConstant outerSh 0
  | any (<= 0) outerSh = Array outerSh (map (const 0) (init strides1)) 0 VS.empty
  -- both arrays are nonempty from here on
  | innerLen == 1 = fmul sn (dropInner arr1) (dropInner arr2)
  | last strides1 == 0 = fmul sn (dropInner arr1) (reduceInner arr2)
  | last strides2 == 0 = fmul sn (reduceInner arr1) (dropInner arr2)
  -- the inner dimension carries real dot-product work for the C kernel
  | otherwise =
      simplifyArray2 arr1 arr2 $ \(Array simpSh simpStrides1 simpOff1 simpVec1 :: Array n' a)
                                  (Array _ simpStrides2 simpOff2 simpVec2) _ _ _ restore ->
        unsafePerformIO $ do
          let rank = length simpSh
              eltSize = sizeOf (undefined :: a)
          outBuf <- VSM.unsafeNew (product (init simpSh))
          VSM.unsafeWith outBuf $ \pOut ->
            withInt64Array rank simpSh $ \pShape ->
              withInt64Array rank simpStrides1 $ \pStrides1 ->
                VS.unsafeWith simpVec1 $ \pVec1 ->
                  withInt64Array rank simpStrides2 $ \pStrides2 ->
                    VS.unsafeWith simpVec2 $ \pVec2 ->
                      fdotinner (fromIntegral @Int @Int64 rank) pShape (ptrconv pOut)
                                pStrides1 (ptrconv pVec1 `plusPtr` (eltSize * simpOff1))
                                pStrides2 (ptrconv pVec2 `plusPtr` (eltSize * simpOff2))
          -- Recover the type-level rank of the output (n' - 1) so the result
          -- can be built with arrayFromVector and mapped back with `restore`.
          TypeNats.withSomeSNat (fromIntegral (rank - 1)) $ \(SNat :: SNat n'm1) -> do
            (Dict :: Dict (1 <= n')) <- case cmpNat (natSing @1) (natSing @n') of
              LTI -> pure Dict
              EQI -> pure Dict
              GTI -> error "impossible"  -- because `last strides1 /= 0`
            case sameNat (natSing @(n' - 1)) (natSing @n'm1) of
              Just Refl -> restore . arrayFromVector (init simpSh) <$> VS.unsafeFreeze outBuf
              Nothing -> error "impossible"
  where
    innerLen = last sh1
    outerSh = init sh1

    -- Drop the innermost dimension (shape and stride), keeping offset and data.
    dropInner :: Array (m + 1) c -> Array m c
    dropInner (Array sh st off v) = Array (init sh) (init st) off v

    -- Reduce an operand along its inner dimension with the supplied kernels.
    reduceInner = vectorRedInnerOp sn valconv ptrconv fscale fred

    -- Marshal an Int list into a temporary Int64 buffer holding (at most)
    -- the given number of elements, for the duration of the continuation.
    withInt64Array :: Int -> [Int] -> (Ptr Int64 -> IO r) -> IO r
    withInt64Array len xs = VS.unsafeWith (VS.fromListN len (map fromIntegral xs))

-- | Scale a value by an 'Int' count: @mulWithInt x k == x * fromIntegral k@.
mulWithInt :: Num a => a -> Int -> a
mulWithInt x = (x *) . fromIntegral


-- For every element type in typesList and every binary operator, declare
-- @<aboName op>Vector<TypeName> :: SNat n -> Array n t -> Array n t -> Array n t@.
-- Each one is liftOpEltwise2 with aboNumOp as the scalar operation and the C
-- kernels c_binary_<atCName t>_{sv,vs,vv}_strided, each applied to the
-- operator's aboEnum code.
$(fmap concat $ forM typesList $ \arithtype -> do
    let elemTy = conT (atType arithtype)
    fmap concat $ forM [minBound .. maxBound] $ \arithop -> do
      let fname = mkName (aboName arithop ++ "Vector" ++ nameBase (atType arithtype))
          opCode = litE (integerL (fromIntegral (aboEnum arithop)))
          -- the strided C kernel for one operand layout suffix, applied to opCode
          stridedKernel layout =
            varE (mkName ("c_binary_" ++ atCName arithtype ++ "_" ++ layout ++ "_strided")) `appE` opCode
          scalarOp = varE (aboNumOp arithop)
          svKernel = stridedKernel "sv"
          vsKernel = stridedKernel "vs"
          vvKernel = stridedKernel "vv"
      sequence
        [ sigD fname [t| forall n. SNat n -> Array n $elemTy -> Array n $elemTy -> Array n $elemTy |]
        , funD fname [clause [] (normalB [| \sn -> liftOpEltwise2 sn id id $scalarOp $svKernel $vsKernel $vvKernel |]) []]
        ])

-- For every integral element type in intTypesList and every integer binary
-- operator, declare
-- @<aiboName op>Vector<TypeName> :: SNat n -> Array n t -> Array n t -> Array n t@.
-- Each one is liftOpEltwise2 with aiboNumOp as the scalar operation and the C
-- kernels c_ibinary_<atCName t>_{sv,vs,vv}_strided, each applied to the
-- operator's aiboEnum code.
$(fmap concat $ forM intTypesList $ \arithtype -> do
    let elemTy = conT (atType arithtype)
    fmap concat $ forM [minBound .. maxBound] $ \arithop -> do
      let fname = mkName (aiboName arithop ++ "Vector" ++ nameBase (atType arithtype))
          opCode = litE (integerL (fromIntegral (aiboEnum arithop)))
          -- the strided C kernel for one operand layout suffix, applied to opCode
          stridedKernel layout =
            varE (mkName ("c_ibinary_" ++ atCName arithtype ++ "_" ++ layout ++ "_strided")) `appE` opCode
          scalarOp = varE (aiboNumOp arithop)
          svKernel = stridedKernel "sv"
          vsKernel = stridedKernel "vs"
          vvKernel = stridedKernel "vv"
      sequence
        [ sigD fname [t| forall n. SNat n -> Array n $elemTy -> Array n $elemTy -> Array n $elemTy |]
        , funD fname [clause [] (normalB [| \sn -> liftOpEltwise2 sn id id $scalarOp $svKernel $vsKernel $vvKernel |]) []]
        ])

-- For every element type in 'floatTypesList' and every floating-only binary
-- operation, generate
--   <afboName op>Vector<Type> :: SNat n -> Array n T -> Array n T -> Array n T
-- defined as 'liftOpEltwise2' over the Haskell scalar operation
-- ('afboNumOp') and the strided C kernels c_fbinary_<ctype>_{sv,vs,vv}_strided,
-- each partially applied to the operation's enum tag ('afboEnum').
$(fmap concat $ forM floatTypesList $ \arithtype -> do
    let elemTy = conT (atType arithtype)
        tyName = nameBase (atType arithtype)
    fmap concat $ forM [minBound .. maxBound] $ \arithop -> do
      let fnName = mkName (afboName arithop ++ "Vector" ++ tyName)
          scalarOp = varE (afboNumOp arithop)
          -- C kernel with the given suffix, applied to this op's enum tag.
          kernel suffix = varE (mkName ("c_fbinary_" ++ atCName arithtype ++ suffix))
                            `appE` litE (integerL (fromIntegral (afboEnum arithop)))
      sig <- SigD fnName <$>
               [t| forall n. SNat n -> Array n $elemTy -> Array n $elemTy -> Array n $elemTy |]
      body <- [| \sn -> liftOpEltwise2 sn id id $scalarOp $(kernel "_sv_strided") $(kernel "_vs_strided") $(kernel "_vv_strided") |]
      return [sig, FunD fnName [Clause [] (NormalB body) []]])

-- For every element type in 'typesList' and every unary operation, generate
--   <auoName op>Vector<Type> :: SNat n -> Array n T -> Array n T
-- defined as 'liftOpEltwise1' over the C kernel c_unary_<ctype>_strided
-- partially applied to the operation's enum tag ('auoEnum').
$(fmap concat $ forM typesList $ \arithtype -> do
    let elemTy = conT (atType arithtype)
    fmap concat $ forM [minBound .. maxBound] $ \arithop -> do
      let fnName = mkName (auoName arithop ++ "Vector" ++ nameBase (atType arithtype))
          opTag = litE (integerL (fromIntegral (auoEnum arithop)))
          kernel = varE (mkName ("c_unary_" ++ atCName arithtype ++ "_strided")) `appE` opTag
      sig <- SigD fnName <$> [t| forall n. SNat n -> Array n $elemTy -> Array n $elemTy |]
      body <- [| \sn -> liftOpEltwise1 sn id $kernel |]
      return [sig, FunD fnName [Clause [] (NormalB body) []]])

-- For every element type in 'floatTypesList' and every floating-only unary
-- operation, generate
--   <afuoName op>Vector<Type> :: SNat n -> Array n T -> Array n T
-- defined as 'liftOpEltwise1' over the C kernel c_funary_<ctype>_strided
-- partially applied to the operation's enum tag ('afuoEnum').
$(fmap concat $ forM floatTypesList $ \arithtype -> do
    let elemTy = conT (atType arithtype)
    fmap concat $ forM [minBound .. maxBound] $ \arithop -> do
      let fnName = mkName (afuoName arithop ++ "Vector" ++ nameBase (atType arithtype))
          opTag = litE (integerL (fromIntegral (afuoEnum arithop)))
          kernel = varE (mkName ("c_funary_" ++ atCName arithtype ++ "_strided")) `appE` opTag
      sig <- SigD fnName <$> [t| forall n. SNat n -> Array n $elemTy -> Array n $elemTy |]
      body <- [| \sn -> liftOpEltwise1 sn id $kernel |]
      return [sig, FunD fnName [Clause [] (NormalB body) []]])

-- For every element type in 'typesList' and every reduction operation,
-- generate two functions:
--   <aroName op>1Vector<Type>    :: SNat n -> Array (n + 1) T -> Array n T
--       via 'vectorRedInnerOp', using the strided BO_MUL scalar-vector kernel
--       and c_reduce1_<ctype> applied to the op's enum tag;
--   <aroName op>FullVector<Type> :: SNat n -> Array n T -> T
--       via 'vectorRedFullOp', using 'mulWithInt' (RO_SUM) or '(^)'
--       (RO_PRODUCT) as the scale function and c_reducefull_<ctype>.
$(fmap concat $ forM typesList $ \arithtype -> do
    let elemTy = conT (atType arithtype)
        tyName = nameBase (atType arithtype)
        cname = atCName arithtype
        mulTag = litE (integerL (fromIntegral (aboEnum BO_MUL)))
        -- Does not depend on the reduction op, so it is built once per type.
        scaleKernel = varE (mkName ("c_binary_" ++ cname ++ "_sv_strided")) `appE` mulTag
    fmap concat $ forM [minBound .. maxBound] $ \arithop -> do
      let opTag = litE (integerL (fromIntegral (aroEnum arithop)))
          scaleVar = case arithop of
                       RO_SUM -> varE 'mulWithInt
                       RO_PRODUCT -> varE '(^)
          name1 = mkName (aroName arithop ++ "1Vector" ++ tyName)
          namefull = mkName (aroName arithop ++ "FullVector" ++ tyName)
          reduce1Kernel = varE (mkName ("c_reduce1_" ++ cname)) `appE` opTag
          reduceFullKernel = varE (mkName ("c_reducefull_" ++ cname)) `appE` opTag
      sig1 <- SigD name1 <$>
                [t| forall n. SNat n -> Array (n + 1) $elemTy -> Array n $elemTy |]
      body1 <- [| \sn -> vectorRedInnerOp sn id id $scaleKernel $reduce1Kernel |]
      sigFull <- SigD namefull <$>
                   [t| forall n. SNat n -> Array n $elemTy -> $elemTy |]
      bodyFull <- [| \sn -> vectorRedFullOp sn $scaleVar id id $reduceFullKernel |]
      return [ sig1
             , FunD name1 [Clause [] (NormalB body1) []]
             , sigFull
             , FunD namefull [Clause [] (NormalB bodyFull) []]
             ])

-- For every element type in 'typesList', generate minindexVector<Type> and
-- maxindexVector<Type> :: Array n T -> [Int], each defined as
-- 'vectorExtremumOp' over the C kernel c_extremum_{min,max}_<ctype>.
$(fmap concat $ forM typesList $ \arithtype ->
    fmap concat $ forM ["min", "max"] $ \fname -> do
      let elemTy = conT (atType arithtype)
          fnName = mkName (fname ++ "indexVector" ++ nameBase (atType arithtype))
          kernel = varE (mkName ("c_extremum_" ++ fname ++ "_" ++ atCName arithtype))
      sig <- SigD fnName <$> [t| forall n. Array n $elemTy -> [Int] |]
      body <- [| vectorExtremumOp id $kernel |]
      return [sig, FunD fnName [Clause [] (NormalB body) []]])

-- For every element type in 'typesList', generate
--   dotprodinnerVector<Type> :: SNat n -> Array (n + 1) T -> Array (n + 1) T -> Array n T
-- defined as 'vectorDotprodInnerOp' over mulVector<Type>, the strided BO_MUL
-- scalar-vector kernel, the RO_SUM c_reduce1 kernel and c_dotprodinner_<ctype>.
$(fmap concat $ forM typesList $ \arithtype -> do
    let elemTy = conT (atType arithtype)
        tyName = nameBase (atType arithtype)
        cname = atCName arithtype
        fnName = mkName ("dotprodinnerVector" ++ tyName)
        dotKernel = varE (mkName ("c_dotprodinner_" ++ cname))
        mulOp = varE (mkName ("mulVector" ++ tyName))
        scaleKernel = varE (mkName ("c_binary_" ++ cname ++ "_sv_strided"))
                        `appE` litE (integerL (fromIntegral (aboEnum BO_MUL)))
        sumKernel = varE (mkName ("c_reduce1_" ++ cname))
                      `appE` litE (integerL (fromIntegral (aroEnum RO_SUM)))
    sig <- SigD fnName <$>
             [t| forall n. SNat n -> Array (n + 1) $elemTy -> Array (n + 1) $elemTy -> Array n $elemTy |]
    body <- [| \sn -> vectorDotprodInnerOp sn id id $mulOp $scaleKernel $sumKernel $dotKernel |]
    return [sig, FunD fnName [Clause [] (NormalB body) []]])

-- C-side statistics hooks, wrapped by 'statisticsEnable' and
-- 'statisticsPrintAll'.

-- | Raw binding to @oxarrays_stats_enable@.
foreign import ccall unsafe "oxarrays_stats_enable"
  c_stats_enable :: Int32 -> IO ()

-- | Raw binding to @oxarrays_stats_print_all@.
foreign import ccall unsafe "oxarrays_stats_print_all"
  c_stats_print_all :: IO ()

-- | Pass the flag to the C function @oxarrays_stats_enable@, encoded as
-- @1@ for 'True' and @0@ for 'False'.
statisticsEnable :: Bool -> IO ()
statisticsEnable b = c_stats_enable (if b then 1 else 0)

-- | Consumes the log: one particular event will only ever be printed once,
-- even if statisticsPrintAll is called multiple times.
--
-- Haskell's 'stdout' buffer is flushed before the C printer runs, so that
-- previously buffered Haskell output appears before the statistics.
statisticsPrintAll :: IO ()
statisticsPrintAll = do
  hFlush stdout  -- lower the chance of overlapping output
  c_stats_print_all

-- | Apply a unary elementwise kernel to an integral array, dispatching on the
-- bit width of the element type @i@: width 32 uses the 'Int32' kernel,
-- width 64 uses the 'Int64' kernel, and any other width is an error.
--
-- The element pointer is reinterpreted with 'castPtr', so @i@ is assumed to
-- share its in-memory representation with the selected fixed-width type.
--
-- This branch is ostensibly a runtime branch, but will (hopefully) be
-- constant-folded away by GHC, since the width depends only on the type.
intWidBranch1 :: forall i n. (FiniteBits i, Storable i)
              => (forall b. b ~ Int32 => Int64 -> Ptr b -> Ptr Int64 -> Ptr Int64 -> Ptr b -> IO ())  -- ^ 32-bit kernel
              -> (forall b. b ~ Int64 => Int64 -> Ptr b -> Ptr Int64 -> Ptr Int64 -> Ptr b -> IO ())  -- ^ 64-bit kernel
              -> (SNat n -> Array n i -> Array n i)
intWidBranch1 f32 f64 sn = case width of
    32 -> liftOpEltwise1 sn castPtr f32
    64 -> liftOpEltwise1 sn castPtr f64
    _  -> error "Unsupported Int width"
  where
    -- 'finiteBitSize' is expected not to inspect its argument (true for the
    -- standard integer instances), so 'undefined' is never forced.
    width = finiteBitSize (undefined :: i)

-- | Binary elementwise counterpart of 'intWidBranch1'. It dispatches on the
-- bit width of @i@ and forwards to 'liftOpEltwise2', using 'fromIntegral' as
-- the element conversion and 'castPtr' as the pointer reinterpretation.
-- Any width other than 32 or 64 is an error.
--
-- The @ss@ / @sv@ / @vs@ / @vv@ arguments are presumably the
-- scalar/vector operand combinations (s = scalar, v = vector), matching the
-- @_sv_strided@ / @_vs_strided@ / @_vv_strided@ C kernel names used elsewhere
-- in this module.
intWidBranch2 :: forall i n. (FiniteBits i, Storable i, Integral i)
              => (i -> i -> i)  -- ^ ss
                 -- int32
              -> (forall b. b ~ Int32 => Int64 -> Ptr Int64 -> Ptr b -> b -> Ptr Int64 -> Ptr b -> IO ())  -- ^ sv
              -> (forall b. b ~ Int32 => Int64 -> Ptr Int64 -> Ptr b -> Ptr Int64 -> Ptr b -> b -> IO ())  -- ^ vs
              -> (forall b. b ~ Int32 => Int64 -> Ptr Int64 -> Ptr b -> Ptr Int64 -> Ptr b -> Ptr Int64 -> Ptr b -> IO ())  -- ^ vv
                 -- int64
              -> (forall b. b ~ Int64 => Int64 -> Ptr Int64 -> Ptr b -> b -> Ptr Int64 -> Ptr b -> IO ())  -- ^ sv
              -> (forall b. b ~ Int64 => Int64 -> Ptr Int64 -> Ptr b -> Ptr Int64 -> Ptr b -> b -> IO ())  -- ^ vs
              -> (forall b. b ~ Int64 => Int64 -> Ptr Int64 -> Ptr b -> Ptr Int64 -> Ptr b -> Ptr Int64 -> Ptr b -> IO ())  -- ^ vv
              -> (SNat n -> Array n i -> Array n i -> Array n i)
intWidBranch2 ss sv32 vs32 vv32 sv64 vs64 vv64 sn = case width of
    32 -> liftOpEltwise2 sn fromIntegral castPtr ss sv32 vs32 vv32
    64 -> liftOpEltwise2 sn fromIntegral castPtr ss sv64 vs64 vv64
    _  -> error "Unsupported Int width"
  where
    -- Depends only on the type @i@, so GHC can (hopefully) fold the case.
    width = finiteBitSize (undefined :: i)

-- | Single-dimension reduction (@Array (n + 1) i -> Array n i@) for integral
-- element types. It dispatches on the bit width of @i@ and forwards to
-- 'vectorRedInnerOp' with the matching 32-bit or 64-bit kernel pair, using
-- 'fromIntegral' as the element conversion and 'castPtr' as the pointer
-- reinterpretation. Any width other than 32 or 64 is an error.
intWidBranchRed1 :: forall i n. (FiniteBits i, Storable i, Integral i)
                 => -- int32
                    (forall b. b ~ Int32 => Int64 -> Ptr Int64 -> Ptr b -> b -> Ptr Int64 -> Ptr b -> IO ())  -- ^ scale by constant
                 -> (forall b. b ~ Int32 => Int64 -> Ptr b -> Ptr Int64 -> Ptr Int64 -> Ptr b -> IO ())  -- ^ reduction kernel
                    -- int64
                 -> (forall b. b ~ Int64 => Int64 -> Ptr Int64 -> Ptr b -> b -> Ptr Int64 -> Ptr b -> IO ())  -- ^ scale by constant
                 -> (forall b. b ~ Int64 => Int64 -> Ptr b -> Ptr Int64 -> Ptr Int64 -> Ptr b -> IO ())  -- ^ reduction kernel
                 -> (SNat n -> Array (n + 1) i -> Array n i)
intWidBranchRed1 fsc32 fred32 fsc64 fred64 sn = case width of
    32 -> vectorRedInnerOp @i @Int32 sn fromIntegral castPtr fsc32 fred32
    64 -> vectorRedInnerOp @i @Int64 sn fromIntegral castPtr fsc64 fred64
    _  -> error "Unsupported Int width"
  where
    -- Depends only on the type @i@, so GHC can (hopefully) fold the case.
    width = finiteBitSize (undefined :: i)

-- | Full reduction (@Array n i -> i@) for integral element types. It
-- dispatches on the bit width of @i@ and forwards to 'vectorRedFullOp' with
-- the matching 32-bit or 64-bit reduction kernel. The caller's scale
-- function is passed through unchanged; 'fromIntegral' converts the kernel's
-- result back to @i@. Any width other than 32 or 64 is an error.
intWidBranchRedFull :: forall i n. (FiniteBits i, Storable i, Integral i)
                    => (i -> Int -> i)  -- ^ scale op
                       -- int32
                    -> (forall b. b ~ Int32 => Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr b -> IO b)  -- ^ reduction kernel
                       -- int64
                    -> (forall b. b ~ Int64 => Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr b -> IO b)  -- ^ reduction kernel
                    -> (SNat n -> Array n i -> i)
intWidBranchRedFull fsc fred32 fred64 sn = case width of
    32 -> vectorRedFullOp @i @Int32 sn fsc fromIntegral castPtr fred32
    64 -> vectorRedFullOp @i @Int64 sn fsc fromIntegral castPtr fred64
    _  -> error "Unsupported Int width"
  where
    -- Depends only on the type @i@, so GHC can (hopefully) fold the case.
    width = finiteBitSize (undefined :: i)

-- | Extremum-index search for integral element types. It dispatches on the
-- bit width of @i@ and forwards to 'vectorExtremumOp' with the matching
-- 32-bit or 64-bit kernel, reinterpreting the element pointer with
-- 'castPtr'. Any width other than 32 or 64 is an error.
intWidBranchExtr :: forall i n. (FiniteBits i, Storable i)
                 => -- int32
                    (forall b. b ~ Int32 => Ptr Int64 -> Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr b -> IO ())  -- ^ extremum kernel
                    -- int64
                 -> (forall b. b ~ Int64 => Ptr Int64 -> Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr b -> IO ())  -- ^ extremum kernel
                 -> (Array n i -> [Int])
intWidBranchExtr fextr32 fextr64 = case width of
    32 -> vectorExtremumOp @i @Int32 castPtr fextr32
    64 -> vectorExtremumOp @i @Int64 castPtr fextr64
    _  -> error "Unsupported Int width"
  where
    -- Depends only on the type @i@, so GHC can (hopefully) fold the case.
    width = finiteBitSize (undefined :: i)

-- | Inner dot product (@Array (n + 1) i -> Array (n + 1) i -> Array n i@) for
-- integral element types. It dispatches on the bit width of @i@ and forwards
-- to 'vectorDotprodInnerOp' with 'numEltMul' and the matching 32-bit or
-- 64-bit (scale, reduction, dot-product) kernels, using 'fromIntegral' as the
-- element conversion and 'castPtr' as the pointer reinterpretation. Any
-- width other than 32 or 64 is an error.
intWidBranchDotprod :: forall i n. (FiniteBits i, Storable i, Integral i, NumElt i)
                    => -- int32
                       (forall b. b ~ Int32 => Int64 -> Ptr Int64 -> Ptr b -> b -> Ptr Int64 -> Ptr b -> IO ())  -- ^ scale by constant
                    -> (forall b. b ~ Int32 => Int64 -> Ptr b -> Ptr Int64 -> Ptr Int64 -> Ptr b -> IO ())  -- ^ reduction kernel
                    -> (forall b. b ~ Int32 => Int64 -> Ptr Int64 -> Ptr b -> Ptr Int64 -> Ptr b -> Ptr Int64 -> Ptr b -> IO ())  -- ^ dotprod kernel
                       -- int64
                    -> (forall b. b ~ Int64 => Int64 -> Ptr Int64 -> Ptr b -> b -> Ptr Int64 -> Ptr b -> IO ())  -- ^ scale by constant
                    -> (forall b. b ~ Int64 => Int64 -> Ptr b -> Ptr Int64 -> Ptr Int64 -> Ptr b -> IO ())  -- ^ reduction kernel
                    -> (forall b. b ~ Int64 => Int64 -> Ptr Int64 -> Ptr b -> Ptr Int64 -> Ptr b -> Ptr Int64 -> Ptr b -> IO ())  -- ^ dotprod kernel
                    -> (SNat n -> Array (n + 1) i -> Array (n + 1) i -> Array n i)
intWidBranchDotprod fsc32 fred32 fdot32 fsc64 fred64 fdot64 sn = case width of
    32 -> vectorDotprodInnerOp @i @Int32 sn fromIntegral castPtr numEltMul fsc32 fred32 fdot32
    64 -> vectorDotprodInnerOp @i @Int64 sn fromIntegral castPtr numEltMul fsc64 fred64 fdot64
    _  -> error "Unsupported Int width"
  where
    -- Depends only on the type @i@, so GHC can (hopefully) fold the case.
    width = finiteBitSize (undefined :: i)

-- | Element types that support native (C-kernel-backed) arithmetic on
-- strided arrays.
--
-- Every method takes an @'SNat' n@ singleton for the rank @n@ of the
-- array(s) involved.
class NumElt a where
  -- | Elementwise binary arithmetic on two arrays of the same rank.
  numEltAdd, numEltSub, numEltMul :: SNat n -> Array n a -> Array n a -> Array n a

  -- | Elementwise unary arithmetic.
  numEltNeg, numEltAbs, numEltSignum :: SNat n -> Array n a -> Array n a

  -- | Reduce away the innermost dimension; the result has rank one lower.
  numEltSum1Inner, numEltProduct1Inner :: SNat n -> Array (n + 1) a -> Array n a

  -- | Reduce every element of the array to a single scalar.
  numEltSumFull, numEltProductFull :: SNat n -> Array n a -> a

  -- | Position of the smallest / largest element, returned as a list of
  -- 'Int's (presumably one coordinate per dimension — confirm).
  numEltMinIndex, numEltMaxIndex :: SNat n -> Array n a -> [Int]

  -- | Dot product of two arrays along their innermost dimension, lowering
  -- the rank by one.
  numEltDotprodInner :: SNat n -> Array (n + 1) a -> Array (n + 1) a -> Array n a

-- | 'Int32' elements: each method forwards to the matching monomorphic
-- @…VectorInt32@ implementation.
instance NumElt Int32 where
  -- elementwise binary
  numEltAdd = addVectorInt32
  numEltSub = subVectorInt32
  numEltMul = mulVectorInt32

  -- elementwise unary
  numEltNeg = negVectorInt32
  numEltAbs = absVectorInt32
  numEltSignum = signumVectorInt32

  -- reductions (innermost dimension, then whole array)
  numEltSum1Inner = sum1VectorInt32
  numEltProduct1Inner = product1VectorInt32
  numEltSumFull = sumFullVectorInt32
  numEltProductFull = productFullVectorInt32

  -- extremum indices: these implementations do not need the rank witness
  numEltMinIndex _ arr = minindexVectorInt32 arr
  numEltMaxIndex _ arr = maxindexVectorInt32 arr

  -- inner-dimension dot product
  numEltDotprodInner = dotprodinnerVectorInt32

-- | 'Int64' elements: each method forwards to the matching monomorphic
-- @…VectorInt64@ implementation.
instance NumElt Int64 where
  -- elementwise binary
  numEltAdd = addVectorInt64
  numEltSub = subVectorInt64
  numEltMul = mulVectorInt64

  -- elementwise unary
  numEltNeg = negVectorInt64
  numEltAbs = absVectorInt64
  numEltSignum = signumVectorInt64

  -- reductions (innermost dimension, then whole array)
  numEltSum1Inner = sum1VectorInt64
  numEltProduct1Inner = product1VectorInt64
  numEltSumFull = sumFullVectorInt64
  numEltProductFull = productFullVectorInt64

  -- extremum indices: these implementations do not need the rank witness
  numEltMinIndex _ arr = minindexVectorInt64 arr
  numEltMaxIndex _ arr = maxindexVectorInt64 arr

  -- inner-dimension dot product
  numEltDotprodInner = dotprodinnerVectorInt64

-- | 'Float' elements: each method forwards to the matching monomorphic
-- @…VectorFloat@ implementation.
instance NumElt Float where
  -- elementwise binary
  numEltAdd = addVectorFloat
  numEltSub = subVectorFloat
  numEltMul = mulVectorFloat

  -- elementwise unary
  numEltNeg = negVectorFloat
  numEltAbs = absVectorFloat
  numEltSignum = signumVectorFloat

  -- reductions (innermost dimension, then whole array)
  numEltSum1Inner = sum1VectorFloat
  numEltProduct1Inner = product1VectorFloat
  numEltSumFull = sumFullVectorFloat
  numEltProductFull = productFullVectorFloat

  -- extremum indices: these implementations do not need the rank witness
  numEltMinIndex _ arr = minindexVectorFloat arr
  numEltMaxIndex _ arr = maxindexVectorFloat arr

  -- inner-dimension dot product
  numEltDotprodInner = dotprodinnerVectorFloat

-- | 'Double' elements: each method forwards to the matching monomorphic
-- @…VectorDouble@ implementation.
instance NumElt Double where
  -- elementwise binary
  numEltAdd = addVectorDouble
  numEltSub = subVectorDouble
  numEltMul = mulVectorDouble

  -- elementwise unary
  numEltNeg = negVectorDouble
  numEltAbs = absVectorDouble
  numEltSignum = signumVectorDouble

  -- reductions (innermost dimension, then whole array)
  numEltSum1Inner = sum1VectorDouble
  numEltProduct1Inner = product1VectorDouble
  numEltSumFull = sumFullVectorDouble
  numEltProductFull = productFullVectorDouble

  -- extremum indices: these implementations do not need the rank witness
  numEltMinIndex _ arr = minindexVectorDouble arr
  numEltMaxIndex _ arr = maxindexVectorDouble arr

  -- inner-dimension dot product
  numEltDotprodInner = dotprodinnerVectorDouble

-- | 'Int' elements. 'Int' has no fixed width, so every method hands an
-- @intWidBranch*@ dispatcher both the 32-bit kernels (first) and the 64-bit
-- kernels (second). The dispatcher presumably selects the pair matching the
-- platform's 'Int' width — confirm against the @intWidBranch*@ definitions.
instance NumElt Int where
  -- Elementwise binary ops: the Haskell operator plus the
  -- scalar-vector / vector-scalar / vector-vector kernels for each width.
  numEltAdd =
    intWidBranch2 @Int (+)
      (c_binary_i32_sv_strided opcode) (c_binary_i32_vs_strided opcode) (c_binary_i32_vv_strided opcode)
      (c_binary_i64_sv_strided opcode) (c_binary_i64_vs_strided opcode) (c_binary_i64_vv_strided opcode)
    where
      opcode = aboEnum BO_ADD
  numEltSub =
    intWidBranch2 @Int (-)
      (c_binary_i32_sv_strided opcode) (c_binary_i32_vs_strided opcode) (c_binary_i32_vv_strided opcode)
      (c_binary_i64_sv_strided opcode) (c_binary_i64_vs_strided opcode) (c_binary_i64_vv_strided opcode)
    where
      opcode = aboEnum BO_SUB
  numEltMul =
    intWidBranch2 @Int (*)
      (c_binary_i32_sv_strided opcode) (c_binary_i32_vs_strided opcode) (c_binary_i32_vv_strided opcode)
      (c_binary_i64_sv_strided opcode) (c_binary_i64_vs_strided opcode) (c_binary_i64_vv_strided opcode)
    where
      opcode = aboEnum BO_MUL

  -- Elementwise unary ops.
  numEltNeg = intWidBranch1 @Int (c_unary_i32_strided opcode) (c_unary_i64_strided opcode)
    where
      opcode = auoEnum UO_NEG
  numEltAbs = intWidBranch1 @Int (c_unary_i32_strided opcode) (c_unary_i64_strided opcode)
    where
      opcode = auoEnum UO_ABS
  numEltSignum = intWidBranch1 @Int (c_unary_i32_strided opcode) (c_unary_i64_strided opcode)
    where
      opcode = auoEnum UO_SIGNUM

  -- Innermost-dimension reductions: a scalar-vector multiply kernel plus
  -- the reduction kernel, for each width.
  numEltSum1Inner =
    intWidBranchRed1 @Int
      (c_binary_i32_sv_strided scaleOp) (c_reduce1_i32 redOp)
      (c_binary_i64_sv_strided scaleOp) (c_reduce1_i64 redOp)
    where
      scaleOp = aboEnum BO_MUL
      redOp = aroEnum RO_SUM
  -- NOTE(review): this passes the same BO_MUL scalar kernel as
  -- numEltSum1Inner, whereas numEltProductFull pairs RO_PRODUCT with (^)
  -- (numEltSumFull uses (*)). If intWidBranchRed1 uses this kernel to fold a
  -- replicated (stride-0) inner dimension by its length, a product would come
  -- out as x*k rather than x^k — confirm against intWidBranchRed1.
  numEltProduct1Inner =
    intWidBranchRed1 @Int
      (c_binary_i32_sv_strided scaleOp) (c_reduce1_i32 redOp)
      (c_binary_i64_sv_strided scaleOp) (c_reduce1_i64 redOp)
    where
      scaleOp = aboEnum BO_MUL
      redOp = aroEnum RO_PRODUCT

  -- Whole-array reductions: a Haskell-side (element -> count -> element)
  -- function plus the full-reduction kernel for each width.
  numEltSumFull = intWidBranchRedFull @Int (*) (c_reducefull_i32 redOp) (c_reducefull_i64 redOp)
    where
      redOp = aroEnum RO_SUM
  numEltProductFull = intWidBranchRedFull @Int (^) (c_reducefull_i32 redOp) (c_reducefull_i64 redOp)
    where
      redOp = aroEnum RO_PRODUCT

  -- Extremum indices; the rank witness is not used.
  numEltMinIndex _ = intWidBranchExtr @Int c_extremum_min_i32 c_extremum_min_i64
  numEltMaxIndex _ = intWidBranchExtr @Int c_extremum_max_i32 c_extremum_max_i64

  -- Inner dot product: scalar-vector multiply, sum reduction and the
  -- dedicated dot-product kernel, for each width.
  numEltDotprodInner =
    intWidBranchDotprod @Int
      (c_binary_i32_sv_strided scaleOp) (c_reduce1_i32 redOp) c_dotprodinner_i32
      (c_binary_i64_sv_strided scaleOp) (c_reduce1_i64 redOp) c_dotprodinner_i64
    where
      scaleOp = aboEnum BO_MUL
      redOp = aroEnum RO_SUM

instance NumElt CInt where
  numEltAdd :: forall (n :: Nat).
SNat n -> Array n CInt -> Array n CInt -> Array n CInt
numEltAdd = forall i (n :: Nat).
(FiniteBits i, Storable i, Integral i) =>
(i -> i -> i)
-> (forall b.
    (b ~ Int32) =>
    Int64 -> Ptr Int64 -> Ptr b -> b -> Ptr Int64 -> Ptr b -> IO ())
-> (forall b.
    (b ~ Int32) =>
    Int64 -> Ptr Int64 -> Ptr b -> Ptr Int64 -> Ptr b -> b -> IO ())
-> (forall b.
    (b ~ Int32) =>
    Int64
    -> Ptr Int64
    -> Ptr b
    -> Ptr Int64
    -> Ptr b
    -> Ptr Int64
    -> Ptr b
    -> IO ())
-> (forall b.
    (b ~ Int64) =>
    Int64 -> Ptr Int64 -> Ptr b -> b -> Ptr Int64 -> Ptr b -> IO ())
-> (forall b.
    (b ~ Int64) =>
    Int64 -> Ptr Int64 -> Ptr b -> Ptr Int64 -> Ptr b -> b -> IO ())
-> (forall b.
    (b ~ Int64) =>
    Int64
    -> Ptr Int64
    -> Ptr b
    -> Ptr Int64
    -> Ptr b
    -> Ptr Int64
    -> Ptr b
    -> IO ())
-> SNat n
-> Array n i
-> Array n i
-> Array n i
intWidBranch2 @CInt CInt -> CInt -> CInt
forall a. Num a => a -> a -> a
(+)
                (CInt
-> Int64
-> Ptr Int64
-> Ptr Int32
-> Int32
-> Ptr Int64
-> Ptr Int32
-> IO ()
c_binary_i32_sv_strided (ArithBOp -> CInt
aboEnum ArithBOp
BO_ADD)) (CInt
-> Int64
-> Ptr Int64
-> Ptr Int32
-> Ptr Int64
-> Ptr Int32
-> Int32
-> IO ()
c_binary_i32_vs_strided (ArithBOp -> CInt
aboEnum ArithBOp
BO_ADD)) (CInt
-> Int64
-> Ptr Int64
-> Ptr Int32
-> Ptr Int64
-> Ptr Int32
-> Ptr Int64
-> Ptr Int32
-> IO ()
c_binary_i32_vv_strided (ArithBOp -> CInt
aboEnum ArithBOp
BO_ADD))
                (CInt
-> Int64
-> Ptr Int64
-> Ptr Int64
-> Int64
-> Ptr Int64
-> Ptr Int64
-> IO ()
c_binary_i64_sv_strided (ArithBOp -> CInt
aboEnum ArithBOp
BO_ADD)) (CInt
-> Int64
-> Ptr Int64
-> Ptr Int64
-> Ptr Int64
-> Ptr Int64
-> Int64
-> IO ()
c_binary_i64_vs_strided (ArithBOp -> CInt
aboEnum ArithBOp
BO_ADD)) (CInt
-> Int64
-> Ptr Int64
-> Ptr Int64
-> Ptr Int64
-> Ptr Int64
-> Ptr Int64
-> Ptr Int64
-> IO ()
c_binary_i64_vv_strided (ArithBOp -> CInt
aboEnum ArithBOp
BO_ADD))
  -- Element-wise subtraction. Both kernel triples are supplied:
  -- scalar-vector (sv), vector-scalar (vs) and vector-vector (vv), for
  -- Int32 and for Int64. 'intWidBranch2' (FiniteBits-constrained, defined
  -- elsewhere) presumably selects the triple matching the width of 'CInt'.
  -- '(-)' is the pure Haskell operator argument it also takes.
  numEltSub = intWidBranch2 @CInt (-)
                (c_binary_i32_sv_strided (aboEnum BO_SUB))
                (c_binary_i32_vs_strided (aboEnum BO_SUB))
                (c_binary_i32_vv_strided (aboEnum BO_SUB))
                (c_binary_i64_sv_strided (aboEnum BO_SUB))
                (c_binary_i64_vs_strided (aboEnum BO_SUB))
                (c_binary_i64_vv_strided (aboEnum BO_SUB))
  -- Element-wise multiplication. Same dispatch shape as numEltSub: '(*)' as
  -- the pure operator, then the Int32 sv/vs/vv kernels and the Int64
  -- sv/vs/vv kernels, all specialised to BO_MUL.
  numEltMul = intWidBranch2 @CInt (*)
                (c_binary_i32_sv_strided (aboEnum BO_MUL))
                (c_binary_i32_vs_strided (aboEnum BO_MUL))
                (c_binary_i32_vv_strided (aboEnum BO_MUL))
                (c_binary_i64_sv_strided (aboEnum BO_MUL))
                (c_binary_i64_vs_strided (aboEnum BO_MUL))
                (c_binary_i64_vv_strided (aboEnum BO_MUL))
  -- Element-wise negation via the strided unary kernels. The Int32 and
  -- Int64 variants are both passed to 'intWidBranch1'.
  numEltNeg = intWidBranch1 @CInt
                (c_unary_i32_strided (auoEnum UO_NEG))
                (c_unary_i64_strided (auoEnum UO_NEG))
  -- Element-wise absolute value via the strided unary kernels (UO_ABS).
  numEltAbs = intWidBranch1 @CInt
                (c_unary_i32_strided (auoEnum UO_ABS))
                (c_unary_i64_strided (auoEnum UO_ABS))
  -- Element-wise signum via the strided unary kernels (UO_SIGNUM).
  numEltSignum = intWidBranch1 @CInt
                   (c_unary_i32_strided (auoEnum UO_SIGNUM))
                   (c_unary_i64_strided (auoEnum UO_SIGNUM))
  -- Sum over the innermost dimension (Array (n + 1) -> Array n).
  -- Arguments per width are (scale-by-constant sv kernel, reduction kernel).
  -- BO_MUL as the scale kernel presumably handles a replicated reduced
  -- dimension, where the sum of k copies of x is x * k
  -- (confirm in intWidBranchRed1).
  numEltSum1Inner = intWidBranchRed1 @CInt
                      (c_binary_i32_sv_strided (aboEnum BO_MUL)) (c_reduce1_i32 (aroEnum RO_SUM))
                      (c_binary_i64_sv_strided (aboEnum BO_MUL)) (c_reduce1_i64 (aroEnum RO_SUM))
  -- Product over the innermost dimension (Array (n + 1) -> Array n).
  -- Arguments per width are (scale-by-constant sv kernel, reduction kernel).
  -- NOTE(review): the scale kernel here is BO_MUL, identical to
  -- numEltSum1Inner. numEltProductFull instead passes (^) as its count
  -- combinator. If intWidBranchRed1 uses the scale kernel for a replicated
  -- (stride-0) reduced dimension, a product would need x ^ k rather than
  -- x * k. Confirm against intWidBranchRed1 before relying on that path.
  numEltProduct1Inner = intWidBranchRed1 @CInt
                          (c_binary_i32_sv_strided (aboEnum BO_MUL)) (c_reduce1_i32 (aroEnum RO_PRODUCT))
                          (c_binary_i64_sv_strided (aboEnum BO_MUL)) (c_reduce1_i64 (aroEnum RO_PRODUCT))
  -- Sum of all elements to a single scalar. The first argument has type
  -- @i -> Int -> i@ ('mulWithInt', defined elsewhere). It presumably combines
  -- a value with a repetition count for replicated dimensions; confirm in
  -- intWidBranchRedFull.
  numEltSumFull = intWidBranchRedFull @CInt mulWithInt
                    (c_reducefull_i32 (aroEnum RO_SUM))
                    (c_reducefull_i64 (aroEnum RO_SUM))
  -- Product of all elements to a single scalar. '(^)' fills the
  -- @i -> Int -> i@ slot that numEltSumFull fills with 'mulWithInt'.
  numEltProductFull = intWidBranchRedFull @CInt (^)
                        (c_reducefull_i32 (aroEnum RO_PRODUCT))
                        (c_reducefull_i64 (aroEnum RO_PRODUCT))
  -- Index (as a list of Ints) of a minimal element. 'intWidBranchExtr' does
  -- not take the rank witness, so the SNat argument is ignored.
  numEltMinIndex _ = intWidBranchExtr @CInt c_extremum_min_i32 c_extremum_min_i64
  -- Index (as a list of Ints) of a maximal element. The SNat argument is
  -- ignored, as in numEltMinIndex.
  numEltMaxIndex _ = intWidBranchExtr @CInt c_extremum_max_i32 c_extremum_max_i64
  -- Dot product along the innermost dimension of two arrays
  -- (Array (n + 1) x Array (n + 1) -> Array n). Each width supplies three
  -- kernels: an sv multiply (BO_MUL), an inner-dimension sum reduction, and
  -- a dedicated dotprodinner kernel. How intWidBranchDotprod chooses among
  -- them is defined elsewhere.
  numEltDotprodInner = intWidBranchDotprod @CInt
                         (c_binary_i32_sv_strided (aboEnum BO_MUL))
                         (c_reduce1_i32 (aroEnum RO_SUM))
                         c_dotprodinner_i32
                         (c_binary_i64_sv_strided (aboEnum BO_MUL))
                         (c_reduce1_i64 (aroEnum RO_SUM))
                         c_dotprodinner_i64

-- | Integral element types: everything from 'NumElt' plus element-wise
-- integer division on strided arrays of rank @n@.
class NumElt a => IntElt a where
  -- | Element-wise quotient ('quot' semantics in the instances shown here).
  intEltQuot
    :: forall (n :: Nat).
       SNat n -> Array n a -> Array n a -> Array n a
  -- | Element-wise remainder ('rem' semantics in the instances shown here).
  intEltRem
    :: forall (n :: Nat).
       SNat n -> Array n a -> Array n a -> Array n a

-- | 'Int32' delegates directly to its dedicated vector kernels.
instance IntElt Int32 where
  intEltQuot = quotVectorInt32
  intEltRem = remVectorInt32

-- | 'Int64' delegates directly to its dedicated vector kernels.
instance IntElt Int64 where
  intEltQuot = quotVectorInt64
  intEltRem = remVectorInt64

-- | 'Int' has a platform-dependent width, so both the Int32 and the Int64
-- kernel triples (sv, vs, vv) are handed to 'intWidBranch2', which is
-- FiniteBits-constrained and presumably picks one by width. The integral
-- opcodes come from 'aiboEnum' (IB_QUOT / IB_REM).
instance IntElt Int where
  intEltQuot = intWidBranch2 @Int quot
                 (c_binary_i32_sv_strided (aiboEnum IB_QUOT))
                 (c_binary_i32_vs_strided (aiboEnum IB_QUOT))
                 (c_binary_i32_vv_strided (aiboEnum IB_QUOT))
                 (c_binary_i64_sv_strided (aiboEnum IB_QUOT))
                 (c_binary_i64_vs_strided (aiboEnum IB_QUOT))
                 (c_binary_i64_vv_strided (aiboEnum IB_QUOT))
  intEltRem = intWidBranch2 @Int rem
                (c_binary_i32_sv_strided (aiboEnum IB_REM))
                (c_binary_i32_vs_strided (aiboEnum IB_REM))
                (c_binary_i32_vv_strided (aiboEnum IB_REM))
                (c_binary_i64_sv_strided (aiboEnum IB_REM))
                (c_binary_i64_vs_strided (aiboEnum IB_REM))
                (c_binary_i64_vv_strided (aiboEnum IB_REM))

-- | 'CInt' uses the same width dispatch as the 'Int' instance: both the
-- Int32 and the Int64 kernel triples are supplied to 'intWidBranch2',
-- with the integral opcodes IB_QUOT / IB_REM.
instance IntElt CInt where
  intEltQuot = intWidBranch2 @CInt quot
                 (c_binary_i32_sv_strided (aiboEnum IB_QUOT))
                 (c_binary_i32_vs_strided (aiboEnum IB_QUOT))
                 (c_binary_i32_vv_strided (aiboEnum IB_QUOT))
                 (c_binary_i64_sv_strided (aiboEnum IB_QUOT))
                 (c_binary_i64_vs_strided (aiboEnum IB_QUOT))
                 (c_binary_i64_vv_strided (aiboEnum IB_QUOT))
  intEltRem = intWidBranch2 @CInt rem
                (c_binary_i32_sv_strided (aiboEnum IB_REM))
                (c_binary_i32_vs_strided (aiboEnum IB_REM))
                (c_binary_i32_vv_strided (aiboEnum IB_REM))
                (c_binary_i64_sv_strided (aiboEnum IB_REM))
                (c_binary_i64_vs_strided (aiboEnum IB_REM))
                (c_binary_i64_vv_strided (aiboEnum IB_REM))

-- | Element types that support the floating-point array operations.
--
-- Every method takes an 'SNat' for the type-level index @n@ of the arrays it
-- works on. The method names mirror the corresponding Prelude functions
-- (presumably applied elementwise; confirm in the instances' helpers).
class NumElt a => FloatElt a where
  -- | Binary operations: division, power and logarithm to a given base.
  floatEltDiv, floatEltPow, floatEltLogbase
    :: SNat n -> Array n a -> Array n a -> Array n a

  -- | Unary operations.
  floatEltRecip, floatEltExp, floatEltLog, floatEltSqrt,
    floatEltSin, floatEltCos, floatEltTan,
    floatEltAsin, floatEltAcos, floatEltAtan,
    floatEltSinh, floatEltCosh, floatEltTanh,
    floatEltAsinh, floatEltAcosh, floatEltAtanh,
    floatEltLog1p, floatEltExpm1, floatEltLog1pexp, floatEltLog1mexp
    :: SNat n -> Array n a -> Array n a

  -- | Two-argument arctangent.
  floatEltAtan2 :: SNat n -> Array n a -> Array n a -> Array n a

-- | 'Float' instance: each method delegates directly to the matching
-- @*VectorFloat@ implementation defined elsewhere in this module.
instance FloatElt Float where
  -- binary operations
  floatEltDiv      = divVectorFloat
  floatEltPow      = powVectorFloat
  floatEltLogbase  = logbaseVectorFloat
  -- unary operations
  floatEltRecip    = recipVectorFloat
  floatEltExp      = expVectorFloat
  floatEltLog      = logVectorFloat
  floatEltSqrt     = sqrtVectorFloat
  floatEltSin      = sinVectorFloat
  floatEltCos      = cosVectorFloat
  floatEltTan      = tanVectorFloat
  floatEltAsin     = asinVectorFloat
  floatEltAcos     = acosVectorFloat
  floatEltAtan     = atanVectorFloat
  floatEltSinh     = sinhVectorFloat
  floatEltCosh     = coshVectorFloat
  floatEltTanh     = tanhVectorFloat
  floatEltAsinh    = asinhVectorFloat
  floatEltAcosh    = acoshVectorFloat
  floatEltAtanh    = atanhVectorFloat
  floatEltLog1p    = log1pVectorFloat
  floatEltExpm1    = expm1VectorFloat
  floatEltLog1pexp = log1pexpVectorFloat
  floatEltLog1mexp = log1mexpVectorFloat
  -- two-argument arctangent
  floatEltAtan2    = atan2VectorFloat

-- | 'Double' instance: each method delegates directly to the matching
-- @*VectorDouble@ implementation defined elsewhere in this module.
instance FloatElt Double where
  -- binary operations
  floatEltDiv      = divVectorDouble
  floatEltPow      = powVectorDouble
  floatEltLogbase  = logbaseVectorDouble
  -- unary operations
  floatEltRecip    = recipVectorDouble
  floatEltExp      = expVectorDouble
  floatEltLog      = logVectorDouble
  floatEltSqrt     = sqrtVectorDouble
  floatEltSin      = sinVectorDouble
  floatEltCos      = cosVectorDouble
  floatEltTan      = tanVectorDouble
  floatEltAsin     = asinVectorDouble
  floatEltAcos     = acosVectorDouble
  floatEltAtan     = atanVectorDouble
  floatEltSinh     = sinhVectorDouble
  floatEltCosh     = coshVectorDouble
  floatEltTanh     = tanhVectorDouble
  floatEltAsinh    = asinhVectorDouble
  floatEltAcosh    = acoshVectorDouble
  floatEltAtanh    = atanhVectorDouble
  floatEltLog1p    = log1pVectorDouble
  floatEltExpm1    = expm1VectorDouble
  floatEltLog1pexp = log1pexpVectorDouble
  floatEltLog1mexp = log1mexpVectorDouble
  -- two-argument arctangent
  floatEltAtan2    = atan2VectorDouble