-- | Operations that (impurely, via a thread-safe counter that only ever
-- increases, except when reset by the test-only 'resetVarCounter')
-- generate fresh variables and sometimes also produce AST terms
-- by applying functions to such variables. This module encapsulates
-- the impurity, though some functions are in IO and they are used
-- with @unsafePerformIO@ outside, so some of the impurity escapes
-- and is encapsulated elsewhere.
module HordeAd.Core.AstFreshId
  ( funToAstIO, funToAst, funToAst2, fun1ToAst
  , funToAstRevIO, funToAstFwdIO
  , funToAstIntVarIO, funToAstIntVar, funToAstI
  , funToVarsIxS, funToAstIxS
    -- * Low level counter manipulation to be used only in sequential tests
  , resetVarCounter
  ) where

import Prelude

import Control.Concurrent.Counter (Counter, add, new, set)
import Data.Int (Int64)
import GHC.Exts (IsList (..))
import System.IO.Unsafe (unsafePerformIO)

import Data.Array.Nested.Shaped.Shape

import HordeAd.Core.Ast
import HordeAd.Core.TensorKind
import HordeAd.Core.Types

-- | The value the fresh-variable counter starts from and is reset to
-- by 'resetVarCounter'. Defined once so the two cannot drift apart.
initialVarCounterValue :: Int
initialVarCounterValue = 100000001

