{-# LANGUAGE QuantifiedConstraints #-}
-- | The environment type and operations for creating, extending and showing it.
module HordeAd.Core.AstEnv
  ( AstEnv, emptyEnv, showsPrecAstEnv
  , extendEnv, extendEnvI, extendEnvVarsS
  ) where

import Prelude

import Data.Coerce (coerce)
import Data.Dependent.EnumMap.Strict (DEnumMap)
import Data.Dependent.EnumMap.Strict qualified as DMap
import Data.Dependent.Sum
import Data.Foldable qualified as Foldable
import Data.Kind (Type)
import Text.Show (showListWith)

import Data.Array.Nested.Shaped.Shape

import HordeAd.Core.Ast
import HordeAd.Core.Ops
import HordeAd.Core.TensorKind
import HordeAd.Core.Types

-- | The environment that keeps values assigned to variables
-- during interpretation, as a dependent map keyed by variable names.
--
-- Every key is a 'FullSpan' variable name: we can't easily index over
-- both span and tensor kind at once, so PrimalSpan values are represented
-- as FullSpan (dual number) values with a zero dual component, and
-- DualSpan values as FullSpan values with a zero primal component.
type AstEnv :: Target -> Type
type AstEnv = DEnumMap (AstVarName FullSpan)

-- | The environment with no variable bindings.
emptyEnv :: AstEnv target
emptyEnv = DMap.empty

-- | Render an environment in 'showsPrec' style, as
-- @fromList [k1 :=> v1, k2 :=> v2, ...]@, parenthesised when the
-- surrounding precedence exceeds 10.
--
-- The quantified constraint asks for a 'Show' instance for values of every
-- tensor kind that has a 'KnownSTK' instance. For each binding, that
-- 'KnownSTK' instance is reconstructed from the variable's full shape
-- ('varNameToFTK' followed by 'ftkToSTK'), so the value can be shown at
-- its own tensor kind.
showsPrecAstEnv
  :: (forall y. KnownSTK y => Show (target y))
  => Int -> AstEnv target -> ShowS
showsPrecAstEnv d demap =
  showParen (d > 10) $
    showString "fromList "
    . showListWith
        (\(k :=> val) ->
           withKnownSTK (ftkToSTK $ varNameToFTK k) $
           showsPrec 2 k . showString " :=> " . showsPrec 1 val)
        (DMap.toList demap)

-- | Bind a variable to a value in the environment.
--
-- The variable may be of any span. It is stored under its 'FullSpan'
-- counterpart (via 'coerce'), because the environment only holds
-- 'FullSpan' keys (see the comment on 'AstEnv').
--
-- Rebinding a variable that is already present is an error: the combining
-- function passed to 'DMap.insertWithKey' calls 'error' with a message
-- that names the variable.
extendEnv :: forall target s y.
             AstVarName s y -> target y -> AstEnv target
          -> AstEnv target
extendEnv var !t !env =
  let var2 :: AstVarName FullSpan y
      var2 = coerce var  -- only FullSpan variables permitted in env
  in DMap.insertWithKey (\_ _ _ -> error $ "extendEnv: duplicate " ++ show var)
                        var2 t env

-- | Bind an integer index variable.
--
-- The index arrives as an 'IntOf' value (a primal scalar). It is embedded
-- into the full target with 'tfromPrimal' at scalar kind ('STKScalar')
-- before being stored with 'extendEnv'. As with 'extendEnv', the variable
-- must not already be bound.
extendEnvI :: BaseTensor target
           => IntVarName -> IntOf target -> AstEnv target
           -> AstEnv target
extendEnvI var !i !env = extendEnv var (tfromPrimal STKScalar i) env

-- | Bind each index variable of a shaped variable list to the index
-- component at the same position, using 'extendEnvI' for each pair.
--
-- The variable list and the index are both indexed by the same shape
-- @sh@, so they pair up position by position in 'zip'. Every variable
-- must be fresh with respect to @env@ (see 'extendEnv').
extendEnvVarsS :: forall target sh. BaseTensor target
               => AstVarListS sh -> IxSOf target sh -> AstEnv target
               -> AstEnv target
extendEnvVarsS vars !ix !env =
  let assocs = zip (listsToList vars) (Foldable.toList ix)
  in foldr (uncurry extendEnvI) env assocs