{-# LANGUAGE QuantifiedConstraints #-}
module HordeAd.Core.AstEnv
( AstEnv, emptyEnv, showsPrecAstEnv
, extendEnv, extendEnvI, extendEnvVarsS
) where
import Prelude
import Data.Coerce (coerce)
import Data.Dependent.EnumMap.Strict (DEnumMap)
import Data.Dependent.EnumMap.Strict qualified as DMap
import Data.Dependent.Sum
import Data.Foldable qualified as Foldable
import Data.Kind (Type)
import Text.Show (showListWith)
import Data.Array.Nested.Shaped.Shape
import HordeAd.Core.Ast
import HordeAd.Core.Ops
import HordeAd.Core.TensorKind
import HordeAd.Core.Types
-- | The environment used when interpreting AST terms.
--
-- A dependent map whose keys are full-span AST variables: a key of type
-- @'AstVarName' 'FullSpan' y@ is associated with a value of type @f y@, so
-- each value's tensor kind is the one recorded in its variable's type index.
type AstEnv :: Target -> Type
type AstEnv f = DEnumMap (AstVarName FullSpan) f
-- | The environment containing no variable bindings.
emptyEnv :: AstEnv target
emptyEnv = DMap.empty
-- | Render an environment in the same style as the 'Show' instance of a
-- map: @fromList [var :=> value, ...]@.
--
-- A 'Show' instance for a value is only available once 'KnownSTK' holds for
-- its tensor kind. Each entry therefore recovers that evidence from the
-- shape stored in its variable ('varNameToFTK') before showing the value.
showsPrecAstEnv
  :: forall target. (forall y. KnownSTK y => Show (target y))
  => Int            -- ^ precedence of the enclosing context
  -> AstEnv target  -- ^ environment to render
  -> ShowS
showsPrecAstEnv d demap =
  -- Parenthesize when the context binds tighter than function application.
  showParen (d > 10) $
    showString "fromList " . showListWith showEntry (DMap.toList demap)
  where
    -- Key at precedence 2, value at precedence 1, as suits the operands of
    -- a right-associative operator of fixity 1 (assumed to match the
    -- declared fixity of @:=>@ -- confirm against Data.Dependent.Sum).
    showEntry :: DSum (AstVarName FullSpan) target -> ShowS
    showEntry (var :=> val) =
      withKnownSTK (ftkToSTK (varNameToFTK var))
        (showsPrec 2 var . showString " :=> " . showsPrec 1 val)
-- | Bind a variable to a value, adding the binding to the environment.
--
-- The variable is re-indexed to 'FullSpan' with 'coerce', because the
-- environment is keyed by full-span variables only. The value and the
-- environment are both forced to WHNF. If the variable is already bound,
-- this calls 'error' with a message naming the duplicate variable.
extendEnv :: forall target s y.
             AstVarName s y -> target y -> AstEnv target
          -> AstEnv target
extendEnv var !t !env =
  let var2 :: AstVarName FullSpan y
      var2 = coerce var
  in DMap.insertWithKey (\_ _ _ -> error $ "extendEnv: duplicate " ++ show var)
                        var2 t env
-- | Bind an integer variable to an integer value.
--
-- The value is given as an 'IntOf', i.e. on the primal side. It is
-- embedded into the target as a scalar with 'tfromPrimal' 'STKScalar'
-- before being added via 'extendEnv'. Its duplicate-binding check applies.
extendEnvI :: BaseTensor target
           => IntVarName -> IntOf target -> AstEnv target
           -> AstEnv target
extendEnvI var !i !env = extendEnv var (tfromPrimal STKScalar i) env
-- | Bind each variable of a shaped variable list to the corresponding
-- component of a shaped index, using 'extendEnvI' for every pair.
--
-- The variable list and the index are indexed by the same shape @sh@. This
-- should guarantee equal lengths. NOTE(review): 'zip' would silently
-- truncate otherwise. A repeated variable fails as in 'extendEnv'.
extendEnvVarsS :: forall target sh. BaseTensor target
               => AstVarListS sh -> IxSOf target sh -> AstEnv target
               -> AstEnv target
extendEnvVarsS vars !ix !env =
  let assocs = zip (listsToList vars) (Foldable.toList ix)
  in foldr (uncurry extendEnvI) env assocs