-- | A global counter that is impure but only in the most trivial way:
-- the fresh-name functions of this module only ever add one to it.
-- The only other mutation is the test-only 'resetVarCounter'.
--
-- The @NOINLINE@ pragma keeps GHC from inlining (and so duplicating)
-- the @unsafePerformIO@ call, so every use site shares one counter.
unsafeAstVarCounter :: Counter
{-# NOINLINE unsafeAstVarCounter #-}
unsafeAstVarCounter = unsafePerformIO (new initialVarCounterValue)

-- | Only for tests, e.g., to ensure `show` applied to terms has stable length.
-- Resets the counter to 'initialVarCounterValue'.
-- Tests that use this tool need to be run sequentially
-- to avoid variable confusion.
resetVarCounter :: IO ()
resetVarCounter = set unsafeAstVarCounter initialVarCounterValue

-- | Draw a fresh variable identifier.
--
-- Adds one to the shared thread-safe counter and converts the 'Int'
-- that 'add' returns into an 'AstVarId'.
unsafeGetFreshAstVarId :: IO AstVarId
{-# INLINE unsafeGetFreshAstVarId #-}
unsafeGetFreshAstVarId = fmap intToAstVarId (add unsafeAstVarCounter 1)

-- | Allocate a fresh variable name of shape @ftk@ carrying the optional
-- @bounds@ (both forwarded unchanged to 'mkAstVarName').
--
-- The identifier is obtained through 'unsafeGetFreshAstVarId'. This way
-- there is a single definition that bumps the global counter, rather
-- than a second copy of the @add unsafeAstVarCounter 1@ logic here.
unsafeGetFreshAstVarName :: FullShapeTK y -> Maybe (Int64, Int64)
                         -> IO (AstVarName s y)
{-# INLINE unsafeGetFreshAstVarName #-}
unsafeGetFreshAstVarName ftk bounds =
  mkAstVarName ftk bounds <$> unsafeGetFreshAstVarId

-- | Allocate a fresh variable of shape @ftk@ (with optional @bounds@,
-- forwarded to 'mkAstVarName') and apply @f@ to the AST term that
-- refers to it.
--
-- Returns the variable name paired with @f@'s result. The result is
-- forced to WHNF before being returned; the name itself is not forced.
funToAstIO2 :: forall y z s s2 ms. AstSpan s
            => FullShapeTK y -> Maybe (Int64, Int64)
            -> (AstTensor ms s y -> AstTensor ms s2 z)
            -> IO (AstVarName s y, AstTensor ms s2 z)
{-# INLINE funToAstIO2 #-}
funToAstIO2 ftk bounds f = do
  varName <- unsafeGetFreshAstVarName ftk bounds
  let !body = f (astVar varName)
  return (varName, body)
-- Warning: making the @varName@ binding above strict (@!varName@) breaks
-- fragile tests. Probably GHC then optimizes differently and less
-- predictably and so changes results between -O0 vs -O1 and possibly also
-- between different GHC versions and between local vs CI setup.

-- | Pure face of 'funToAstIO2': allocates the fresh variable via
-- 'unsafePerformIO' each time the function is fully applied.
--
-- The explicit lambda keeps the definition's syntactic arity at two,
-- exactly as the equivalent @unsafePerformIO . funToAstIO2 ftk bounds@.
-- The @NOINLINE@ pragma stops GHC from inlining the impure call into
-- call sites. Presumably this avoids duplicated or shared allocations;
-- confirm before relaxing it.
funToAst2 :: AstSpan s
          => FullShapeTK y -> Maybe (Int64, Int64)
          -> (AstTensor ms s y -> AstTensor ms s2 z)
          -> (AstVarName s y, AstTensor ms s2 z)
{-# NOINLINE funToAst2 #-}
funToAst2 ftk bounds = \f -> unsafePerformIO (funToAstIO2 ftk bounds f)

-- | 'funToAstIO2' with no bounds attached to the fresh variable
-- (@Nothing@) and with the result term living in the same span @s@
-- as the variable.
funToAstIO :: forall y z s ms. AstSpan s
           => FullShapeTK y
           -> (AstTensor ms s y -> AstTensor ms s z)
           -> IO (AstVarName s y, AstTensor ms s z)
{-# INLINE funToAstIO #-}
funToAstIO shapeFtk = funToAstIO2 shapeFtk Nothing

-- | Pure, same-span variant of 'funToAstIO2': the fresh variable is
-- allocated via 'unsafePerformIO' each time the function is fully applied.
--
-- Written with an explicit lambda so the syntactic arity stays at two,
-- identical to @unsafePerformIO . funToAstIO2 ftk bounds@. The
-- @NOINLINE@ pragma keeps the impure call out of call sites.
funToAst :: AstSpan s
         => FullShapeTK y -> Maybe (Int64, Int64)
         -> (AstTensor ms s y -> AstTensor ms s z)
         -> (AstVarName s y, AstTensor ms s z)
{-# NOINLINE funToAst #-}
funToAst ftk bounds = \f -> unsafePerformIO (funToAstIO2 ftk bounds f)

-- | Allocate a fresh, unbounded (@Nothing@) variable name of shape @ftk@
-- and hand the name itself (not an AST term) to @f@.
--
-- Both the name and @f@'s result are forced to WHNF.
fun1ToAstIO :: FullShapeTK y -> (AstVarName s y -> AstTensor ms s y)
            -> IO (AstTensor ms s y)
{-# INLINE fun1ToAstIO #-}
fun1ToAstIO ftk f =
  unsafeGetFreshAstVarName ftk Nothing >>= \ !varName -> return $! f varName

-- | Pure face of 'fun1ToAstIO', running it via 'unsafePerformIO' each time
-- the function is fully applied.
--
-- The explicit lambda preserves the original syntactic arity of one,
-- matching @unsafePerformIO . fun1ToAstIO ftk@. @NOINLINE@ keeps the
-- impure call from being inlined into callers.
fun1ToAst :: FullShapeTK y -> (AstVarName s y -> AstTensor ms s y)
          -> AstTensor ms s y
{-# NOINLINE fun1ToAst #-}
fun1ToAst ftk = \f -> unsafePerformIO (fun1ToAstIO ftk f)

-- | Draw ONE fresh identifier and build two variable names of shape @ftk@
-- from it: one in 'PrimalSpan' and one in 'FullSpan'.
--
-- Returns each name together with its (strictly evaluated) AST term:
-- the primal term uses 'AstMethodShare', the full-span term uses
-- 'AstMethodLet'. Neither name carries bounds.
funToAstRevIO :: forall x.
                 FullShapeTK x
              -> IO ( AstVarName PrimalSpan x
                    , AstTensor AstMethodShare PrimalSpan x
                    , AstVarName FullSpan x
                    , AstTensor AstMethodLet FullSpan x )
{-# INLINE funToAstRevIO #-}
funToAstRevIO ftk = do
  !varId <- unsafeGetFreshAstVarId
  let primalName :: AstVarName PrimalSpan x
      primalName = mkAstVarName ftk Nothing varId
      fullName :: AstVarName FullSpan x
      fullName = mkAstVarName ftk Nothing varId
      primalTerm :: AstTensor AstMethodShare PrimalSpan x
      !primalTerm = astVar primalName
      fullTerm :: AstTensor AstMethodLet FullSpan x
      !fullTerm = astVar fullName
  return (primalName, primalTerm, fullName, fullTerm)

-- | Draw TWO fresh identifiers, in this order:
--
-- 1. one for a 'PrimalSpan' variable of shape @adFTK ftk@, i.e. of kind
--    @ADTensorKind x@ (presumably the derivative input; confirm
--    with callers);
-- 2. one shared by a 'PrimalSpan' and a 'FullSpan' variable of shape @ftk@.
--
-- Returns every name together with its strictly evaluated AST term.
-- None of the names carries bounds.
funToAstFwdIO :: forall x.
                 FullShapeTK x
              -> IO ( AstVarName PrimalSpan (ADTensorKind x)
                    , AstTensor AstMethodShare PrimalSpan (ADTensorKind x)
                    , AstVarName PrimalSpan x
                    , AstTensor AstMethodShare PrimalSpan x
                    , AstVarName FullSpan x
                    , AstTensor AstMethodLet FullSpan x )
{-# INLINE funToAstFwdIO #-}
funToAstFwdIO ftk = do
  -- The order of these two draws determines which identifier each
  -- variable receives; keep it.
  !adId <- unsafeGetFreshAstVarId
  !baseId <- unsafeGetFreshAstVarId
  let adName :: AstVarName PrimalSpan (ADTensorKind x)
      adName = mkAstVarName (adFTK ftk) Nothing adId
      primalName :: AstVarName PrimalSpan x
      primalName = mkAstVarName ftk Nothing baseId
      fullName :: AstVarName FullSpan x
      fullName = mkAstVarName ftk Nothing baseId
      adTerm :: AstTensor AstMethodShare PrimalSpan (ADTensorKind x)
      !adTerm = astVar adName
      primalTerm :: AstTensor AstMethodShare PrimalSpan x
      !primalTerm = astVar primalName
      fullTerm :: AstTensor AstMethodLet FullSpan x
      !fullTerm = astVar fullName
  return (adName, adTerm, primalName, primalTerm, fullName, fullTerm)

-- | Allocate a fresh 'Int64' scalar variable with the given optional
-- @bounds@, then pass both its name and its AST term to @f@.
--
-- The name and @f@'s result are each forced to WHNF.
funToAstIntVarIO :: Maybe (Int64, Int64) -> ((IntVarName, AstInt ms) -> a)
                 -> IO a
{-# INLINE funToAstIntVarIO #-}
funToAstIntVarIO bounds f =
  unsafeGetFreshAstVarName (FTKScalar @Int64) bounds
    >>= \ !intVar -> return $! f (intVar, astVar intVar)

-- | Pure face of 'funToAstIntVarIO', run through 'unsafePerformIO' on
-- every full application.
--
-- The lambda keeps the syntactic arity at one, the same as
-- @unsafePerformIO . funToAstIntVarIO bounds@. @NOINLINE@ keeps the
-- impure call out of callers.
funToAstIntVar :: Maybe (Int64, Int64) -> ((IntVarName, AstInt ms) -> a) -> a
{-# NOINLINE funToAstIntVar #-}
funToAstIntVar bounds = \f -> unsafePerformIO (funToAstIntVarIO bounds f)

-- | Allocate a fresh 'Int64' scalar variable (with optional @bounds@),
-- apply @f@ to its AST term, and return the variable name paired with
-- @f@'s result.
--
-- The allocation runs through 'unsafePerformIO' and the definition is
-- @NOINLINE@. The name, the term and @f@'s result are all forced to WHNF
-- before the pair is built.
funToAstI :: Maybe (Int64, Int64) -> (AstInt ms -> t) -> (IntVarName, t)
{-# NOINLINE funToAstI #-}
funToAstI bounds f = unsafePerformIO (funToAstIntVarIO bounds pairUp)
  where
    pairUp (!name, !term) = let !result = f term in (name, result)

-- | Allocate one fresh 'Int64' index variable per dimension of @sh@, in
-- dimension order.
--
-- For a dimension of size @n@ the variable gets bounds @(0, n - 1)@.
-- @f@ receives both the variable list and the index built from the
-- corresponding AST terms; the list, both conversions and @f@'s result
-- are forced to WHNF.
funToVarsIxIOS
  :: forall sh a ms.
     ShS sh -> ((AstVarListS sh, AstIxS ms sh) -> a) -> IO a
{-# INLINE funToVarsIxIOS #-}
funToVarsIxIOS sh f = withKnownShS sh $ do
  -- 'withKnownShS' supplies @KnownShS sh@. Presumably this is needed by
  -- the 'fromList' conversions below; confirm before removing.
  let freshIndexVar :: Int -> IO IntVarName
      freshIndexVar size =
        unsafeGetFreshAstVarName (FTKScalar @Int64)
                                 (Just (0, fromIntegral size - 1))
  !names <- traverse freshIndexVar (shsToList sh)
  let !vars = fromList names
  let !ix = fromList (map astVar names)
  return $! f (vars, ix)

-- | Pure face of 'funToVarsIxIOS', run through 'unsafePerformIO' on every
-- full application.
--
-- The explicit lambda keeps the syntactic arity at one, identical to
-- @unsafePerformIO . funToVarsIxIOS sh@. @NOINLINE@ keeps the impure call
-- out of callers.
funToVarsIxS
  :: ShS sh -> ((AstVarListS sh, AstIxS ms sh) -> a) -> a
{-# NOINLINE funToVarsIxS #-}
funToVarsIxS sh = \f -> unsafePerformIO (funToVarsIxIOS sh f)

-- | Allocate fresh index variables for every dimension of @sh@ (see
-- 'funToVarsIxIOS'), apply @f@ to the resulting index, and return the
-- variables paired with the transformed index.
--
-- The allocation runs through 'unsafePerformIO' and the definition is
-- @NOINLINE@. The variables, the index and @f@'s result are forced to
-- WHNF before the pair is built.
funToAstIxS
  :: ShS sh -> (AstIxS ms sh -> AstIxS ms sh2)
  -> (AstVarListS sh, AstIxS ms sh2)
{-# NOINLINE funToAstIxS #-}
funToAstIxS sh f = unsafePerformIO (funToVarsIxIOS sh applyToIx)
  where
    applyToIx (!vars, !ix) = let !ix2 = f ix in (vars, ix2)