{-# LANGUAGE CPP #-}
module HordeAd.Core.AstInterpret
( interpretAstFull, interpretAstPrimal, interpretAstDual, interpretAst
, interpretAstBool
) where
import Prelude
import Data.Coerce (coerce)
import Data.Dependent.EnumMap.Strict qualified as DMap
import Data.Proxy (Proxy (Proxy))
import Data.Type.Equality (testEquality, (:~:) (Refl))
import Data.Vector.Generic qualified as V
import Type.Reflection (typeRep)
import Data.Array.Nested.Shaped.Shape
import HordeAd.Core.Ast
import HordeAd.Core.AstEnv
import HordeAd.Core.AstTools
import HordeAd.Core.ConvertTensor
import HordeAd.Core.Ops
import HordeAd.Core.TensorKind
import HordeAd.Core.Types
#ifdef WITH_EXPENSIVE_ASSERTIONS
import Control.Exception.Assert.Sugar
#endif
-- | Interpret a 'FullSpan' AST term in the given environment, producing a
-- value of the target.
--
-- This is just 'interpretAst' with the span index fixed to 'FullSpan'.
interpretAstFull
  :: forall target y. ADReady target
  => AstEnv target -> AstTensor AstMethodLet FullSpan y
  -> target y
{-# INLINE interpretAstFull #-}
interpretAstFull = interpretAst
-- | Interpret a 'PrimalSpan' AST term in the given environment, producing a
-- value in the primal part of the target, @'PrimalOf' target@.
--
-- Map-accumulate nodes ('AstMapAccumRDer', 'AstMapAccumLDer') and
-- conditionals ('AstCond') are interpreted structurally, directly in
-- @'PrimalOf' target@. The function arguments of the map-accumulate nodes
-- are interpreted with 'interpretAstHFunPrimal'. Every other node is
-- interpreted at the full target with 'interpretAst', and its primal part
-- is then taken with 'tprimalPart'.
--
-- NOTE(review): the special cases presumably exist so that these nodes are
-- never built in the full target only to be projected back; confirm.
interpretAstPrimal
  :: forall target y. ADReady target
  => AstEnv target -> AstTensor AstMethodLet PrimalSpan y
  -> PrimalOf target y
interpretAstPrimal !env v1 = case v1 of
  AstMapAccumRDer k bftk eftk f0 df0 rf0 acc0 es ->
    let f = interpretAstHFunPrimal env f0
        df = interpretAstHFunPrimal env df0
        rf = interpretAstHFunPrimal env rf0
        acc02 = interpretAstPrimal env acc0
        es2 = interpretAstPrimal env es
    -- The accumulator's full shape is read off the (uninterpreted) AST.
    in tmapAccumRDer (Proxy @(PrimalOf target))
                     k (ftkAst acc0) bftk eftk f df rf acc02 es2
  AstMapAccumLDer k bftk eftk f0 df0 rf0 acc0 es ->
    let f = interpretAstHFunPrimal env f0
        df = interpretAstHFunPrimal env df0
        rf = interpretAstHFunPrimal env rf0
        acc02 = interpretAstPrimal env acc0
        es2 = interpretAstPrimal env es
    in tmapAccumLDer (Proxy @(PrimalOf target))
                     k (ftkAst acc0) bftk eftk f df rf acc02 es2
  AstCond b a1 a2 ->
    let c = interpretAstBool env b
    -- The result's singleton kind is recovered from the full shape of the
    -- first branch.
    in tcond (ftkToSTK $ ftkAst a1) c
             (interpretAstPrimal env a1) (interpretAstPrimal env a2)
  _ -> tprimalPart (interpretAst env v1)
-- | Interpret a 'DualSpan' AST term in the given environment, producing a
-- value in the dual part of the target, @'DualOf' target@.
--
-- The term is interpreted at the full target with 'interpretAst'. The dual
-- part of the result is then extracted with 'tdualPart', using the
-- singleton kind derived from the term's full shape.
interpretAstDual
  :: forall target y. ADReady target
  => AstEnv target -> AstTensor AstMethodLet DualSpan y
  -> DualOf target y
{-# INLINE interpretAstDual #-}
interpretAstDual !env v1 =
  tdualPart (ftkToSTK $ ftkAst v1) (interpretAst env v1)
interpretAst
:: forall target s y. (ADReady target, AstSpan s)
=> AstEnv target -> AstTensor AstMethodLet s y
-> target y
interpretAst :: forall (target :: Target) (s :: AstSpanType) (y :: TK).
(ADReady target, AstSpan s) =>
AstEnv target -> AstTensor AstMethodLet s y -> target y
interpretAst !AstEnv target
env = \case
AstPair AstTensor AstMethodLet s y
t1 AstTensor AstMethodLet s z
t2 -> target y -> target z -> target (TKProduct y z)
forall (x :: TK) (z :: TK).
target x -> target z -> target (TKProduct x z)
forall (target :: Target) (x :: TK) (z :: TK).
BaseTensor target =>
target x -> target z -> target (TKProduct x z)
tpair (AstEnv target -> AstTensor AstMethodLet s y -> target y
forall (target :: Target) (s :: AstSpanType) (y :: TK).
(ADReady target, AstSpan s) =>
AstEnv target -> AstTensor AstMethodLet s y -> target y
interpretAst AstEnv target
env AstTensor AstMethodLet s y
t1) (AstEnv target -> AstTensor AstMethodLet s z -> target z
forall (target :: Target) (s :: AstSpanType) (y :: TK).
(ADReady target, AstSpan s) =>
AstEnv target -> AstTensor AstMethodLet s y -> target y
interpretAst AstEnv target
env AstTensor AstMethodLet s z
t2)
AstProject1 AstTensor AstMethodLet s (TKProduct y z)
t -> target (TKProduct y z) -> target y
forall (x :: TK) (z :: TK). target (TKProduct x z) -> target x
forall (target :: Target) (x :: TK) (z :: TK).
BaseTensor target =>
target (TKProduct x z) -> target x
tproject1 (AstEnv target
-> AstTensor AstMethodLet s (TKProduct y z)
-> target (TKProduct y z)
forall (target :: Target) (s :: AstSpanType) (y :: TK).
(ADReady target, AstSpan s) =>
AstEnv target -> AstTensor AstMethodLet s y -> target y
interpretAst AstEnv target
env AstTensor AstMethodLet s (TKProduct y z)
t)
AstProject2 AstTensor AstMethodLet s (TKProduct y y)
t -> target (TKProduct y y) -> target y
forall (x :: TK) (z :: TK). target (TKProduct x z) -> target z
forall (target :: Target) (x :: TK) (z :: TK).
BaseTensor target =>
target (TKProduct x z) -> target z
tproject2 (AstEnv target
-> AstTensor AstMethodLet s (TKProduct y y)
-> target (TKProduct y y)
forall (target :: Target) (s :: AstSpanType) (y :: TK).
(ADReady target, AstSpan s) =>
AstEnv target -> AstTensor AstMethodLet s y -> target y
interpretAst AstEnv target
env AstTensor AstMethodLet s (TKProduct y y)
t)
AstFromVector SNat k
snat SingletonTK y
stk Vector (AstTensor AstMethodLet s y)
l ->
let l2 :: Vector (target y)
l2 = (AstTensor AstMethodLet s y -> target y)
-> Vector (AstTensor AstMethodLet s y) -> Vector (target y)
forall (v :: Type -> Type) a b.
(Vector v a, Vector v b) =>
(a -> b) -> v a -> v b
V.map (AstEnv target -> AstTensor AstMethodLet s y -> target y
forall (target :: Target) (s :: AstSpanType) (y :: TK).
(ADReady target, AstSpan s) =>
AstEnv target -> AstTensor AstMethodLet s y -> target y
interpretAst AstEnv target
env) Vector (AstTensor AstMethodLet s y)
l
in SNat k
-> SingletonTK y
-> Vector (target y)
-> target (BuildTensorKind k y)
forall (y :: TK) (k :: Nat).
SNat k
-> SingletonTK y
-> Vector (target y)
-> target (BuildTensorKind k y)
forall (target :: Target) (y :: TK) (k :: Nat).
BaseTensor target =>
SNat k
-> SingletonTK y
-> Vector (target y)
-> target (BuildTensorKind k y)
tfromVector SNat k
snat SingletonTK y
stk Vector (target y)
l2
AstSum SNat k
snat SingletonTK y
stk AstTensor AstMethodLet s (BuildTensorKind k y)
v -> SNat k -> SingletonTK y -> target (BuildTensorKind k y) -> target y
forall (z :: TK) (k :: Nat).
ConvertTensor target =>
SNat k -> SingletonTK z -> target (BuildTensorKind k z) -> target z
forall (target :: Target) (z :: TK) (k :: Nat).
(BaseTensor target, ConvertTensor target) =>
SNat k -> SingletonTK z -> target (BuildTensorKind k z) -> target z
tsum SNat k
snat SingletonTK y
stk (target (BuildTensorKind k y) -> target y)
-> target (BuildTensorKind k y) -> target y
forall a b. (a -> b) -> a -> b
$ AstEnv target
-> AstTensor AstMethodLet s (BuildTensorKind k y)
-> target (BuildTensorKind k y)
forall (target :: Target) (s :: AstSpanType) (y :: TK).
(ADReady target, AstSpan s) =>
AstEnv target -> AstTensor AstMethodLet s y -> target y
interpretAst AstEnv target
env AstTensor AstMethodLet s (BuildTensorKind k y)
v
AstReplicate SNat k
snat SingletonTK y
stk AstTensor AstMethodLet s y
v ->
SNat k -> SingletonTK y -> target y -> target (BuildTensorKind k y)
forall (z :: TK) (k :: Nat).
ConvertTensor target =>
SNat k -> SingletonTK z -> target z -> target (BuildTensorKind k z)
forall (target :: Target) (z :: TK) (k :: Nat).
(BaseTensor target, ConvertTensor target) =>
SNat k -> SingletonTK z -> target z -> target (BuildTensorKind k z)
treplicate SNat k
snat SingletonTK y
stk (AstEnv target -> AstTensor AstMethodLet s y -> target y
forall (target :: Target) (s :: AstSpanType) (y :: TK).
(ADReady target, AstSpan s) =>
AstEnv target -> AstTensor AstMethodLet s y -> target y
interpretAst AstEnv target
env AstTensor AstMethodLet s y
v)
AstMapAccumRDer SNat k
k FullShapeTK by
bftk FullShapeTK ey
eftk AstHFun s s (TKProduct accy ey) (TKProduct accy by)
f0 AstHFun
s
s
(TKProduct (ADTensorKind (TKProduct accy ey)) (TKProduct accy ey))
(ADTensorKind (TKProduct accy by))
df0 AstHFun
s
s
(TKProduct (ADTensorKind (TKProduct accy by)) (TKProduct accy ey))
(ADTensorKind (TKProduct accy ey))
rf0 AstTensor AstMethodLet s accy
acc0 AstTensor AstMethodLet s (BuildTensorKind k ey)
es ->
let f :: HFunOf target (TKProduct accy ey) (TKProduct accy by)
f = AstEnv target
-> AstHFun s s (TKProduct accy ey) (TKProduct accy by)
-> HFunOf target (TKProduct accy ey) (TKProduct accy by)
forall (target :: Target) (x :: TK) (y :: TK) (s :: AstSpanType)
(s2 :: AstSpanType).
(AstSpan s2, BaseTensor target) =>
AstEnv target -> AstHFun s s2 x y -> HFunOf target x y
interpretAstHFun AstEnv target
env AstHFun s s (TKProduct accy ey) (TKProduct accy by)
f0
df :: HFunOf
target
(TKProduct
(TKProduct (ADTensorKind accy) (ADTensorKind ey))
(TKProduct accy ey))
(TKProduct (ADTensorKind accy) (ADTensorKind by))
df = AstEnv target
-> AstHFun
s
s
(TKProduct
(TKProduct (ADTensorKind accy) (ADTensorKind ey))
(TKProduct accy ey))
(TKProduct (ADTensorKind accy) (ADTensorKind by))
-> HFunOf
target
(TKProduct
(TKProduct (ADTensorKind accy) (ADTensorKind ey))
(TKProduct accy ey))
(TKProduct (ADTensorKind accy) (ADTensorKind by))
forall (target :: Target) (x :: TK) (y :: TK) (s :: AstSpanType)
(s2 :: AstSpanType).
(AstSpan s2, BaseTensor target) =>
AstEnv target -> AstHFun s s2 x y -> HFunOf target x y
interpretAstHFun AstEnv target
env AstHFun
s
s
(TKProduct (ADTensorKind (TKProduct accy ey)) (TKProduct accy ey))
(ADTensorKind (TKProduct accy by))
AstHFun
s
s
(TKProduct
(TKProduct (ADTensorKind accy) (ADTensorKind ey))
(TKProduct accy ey))
(TKProduct (ADTensorKind accy) (ADTensorKind by))
df0
rf :: HFunOf
target
(TKProduct
(TKProduct (ADTensorKind accy) (ADTensorKind by))
(TKProduct accy ey))
(TKProduct (ADTensorKind accy) (ADTensorKind ey))
rf = AstEnv target
-> AstHFun
s
s
(TKProduct
(TKProduct (ADTensorKind accy) (ADTensorKind by))
(TKProduct accy ey))
(TKProduct (ADTensorKind accy) (ADTensorKind ey))
-> HFunOf
target
(TKProduct
(TKProduct (ADTensorKind accy) (ADTensorKind by))
(TKProduct accy ey))
(TKProduct (ADTensorKind accy) (ADTensorKind ey))
forall (target :: Target) (x :: TK) (y :: TK) (s :: AstSpanType)
(s2 :: AstSpanType).
(AstSpan s2, BaseTensor target) =>
AstEnv target -> AstHFun s s2 x y -> HFunOf target x y
interpretAstHFun AstEnv target
env AstHFun
s
s
(TKProduct (ADTensorKind (TKProduct accy by)) (TKProduct accy ey))
(ADTensorKind (TKProduct accy ey))
AstHFun
s
s
(TKProduct
(TKProduct (ADTensorKind accy) (ADTensorKind by))
(TKProduct accy ey))
(TKProduct (ADTensorKind accy) (ADTensorKind ey))
rf0
acc02 :: target accy
acc02 = AstEnv target -> AstTensor AstMethodLet s accy -> target accy
forall (target :: Target) (s :: AstSpanType) (y :: TK).
(ADReady target, AstSpan s) =>
AstEnv target -> AstTensor AstMethodLet s y -> target y
interpretAst AstEnv target
env AstTensor AstMethodLet s accy
acc0
es2 :: target (BuildTensorKind k ey)
es2 = AstEnv target
-> AstTensor AstMethodLet s (BuildTensorKind k ey)
-> target (BuildTensorKind k ey)
forall (target :: Target) (s :: AstSpanType) (y :: TK).
(ADReady target, AstSpan s) =>
AstEnv target -> AstTensor AstMethodLet s y -> target y
interpretAst AstEnv target
env AstTensor AstMethodLet s (BuildTensorKind k ey)
es
in Proxy @Target target
-> SNat k
-> FullShapeTK accy
-> FullShapeTK by
-> FullShapeTK ey
-> HFunOf target (TKProduct accy ey) (TKProduct accy by)
-> HFunOf
target
(TKProduct (ADTensorKind (TKProduct accy ey)) (TKProduct accy ey))
(ADTensorKind (TKProduct accy by))
-> HFunOf
target
(TKProduct (ADTensorKind (TKProduct accy by)) (TKProduct accy ey))
(ADTensorKind (TKProduct accy ey))
-> target accy
-> target (BuildTensorKind k ey)
-> target (TKProduct accy (BuildTensorKind k by))
forall (accy :: TK) (by :: TK) (ey :: TK) (k :: Nat).
Proxy @Target target
-> SNat k
-> FullShapeTK accy
-> FullShapeTK by
-> FullShapeTK ey
-> HFunOf target (TKProduct accy ey) (TKProduct accy by)
-> HFunOf
target
(TKProduct (ADTensorKind (TKProduct accy ey)) (TKProduct accy ey))
(ADTensorKind (TKProduct accy by))
-> HFunOf
target
(TKProduct (ADTensorKind (TKProduct accy by)) (TKProduct accy ey))
(ADTensorKind (TKProduct accy ey))
-> target accy
-> target (BuildTensorKind k ey)
-> target (TKProduct accy (BuildTensorKind k by))
forall (target :: Target) (accy :: TK) (by :: TK) (ey :: TK)
(k :: Nat).
BaseTensor target =>
Proxy @Target target
-> SNat k
-> FullShapeTK accy
-> FullShapeTK by
-> FullShapeTK ey
-> HFunOf target (TKProduct accy ey) (TKProduct accy by)
-> HFunOf
target
(TKProduct (ADTensorKind (TKProduct accy ey)) (TKProduct accy ey))
(ADTensorKind (TKProduct accy by))
-> HFunOf
target
(TKProduct (ADTensorKind (TKProduct accy by)) (TKProduct accy ey))
(ADTensorKind (TKProduct accy ey))
-> target accy
-> target (BuildTensorKind k ey)
-> target (TKProduct accy (BuildTensorKind k by))
tmapAccumRDer (forall {k} (t :: k). Proxy @k t
forall (t :: Target). Proxy @Target t
Proxy @target) SNat k
k (AstTensor AstMethodLet s accy -> FullShapeTK accy
forall (s :: AstSpanType) (y :: TK) (ms :: AstMethodOfSharing).
AstTensor ms s y -> FullShapeTK y
ftkAst AstTensor AstMethodLet s accy
acc0) FullShapeTK by
bftk FullShapeTK ey
eftk HFunOf target (TKProduct accy ey) (TKProduct accy by)
f HFunOf
target
(TKProduct (ADTensorKind (TKProduct accy ey)) (TKProduct accy ey))
(ADTensorKind (TKProduct accy by))
HFunOf
target
(TKProduct
(TKProduct (ADTensorKind accy) (ADTensorKind ey))
(TKProduct accy ey))
(TKProduct (ADTensorKind accy) (ADTensorKind by))
df HFunOf
target
(TKProduct (ADTensorKind (TKProduct accy by)) (TKProduct accy ey))
(ADTensorKind (TKProduct accy ey))
HFunOf
target
(TKProduct
(TKProduct (ADTensorKind accy) (ADTensorKind by))
(TKProduct accy ey))
(TKProduct (ADTensorKind accy) (ADTensorKind ey))
rf target accy
acc02 target (BuildTensorKind k ey)
es2
AstMapAccumLDer SNat k
k FullShapeTK by
bftk FullShapeTK ey
eftk AstHFun s s (TKProduct accy ey) (TKProduct accy by)
f0 AstHFun
s
s
(TKProduct (ADTensorKind (TKProduct accy ey)) (TKProduct accy ey))
(ADTensorKind (TKProduct accy by))
df0 AstHFun
s
s
(TKProduct (ADTensorKind (TKProduct accy by)) (TKProduct accy ey))
(ADTensorKind (TKProduct accy ey))
rf0 AstTensor AstMethodLet s accy
acc0 AstTensor AstMethodLet s (BuildTensorKind k ey)
es ->
let f :: HFunOf target (TKProduct accy ey) (TKProduct accy by)
f = AstEnv target
-> AstHFun s s (TKProduct accy ey) (TKProduct accy by)
-> HFunOf target (TKProduct accy ey) (TKProduct accy by)
forall (target :: Target) (x :: TK) (y :: TK) (s :: AstSpanType)
(s2 :: AstSpanType).
(AstSpan s2, BaseTensor target) =>
AstEnv target -> AstHFun s s2 x y -> HFunOf target x y
interpretAstHFun AstEnv target
env AstHFun s s (TKProduct accy ey) (TKProduct accy by)
f0
df :: HFunOf
target
(TKProduct
(TKProduct (ADTensorKind accy) (ADTensorKind ey))
(TKProduct accy ey))
(TKProduct (ADTensorKind accy) (ADTensorKind by))
df = AstEnv target
-> AstHFun
s
s
(TKProduct
(TKProduct (ADTensorKind accy) (ADTensorKind ey))
(TKProduct accy ey))
(TKProduct (ADTensorKind accy) (ADTensorKind by))
-> HFunOf
target
(TKProduct
(TKProduct (ADTensorKind accy) (ADTensorKind ey))
(TKProduct accy ey))
(TKProduct (ADTensorKind accy) (ADTensorKind by))
forall (target :: Target) (x :: TK) (y :: TK) (s :: AstSpanType)
(s2 :: AstSpanType).
(AstSpan s2, BaseTensor target) =>
AstEnv target -> AstHFun s s2 x y -> HFunOf target x y
interpretAstHFun AstEnv target
env AstHFun
s
s
(TKProduct (ADTensorKind (TKProduct accy ey)) (TKProduct accy ey))
(ADTensorKind (TKProduct accy by))
AstHFun
s
s
(TKProduct
(TKProduct (ADTensorKind accy) (ADTensorKind ey))
(TKProduct accy ey))
(TKProduct (ADTensorKind accy) (ADTensorKind by))
df0
rf :: HFunOf
target
(TKProduct
(TKProduct (ADTensorKind accy) (ADTensorKind by))
(TKProduct accy ey))
(TKProduct (ADTensorKind accy) (ADTensorKind ey))
rf = AstEnv target
-> AstHFun
s
s
(TKProduct
(TKProduct (ADTensorKind accy) (ADTensorKind by))
(TKProduct accy ey))
(TKProduct (ADTensorKind accy) (ADTensorKind ey))
-> HFunOf
target
(TKProduct
(TKProduct (ADTensorKind accy) (ADTensorKind by))
(TKProduct accy ey))
(TKProduct (ADTensorKind accy) (ADTensorKind ey))
forall (target :: Target) (x :: TK) (y :: TK) (s :: AstSpanType)
(s2 :: AstSpanType).
(AstSpan s2, BaseTensor target) =>
AstEnv target -> AstHFun s s2 x y -> HFunOf target x y
interpretAstHFun AstEnv target
env AstHFun
s
s
(TKProduct (ADTensorKind (TKProduct accy by)) (TKProduct accy ey))
(ADTensorKind (TKProduct accy ey))
AstHFun
s
s
(TKProduct
(TKProduct (ADTensorKind accy) (ADTensorKind by))
(TKProduct accy ey))
(TKProduct (ADTensorKind accy) (ADTensorKind ey))
rf0
acc02 :: target accy
acc02 = AstEnv target -> AstTensor AstMethodLet s accy -> target accy
forall (target :: Target) (s :: AstSpanType) (y :: TK).
(ADReady target, AstSpan s) =>
AstEnv target -> AstTensor AstMethodLet s y -> target y
interpretAst AstEnv target
env AstTensor AstMethodLet s accy
acc0
es2 :: target (BuildTensorKind k ey)
es2 = AstEnv target
-> AstTensor AstMethodLet s (BuildTensorKind k ey)
-> target (BuildTensorKind k ey)
forall (target :: Target) (s :: AstSpanType) (y :: TK).
(ADReady target, AstSpan s) =>
AstEnv target -> AstTensor AstMethodLet s y -> target y
interpretAst AstEnv target
env AstTensor AstMethodLet s (BuildTensorKind k ey)
es
in Proxy @Target target
-> SNat k
-> FullShapeTK accy
-> FullShapeTK by
-> FullShapeTK ey
-> HFunOf target (TKProduct accy ey) (TKProduct accy by)
-> HFunOf
target
(TKProduct (ADTensorKind (TKProduct accy ey)) (TKProduct accy ey))
(ADTensorKind (TKProduct accy by))
-> HFunOf
target
(TKProduct (ADTensorKind (TKProduct accy by)) (TKProduct accy ey))
(ADTensorKind (TKProduct accy ey))
-> target accy
-> target (BuildTensorKind k ey)
-> target (TKProduct accy (BuildTensorKind k by))
forall (accy :: TK) (by :: TK) (ey :: TK) (k :: Nat).
Proxy @Target target
-> SNat k
-> FullShapeTK accy
-> FullShapeTK by
-> FullShapeTK ey
-> HFunOf target (TKProduct accy ey) (TKProduct accy by)
-> HFunOf
target
(TKProduct (ADTensorKind (TKProduct accy ey)) (TKProduct accy ey))
(ADTensorKind (TKProduct accy by))
-> HFunOf
target
(TKProduct (ADTensorKind (TKProduct accy by)) (TKProduct accy ey))
(ADTensorKind (TKProduct accy ey))
-> target accy
-> target (BuildTensorKind k ey)
-> target (TKProduct accy (BuildTensorKind k by))
forall (target :: Target) (accy :: TK) (by :: TK) (ey :: TK)
(k :: Nat).
BaseTensor target =>
Proxy @Target target
-> SNat k
-> FullShapeTK accy
-> FullShapeTK by
-> FullShapeTK ey
-> HFunOf target (TKProduct accy ey) (TKProduct accy by)
-> HFunOf
target
(TKProduct (ADTensorKind (TKProduct accy ey)) (TKProduct accy ey))
(ADTensorKind (TKProduct accy by))
-> HFunOf
target
(TKProduct (ADTensorKind (TKProduct accy by)) (TKProduct accy ey))
(ADTensorKind (TKProduct accy ey))
-> target accy
-> target (BuildTensorKind k ey)
-> target (TKProduct accy (BuildTensorKind k by))
tmapAccumLDer (forall {k} (t :: k). Proxy @k t
forall (t :: Target). Proxy @Target t
Proxy @target) SNat k
k (AstTensor AstMethodLet s accy -> FullShapeTK accy
forall (s :: AstSpanType) (y :: TK) (ms :: AstMethodOfSharing).
AstTensor ms s y -> FullShapeTK y
ftkAst AstTensor AstMethodLet s accy
acc0) FullShapeTK by
bftk FullShapeTK ey
eftk HFunOf target (TKProduct accy ey) (TKProduct accy by)
f HFunOf
target
(TKProduct (ADTensorKind (TKProduct accy ey)) (TKProduct accy ey))
(ADTensorKind (TKProduct accy by))
HFunOf
target
(TKProduct
(TKProduct (ADTensorKind accy) (ADTensorKind ey))
(TKProduct accy ey))
(TKProduct (ADTensorKind accy) (ADTensorKind by))
df HFunOf
target
(TKProduct (ADTensorKind (TKProduct accy by)) (TKProduct accy ey))
(ADTensorKind (TKProduct accy ey))
HFunOf
target
(TKProduct
(TKProduct (ADTensorKind accy) (ADTensorKind by))
(TKProduct accy ey))
(TKProduct (ADTensorKind accy) (ADTensorKind ey))
rf target accy
acc02 target (BuildTensorKind k ey)
es2
AstApply AstHFun s1 s x y
t AstTensor AstMethodLet s1 x
ll ->
let t2 :: HFunOf target x y
t2 = AstEnv target -> AstHFun s1 s x y -> HFunOf target x y
forall (target :: Target) (x :: TK) (y :: TK) (s :: AstSpanType)
(s2 :: AstSpanType).
(AstSpan s2, BaseTensor target) =>
AstEnv target -> AstHFun s s2 x y -> HFunOf target x y
interpretAstHFun AstEnv target
env AstHFun s1 s x y
t
ll2 :: target x
ll2 = AstEnv target -> AstTensor AstMethodLet s1 x -> target x
forall (target :: Target) (s :: AstSpanType) (y :: TK).
(ADReady target, AstSpan s) =>
AstEnv target -> AstTensor AstMethodLet s y -> target y
interpretAst AstEnv target
env AstTensor AstMethodLet s1 x
ll
in HFunOf target x y -> target x -> target y
forall (x :: TK) (z :: TK).
HFunOf target x z -> target x -> target z
forall (target :: Target) (x :: TK) (z :: TK).
BaseTensor target =>
HFunOf target x z -> target x -> target z
tApply HFunOf target x y
t2 target x
ll2
AstVar AstVarName s y
var ->
let var2 :: AstVarName FullSpan y
var2 :: AstVarName FullSpan y
var2 = AstVarName s y -> AstVarName FullSpan y
forall a b. Coercible @Type a b => a -> b
coerce AstVarName s y
var
in case AstVarName FullSpan y -> AstEnv target -> Maybe (target y)
forall {kind} (k :: kind -> Type) (a :: kind) (v :: kind -> Type).
(Enum1 @kind k, TestEquality @kind k) =>
k a -> DEnumMap @kind k v -> Maybe (v a)
DMap.lookup AstVarName FullSpan y
var2 AstEnv target
env of
Just target y
t ->
#ifdef WITH_EXPENSIVE_ASSERTIONS
withKnownSTK (ftkToSTK $ varNameToFTK var) $
assert (tftk (ftkToSTK $ varNameToFTK var) t == varNameToFTK var
`blame` ( tftk (ftkToSTK $ varNameToFTK var) t
, varNameToFTK var, var, t ))
#endif
target y
t
Maybe (target y)
_ -> [Char] -> target y
forall a. HasCallStack => [Char] -> a
error ([Char] -> target y) -> [Char] -> target y
forall a b. (a -> b) -> a -> b
$ [Char]
"interpretAst: unknown AstVar " [Char] -> [Char] -> [Char]
forall a. [a] -> [a] -> [a]
++ AstVarName s y -> [Char]
forall a. Show a => a -> [Char]
show AstVarName s y
var
AstCond AstBool AstMethodLet
b AstTensor AstMethodLet s y
a1 AstTensor AstMethodLet s y
a2 ->
let c :: BoolOf target
c = AstEnv target -> AstBool AstMethodLet -> BoolOf target
forall (target :: Target).
ADReady target =>
AstEnv target -> AstBool AstMethodLet -> BoolOf target
interpretAstBool AstEnv target
env AstBool AstMethodLet
b
in SingletonTK y -> BoolOf target -> target y -> target y -> target y
forall (y :: TK).
Boolean (BoolOf target) =>
SingletonTK y -> BoolOf target -> target y -> target y -> target y
forall (target :: Target) (y :: TK).
(BaseTensor target, Boolean (BoolOf target)) =>
SingletonTK y -> BoolOf target -> target y -> target y -> target y
tcond (FullShapeTK y -> SingletonTK y
forall (y :: TK). FullShapeTK y -> SingletonTK y
ftkToSTK (AstTensor AstMethodLet s y -> FullShapeTK y
forall (s :: AstSpanType) (y :: TK) (ms :: AstMethodOfSharing).
AstTensor ms s y -> FullShapeTK y
ftkAst AstTensor AstMethodLet s y
a1)) BoolOf target
c
(AstEnv target -> AstTensor AstMethodLet s y -> target y
forall (target :: Target) (s :: AstSpanType) (y :: TK).
(ADReady target, AstSpan s) =>
AstEnv target -> AstTensor AstMethodLet s y -> target y
interpretAst AstEnv target
env AstTensor AstMethodLet s y
a1) (AstEnv target -> AstTensor AstMethodLet s y -> target y
forall (target :: Target) (s :: AstSpanType) (y :: TK).
(ADReady target, AstSpan s) =>
AstEnv target -> AstTensor AstMethodLet s y -> target y
interpretAst AstEnv target
env AstTensor AstMethodLet s y
a2)
AstBuild1 SNat k
snat SingletonTK y
stk (IntVarName
var, AstTensor AstMethodLet s y
v) ->
let f :: PrimalOf target (TKScalar Int64) -> target y
f PrimalOf target (TKScalar Int64)
i = AstEnv target -> AstTensor AstMethodLet s y -> target y
forall (target :: Target) (s :: AstSpanType) (y :: TK).
(ADReady target, AstSpan s) =>
AstEnv target -> AstTensor AstMethodLet s y -> target y
interpretAst (IntVarName
-> PrimalOf target (TKScalar Int64)
-> AstEnv target
-> AstEnv target
forall (target :: Target).
BaseTensor target =>
IntVarName -> IntOf target -> AstEnv target -> AstEnv target
extendEnvI IntVarName
var PrimalOf target (TKScalar Int64)
i AstEnv target
env) AstTensor AstMethodLet s y
v
in SNat k
-> SingletonTK y
-> (PrimalOf target (TKScalar Int64) -> target y)
-> target (BuildTensorKind k y)
forall (y :: TK) (k :: Nat).
ConvertTensor target =>
SNat k
-> SingletonTK y
-> (PrimalOf target (TKScalar Int64) -> target y)
-> target (BuildTensorKind k y)
forall (target :: Target) (y :: TK) (k :: Nat).
(BaseTensor target, ConvertTensor target) =>
SNat k
-> SingletonTK y
-> (IntOf target -> target y)
-> target (BuildTensorKind k y)
tbuild1 SNat k
snat SingletonTK y
stk PrimalOf target (TKScalar Int64) -> target y
f
AstLet AstVarName s y
var AstTensor AstMethodLet s y
u AstTensor AstMethodLet s y
v ->
let t :: target y
t = AstEnv target -> AstTensor AstMethodLet s y -> target y
forall (target :: Target) (s :: AstSpanType) (y :: TK).
(ADReady target, AstSpan s) =>
AstEnv target -> AstTensor AstMethodLet s y -> target y
interpretAst AstEnv target
env AstTensor AstMethodLet s y
u
env2 :: target y -> AstEnv target
env2 target y
w = AstVarName s y -> target y -> AstEnv target -> AstEnv target
forall (target :: Target) (s :: AstSpanType) (y :: TK).
AstVarName s y -> target y -> AstEnv target -> AstEnv target
extendEnv AstVarName s y
var target y
w AstEnv target
env
in target y -> (target y -> target y) -> target y
forall (x :: TK) (z :: TK).
target x -> (target x -> target z) -> target z
forall (target :: Target) (x :: TK) (z :: TK).
LetTensor target =>
target x -> (target x -> target z) -> target z
ttlet target y
t (\target y
w -> AstEnv target -> AstTensor AstMethodLet s y -> target y
forall (target :: Target) (s :: AstSpanType) (y :: TK).
(ADReady target, AstSpan s) =>
AstEnv target -> AstTensor AstMethodLet s y -> target y
interpretAst (target y -> AstEnv target
env2 target y
w) AstTensor AstMethodLet s y
v)
AstPrimalPart AstTensor AstMethodLet FullSpan y
a ->
SingletonTK y -> PrimalOf target y -> target y
forall (y :: TK). SingletonTK y -> PrimalOf target y -> target y
forall (target :: Target) (y :: TK).
BaseTensor target =>
SingletonTK y -> PrimalOf target y -> target y
tfromPrimal (FullShapeTK y -> SingletonTK y
forall (y :: TK). FullShapeTK y -> SingletonTK y
ftkToSTK (AstTensor AstMethodLet FullSpan y -> FullShapeTK y
forall (s :: AstSpanType) (y :: TK) (ms :: AstMethodOfSharing).
AstTensor ms s y -> FullShapeTK y
ftkAst AstTensor AstMethodLet FullSpan y
a)) (target y -> PrimalOf target y
forall (y :: TK). target y -> PrimalOf target y
forall (target :: Target) (y :: TK).
BaseTensor target =>
target y -> PrimalOf target y
tprimalPart (target y -> PrimalOf target y) -> target y -> PrimalOf target y
forall a b. (a -> b) -> a -> b
$ AstEnv target -> AstTensor AstMethodLet FullSpan y -> target y
forall (target :: Target) (y :: TK).
ADReady target =>
AstEnv target -> AstTensor AstMethodLet FullSpan y -> target y
interpretAstFull AstEnv target
env AstTensor AstMethodLet FullSpan y
a)
AstDualPart AstTensor AstMethodLet FullSpan y
a ->
DualOf target y -> target y
forall (y :: TK). DualOf target y -> target y
forall (target :: Target) (y :: TK).
BaseTensor target =>
DualOf target y -> target y
tfromDual (SingletonTK y -> target y -> DualOf target y
forall (y :: TK). SingletonTK y -> target y -> DualOf target y
forall (target :: Target) (y :: TK).
BaseTensor target =>
SingletonTK y -> target y -> DualOf target y
tdualPart (FullShapeTK y -> SingletonTK y
forall (y :: TK). FullShapeTK y -> SingletonTK y
ftkToSTK (AstTensor AstMethodLet FullSpan y -> FullShapeTK y
forall (s :: AstSpanType) (y :: TK) (ms :: AstMethodOfSharing).
AstTensor ms s y -> FullShapeTK y
ftkAst AstTensor AstMethodLet FullSpan y
a)) (target y -> DualOf target y) -> target y -> DualOf target y
forall a b. (a -> b) -> a -> b
$ AstEnv target -> AstTensor AstMethodLet FullSpan y -> target y
forall (target :: Target) (y :: TK).
ADReady target =>
AstEnv target -> AstTensor AstMethodLet FullSpan y -> target y
interpretAstFull AstEnv target
env AstTensor AstMethodLet FullSpan y
a)
AstFromPrimal AstTensor AstMethodLet PrimalSpan y
a ->
SingletonTK y -> PrimalOf target y -> target y
forall (y :: TK). SingletonTK y -> PrimalOf target y -> target y
forall (target :: Target) (y :: TK).
BaseTensor target =>
SingletonTK y -> PrimalOf target y -> target y
tfromPrimal (FullShapeTK y -> SingletonTK y
forall (y :: TK). FullShapeTK y -> SingletonTK y
ftkToSTK (AstTensor AstMethodLet PrimalSpan y -> FullShapeTK y
forall (s :: AstSpanType) (y :: TK) (ms :: AstMethodOfSharing).
AstTensor ms s y -> FullShapeTK y
ftkAst AstTensor AstMethodLet PrimalSpan y
a)) (AstEnv target
-> AstTensor AstMethodLet PrimalSpan y -> PrimalOf target y
forall (target :: Target) (y :: TK).
ADReady target =>
AstEnv target
-> AstTensor AstMethodLet PrimalSpan y -> PrimalOf target y
interpretAstPrimal AstEnv target
env AstTensor AstMethodLet PrimalSpan y
a)
AstFromDual AstTensor AstMethodLet DualSpan y
a -> AstEnv target -> AstTensor AstMethodLet DualSpan y -> target y
forall (target :: Target) (s :: AstSpanType) (y :: TK).
(ADReady target, AstSpan s) =>
AstEnv target -> AstTensor AstMethodLet s y -> target y
interpretAst AstEnv target
env AstTensor AstMethodLet DualSpan y
a
AstPlusK AstTensor AstMethodLet s (TKScalar r)
u AstTensor AstMethodLet s (TKScalar r)
v -> AstEnv target -> AstTensor AstMethodLet s y -> target y
forall (target :: Target) (s :: AstSpanType) (y :: TK).
(ADReady target, AstSpan s) =>
AstEnv target -> AstTensor AstMethodLet s y -> target y
interpretAst AstEnv target
env AstTensor AstMethodLet s y
AstTensor AstMethodLet s (TKScalar r)
u target y -> target y -> target y
forall a. Num a => a -> a -> a
+ AstEnv target -> AstTensor AstMethodLet s y -> target y
forall (target :: Target) (s :: AstSpanType) (y :: TK).
(ADReady target, AstSpan s) =>
AstEnv target -> AstTensor AstMethodLet s y -> target y
interpretAst AstEnv target
env AstTensor AstMethodLet s y
AstTensor AstMethodLet s (TKScalar r)
v
AstTimesK AstTensor AstMethodLet s (TKScalar r)
u AstTensor AstMethodLet s (TKScalar r)
v -> AstEnv target -> AstTensor AstMethodLet s y -> target y
forall (target :: Target) (s :: AstSpanType) (y :: TK).
(ADReady target, AstSpan s) =>
AstEnv target -> AstTensor AstMethodLet s y -> target y
interpretAst AstEnv target
env AstTensor AstMethodLet s y
AstTensor AstMethodLet s (TKScalar r)
u target y -> target y -> target y
forall a. Num a => a -> a -> a
* AstEnv target -> AstTensor AstMethodLet s y -> target y
forall (target :: Target) (s :: AstSpanType) (y :: TK).
(ADReady target, AstSpan s) =>
AstEnv target -> AstTensor AstMethodLet s y -> target y
interpretAst AstEnv target
env AstTensor AstMethodLet s y
AstTensor AstMethodLet s (TKScalar r)
v
AstN1K OpCodeNum1
opCode AstTensor AstMethodLet s (TKScalar r)
u ->
let u2 :: target (TKScalar r)
u2 = AstEnv target
-> AstTensor AstMethodLet s (TKScalar r) -> target (TKScalar r)
forall (target :: Target) (s :: AstSpanType) (y :: TK).
(ADReady target, AstSpan s) =>
AstEnv target -> AstTensor AstMethodLet s y -> target y
interpretAst AstEnv target
env AstTensor AstMethodLet s (TKScalar r)
u
in OpCodeNum1 -> target y -> target y
forall a. Num a => OpCodeNum1 -> a -> a
interpretAstN1 OpCodeNum1
opCode target y
target (TKScalar r)
u2
AstR1K OpCode1
opCode AstTensor AstMethodLet s (TKScalar r)
u ->
let u2 :: target (TKScalar r)
u2 = AstEnv target
-> AstTensor AstMethodLet s (TKScalar r) -> target (TKScalar r)
forall (target :: Target) (s :: AstSpanType) (y :: TK).
(ADReady target, AstSpan s) =>
AstEnv target -> AstTensor AstMethodLet s y -> target y
interpretAst AstEnv target
env AstTensor AstMethodLet s (TKScalar r)
u
in OpCode1 -> target y -> target y
forall a. Floating a => OpCode1 -> a -> a
interpretAstR1 OpCode1
opCode target y
target (TKScalar r)
u2
AstR2K OpCode2
opCode AstTensor AstMethodLet s (TKScalar r)
u AstTensor AstMethodLet s (TKScalar r)
v ->
let u2 :: target (TKScalar r)
u2 = AstEnv target
-> AstTensor AstMethodLet s (TKScalar r) -> target (TKScalar r)
forall (target :: Target) (s :: AstSpanType) (y :: TK).
(ADReady target, AstSpan s) =>
AstEnv target -> AstTensor AstMethodLet s y -> target y
interpretAst AstEnv target
env AstTensor AstMethodLet s (TKScalar r)
u
v2 :: target (TKScalar r)
v2 = AstEnv target
-> AstTensor AstMethodLet s (TKScalar r) -> target (TKScalar r)
forall (target :: Target) (s :: AstSpanType) (y :: TK).
(ADReady target, AstSpan s) =>
AstEnv target -> AstTensor AstMethodLet s y -> target y
interpretAst AstEnv target
env AstTensor AstMethodLet s (TKScalar r)
v
in OpCode2 -> target y -> target y -> target y
forall a. RealFloatH a => OpCode2 -> a -> a -> a
interpretAstR2 OpCode2
opCode target y
target (TKScalar r)
u2 target y
target (TKScalar r)
v2
AstI2K OpCodeIntegral2
opCode AstTensor AstMethodLet s (TKScalar r)
u AstTensor AstMethodLet s (TKScalar r)
v ->
let u2 :: target (TKScalar r)
u2 = AstEnv target
-> AstTensor AstMethodLet s (TKScalar r) -> target (TKScalar r)
forall (target :: Target) (s :: AstSpanType) (y :: TK).
(ADReady target, AstSpan s) =>
AstEnv target -> AstTensor AstMethodLet s y -> target y
interpretAst AstEnv target
env AstTensor AstMethodLet s (TKScalar r)
u
v2 :: target (TKScalar r)
v2 = AstEnv target
-> AstTensor AstMethodLet s (TKScalar r) -> target (TKScalar r)
forall (target :: Target) (s :: AstSpanType) (y :: TK).
(ADReady target, AstSpan s) =>
AstEnv target -> AstTensor AstMethodLet s y -> target y
interpretAst AstEnv target
env AstTensor AstMethodLet s (TKScalar r)
v
in OpCodeIntegral2 -> target y -> target y -> target y
forall a. IntegralH a => OpCodeIntegral2 -> a -> a -> a
interpretAstI2 OpCodeIntegral2
opCode target y
target (TKScalar r)
u2 target y
target (TKScalar r)
v2
AstConcreteK r
k ->
forall (target :: Target) r.
(BaseTensor target, GoodScalar r) =>
r -> target (TKScalar r)
tkconcrete @target r
k
AstFloorK AstTensor AstMethodLet PrimalSpan (TKScalar r1)
v ->
target (TKScalar r1) -> target (TKScalar r2)
forall r r2.
(GoodScalar r, RealFrac r, GoodScalar r2, Integral r2) =>
target (TKScalar r) -> target (TKScalar r2)
forall (target :: Target) r r2.
(BaseTensor target, GoodScalar r, RealFrac r, GoodScalar r2,
Integral r2) =>
target (TKScalar r) -> target (TKScalar r2)
tkfloor (target (TKScalar r1) -> target (TKScalar r2))
-> target (TKScalar r1) -> target (TKScalar r2)
forall a b. (a -> b) -> a -> b
$ AstEnv target
-> AstTensor AstMethodLet PrimalSpan (TKScalar r1)
-> target (TKScalar r1)
forall (target :: Target) (s :: AstSpanType) (y :: TK).
(ADReady target, AstSpan s) =>
AstEnv target -> AstTensor AstMethodLet s y -> target y
interpretAst AstEnv target
env AstTensor AstMethodLet PrimalSpan (TKScalar r1)
v
AstFromIntegralK AstTensor AstMethodLet PrimalSpan (TKScalar r1)
v ->
target (TKScalar r1) -> target (TKScalar r2)
forall r1 r2.
(GoodScalar r1, Integral r1, GoodScalar r2) =>
target (TKScalar r1) -> target (TKScalar r2)
forall (target :: Target) r1 r2.
(BaseTensor target, GoodScalar r1, Integral r1, GoodScalar r2) =>
target (TKScalar r1) -> target (TKScalar r2)
tkfromIntegral (target (TKScalar r1) -> target (TKScalar r2))
-> target (TKScalar r1) -> target (TKScalar r2)
forall a b. (a -> b) -> a -> b
$ AstEnv target
-> AstTensor AstMethodLet PrimalSpan (TKScalar r1)
-> target (TKScalar r1)
forall (target :: Target) (s :: AstSpanType) (y :: TK).
(ADReady target, AstSpan s) =>
AstEnv target -> AstTensor AstMethodLet s y -> target y
interpretAst AstEnv target
env AstTensor AstMethodLet PrimalSpan (TKScalar r1)
v
AstCastK AstTensor AstMethodLet s (TKScalar r1)
v -> target (TKScalar r1) -> target (TKScalar r2)
forall r1 r2.
(RealFrac r1, GoodScalar r1, RealFrac r2, GoodScalar r2) =>
target (TKScalar r1) -> target (TKScalar r2)
forall (target :: Target) r1 r2.
(BaseTensor target, RealFrac r1, GoodScalar r1, RealFrac r2,
GoodScalar r2) =>
target (TKScalar r1) -> target (TKScalar r2)
tkcast (target (TKScalar r1) -> target (TKScalar r2))
-> target (TKScalar r1) -> target (TKScalar r2)
forall a b. (a -> b) -> a -> b
$ AstEnv target
-> AstTensor AstMethodLet s (TKScalar r1) -> target (TKScalar r1)
forall (target :: Target) (s :: AstSpanType) (y :: TK).
(ADReady target, AstSpan s) =>
AstEnv target -> AstTensor AstMethodLet s y -> target y
interpretAst AstEnv target
env AstTensor AstMethodLet s (TKScalar r1)
v
AstPlusS AstTensor AstMethodLet s (TKS2 sh (TKScalar r))
u AstTensor AstMethodLet s (TKS2 sh (TKScalar r))
v -> AstEnv target -> AstTensor AstMethodLet s y -> target y
forall (target :: Target) (s :: AstSpanType) (y :: TK).
(ADReady target, AstSpan s) =>
AstEnv target -> AstTensor AstMethodLet s y -> target y
interpretAst AstEnv target
env AstTensor AstMethodLet s y
AstTensor AstMethodLet s (TKS2 sh (TKScalar r))
u target y -> target y -> target y
forall a. Num a => a -> a -> a
+ AstEnv target -> AstTensor AstMethodLet s y -> target y
forall (target :: Target) (s :: AstSpanType) (y :: TK).
(ADReady target, AstSpan s) =>
AstEnv target -> AstTensor AstMethodLet s y -> target y
interpretAst AstEnv target
env AstTensor AstMethodLet s y
AstTensor AstMethodLet s (TKS2 sh (TKScalar r))
v
AstTimesS AstTensor AstMethodLet s (TKS2 sh (TKScalar r))
u AstTensor AstMethodLet s (TKS2 sh (TKScalar r))
v -> AstEnv target -> AstTensor AstMethodLet s y -> target y
forall (target :: Target) (s :: AstSpanType) (y :: TK).
(ADReady target, AstSpan s) =>
AstEnv target -> AstTensor AstMethodLet s y -> target y
interpretAst AstEnv target
env AstTensor AstMethodLet s y
AstTensor AstMethodLet s (TKS2 sh (TKScalar r))
u target y -> target y -> target y
forall a. Num a => a -> a -> a
* AstEnv target -> AstTensor AstMethodLet s y -> target y
forall (target :: Target) (s :: AstSpanType) (y :: TK).
(ADReady target, AstSpan s) =>
AstEnv target -> AstTensor AstMethodLet s y -> target y
interpretAst AstEnv target
env AstTensor AstMethodLet s y
AstTensor AstMethodLet s (TKS2 sh (TKScalar r))
v
AstN1S OpCodeNum1
opCode AstTensor AstMethodLet s (TKS2 sh (TKScalar r))
u -> OpCodeNum1 -> target y -> target y
forall a. Num a => OpCodeNum1 -> a -> a
interpretAstN1 OpCodeNum1
opCode (AstEnv target -> AstTensor AstMethodLet s y -> target y
forall (target :: Target) (s :: AstSpanType) (y :: TK).
(ADReady target, AstSpan s) =>
AstEnv target -> AstTensor AstMethodLet s y -> target y
interpretAst AstEnv target
env AstTensor AstMethodLet s y
AstTensor AstMethodLet s (TKS2 sh (TKScalar r))
u)
AstR1S OpCode1
opCode AstTensor AstMethodLet s (TKS2 sh (TKScalar r))
u -> OpCode1 -> target y -> target y
forall a. Floating a => OpCode1 -> a -> a
interpretAstR1 OpCode1
opCode (AstEnv target -> AstTensor AstMethodLet s y -> target y
forall (target :: Target) (s :: AstSpanType) (y :: TK).
(ADReady target, AstSpan s) =>
AstEnv target -> AstTensor AstMethodLet s y -> target y
interpretAst AstEnv target
env AstTensor AstMethodLet s y
AstTensor AstMethodLet s (TKS2 sh (TKScalar r))
u)
AstR2S OpCode2
opCode AstTensor AstMethodLet s (TKS2 sh (TKScalar r))
u AstTensor AstMethodLet s (TKS2 sh (TKScalar r))
v ->
OpCode2 -> target y -> target y -> target y
forall a. RealFloatH a => OpCode2 -> a -> a -> a
interpretAstR2 OpCode2
opCode (AstEnv target -> AstTensor AstMethodLet s y -> target y
forall (target :: Target) (s :: AstSpanType) (y :: TK).
(ADReady target, AstSpan s) =>
AstEnv target -> AstTensor AstMethodLet s y -> target y
interpretAst AstEnv target
env AstTensor AstMethodLet s y
AstTensor AstMethodLet s (TKS2 sh (TKScalar r))
u) (AstEnv target -> AstTensor AstMethodLet s y -> target y
forall (target :: Target) (s :: AstSpanType) (y :: TK).
(ADReady target, AstSpan s) =>
AstEnv target -> AstTensor AstMethodLet s y -> target y
interpretAst AstEnv target
env AstTensor AstMethodLet s y
AstTensor AstMethodLet s (TKS2 sh (TKScalar r))
v)
AstI2S OpCodeIntegral2
opCode AstTensor AstMethodLet s (TKS2 sh (TKScalar r))
u AstTensor AstMethodLet s (TKS2 sh (TKScalar r))
v ->
OpCodeIntegral2 -> target y -> target y -> target y
forall a. IntegralH a => OpCodeIntegral2 -> a -> a -> a
interpretAstI2 OpCodeIntegral2
opCode (AstEnv target -> AstTensor AstMethodLet s y -> target y
forall (target :: Target) (s :: AstSpanType) (y :: TK).
(ADReady target, AstSpan s) =>
AstEnv target -> AstTensor AstMethodLet s y -> target y
interpretAst AstEnv target
env AstTensor AstMethodLet s y
AstTensor AstMethodLet s (TKS2 sh (TKScalar r))
u) (AstEnv target -> AstTensor AstMethodLet s y -> target y
forall (target :: Target) (s :: AstSpanType) (y :: TK).
(ADReady target, AstSpan s) =>
AstEnv target -> AstTensor AstMethodLet s y -> target y
interpretAst AstEnv target
env AstTensor AstMethodLet s y
AstTensor AstMethodLet s (TKS2 sh (TKScalar r))
v)
AstConcreteS Shaped sh r
a -> Shaped sh r -> target (TKS2 sh (TKScalar r))
forall r (sh :: [Nat]).
GoodScalar r =>
Shaped sh r -> target (TKS sh r)
forall (target :: Target) r (sh :: [Nat]).
(BaseTensor target, GoodScalar r) =>
Shaped sh r -> target (TKS sh r)
tsconcrete Shaped sh r
a
AstFloorS AstTensor AstMethodLet PrimalSpan (TKS sh r1)
v ->
target (TKS sh r1) -> target (TKS2 sh (TKScalar r2))
forall r r2 (sh :: [Nat]).
(GoodScalar r, RealFrac r, GoodScalar r2, Integral r2) =>
target (TKS sh r) -> target (TKS sh r2)
forall (target :: Target) r r2 (sh :: [Nat]).
(BaseTensor target, GoodScalar r, RealFrac r, GoodScalar r2,
Integral r2) =>
target (TKS sh r) -> target (TKS sh r2)
tsfloor (target (TKS sh r1) -> target (TKS2 sh (TKScalar r2)))
-> target (TKS sh r1) -> target (TKS2 sh (TKScalar r2))
forall a b. (a -> b) -> a -> b
$ AstEnv target
-> AstTensor AstMethodLet PrimalSpan (TKS sh r1)
-> target (TKS sh r1)
forall (target :: Target) (s :: AstSpanType) (y :: TK).
(ADReady target, AstSpan s) =>
AstEnv target -> AstTensor AstMethodLet s y -> target y
interpretAst AstEnv target
env AstTensor AstMethodLet PrimalSpan (TKS sh r1)
v
AstFromIntegralS AstTensor AstMethodLet PrimalSpan (TKS sh r1)
v ->
target (TKS sh r1) -> target (TKS2 sh (TKScalar r2))
forall r1 r2 (sh :: [Nat]).
(GoodScalar r1, Integral r1, GoodScalar r2) =>
target (TKS sh r1) -> target (TKS sh r2)
forall (target :: Target) r1 r2 (sh :: [Nat]).
(BaseTensor target, GoodScalar r1, Integral r1, GoodScalar r2) =>
target (TKS sh r1) -> target (TKS sh r2)
tsfromIntegral (target (TKS sh r1) -> target (TKS2 sh (TKScalar r2)))
-> target (TKS sh r1) -> target (TKS2 sh (TKScalar r2))
forall a b. (a -> b) -> a -> b
$ AstEnv target
-> AstTensor AstMethodLet PrimalSpan (TKS sh r1)
-> target (TKS sh r1)
forall (target :: Target) (s :: AstSpanType) (y :: TK).
(ADReady target, AstSpan s) =>
AstEnv target -> AstTensor AstMethodLet s y -> target y
interpretAst AstEnv target
env AstTensor AstMethodLet PrimalSpan (TKS sh r1)
v
AstCastS @r1 @r2 AstTensor AstMethodLet s (TKS sh r1)
v ->
case TypeRep @Type r1
-> TypeRep @Type Double -> Maybe ((:~:) @Type r1 Double)
forall a b.
TypeRep @Type a -> TypeRep @Type b -> Maybe ((:~:) @Type a b)
forall {k} (f :: k -> Type) (a :: k) (b :: k).
TestEquality @k f =>
f a -> f b -> Maybe ((:~:) @k a b)
testEquality (forall a. Typeable @Type a => TypeRep @Type a
forall {k} (a :: k). Typeable @k a => TypeRep @k a
typeRep @r1) (forall a. Typeable @Type a => TypeRep @Type a
forall {k} (a :: k). Typeable @k a => TypeRep @k a
typeRep @Double) of
Just (:~:) @Type r1 Double
Refl -> case TypeRep @Type r2
-> TypeRep @Type Float -> Maybe ((:~:) @Type r2 Float)
forall a b.
TypeRep @Type a -> TypeRep @Type b -> Maybe ((:~:) @Type a b)
forall {k} (f :: k -> Type) (a :: k) (b :: k).
TestEquality @k f =>
f a -> f b -> Maybe ((:~:) @k a b)
testEquality (forall a. Typeable @Type a => TypeRep @Type a
forall {k} (a :: k). Typeable @k a => TypeRep @k a
typeRep @r2) (forall a. Typeable @Type a => TypeRep @Type a
forall {k} (a :: k). Typeable @k a => TypeRep @k a
typeRep @Float) of
Just (:~:) @Type r2 Float
Refl -> forall (target :: Target) r1 r2 (sh :: [Nat]).
(BaseTensor target, RealFrac r1, GoodScalar r1, RealFrac r2,
GoodScalar r2) =>
target (TKS sh r1) -> target (TKS sh r2)
tscast @_ @Double @Float (target (TKS sh Double) -> target (TKS sh Float))
-> target (TKS sh Double) -> target (TKS sh Float)
forall a b. (a -> b) -> a -> b
$ AstEnv target
-> AstTensor AstMethodLet s (TKS sh Double)
-> target (TKS sh Double)
forall (target :: Target) (s :: AstSpanType) (y :: TK).
(ADReady target, AstSpan s) =>
AstEnv target -> AstTensor AstMethodLet s y -> target y
interpretAst AstEnv target
env AstTensor AstMethodLet s (TKS sh r1)
AstTensor AstMethodLet s (TKS sh Double)
v
Maybe ((:~:) @Type r2 Float)
_ -> forall (target :: Target) r1 r2 (sh :: [Nat]).
(BaseTensor target, RealFrac r1, GoodScalar r1, RealFrac r2,
GoodScalar r2) =>
target (TKS sh r1) -> target (TKS sh r2)
tscast @_ @Double (target (TKS sh Double) -> target (TKS2 sh (TKScalar r2)))
-> target (TKS sh Double) -> target (TKS2 sh (TKScalar r2))
forall a b. (a -> b) -> a -> b
$ AstEnv target
-> AstTensor AstMethodLet s (TKS sh Double)
-> target (TKS sh Double)
forall (target :: Target) (s :: AstSpanType) (y :: TK).
(ADReady target, AstSpan s) =>
AstEnv target -> AstTensor AstMethodLet s y -> target y
interpretAst AstEnv target
env AstTensor AstMethodLet s (TKS sh r1)
AstTensor AstMethodLet s (TKS sh Double)
v
Maybe ((:~:) @Type r1 Double)
_ -> case TypeRep @Type r1
-> TypeRep @Type Float -> Maybe ((:~:) @Type r1 Float)
forall a b.
TypeRep @Type a -> TypeRep @Type b -> Maybe ((:~:) @Type a b)
forall {k} (f :: k -> Type) (a :: k) (b :: k).
TestEquality @k f =>
f a -> f b -> Maybe ((:~:) @k a b)
testEquality (forall a. Typeable @Type a => TypeRep @Type a
forall {k} (a :: k). Typeable @k a => TypeRep @k a
typeRep @r1) (forall a. Typeable @Type a => TypeRep @Type a
forall {k} (a :: k). Typeable @k a => TypeRep @k a
typeRep @Float) of
Just (:~:) @Type r1 Float
Refl -> case TypeRep @Type r2
-> TypeRep @Type Double -> Maybe ((:~:) @Type r2 Double)
forall a b.
TypeRep @Type a -> TypeRep @Type b -> Maybe ((:~:) @Type a b)
forall {k} (f :: k -> Type) (a :: k) (b :: k).
TestEquality @k f =>
f a -> f b -> Maybe ((:~:) @k a b)
testEquality (forall a. Typeable @Type a => TypeRep @Type a
forall {k} (a :: k). Typeable @k a => TypeRep @k a
typeRep @r2) (forall a. Typeable @Type a => TypeRep @Type a
forall {k} (a :: k). Typeable @k a => TypeRep @k a
typeRep @Double) of
Just (:~:) @Type r2 Double
Refl -> forall (target :: Target) r1 r2 (sh :: [Nat]).
(BaseTensor target, RealFrac r1, GoodScalar r1, RealFrac r2,
GoodScalar r2) =>
target (TKS sh r1) -> target (TKS sh r2)
tscast @_ @Float @Double (target (TKS sh Float) -> target (TKS sh Double))
-> target (TKS sh Float) -> target (TKS sh Double)
forall a b. (a -> b) -> a -> b
$ AstEnv target
-> AstTensor AstMethodLet s (TKS sh Float) -> target (TKS sh Float)
forall (target :: Target) (s :: AstSpanType) (y :: TK).
(ADReady target, AstSpan s) =>
AstEnv target -> AstTensor AstMethodLet s y -> target y
interpretAst AstEnv target
env AstTensor AstMethodLet s (TKS sh r1)
AstTensor AstMethodLet s (TKS sh Float)
v
Maybe ((:~:) @Type r2 Double)
_ -> forall (target :: Target) r1 r2 (sh :: [Nat]).
(BaseTensor target, RealFrac r1, GoodScalar r1, RealFrac r2,
GoodScalar r2) =>
target (TKS sh r1) -> target (TKS sh r2)
tscast @_ @Float (target (TKS sh Float) -> target (TKS2 sh (TKScalar r2)))
-> target (TKS sh Float) -> target (TKS2 sh (TKScalar r2))
forall a b. (a -> b) -> a -> b
$ AstEnv target
-> AstTensor AstMethodLet s (TKS sh Float) -> target (TKS sh Float)
forall (target :: Target) (s :: AstSpanType) (y :: TK).
(ADReady target, AstSpan s) =>
AstEnv target -> AstTensor AstMethodLet s y -> target y
interpretAst AstEnv target
env AstTensor AstMethodLet s (TKS sh r1)
AstTensor AstMethodLet s (TKS sh Float)
v
Maybe ((:~:) @Type r1 Float)
_ -> target (TKS sh r1) -> target (TKS2 sh (TKScalar r2))
forall r1 r2 (sh :: [Nat]).
(RealFrac r1, GoodScalar r1, RealFrac r2, GoodScalar r2) =>
target (TKS sh r1) -> target (TKS sh r2)
forall (target :: Target) r1 r2 (sh :: [Nat]).
(BaseTensor target, RealFrac r1, GoodScalar r1, RealFrac r2,
GoodScalar r2) =>
target (TKS sh r1) -> target (TKS sh r2)
tscast (target (TKS sh r1) -> target (TKS2 sh (TKScalar r2)))
-> target (TKS sh r1) -> target (TKS2 sh (TKScalar r2))
forall a b. (a -> b) -> a -> b
$ AstEnv target
-> AstTensor AstMethodLet s (TKS sh r1) -> target (TKS sh r1)
forall (target :: Target) (s :: AstSpanType) (y :: TK).
(ADReady target, AstSpan s) =>
AstEnv target -> AstTensor AstMethodLet s y -> target y
interpretAst AstEnv target
env AstTensor AstMethodLet s (TKS sh r1)
v
AstIndexS @sh1 ShS shn
sh2 AstTensor AstMethodLet s (TKS2 ((++) @Nat shm shn) x)
v AstIxS AstMethodLet shm
ix -> case FullShapeTK (TKS2 ((++) @Nat shm shn) x)
-> SingletonTK (TKS2 ((++) @Nat shm shn) x)
forall (y :: TK). FullShapeTK y -> SingletonTK y
ftkToSTK (AstTensor AstMethodLet s (TKS2 ((++) @Nat shm shn) x)
-> FullShapeTK (TKS2 ((++) @Nat shm shn) x)
forall (s :: AstSpanType) (y :: TK) (ms :: AstMethodOfSharing).
AstTensor ms s y -> FullShapeTK y
ftkAst AstTensor AstMethodLet s (TKS2 ((++) @Nat shm shn) x)
v) of
STKS ShS sh
_ SingletonTK x
x ->
ShS shm -> (KnownShS shm => target y) -> target y
forall (sh :: [Nat]) r. ShS sh -> (KnownShS sh => r) -> r
withKnownShS (AstIxS AstMethodLet shm -> ShS shm
forall (sh :: [Nat]) i. IxS sh i -> ShS sh
shsFromIxS AstIxS AstMethodLet shm
ix) ((KnownShS shm => target y) -> target y)
-> (KnownShS shm => target y) -> target y
forall a b. (a -> b) -> a -> b
$
ShS shn -> (KnownShS shn => target y) -> target y
forall (sh :: [Nat]) r. ShS sh -> (KnownShS sh => r) -> r
withKnownShS ShS shn
sh2 ((KnownShS shn => target y) -> target y)
-> (KnownShS shn => target y) -> target y
forall a b. (a -> b) -> a -> b
$
SingletonTK x -> (KnownSTK x => target y) -> target y
forall (y :: TK) r. SingletonTK y -> (KnownSTK y => r) -> r
withKnownSTK SingletonTK x
x ((KnownSTK x => target y) -> target y)
-> (KnownSTK x => target y) -> target y
forall a b. (a -> b) -> a -> b
$
let v2 :: target (TKS2 sh x)
v2 = AstEnv target
-> AstTensor AstMethodLet s (TKS2 sh x) -> target (TKS2 sh x)
forall (target :: Target) (s :: AstSpanType) (y :: TK).
(ADReady target, AstSpan s) =>
AstEnv target -> AstTensor AstMethodLet s y -> target y
interpretAst AstEnv target
env AstTensor AstMethodLet s (TKS2 sh x)
AstTensor AstMethodLet s (TKS2 ((++) @Nat shm shn) x)
v
ix3 :: IxS shm (PrimalOf target (TKScalar Int64))
ix3 = AstEnv target
-> AstTensor AstMethodLet PrimalSpan (TKScalar Int64)
-> PrimalOf target (TKScalar Int64)
forall (target :: Target) (y :: TK).
ADReady target =>
AstEnv target
-> AstTensor AstMethodLet PrimalSpan y -> PrimalOf target y
interpretAstPrimal AstEnv target
env (AstTensor AstMethodLet PrimalSpan (TKScalar Int64)
-> PrimalOf target (TKScalar Int64))
-> AstIxS AstMethodLet shm
-> IxS shm (PrimalOf target (TKScalar Int64))
forall (f :: Type -> Type) a b. Functor f => (a -> b) -> f a -> f b
<$> AstIxS AstMethodLet shm
ix
in forall (target :: Target) (shm :: [Nat]) (shn :: [Nat]) (x :: TK).
(BaseTensor target, KnownShS shm, KnownShS shn, KnownSTK x) =>
target (TKS2 ((++) @Nat shm shn) x)
-> IxSOf target shm -> target (TKS2 shn x)
tsindex @target @sh1 target (TKS2 sh x)
target (TKS2 ((++) @Nat shm shn) x)
v2 IxS shm (PrimalOf target (TKScalar Int64))
ix3
AstScatterS @shm @shn @shp
ShS shn
shn AstTensor AstMethodLet s (TKS2 ((++) @Nat shm shn) x)
v (AstVarListS shm
vars, AstIxS AstMethodLet shp
ix) -> case FullShapeTK (TKS2 ((++) @Nat shm shn) x)
-> SingletonTK (TKS2 ((++) @Nat shm shn) x)
forall (y :: TK). FullShapeTK y -> SingletonTK y
ftkToSTK (AstTensor AstMethodLet s (TKS2 ((++) @Nat shm shn) x)
-> FullShapeTK (TKS2 ((++) @Nat shm shn) x)
forall (s :: AstSpanType) (y :: TK) (ms :: AstMethodOfSharing).
AstTensor ms s y -> FullShapeTK y
ftkAst AstTensor AstMethodLet s (TKS2 ((++) @Nat shm shn) x)
v) of
STKS ShS sh
_ SingletonTK x
x ->
ShS shm -> (KnownShS shm => target y) -> target y
forall (sh :: [Nat]) r. ShS sh -> (KnownShS sh => r) -> r
withKnownShS (AstVarListS shm -> ShS shm
forall (sh :: [Nat]) (f :: Nat -> Type). ListS sh f -> ShS sh
shsFromListS AstVarListS shm
vars) ((KnownShS shm => target y) -> target y)
-> (KnownShS shm => target y) -> target y
forall a b. (a -> b) -> a -> b
$
ShS shn -> (KnownShS shn => target y) -> target y
forall (sh :: [Nat]) r. ShS sh -> (KnownShS sh => r) -> r
withKnownShS ShS shn
shn ((KnownShS shn => target y) -> target y)
-> (KnownShS shn => target y) -> target y
forall a b. (a -> b) -> a -> b
$
ShS shp -> (KnownShS shp => target y) -> target y
forall (sh :: [Nat]) r. ShS sh -> (KnownShS sh => r) -> r
withKnownShS (AstIxS AstMethodLet shp -> ShS shp
forall (sh :: [Nat]) i. IxS sh i -> ShS sh
shsFromIxS AstIxS AstMethodLet shp
ix) ((KnownShS shp => target y) -> target y)
-> (KnownShS shp => target y) -> target y
forall a b. (a -> b) -> a -> b
$
SingletonTK x -> (KnownSTK x => target y) -> target y
forall (y :: TK) r. SingletonTK y -> (KnownSTK y => r) -> r
withKnownSTK SingletonTK x
x ((KnownSTK x => target y) -> target y)
-> (KnownSTK x => target y) -> target y
forall a b. (a -> b) -> a -> b
$
let t1 :: target (TKS2 sh x)
t1 = AstEnv target
-> AstTensor AstMethodLet s (TKS2 sh x) -> target (TKS2 sh x)
forall (target :: Target) (s :: AstSpanType) (y :: TK).
(ADReady target, AstSpan s) =>
AstEnv target -> AstTensor AstMethodLet s y -> target y
interpretAst AstEnv target
env AstTensor AstMethodLet s (TKS2 sh x)
AstTensor AstMethodLet s (TKS2 ((++) @Nat shm shn) x)
v
f2 :: IxSOf target shm -> IxSOf target shp
f2 :: IxSOf target shm -> IxSOf target shp
f2 !IxSOf target shm
ix2 = AstEnv target
-> AstTensor AstMethodLet PrimalSpan (TKScalar Int64)
-> PrimalOf target (TKScalar Int64)
forall (target :: Target) (y :: TK).
ADReady target =>
AstEnv target
-> AstTensor AstMethodLet PrimalSpan y -> PrimalOf target y
interpretAstPrimal (AstVarListS shm
-> IxSOf target shm -> AstEnv target -> AstEnv target
forall (target :: Target) (sh :: [Nat]).
BaseTensor target =>
AstVarListS sh -> IxSOf target sh -> AstEnv target -> AstEnv target
extendEnvVarsS AstVarListS shm
vars IxSOf target shm
ix2 AstEnv target
env) (AstTensor AstMethodLet PrimalSpan (TKScalar Int64)
-> PrimalOf target (TKScalar Int64))
-> AstIxS AstMethodLet shp -> IxSOf target shp
forall (f :: Type -> Type) a b. Functor f => (a -> b) -> f a -> f b
<$> AstIxS AstMethodLet shp
ix
in forall (target :: Target) (shm :: [Nat]) (shn :: [Nat])
(shp :: [Nat]) (x :: TK).
(BaseTensor target, KnownShS shm, KnownShS shn, KnownShS shp,
KnownSTK x) =>
target (TKS2 ((++) @Nat shm shn) x)
-> (IxSOf target shm -> IxSOf target shp)
-> target (TKS2 ((++) @Nat shp shn) x)
tsscatter @_ @shm @shn @shp target (TKS2 sh x)
target (TKS2 ((++) @Nat shm shn) x)
t1 IxSOf target shm -> IxSOf target shp
f2
AstGatherS ShS shn
shn AstTensor AstMethodLet s (TKS2 ((++) @Nat shp shn) x)
v (AstVarListS shm
ZS, AstIxS AstMethodLet shp
ix) -> case FullShapeTK (TKS2 ((++) @Nat shp shn) x)
-> SingletonTK (TKS2 ((++) @Nat shp shn) x)
forall (y :: TK). FullShapeTK y -> SingletonTK y
ftkToSTK (AstTensor AstMethodLet s (TKS2 ((++) @Nat shp shn) x)
-> FullShapeTK (TKS2 ((++) @Nat shp shn) x)
forall (s :: AstSpanType) (y :: TK) (ms :: AstMethodOfSharing).
AstTensor ms s y -> FullShapeTK y
ftkAst AstTensor AstMethodLet s (TKS2 ((++) @Nat shp shn) x)
v) of
STKS ShS sh
_ SingletonTK x
x ->
ShS shn -> (KnownShS shn => target y) -> target y
forall (sh :: [Nat]) r. ShS sh -> (KnownShS sh => r) -> r
withKnownShS ShS shn
shn ((KnownShS shn => target y) -> target y)
-> (KnownShS shn => target y) -> target y
forall a b. (a -> b) -> a -> b
$
ShS shp -> (KnownShS shp => target y) -> target y
forall (sh :: [Nat]) r. ShS sh -> (KnownShS sh => r) -> r
withKnownShS (AstIxS AstMethodLet shp -> ShS shp
forall (sh :: [Nat]) i. IxS sh i -> ShS sh
shsFromIxS AstIxS AstMethodLet shp
ix) ((KnownShS shp => target y) -> target y)
-> (KnownShS shp => target y) -> target y
forall a b. (a -> b) -> a -> b
$
SingletonTK x -> (KnownSTK x => target y) -> target y
forall (y :: TK) r. SingletonTK y -> (KnownSTK y => r) -> r
withKnownSTK SingletonTK x
x ((KnownSTK x => target y) -> target y)
-> (KnownSTK x => target y) -> target y
forall a b. (a -> b) -> a -> b
$
target (TKS2 ((++) @Nat shp shn) x)
-> IxSOf target shp -> target (TKS2 shn x)
forall (shm :: [Nat]) (shn :: [Nat]) (x :: TK).
(KnownShS shm, KnownShS shn, KnownSTK x) =>
target (TKS2 ((++) @Nat shm shn) x)
-> IxSOf target shm -> target (TKS2 shn x)
forall (target :: Target) (shm :: [Nat]) (shn :: [Nat]) (x :: TK).
(BaseTensor target, KnownShS shm, KnownShS shn, KnownSTK x) =>
target (TKS2 ((++) @Nat shm shn) x)
-> IxSOf target shm -> target (TKS2 shn x)
tsindex (AstEnv target
-> AstTensor AstMethodLet s (TKS2 sh x) -> target (TKS2 sh x)
forall (target :: Target) (s :: AstSpanType) (y :: TK).
(ADReady target, AstSpan s) =>
AstEnv target -> AstTensor AstMethodLet s y -> target y
interpretAst AstEnv target
env AstTensor AstMethodLet s (TKS2 sh x)
AstTensor AstMethodLet s (TKS2 ((++) @Nat shp shn) x)
v) (AstEnv target
-> AstTensor AstMethodLet PrimalSpan (TKScalar Int64)
-> PrimalOf target (TKScalar Int64)
forall (target :: Target) (y :: TK).
ADReady target =>
AstEnv target
-> AstTensor AstMethodLet PrimalSpan y -> PrimalOf target y
interpretAstPrimal AstEnv target
env (AstTensor AstMethodLet PrimalSpan (TKScalar Int64)
-> PrimalOf target (TKScalar Int64))
-> AstIxS AstMethodLet shp -> IxSOf target shp
forall (f :: Type -> Type) a b. Functor f => (a -> b) -> f a -> f b
<$> AstIxS AstMethodLet shp
ix)
AstGatherS @shm @shn @shp
ShS shn
shn AstTensor AstMethodLet s (TKS2 ((++) @Nat shp shn) x)
v (AstVarListS shm
vars, AstIxS AstMethodLet shp
ix) -> case FullShapeTK (TKS2 ((++) @Nat shp shn) x)
-> SingletonTK (TKS2 ((++) @Nat shp shn) x)
forall (y :: TK). FullShapeTK y -> SingletonTK y
ftkToSTK (AstTensor AstMethodLet s (TKS2 ((++) @Nat shp shn) x)
-> FullShapeTK (TKS2 ((++) @Nat shp shn) x)
forall (s :: AstSpanType) (y :: TK) (ms :: AstMethodOfSharing).
AstTensor ms s y -> FullShapeTK y
ftkAst AstTensor AstMethodLet s (TKS2 ((++) @Nat shp shn) x)
v) of
STKS ShS sh
_ SingletonTK x
x ->
ShS shm -> (KnownShS shm => target y) -> target y
forall (sh :: [Nat]) r. ShS sh -> (KnownShS sh => r) -> r
withKnownShS (AstVarListS shm -> ShS shm
forall (sh :: [Nat]) (f :: Nat -> Type). ListS sh f -> ShS sh
shsFromListS AstVarListS shm
vars) ((KnownShS shm => target y) -> target y)
-> (KnownShS shm => target y) -> target y
forall a b. (a -> b) -> a -> b
$
ShS shn -> (KnownShS shn => target y) -> target y
forall (sh :: [Nat]) r. ShS sh -> (KnownShS sh => r) -> r
withKnownShS ShS shn
shn ((KnownShS shn => target y) -> target y)
-> (KnownShS shn => target y) -> target y
forall a b. (a -> b) -> a -> b
$
ShS shp -> (KnownShS shp => target y) -> target y
forall (sh :: [Nat]) r. ShS sh -> (KnownShS sh => r) -> r
withKnownShS (AstIxS AstMethodLet shp -> ShS shp
forall (sh :: [Nat]) i. IxS sh i -> ShS sh
shsFromIxS AstIxS AstMethodLet shp
ix) ((KnownShS shp => target y) -> target y)
-> (KnownShS shp => target y) -> target y
forall a b. (a -> b) -> a -> b
$
SingletonTK x -> (KnownSTK x => target y) -> target y
forall (y :: TK) r. SingletonTK y -> (KnownSTK y => r) -> r
withKnownSTK SingletonTK x
x ((KnownSTK x => target y) -> target y)
-> (KnownSTK x => target y) -> target y
forall a b. (a -> b) -> a -> b
$
let t1 :: target (TKS2 sh x)
t1 = AstEnv target
-> AstTensor AstMethodLet s (TKS2 sh x) -> target (TKS2 sh x)
forall (target :: Target) (s :: AstSpanType) (y :: TK).
(ADReady target, AstSpan s) =>
AstEnv target -> AstTensor AstMethodLet s y -> target y
interpretAst AstEnv target
env AstTensor AstMethodLet s (TKS2 sh x)
AstTensor AstMethodLet s (TKS2 ((++) @Nat shp shn) x)
v
f2 :: IxSOf target shm -> IxSOf target shp
f2 :: IxSOf target shm -> IxSOf target shp
f2 !IxSOf target shm
ix2 = AstEnv target
-> AstTensor AstMethodLet PrimalSpan (TKScalar Int64)
-> PrimalOf target (TKScalar Int64)
forall (target :: Target) (y :: TK).
ADReady target =>
AstEnv target
-> AstTensor AstMethodLet PrimalSpan y -> PrimalOf target y
interpretAstPrimal (AstVarListS shm
-> IxSOf target shm -> AstEnv target -> AstEnv target
forall (target :: Target) (sh :: [Nat]).
BaseTensor target =>
AstVarListS sh -> IxSOf target sh -> AstEnv target -> AstEnv target
extendEnvVarsS AstVarListS shm
vars IxSOf target shm
ix2 AstEnv target
env) (AstTensor AstMethodLet PrimalSpan (TKScalar Int64)
-> PrimalOf target (TKScalar Int64))
-> AstIxS AstMethodLet shp -> IxSOf target shp
forall (f :: Type -> Type) a b. Functor f => (a -> b) -> f a -> f b
<$> AstIxS AstMethodLet shp
ix
in forall (target :: Target) (shm :: [Nat]) (shn :: [Nat])
(shp :: [Nat]) (x :: TK).
(BaseTensor target, KnownShS shm, KnownShS shn, KnownShS shp,
KnownSTK x) =>
target (TKS2 ((++) @Nat shp shn) x)
-> (IxSOf target shm -> IxSOf target shp)
-> target (TKS2 ((++) @Nat shm shn) x)
tsgather @_ @shm @shn @shp target (TKS2 sh x)
target (TKS2 ((++) @Nat shp shn) x)
t1 IxSOf target shm -> IxSOf target shp
f2
AstMinIndexS AstTensor AstMethodLet PrimalSpan (TKS ((':) @Nat n sh) r)
v ->
target (TKS ((':) @Nat n sh) r)
-> target (TKS2 (Init @Nat ((':) @Nat n sh)) (TKScalar r2))
forall (n :: Nat) (sh :: [Nat]) r r2.
(GoodScalar r, GoodScalar r2) =>
target (TKS ((':) @Nat n sh) r)
-> target (TKS (Init @Nat ((':) @Nat n sh)) r2)
forall (target :: Target) (n :: Nat) (sh :: [Nat]) r r2.
(BaseTensor target, GoodScalar r, GoodScalar r2) =>
target (TKS ((':) @Nat n sh) r)
-> target (TKS (Init @Nat ((':) @Nat n sh)) r2)
tsminIndex (target (TKS ((':) @Nat n sh) r)
-> target (TKS2 (Init @Nat ((':) @Nat n sh)) (TKScalar r2)))
-> target (TKS ((':) @Nat n sh) r)
-> target (TKS2 (Init @Nat ((':) @Nat n sh)) (TKScalar r2))
forall a b. (a -> b) -> a -> b
$ AstEnv target
-> AstTensor AstMethodLet PrimalSpan (TKS ((':) @Nat n sh) r)
-> target (TKS ((':) @Nat n sh) r)
forall (target :: Target) (s :: AstSpanType) (y :: TK).
(ADReady target, AstSpan s) =>
AstEnv target -> AstTensor AstMethodLet s y -> target y
interpretAst AstEnv target
env AstTensor AstMethodLet PrimalSpan (TKS ((':) @Nat n sh) r)
v
AstMaxIndexS AstTensor AstMethodLet PrimalSpan (TKS ((':) @Nat n sh) r)
v ->
target (TKS ((':) @Nat n sh) r)
-> target (TKS2 (Init @Nat ((':) @Nat n sh)) (TKScalar r2))
forall (n :: Nat) (sh :: [Nat]) r r2.
(GoodScalar r, GoodScalar r2) =>
target (TKS ((':) @Nat n sh) r)
-> target (TKS (Init @Nat ((':) @Nat n sh)) r2)
forall (target :: Target) (n :: Nat) (sh :: [Nat]) r r2.
(BaseTensor target, GoodScalar r, GoodScalar r2) =>
target (TKS ((':) @Nat n sh) r)
-> target (TKS (Init @Nat ((':) @Nat n sh)) r2)
tsmaxIndex (target (TKS ((':) @Nat n sh) r)
-> target (TKS2 (Init @Nat ((':) @Nat n sh)) (TKScalar r2)))
-> target (TKS ((':) @Nat n sh) r)
-> target (TKS2 (Init @Nat ((':) @Nat n sh)) (TKScalar r2))
forall a b. (a -> b) -> a -> b
$ AstEnv target
-> AstTensor AstMethodLet PrimalSpan (TKS ((':) @Nat n sh) r)
-> target (TKS ((':) @Nat n sh) r)
forall (target :: Target) (s :: AstSpanType) (y :: TK).
(ADReady target, AstSpan s) =>
AstEnv target -> AstTensor AstMethodLet s y -> target y
interpretAst AstEnv target
env AstTensor AstMethodLet PrimalSpan (TKS ((':) @Nat n sh) r)
v
AstIotaS SNat n
SNat -> target y
target (TKS2 ((':) @Nat n ('[] @Nat)) (TKScalar r))
forall (n :: Nat) r.
(KnownNat n, GoodScalar r) =>
target (TKS ((':) @Nat n ('[] @Nat)) r)
forall (target :: Target) (n :: Nat) r.
(BaseTensor target, KnownNat n, GoodScalar r) =>
target (TKS ((':) @Nat n ('[] @Nat)) r)
tsiota
AstAppendS AstTensor AstMethodLet s (TKS2 ((':) @Nat m sh) x)
a AstTensor AstMethodLet s (TKS2 ((':) @Nat n sh) x)
b -> case FullShapeTK (TKS2 ((':) @Nat m sh) x)
-> SingletonTK (TKS2 ((':) @Nat m sh) x)
forall (y :: TK). FullShapeTK y -> SingletonTK y
ftkToSTK (AstTensor AstMethodLet s (TKS2 ((':) @Nat m sh) x)
-> FullShapeTK (TKS2 ((':) @Nat m sh) x)
forall (s :: AstSpanType) (y :: TK) (ms :: AstMethodOfSharing).
AstTensor ms s y -> FullShapeTK y
ftkAst AstTensor AstMethodLet s (TKS2 ((':) @Nat m sh) x)
a) of
STKS ShS sh
_ SingletonTK x
x ->
SingletonTK x -> (KnownSTK x => target y) -> target y
forall (y :: TK) r. SingletonTK y -> (KnownSTK y => r) -> r
withKnownSTK SingletonTK x
x ((KnownSTK x => target y) -> target y)
-> (KnownSTK x => target y) -> target y
forall a b. (a -> b) -> a -> b
$
let t1 :: target (TKS2 ((':) @Nat m sh) x)
t1 = AstEnv target
-> AstTensor AstMethodLet s (TKS2 ((':) @Nat m sh) x)
-> target (TKS2 ((':) @Nat m sh) x)
forall (target :: Target) (s :: AstSpanType) (y :: TK).
(ADReady target, AstSpan s) =>
AstEnv target -> AstTensor AstMethodLet s y -> target y
interpretAst AstEnv target
env AstTensor AstMethodLet s (TKS2 ((':) @Nat m sh) x)
a
t2 :: target (TKS2 ((':) @Nat n sh) x)
t2 = AstEnv target
-> AstTensor AstMethodLet s (TKS2 ((':) @Nat n sh) x)
-> target (TKS2 ((':) @Nat n sh) x)
forall (target :: Target) (s :: AstSpanType) (y :: TK).
(ADReady target, AstSpan s) =>
AstEnv target -> AstTensor AstMethodLet s y -> target y
interpretAst AstEnv target
env AstTensor AstMethodLet s (TKS2 ((':) @Nat n sh) x)
b
in target (TKS2 ((':) @Nat m sh) x)
-> target (TKS2 ((':) @Nat n sh) x)
-> target (TKS2 ((':) @Nat (m + n) sh) x)
forall (m :: Nat) (n :: Nat) (sh :: [Nat]) (x :: TK).
KnownSTK x =>
target (TKS2 ((':) @Nat m sh) x)
-> target (TKS2 ((':) @Nat n sh) x)
-> target (TKS2 ((':) @Nat (m + n) sh) x)
forall (target :: Target) (m :: Nat) (n :: Nat) (sh :: [Nat])
(x :: TK).
(BaseTensor target, KnownSTK x) =>
target (TKS2 ((':) @Nat m sh) x)
-> target (TKS2 ((':) @Nat n sh) x)
-> target (TKS2 ((':) @Nat (m + n) sh) x)
tsappend target (TKS2 ((':) @Nat m sh) x)
t1 target (TKS2 ((':) @Nat n sh) x)
t2
AstSliceS SNat i
i SNat n
n SNat k
k AstTensor AstMethodLet s (TKS2 ((':) @Nat ((i + n) + k) sh) x)
v -> case FullShapeTK (TKS2 ((':) @Nat ((i + n) + k) sh) x)
-> SingletonTK (TKS2 ((':) @Nat ((i + n) + k) sh) x)
forall (y :: TK). FullShapeTK y -> SingletonTK y
ftkToSTK (AstTensor AstMethodLet s (TKS2 ((':) @Nat ((i + n) + k) sh) x)
-> FullShapeTK (TKS2 ((':) @Nat ((i + n) + k) sh) x)
forall (s :: AstSpanType) (y :: TK) (ms :: AstMethodOfSharing).
AstTensor ms s y -> FullShapeTK y
ftkAst AstTensor AstMethodLet s (TKS2 ((':) @Nat ((i + n) + k) sh) x)
v) of
STKS ShS sh
_ SingletonTK x
x ->
SingletonTK x -> (KnownSTK x => target y) -> target y
forall (y :: TK) r. SingletonTK y -> (KnownSTK y => r) -> r
withKnownSTK SingletonTK x
x ((KnownSTK x => target y) -> target y)
-> (KnownSTK x => target y) -> target y
forall a b. (a -> b) -> a -> b
$
SNat i
-> SNat n
-> SNat k
-> target (TKS2 ((':) @Nat ((i + n) + k) sh) x)
-> target (TKS2 ((':) @Nat n sh) x)
forall (i :: Nat) (n :: Nat) (k :: Nat) (sh :: [Nat]) (x :: TK).
KnownSTK x =>
SNat i
-> SNat n
-> SNat k
-> target (TKS2 ((':) @Nat ((i + n) + k) sh) x)
-> target (TKS2 ((':) @Nat n sh) x)
forall (target :: Target) (i :: Nat) (n :: Nat) (k :: Nat)
(sh :: [Nat]) (x :: TK).
(BaseTensor target, KnownSTK x) =>
SNat i
-> SNat n
-> SNat k
-> target (TKS2 ((':) @Nat ((i + n) + k) sh) x)
-> target (TKS2 ((':) @Nat n sh) x)
tsslice SNat i
i SNat n
n SNat k
k (target (TKS2 ((':) @Nat ((i + n) + k) sh) x)
-> target (TKS2 ((':) @Nat n sh) x))
-> target (TKS2 ((':) @Nat ((i + n) + k) sh) x)
-> target (TKS2 ((':) @Nat n sh) x)
forall a b. (a -> b) -> a -> b
$ AstEnv target
-> AstTensor AstMethodLet s (TKS2 ((':) @Nat ((i + n) + k) sh) x)
-> target (TKS2 ((':) @Nat ((i + n) + k) sh) x)
forall (target :: Target) (s :: AstSpanType) (y :: TK).
(ADReady target, AstSpan s) =>
AstEnv target -> AstTensor AstMethodLet s y -> target y
interpretAst AstEnv target
env AstTensor AstMethodLet s (TKS2 ((':) @Nat ((i + n) + k) sh) x)
v
AstReverseS AstTensor AstMethodLet s (TKS2 ((':) @Nat n sh) x)
v -> case FullShapeTK (TKS2 ((':) @Nat n sh) x)
-> SingletonTK (TKS2 ((':) @Nat n sh) x)
forall (y :: TK). FullShapeTK y -> SingletonTK y
ftkToSTK (AstTensor AstMethodLet s (TKS2 ((':) @Nat n sh) x)
-> FullShapeTK (TKS2 ((':) @Nat n sh) x)
forall (s :: AstSpanType) (y :: TK) (ms :: AstMethodOfSharing).
AstTensor ms s y -> FullShapeTK y
ftkAst AstTensor AstMethodLet s (TKS2 ((':) @Nat n sh) x)
v) of
STKS ShS sh
_ SingletonTK x
x ->
SingletonTK x -> (KnownSTK x => target y) -> target y
forall (y :: TK) r. SingletonTK y -> (KnownSTK y => r) -> r
withKnownSTK SingletonTK x
x ((KnownSTK x => target y) -> target y)
-> (KnownSTK x => target y) -> target y
forall a b. (a -> b) -> a -> b
$
target (TKS2 ((':) @Nat n sh) x)
-> target (TKS2 ((':) @Nat n sh) x)
forall (n :: Nat) (sh :: [Nat]) (x :: TK).
KnownSTK x =>
target (TKS2 ((':) @Nat n sh) x)
-> target (TKS2 ((':) @Nat n sh) x)
forall (target :: Target) (n :: Nat) (sh :: [Nat]) (x :: TK).
(BaseTensor target, KnownSTK x) =>
target (TKS2 ((':) @Nat n sh) x)
-> target (TKS2 ((':) @Nat n sh) x)
tsreverse (AstEnv target
-> AstTensor AstMethodLet s (TKS2 ((':) @Nat n sh) x)
-> target (TKS2 ((':) @Nat n sh) x)
forall (target :: Target) (s :: AstSpanType) (y :: TK).
(ADReady target, AstSpan s) =>
AstEnv target -> AstTensor AstMethodLet s y -> target y
interpretAst AstEnv target
env AstTensor AstMethodLet s (TKS2 ((':) @Nat n sh) x)
v)
AstTransposeS Perm perm
perm AstTensor AstMethodLet s (TKS2 sh x)
v -> case FullShapeTK (TKS2 sh x) -> SingletonTK (TKS2 sh x)
forall (y :: TK). FullShapeTK y -> SingletonTK y
ftkToSTK (AstTensor AstMethodLet s (TKS2 sh x) -> FullShapeTK (TKS2 sh x)
forall (s :: AstSpanType) (y :: TK) (ms :: AstMethodOfSharing).
AstTensor ms s y -> FullShapeTK y
ftkAst AstTensor AstMethodLet s (TKS2 sh x)
v) of
STKS ShS sh
_ SingletonTK x
x ->
SingletonTK x -> (KnownSTK x => target y) -> target y
forall (y :: TK) r. SingletonTK y -> (KnownSTK y => r) -> r
withKnownSTK SingletonTK x
x ((KnownSTK x => target y) -> target y)
-> (KnownSTK x => target y) -> target y
forall a b. (a -> b) -> a -> b
$
Perm perm
-> target (TKS2 sh x)
-> target (TKS2 (PermutePrefix @Nat perm sh) x)
forall (perm :: [Nat]) (sh :: [Nat]) (x :: TK).
(IsPermutation perm, (<=) @Nat (Rank @Nat perm) (Rank @Nat sh),
KnownSTK x) =>
Perm perm
-> target (TKS2 sh x)
-> target (TKS2 (PermutePrefix @Nat perm sh) x)
forall (target :: Target) (perm :: [Nat]) (sh :: [Nat]) (x :: TK).
(BaseTensor target, IsPermutation perm,
(<=) @Nat (Rank @Nat perm) (Rank @Nat sh), KnownSTK x) =>
Perm perm
-> target (TKS2 sh x)
-> target (TKS2 (PermutePrefix @Nat perm sh) x)
tstranspose Perm perm
perm (target (TKS2 sh x)
-> target (TKS2 (PermutePrefix @Nat perm sh) x))
-> target (TKS2 sh x)
-> target (TKS2 (PermutePrefix @Nat perm sh) x)
forall a b. (a -> b) -> a -> b
$ AstEnv target
-> AstTensor AstMethodLet s (TKS2 sh x) -> target (TKS2 sh x)
forall (target :: Target) (s :: AstSpanType) (y :: TK).
(ADReady target, AstSpan s) =>
AstEnv target -> AstTensor AstMethodLet s y -> target y
interpretAst AstEnv target
env AstTensor AstMethodLet s (TKS2 sh x)
v
AstReshapeS ShS sh2
sh2 AstTensor AstMethodLet s (TKS2 sh x)
v -> case FullShapeTK (TKS2 sh x) -> SingletonTK (TKS2 sh x)
forall (y :: TK). FullShapeTK y -> SingletonTK y
ftkToSTK (AstTensor AstMethodLet s (TKS2 sh x) -> FullShapeTK (TKS2 sh x)
forall (s :: AstSpanType) (y :: TK) (ms :: AstMethodOfSharing).
AstTensor ms s y -> FullShapeTK y
ftkAst AstTensor AstMethodLet s (TKS2 sh x)
v) of
STKS ShS sh
_ SingletonTK x
x ->
SingletonTK x -> (KnownSTK x => target y) -> target y
forall (y :: TK) r. SingletonTK y -> (KnownSTK y => r) -> r
withKnownSTK SingletonTK x
x ((KnownSTK x => target y) -> target y)
-> (KnownSTK x => target y) -> target y
forall a b. (a -> b) -> a -> b
$
ShS sh2 -> target (TKS2 sh x) -> target (TKS2 sh2 x)
forall (sh :: [Nat]) (sh2 :: [Nat]) (x :: TK).
((Product sh :: Nat) ~ (Product sh2 :: Nat), KnownSTK x) =>
ShS sh2 -> target (TKS2 sh x) -> target (TKS2 sh2 x)
forall (target :: Target) (sh :: [Nat]) (sh2 :: [Nat]) (x :: TK).
(BaseTensor target, (Product sh :: Nat) ~ (Product sh2 :: Nat),
KnownSTK x) =>
ShS sh2 -> target (TKS2 sh x) -> target (TKS2 sh2 x)
tsreshape ShS sh2
sh2 (AstEnv target
-> AstTensor AstMethodLet s (TKS2 sh x) -> target (TKS2 sh x)
forall (target :: Target) (s :: AstSpanType) (y :: TK).
(ADReady target, AstSpan s) =>
AstEnv target -> AstTensor AstMethodLet s y -> target y
interpretAst AstEnv target
env AstTensor AstMethodLet s (TKS2 sh x)
v)
AstConvert TKConversion a1 y
c AstTensor AstMethodLet s a1
a ->
TKConversion a1 y -> SingletonTK a1 -> target a1 -> target y
forall (a :: TK) (b :: TK).
TKConversion a b -> SingletonTK a -> target a -> target b
forall (target :: Target) (a :: TK) (b :: TK).
ConvertTensor target =>
TKConversion a b -> SingletonTK a -> target a -> target b
tconvert TKConversion a1 y
c (FullShapeTK a1 -> SingletonTK a1
forall (y :: TK). FullShapeTK y -> SingletonTK y
ftkToSTK (AstTensor AstMethodLet s a1 -> FullShapeTK a1
forall (s :: AstSpanType) (y :: TK) (ms :: AstMethodOfSharing).
AstTensor ms s y -> FullShapeTK y
ftkAst AstTensor AstMethodLet s a1
a)) (AstEnv target -> AstTensor AstMethodLet s a1 -> target a1
forall (target :: Target) (s :: AstSpanType) (y :: TK).
(ADReady target, AstSpan s) =>
AstEnv target -> AstTensor AstMethodLet s y -> target y
interpretAst AstEnv target
env AstTensor AstMethodLet s a1
a)
AstSum0S AstTensor AstMethodLet s (TKS2 sh x)
v -> case FullShapeTK (TKS2 sh x) -> SingletonTK (TKS2 sh x)
forall (y :: TK). FullShapeTK y -> SingletonTK y
ftkToSTK (AstTensor AstMethodLet s (TKS2 sh x) -> FullShapeTK (TKS2 sh x)
forall (s :: AstSpanType) (y :: TK) (ms :: AstMethodOfSharing).
AstTensor ms s y -> FullShapeTK y
ftkAst AstTensor AstMethodLet s (TKS2 sh x)
v) of
STKS ShS sh
sh SingletonTK x
x ->
ShS sh -> (KnownShS sh => target y) -> target y
forall (sh :: [Nat]) r. ShS sh -> (KnownShS sh => r) -> r
withKnownShS ShS sh
sh ((KnownShS sh => target y) -> target y)
-> (KnownShS sh => target y) -> target y
forall a b. (a -> b) -> a -> b
$
SingletonTK x -> (KnownSTK x => target y) -> target y
forall (y :: TK) r. SingletonTK y -> (KnownSTK y => r) -> r
withKnownSTK SingletonTK x
x ((KnownSTK x => target y) -> target y)
-> (KnownSTK x => target y) -> target y
forall a b. (a -> b) -> a -> b
$
target (TKS2 sh x) -> target (TKS2 ('[] @Nat) x)
forall (sh :: [Nat]) (x :: TK).
(KnownShS sh, KnownSTK x) =>
target (TKS2 sh x) -> target (TKS2 ('[] @Nat) x)
forall (target :: Target) (sh :: [Nat]) (x :: TK).
(BaseTensor target, KnownShS sh, KnownSTK x) =>
target (TKS2 sh x) -> target (TKS2 ('[] @Nat) x)
tssum0 (AstEnv target
-> AstTensor AstMethodLet s (TKS2 sh x) -> target (TKS2 sh x)
forall (target :: Target) (s :: AstSpanType) (y :: TK).
(ADReady target, AstSpan s) =>
AstEnv target -> AstTensor AstMethodLet s y -> target y
interpretAst AstEnv target
env AstTensor AstMethodLet s (TKS2 sh x)
v)
AstDot0S AstTensor AstMethodLet s (TKS sh r)
u AstTensor AstMethodLet s (TKS sh r)
v -> case AstTensor AstMethodLet s (TKS sh r) -> FullShapeTK (TKS sh r)
forall (s :: AstSpanType) (y :: TK) (ms :: AstMethodOfSharing).
AstTensor ms s y -> FullShapeTK y
ftkAst AstTensor AstMethodLet s (TKS sh r)
u of
FTKS ShS sh
sh FullShapeTK x
_ ->
ShS sh -> (KnownShS sh => target y) -> target y
forall (sh :: [Nat]) r. ShS sh -> (KnownShS sh => r) -> r
withKnownShS ShS sh
sh ((KnownShS sh => target y) -> target y)
-> (KnownShS sh => target y) -> target y
forall a b. (a -> b) -> a -> b
$
target (TKS sh r)
-> target (TKS sh r) -> target (TKS2 ('[] @Nat) (TKScalar r))
forall (sh :: [Nat]) r.
(KnownShS sh, GoodScalar r) =>
target (TKS sh r) -> target (TKS sh r) -> target (TKS ('[] @Nat) r)
forall (target :: Target) (sh :: [Nat]) r.
(BaseTensor target, KnownShS sh, GoodScalar r) =>
target (TKS sh r) -> target (TKS sh r) -> target (TKS ('[] @Nat) r)
tsdot0 (AstEnv target
-> AstTensor AstMethodLet s (TKS sh r) -> target (TKS sh r)
forall (target :: Target) (s :: AstSpanType) (y :: TK).
(ADReady target, AstSpan s) =>
AstEnv target -> AstTensor AstMethodLet s y -> target y
interpretAst AstEnv target
env AstTensor AstMethodLet s (TKS sh r)
u) (AstEnv target
-> AstTensor AstMethodLet s (TKS sh r) -> target (TKS sh r)
forall (target :: Target) (s :: AstSpanType) (y :: TK).
(ADReady target, AstSpan s) =>
AstEnv target -> AstTensor AstMethodLet s y -> target y
interpretAst AstEnv target
env AstTensor AstMethodLet s (TKS sh r)
v)
AstDot1InS @sh @n ShS sh
sh SNat n
SNat AstTensor
AstMethodLet s (TKS ((++) @Nat sh ((':) @Nat n ('[] @Nat))) r)
u AstTensor
AstMethodLet s (TKS ((++) @Nat sh ((':) @Nat n ('[] @Nat))) r)
v ->
ShS sh -> (KnownShS sh => target y) -> target y
forall (sh :: [Nat]) r. ShS sh -> (KnownShS sh => r) -> r
withKnownShS ShS sh
sh ((KnownShS sh => target y) -> target y)
-> (KnownShS sh => target y) -> target y
forall a b. (a -> b) -> a -> b
$
forall (target :: Target) (sh :: [Nat]) r (n :: Nat).
(BaseTensor target, KnownShS sh, GoodScalar r) =>
SNat n
-> target (TKS ((++) @Nat sh ((':) @Nat n ('[] @Nat))) r)
-> target (TKS ((++) @Nat sh ((':) @Nat n ('[] @Nat))) r)
-> target (TKS sh r)
tsdot1In @_ @sh (forall (n :: Nat). KnownNat n => SNat n
SNat @n) (AstEnv target
-> AstTensor
AstMethodLet s (TKS ((++) @Nat sh ((':) @Nat n ('[] @Nat))) r)
-> target (TKS ((++) @Nat sh ((':) @Nat n ('[] @Nat))) r)
forall (target :: Target) (s :: AstSpanType) (y :: TK).
(ADReady target, AstSpan s) =>
AstEnv target -> AstTensor AstMethodLet s y -> target y
interpretAst AstEnv target
env AstTensor
AstMethodLet s (TKS ((++) @Nat sh ((':) @Nat n ('[] @Nat))) r)
u) (AstEnv target
-> AstTensor
AstMethodLet s (TKS ((++) @Nat sh ((':) @Nat n ('[] @Nat))) r)
-> target (TKS ((++) @Nat sh ((':) @Nat n ('[] @Nat))) r)
forall (target :: Target) (s :: AstSpanType) (y :: TK).
(ADReady target, AstSpan s) =>
AstEnv target -> AstTensor AstMethodLet s y -> target y
interpretAst AstEnv target
env AstTensor
AstMethodLet s (TKS ((++) @Nat sh ((':) @Nat n ('[] @Nat))) r)
v)
AstMatmul2S SNat m
SNat SNat n
SNat SNat p
SNat AstTensor
AstMethodLet s (TKS ((':) @Nat m ((':) @Nat n ('[] @Nat))) r)
u AstTensor
AstMethodLet s (TKS ((':) @Nat n ((':) @Nat p ('[] @Nat))) r)
v ->
target (TKS ((':) @Nat m ((':) @Nat n ('[] @Nat))) r)
-> target (TKS ((':) @Nat n ((':) @Nat p ('[] @Nat))) r)
-> target
(TKS2 ((':) @Nat m ((':) @Nat p ('[] @Nat))) (TKScalar r))
forall (m :: Nat) (n :: Nat) (p :: Nat) r.
(KnownNat m, KnownNat n, KnownNat p, GoodScalar r) =>
target (TKS ((':) @Nat m ((':) @Nat n ('[] @Nat))) r)
-> target (TKS ((':) @Nat n ((':) @Nat p ('[] @Nat))) r)
-> target (TKS ((':) @Nat m ((':) @Nat p ('[] @Nat))) r)
forall (target :: Target) (m :: Nat) (n :: Nat) (p :: Nat) r.
(BaseTensor target, KnownNat m, KnownNat n, KnownNat p,
GoodScalar r) =>
target (TKS ((':) @Nat m ((':) @Nat n ('[] @Nat))) r)
-> target (TKS ((':) @Nat n ((':) @Nat p ('[] @Nat))) r)
-> target (TKS ((':) @Nat m ((':) @Nat p ('[] @Nat))) r)
tsmatmul2 (AstEnv target
-> AstTensor
AstMethodLet s (TKS ((':) @Nat m ((':) @Nat n ('[] @Nat))) r)
-> target (TKS ((':) @Nat m ((':) @Nat n ('[] @Nat))) r)
forall (target :: Target) (s :: AstSpanType) (y :: TK).
(ADReady target, AstSpan s) =>
AstEnv target -> AstTensor AstMethodLet s y -> target y
interpretAst AstEnv target
env AstTensor
AstMethodLet s (TKS ((':) @Nat m ((':) @Nat n ('[] @Nat))) r)
u) (AstEnv target
-> AstTensor
AstMethodLet s (TKS ((':) @Nat n ((':) @Nat p ('[] @Nat))) r)
-> target (TKS ((':) @Nat n ((':) @Nat p ('[] @Nat))) r)
forall (target :: Target) (s :: AstSpanType) (y :: TK).
(ADReady target, AstSpan s) =>
AstEnv target -> AstTensor AstMethodLet s y -> target y
interpretAst AstEnv target
env AstTensor
AstMethodLet s (TKS ((':) @Nat n ((':) @Nat p ('[] @Nat))) r)
v)
-- | Interpret an AST lambda ('AstLambda') as a target-level function
-- via 'tlambda'.
--
-- The outer environment argument is deliberately unused: the lambda body
-- is interpreted by 'interpretAst' in an environment that contains only
-- the lambda's own bound variable (@extendEnv var ws emptyEnv@).
-- NOTE(review): this relies on lambda bodies having no free variables
-- other than the bound one; confirm against the AST construction invariants.
--
-- The argument's full shape is taken from the variable name
-- ('varNameToFTK').
interpretAstHFun
  :: forall target x y s s2. (AstSpan s2, BaseTensor target)
  => AstEnv target -> AstHFun s s2 x y
  -> HFunOf target x y
{-# INLINE interpretAstHFun #-}
interpretAstHFun _env (AstLambda var l) =
  tlambda @target (varNameToFTK var)
  $ HFun $ \ws -> interpretAst (extendEnv var ws emptyEnv) l
    -- 'HFun' requires a function polymorphic in the target @f@, so the
    -- body is re-interpreted at whatever @f@ the caller instantiates.
-- | Primal-span counterpart of 'interpretAstHFun': interpret an
-- 'AstLambda' whose variable and body are both in 'PrimalSpan' as a
-- function over @PrimalOf target@ (built with @tlambda \@(PrimalOf target)@).
--
-- As in 'interpretAstHFun', the outer environment argument is deliberately
-- unused. The body is interpreted by 'interpretAst' in an environment
-- holding only the lambda's bound variable.
-- NOTE(review): assumes lambda bodies are closed apart from that variable;
-- confirm against the AST construction invariants.
interpretAstHFunPrimal
  :: forall target x y. ADReady target
  => AstEnv target -> AstHFun PrimalSpan PrimalSpan x y
  -> HFunOf (PrimalOf target) x y
{-# INLINE interpretAstHFunPrimal #-}
interpretAstHFunPrimal _env (AstLambda var l) =
  tlambda @(PrimalOf target) (varNameToFTK var)
  $ HFun $ \ws -> interpretAst (extendEnv var ws emptyEnv) l
    -- 'HFun' requires a function polymorphic in the target @f@, so the
    -- body is re-interpreted at whatever @f@ the caller instantiates.
-- | Interpret a boolean AST in the given environment.
--
-- Constants become 'true' / 'false'. Negation and conjunction are
-- interpreted recursively. The comparisons 'AstLeqK' (scalar operands) and
-- 'AstLeqS' (shaped-tensor operands) interpret both operands with
-- 'interpretAstPrimal' and compare the results with '<=.'.
interpretAstBool :: ADReady target
                 => AstEnv target -> AstBool AstMethodLet
                 -> BoolOf target
-- The bang pattern forces @env@ before dispatching on the AST.
interpretAstBool !env = \case
  AstBoolConst a -> if a then true else false
  AstBoolNot arg -> notB $ interpretAstBool env arg
  AstBoolAnd arg1 arg2 ->
    let b1 = interpretAstBool env arg1
        b2 = interpretAstBool env arg2
    in b1 &&* b2
  -- Comparisons are evaluated on primal values. '<=.' yields
  -- @BoolOf (PrimalOf target)@, which is returned here as @BoolOf target@
  -- (presumably an equality entailed by 'ADReady').
  AstLeqK arg1 arg2 ->
    let r1 = interpretAstPrimal env arg1
        r2 = interpretAstPrimal env arg2
    in r1 <=. r2
  AstLeqS arg1 arg2 ->
    let r1 = interpretAstPrimal env arg1
        r2 = interpretAstPrimal env arg2
    in r1 <=. r2
-- | Apply the unary 'Num' operation named by the opcode:
-- 'NegateOp' is 'negate', 'AbsOp' is 'abs', 'SignumOp' is 'signum'.
interpretAstN1 :: Num a
               => OpCodeNum1 -> a -> a
{-# INLINE interpretAstN1 #-}
interpretAstN1 NegateOp u = negate u
interpretAstN1 AbsOp u = abs u
interpretAstN1 SignumOp u = signum u
-- | Apply the unary floating-point operation named by the opcode.
-- Each opcode maps one-to-one onto the Prelude method of the same name
-- ('recip' comes from 'Fractional', a superclass of 'Floating').
interpretAstR1 :: Floating a
               => OpCode1 -> a -> a
{-# INLINE interpretAstR1 #-}
interpretAstR1 RecipOp u = recip u
interpretAstR1 ExpOp u = exp u
interpretAstR1 LogOp u = log u
interpretAstR1 SqrtOp u = sqrt u
interpretAstR1 SinOp u = sin u
interpretAstR1 CosOp u = cos u
interpretAstR1 TanOp u = tan u
interpretAstR1 AsinOp u = asin u
interpretAstR1 AcosOp u = acos u
interpretAstR1 AtanOp u = atan u
interpretAstR1 SinhOp u = sinh u
interpretAstR1 CoshOp u = cosh u
interpretAstR1 TanhOp u = tanh u
interpretAstR1 AsinhOp u = asinh u
interpretAstR1 AcoshOp u = acosh u
interpretAstR1 AtanhOp u = atanh u
-- | Apply the binary floating-point operation named by the opcode:
-- division, power, 'logBase' (first argument is the base) and 'atan2H'.
-- 'RealFloatH' is assumed to entail 'Floating', since '/', '**' and
-- 'logBase' are used under that constraint alone.
interpretAstR2 :: RealFloatH a
               => OpCode2 -> a -> a -> a
{-# INLINE interpretAstR2 #-}
interpretAstR2 DivideOp u v = u / v
interpretAstR2 PowerOp u v = u ** v
interpretAstR2 LogBaseOp u v = logBase u v
interpretAstR2 Atan2Op u v = atan2H u v
-- | Apply the binary integral operation named by the opcode using the
-- 'IntegralH' methods 'quotH' and 'remH'.
-- NOTE(review): presumably these truncate like Prelude 'quot' / 'rem' —
-- confirm against the 'IntegralH' instances.
interpretAstI2 :: IntegralH a
               => OpCodeIntegral2 -> a -> a -> a
{-# INLINE interpretAstI2 #-}
interpretAstI2 QuotOp u v = quotH u v
interpretAstI2 RemOp u v = remH u v