{-# LANGUAGE ViewPatterns #-}
module HordeAd.Core.Ast
(
AstSpanType(..), AstSpan(..), sameAstSpan
, AstVarId, intToAstVarId
, AstInt, IntVarName, pattern AstIntVar
, AstVarName, mkAstVarName, varNameToAstVarId, varNameToFTK, varNameToBounds
, astVar
, AstArtifactRev(..), AstArtifactFwd(..)
, AstIxS, AstVarListS, pattern AstLeqInt
, AstMethodOfSharing(..), AstTensor(..)
, AstHFun(..)
, AstBool(..), OpCodeNum1(..), OpCode1(..), OpCode2(..), OpCodeIntegral2(..)
) where
import Prelude hiding (foldl')
import Data.Dependent.EnumMap.Strict qualified as DMap
import Data.Functor.Const
import Data.Int (Int64)
import Data.Kind (Type)
import Data.Some
import Data.Type.Equality (TestEquality (..), (:~:) (Refl))
import Data.Vector.Strict qualified as Data.Vector
import GHC.TypeLits (type (+), type (<=))
import Type.Reflection (Typeable, typeRep)
import Data.Array.Nested (type (++))
import Data.Array.Nested qualified as Nested
import Data.Array.Nested.Mixed.Shape
import Data.Array.Nested.Permutation qualified as Permutation
import Data.Array.Nested.Shaped.Shape
import Data.Array.Nested.Types (Init)
import HordeAd.Core.TensorKind
import HordeAd.Core.Types
-- | The span index of an AST term. 'FullSpan' terms can be projected to
-- 'PrimalSpan' and 'DualSpan' terms ('AstPrimalPart', 'AstDualPart') and
-- built from them ('AstFromPrimal', 'AstFromDual').
type data AstSpanType = PrimalSpan | DualSpan | FullSpan
-- | Spans that can be converted to and from the primal and dual spans.
--
-- The 'Typeable' superclass lets 'sameAstSpan' compare two spans at
-- runtime.
class Typeable s => AstSpan (s :: AstSpanType) where
  -- | Embed a primal-span term into span @s@.
  fromPrimal :: AstTensor ms PrimalSpan y -> AstTensor ms s y
  -- | Project a span-@s@ term to the primal span.
  primalPart :: AstTensor ms s y -> AstTensor ms PrimalSpan y
  -- | Embed a dual-span term into span @s@.
  fromDual :: AstTensor ms DualSpan y -> AstTensor ms s y
  -- | Project a span-@s@ term to the dual span.
  dualPart :: AstTensor ms s y -> AstTensor ms DualSpan y
-- | On the primal span, the primal conversions are the identity. The dual
-- conversions go through 'FullSpan': inject with one constructor, then
-- project with the other.
instance AstSpan PrimalSpan where
  fromPrimal = id
  primalPart = id
  fromDual = AstPrimalPart . AstFromDual
  dualPart = AstDualPart . AstFromPrimal
-- | On the dual span, the dual conversions are the identity. The primal
-- conversions go through 'FullSpan': inject with one constructor, then
-- project with the other.
instance AstSpan DualSpan where
  fromDual = id
  dualPart = id
  fromPrimal = AstDualPart . AstFromPrimal
  primalPart = AstPrimalPart . AstFromDual
-- | On the full span, injections are the 'AstFromPrimal' / 'AstFromDual'
-- constructors.
--
-- A projection applied directly to the matching injection returns the
-- injected term. Any other term is wrapped in 'AstPrimalPart' /
-- 'AstDualPart'.
instance AstSpan FullSpan where
  fromPrimal = AstFromPrimal
  fromDual = AstFromDual
  primalPart term = case term of
    AstFromPrimal inner -> inner
    _ -> AstPrimalPart term
  dualPart term = case term of
    AstFromDual inner -> inner
    _ -> AstDualPart term
-- | Compare two spans at runtime via their 'Typeable' representations.
-- Returns a type-equality witness when @s1@ and @s2@ are the same span.
sameAstSpan :: forall s1 s2. (AstSpan s1, AstSpan s2) => Maybe (s1 :~: s2)
sameAstSpan =
  let rep1 = typeRep @s1
      rep2 = typeRep @s2
  in testEquality rep1 rep2
-- | Identifier of an AST variable, a newtype over 'Int'.
--
-- 'fromEnum' and 'toEnum' convert to and from the underlying 'Int'
-- (the 'Enum1' instance of 'AstVarName' relies on this).
newtype AstVarId = AstVarId Int
  deriving (Show, Eq, Ord, Enum)
-- | Wrap a raw 'Int' as an 'AstVarId'.
intToAstVarId :: Int -> AstVarId
intToAstVarId rawId = AstVarId rawId
type role AstVarName phantom nominal
-- | A typed AST variable name. It stores, in order:
--
-- * the full shape of the variable's tensor kind @y@,
-- * a lower and an upper 'Int64' bound (see 'mkAstVarName' and
--   'varNameToBounds' for how "no bounds" is encoded),
-- * the variable's 'AstVarId'.
--
-- The span index @s@ is phantom (see the role annotation).
data AstVarName (s :: AstSpanType) (y :: TK) where
  -- The explicit forall fixes the type-application order as @s@, then @y@.
  AstVarName :: forall s y.
                FullShapeTK y -> Int64 -> Int64 -> AstVarId
             -> AstVarName s y
-- | Two variable names are equal exactly when their 'AstVarId's are equal.
-- The shape and bounds are not compared.
instance Eq (AstVarName s y) where
  lhs == rhs = varNameToAstVarId lhs == varNameToAstVarId rhs
-- | A variable name is shown as its 'AstVarId' alone, with the same
-- precedence handling.
instance Show (AstVarName s y) where
  showsPrec prec = showsPrec prec . varNameToAstVarId
-- | Lets 'AstVarName' serve as a key of "Data.Dependent.EnumMap.Strict".
-- The 'AstVarId' becomes the 'Int' key. Shape and bounds are stored as the
-- accompanying info, so 'toEnum1' can rebuild the full name.
instance DMap.Enum1 (AstVarName s) where
  type Enum1Info (AstVarName s) = Some FtkAndBounds
  fromEnum1 (AstVarName shapeTK lo hi vid) = (fromEnum vid, info)
    where
      info = Some (FtkAndBounds shapeTK lo hi)
  toEnum1 key (Some (FtkAndBounds shapeTK lo hi)) =
    let vid = toEnum key
    in Some (AstVarName shapeTK lo hi vid)
type role FtkAndBounds nominal
-- | Everything in an 'AstVarName' except its id: the full shape and the
-- lower and upper 'Int64' bounds. Used as the 'Enum1Info' of
-- 'AstVarName'.
data FtkAndBounds y where
  FtkAndBounds :: FullShapeTK y -> Int64 -> Int64 -> FtkAndBounds y
-- | Two variable names have the same tensor kind when their recorded full
-- shapes match according to 'matchingFTK'. Ids and bounds are ignored.
instance TestEquality (AstVarName s) where
  testEquality (AstVarName shapeA _ _ _) (AstVarName shapeB _ _ _) =
    shapeA `matchingFTK` shapeB
-- | Smart constructor for 'AstVarName'. The 'AstVarId' is the remaining
-- argument.
--
-- @Nothing@ stores the sentinel bounds @(-1000000000, 1000000000)@;
-- 'varNameToBounds' turns that pair back into @Nothing@.
--
-- NOTE(review): an explicit @Just (-1000000000, 1000000000)@ therefore
-- round-trips to @Nothing@. The two cannot be told apart after construction.
mkAstVarName :: forall s y.
                FullShapeTK y -> Maybe (Int64, Int64) -> AstVarId
             -> AstVarName s y
mkAstVarName shapeTK mbounds = case mbounds of
  Nothing -> AstVarName shapeTK (-1000000000) 1000000000
  Just (lo, hi) -> AstVarName shapeTK lo hi
-- | The 'AstVarId' stored in a variable name.
varNameToAstVarId :: AstVarName s y -> AstVarId
varNameToAstVarId name = case name of
  AstVarName _ _ _ vid -> vid
-- | The full shape of the tensor kind stored in a variable name.
varNameToFTK :: AstVarName s y -> FullShapeTK y
varNameToFTK name = case name of
  AstVarName shapeTK _ _ _ -> shapeTK
-- | The bounds stored in a variable name. The sentinel pair
-- @(-1000000000, 1000000000)@ (what 'mkAstVarName' stores for @Nothing@)
-- is reported as 'Nothing'. Any other pair is returned as-is.
varNameToBounds :: AstVarName s y -> Maybe (Int64, Int64)
varNameToBounds (AstVarName _ lo hi _)
  | lo == -1000000000, hi == 1000000000 = Nothing
  | otherwise = Just (lo, hi)
-- | Turn a variable name into a term, normally 'AstVar'.
--
-- Special case: for an 'Int64' scalar variable whose lower and upper
-- bounds are equal, the result is that bound as an 'AstConcreteK'
-- constant. It is lifted into span @s@ with 'fromPrimal'.
astVar :: AstSpan s
       => AstVarName s y -> AstTensor ms s y
astVar varName = case varName of
  AstVarName (FTKScalar @r) lo hi _
    | lo == hi
    , Just Refl <- testEquality (typeRep @r) (typeRep @Int64)
    -> fromPrimal (AstConcreteK lo)
  _ -> AstVar varName
type role AstArtifactRev nominal nominal
-- | A packaged reverse-mode artifact (per its name), with domain kind @x@
-- and result kind @z@. It holds two primal-span variables and two
-- 'AstMethodLet' primal-span terms.
data AstArtifactRev x z = AstArtifactRev
  { artVarDtRev :: AstVarName PrimalSpan (ADTensorKind z)
      -- ^ Variable of kind @ADTensorKind z@ (presumably the incoming
      -- cotangent of the result — TODO confirm).
  , artVarDomainRev :: AstVarName PrimalSpan x
      -- ^ Variable of the domain kind @x@.
  , artDerivativeRev :: AstTensor AstMethodLet PrimalSpan (ADTensorKind x)
      -- ^ Term of kind @ADTensorKind x@ (presumably the gradient with
      -- respect to the domain — TODO confirm).
  , artPrimalRev :: AstTensor AstMethodLet PrimalSpan z
      -- ^ Term of the result kind @z@ (by its name, the primal result).
  }
  deriving Show
type role AstArtifactFwd nominal nominal
-- | A packaged forward-mode artifact (per its name), with domain kind @x@
-- and result kind @z@. It holds two primal-span variables and two
-- 'AstMethodLet' primal-span terms.
data AstArtifactFwd x z = AstArtifactFwd
  { artVarDsFwd :: AstVarName PrimalSpan (ADTensorKind x)
      -- ^ Variable of kind @ADTensorKind x@ (presumably the input
      -- perturbation — TODO confirm).
  , artVarDomainFwd :: AstVarName PrimalSpan x
      -- ^ Variable of the domain kind @x@.
  , artDerivativeFwd :: AstTensor AstMethodLet PrimalSpan (ADTensorKind z)
      -- ^ Term of kind @ADTensorKind z@ (presumably the derivative of the
      -- result along the perturbation — TODO confirm).
  , artPrimalFwd :: AstTensor AstMethodLet PrimalSpan z
      -- ^ Term of the result kind @z@ (by its name, the primal result).
  }
  deriving Show
-- | Primal-span terms of 'Int64' scalar kind.
type AstInt ms = AstTensor ms PrimalSpan (TKScalar Int64)
-- | Names of primal-span variables of 'Int64' scalar kind.
type IntVarName = AstVarName PrimalSpan (TKScalar Int64)
-- | Match-only pattern: recognizes an 'AstInt' term that is a bare variable
-- reference and binds its name.
pattern AstIntVar :: IntVarName -> AstInt ms
pattern AstIntVar name <- AstVar name
-- | A list of integer variable names indexed by the shape @sh@
-- (presumably one name per dimension; 'ListS' is defined elsewhere).
type AstVarListS sh = ListS sh (Const IntVarName)
-- | A shape-indexed tensor index whose components are 'AstInt' terms.
type AstIxS ms sh = IxS sh (AstInt ms)
-- | Bidirectional pattern for an 'AstLeqK' comparison between two 'AstInt'
-- terms.
--
-- Matching succeeds only when the compared scalars are 'Int64' (checked
-- at runtime by 'matchAstLeqInt'). Construction builds 'AstLeqK' directly.
pattern AstLeqInt :: AstInt ms -> AstInt ms -> AstBool ms
pattern AstLeqInt lhs rhs <- (matchAstLeqInt -> Just (lhs, rhs))
  where
    AstLeqInt lhs rhs = AstLeqK lhs rhs
-- | View function behind 'AstLeqInt'. Returns the two operands of an
-- 'AstLeqK' whose scalar type is 'Int64' at runtime. Returns 'Nothing'
-- for any other boolean term.
matchAstLeqInt :: AstBool ms -> Maybe (AstInt ms, AstInt ms)
matchAstLeqInt boolTerm = case boolTerm of
  AstLeqK @r lhs rhs
    | Just Refl <- testEquality (typeRep @r) (typeRep @Int64)
    -> Just (lhs, rhs)
  _ -> Nothing
-- | How sharing is expressed in a term. 'AstLet' only builds
-- 'AstMethodLet' terms. 'AstShare' and 'AstToShare' only build
-- 'AstMethodShare' terms.
type data AstMethodOfSharing = AstMethodShare | AstMethodLet
type role AstTensor nominal nominal nominal
-- | AST terms, indexed by the method of sharing @ms@, the span @s@ and,
-- in the remaining 'Target'-kinded position, the tensor kind of the term
-- (e.g. @TKProduct y z@, @TKScalar r@, @TKS sh r@).
--
-- Constructor naming: a @K@ suffix works on 'TKScalar' terms, an @S@
-- suffix on shaped ('TKS' / 'TKS2') terms.
data AstTensor :: AstMethodOfSharing -> AstSpanType -> Target where
  -- | Pair two terms into a term of product kind.
  AstPair :: forall y z ms s.
      AstTensor ms s y -> AstTensor ms s z
      -> AstTensor ms s (TKProduct y z)
  -- | First component of a product-kinded term.
  AstProject1 :: forall y z ms s.
      AstTensor ms s (TKProduct y z) -> AstTensor ms s y
  -- | Second component of a product-kinded term.
  AstProject2 :: forall y z ms s.
      AstTensor ms s (TKProduct y z) -> AstTensor ms s z
  -- | Assemble a vector of @y@-kinded terms into a @BuildTensorKind k y@
  -- term (presumably the vector has length @k@ — TODO confirm).
  AstFromVector :: forall y k ms s.
      SNat k -> SingletonTK y
      -> Data.Vector.Vector (AstTensor ms s y)
      -> AstTensor ms s (BuildTensorKind k y)
  -- | Reduce a @BuildTensorKind k y@ term to a @y@ term (by its name, a
  -- sum over the @k@ dimension).
  AstSum :: forall y k ms s.
      SNat k -> SingletonTK y
      -> AstTensor ms s (BuildTensorKind k y)
      -> AstTensor ms s y
  -- | Turn a @y@ term into a @BuildTensorKind k y@ term (by its name,
  -- @k@ copies).
  AstReplicate :: forall y k ms s.
      SNat k -> SingletonTK y
      -> AstTensor ms s y
      -> AstTensor ms s (BuildTensorKind k y)
  -- | By its name, a right-to-left accumulating map over @k@ elements.
  -- Arguments: the step function; two more functions whose kinds involve
  -- 'ADTensorKind' (presumably the step's derivative and its transpose —
  -- TODO confirm); the initial accumulator; the @k@ inputs. Result: the
  -- final accumulator paired with the @k@ outputs.
  AstMapAccumRDer
      :: forall accy by ey k ms s.
         SNat k
      -> FullShapeTK by
      -> FullShapeTK ey
      -> AstHFun s s
           (TKProduct accy ey) (TKProduct accy by)
      -> AstHFun s s
           (TKProduct (ADTensorKind (TKProduct accy ey))
                      (TKProduct accy ey))
           (ADTensorKind (TKProduct accy by))
      -> AstHFun s s
           (TKProduct (ADTensorKind (TKProduct accy by))
                      (TKProduct accy ey))
           (ADTensorKind (TKProduct accy ey))
      -> AstTensor ms s accy
      -> AstTensor ms s (BuildTensorKind k ey)
      -> AstTensor ms s (TKProduct accy (BuildTensorKind k by))
  -- | Same shape as 'AstMapAccumRDer'; by its name, left-to-right.
  AstMapAccumLDer
      :: forall accy by ey k ms s.
         SNat k
      -> FullShapeTK by
      -> FullShapeTK ey
      -> AstHFun s s
           (TKProduct accy ey) (TKProduct accy by)
      -> AstHFun s s
           (TKProduct (ADTensorKind (TKProduct accy ey))
                      (TKProduct accy ey))
           (ADTensorKind (TKProduct accy by))
      -> AstHFun s s
           (TKProduct (ADTensorKind (TKProduct accy by))
                      (TKProduct accy ey))
           (ADTensorKind (TKProduct accy ey))
      -> AstTensor ms s accy
      -> AstTensor ms s (BuildTensorKind k ey)
      -> AstTensor ms s (TKProduct accy (BuildTensorKind k by))
  -- | Apply a function term to an argument. The argument span @s1@ may
  -- differ from the result span @s@.
  AstApply :: (AstSpan s1, AstSpan s)
      => AstHFun s1 s x z -> AstTensor ms s1 x -> AstTensor ms s z
  -- | Reference to a variable.
  AstVar :: AstVarName s y -> AstTensor ms s y
  -- | Choose between two terms of the same kind using a boolean term.
  AstCond :: forall y ms s.
      AstBool ms -> AstTensor ms s y -> AstTensor ms s y
      -> AstTensor ms s y
  -- | Build a @BuildTensorKind k y@ term from an integer variable name and
  -- a @y@-kinded body (presumably the body evaluated at each index below
  -- @k@ — TODO confirm).
  AstBuild1 :: forall y k ms s.
      SNat k -> SingletonTK y
      -> (IntVarName, AstTensor ms s y)
      -> AstTensor ms s (BuildTensorKind k y)
  -- | Let-binding of a variable to a term inside a body. Exists only in
  -- 'AstMethodLet' terms.
  AstLet :: forall y z s s2. AstSpan s
      => AstVarName s y -> AstTensor AstMethodLet s y
      -> AstTensor AstMethodLet s2 z
      -> AstTensor AstMethodLet s2 z
  -- | Sharing marker that names a term. Exists only in 'AstMethodShare'
  -- terms.
  AstShare :: AstVarName s y -> AstTensor AstMethodShare s y
      -> AstTensor AstMethodShare s y
  -- | Embed an 'AstMethodLet' term into an 'AstMethodShare' term.
  AstToShare :: AstTensor AstMethodLet s y
      -> AstTensor AstMethodShare s y
  -- | Project a full-span term to the primal span.
  AstPrimalPart :: forall y ms.
      AstTensor ms FullSpan y -> AstTensor ms PrimalSpan y
  -- | Project a full-span term to the dual span.
  AstDualPart :: forall y ms.
      AstTensor ms FullSpan y -> AstTensor ms DualSpan y
  -- | Inject a primal-span term into the full span.
  AstFromPrimal :: forall y ms.
      AstTensor ms PrimalSpan y -> AstTensor ms FullSpan y
  -- | Inject a dual-span term into the full span.
  AstFromDual :: forall y ms.
      AstTensor ms DualSpan y -> AstTensor ms FullSpan y
  -- | Scalar addition (by its name).
  AstPlusK :: GoodScalar r
      => AstTensor ms s (TKScalar r)
      -> AstTensor ms s (TKScalar r)
      -> AstTensor ms s (TKScalar r)
  -- | Scalar multiplication (by its name).
  AstTimesK :: GoodScalar r
      => AstTensor ms s (TKScalar r)
      -> AstTensor ms s (TKScalar r)
      -> AstTensor ms s (TKScalar r)
  -- | Scalar unary numeric operation selected by an 'OpCodeNum1'.
  AstN1K :: GoodScalar r
      => OpCodeNum1 -> AstTensor ms s (TKScalar r)
      -> AstTensor ms s (TKScalar r)
  -- | Scalar unary floating operation selected by an 'OpCode1'.
  AstR1K :: (RealFloatH r, Nested.FloatElt r, GoodScalar r)
      => OpCode1 -> AstTensor ms s (TKScalar r)
      -> AstTensor ms s (TKScalar r)
  -- | Scalar binary floating operation selected by an 'OpCode2'.
  AstR2K :: (RealFloatH r, Nested.FloatElt r, GoodScalar r)
      => OpCode2 -> AstTensor ms s (TKScalar r)
      -> AstTensor ms s (TKScalar r)
      -> AstTensor ms s (TKScalar r)
  -- | Scalar binary integral operation selected by an 'OpCodeIntegral2'.
  AstI2K :: (IntegralH r, Nested.IntElt r, GoodScalar r)
      => OpCodeIntegral2 -> AstTensor ms s (TKScalar r)
      -> AstTensor ms s (TKScalar r)
      -> AstTensor ms s (TKScalar r)
  -- | A concrete scalar value. Always in the primal span.
  AstConcreteK :: GoodScalar r
      => r -> AstTensor ms PrimalSpan (TKScalar r)
  -- | Scalar conversion from a 'RealFrac' to an 'Integral' type (by its
  -- name, rounding down). Primal span only.
  AstFloorK :: (GoodScalar r1, RealFrac r1, GoodScalar r2, Integral r2)
      => AstTensor ms PrimalSpan (TKScalar r1)
      -> AstTensor ms PrimalSpan (TKScalar r2)
  -- | Scalar conversion from an 'Integral' type. Primal span only.
  AstFromIntegralK :: (GoodScalar r1, Integral r1, GoodScalar r2)
      => AstTensor ms PrimalSpan (TKScalar r1)
      -> AstTensor ms PrimalSpan (TKScalar r2)
  -- | Scalar conversion between two 'RealFrac' types.
  AstCastK :: (GoodScalar r1, RealFrac r1, RealFrac r2, GoodScalar r2)
      => AstTensor ms s (TKScalar r1) -> AstTensor ms s (TKScalar r2)
  -- | Shaped addition (by its name), same shape in and out.
  AstPlusS :: GoodScalar r
      => AstTensor ms s (TKS sh r)
      -> AstTensor ms s (TKS sh r)
      -> AstTensor ms s (TKS sh r)
  -- | Shaped multiplication (by its name), same shape in and out.
  AstTimesS :: GoodScalar r
      => AstTensor ms s (TKS sh r)
      -> AstTensor ms s (TKS sh r)
      -> AstTensor ms s (TKS sh r)
  -- | Shaped unary numeric operation selected by an 'OpCodeNum1'.
  AstN1S :: GoodScalar r
      => OpCodeNum1 -> AstTensor ms s (TKS sh r)
      -> AstTensor ms s (TKS sh r)
  -- | Shaped unary floating operation selected by an 'OpCode1'.
  AstR1S :: (RealFloatH r, Nested.FloatElt r, GoodScalar r)
      => OpCode1 -> AstTensor ms s (TKS sh r)
      -> AstTensor ms s (TKS sh r)
  -- | Shaped binary floating operation selected by an 'OpCode2'.
  AstR2S :: (RealFloatH r, Nested.FloatElt r, GoodScalar r)
      => OpCode2 -> AstTensor ms s (TKS sh r)
      -> AstTensor ms s (TKS sh r)
      -> AstTensor ms s (TKS sh r)
  -- | Shaped binary integral operation selected by an 'OpCodeIntegral2'.
  AstI2S :: (IntegralH r, Nested.IntElt r, GoodScalar r)
      => OpCodeIntegral2 -> AstTensor ms s (TKS sh r)
      -> AstTensor ms s (TKS sh r)
      -> AstTensor ms s (TKS sh r)
  -- | A concrete 'Nested.Shaped' array. Always in the primal span.
  AstConcreteS :: GoodScalar r
      => Nested.Shaped sh r -> AstTensor ms PrimalSpan (TKS sh r)
  -- | Shaped counterpart of 'AstFloorK'. Primal span only.
  AstFloorS :: (GoodScalar r1, RealFrac r1, Integral r2, GoodScalar r2)
      => AstTensor ms PrimalSpan (TKS sh r1)
      -> AstTensor ms PrimalSpan (TKS sh r2)
  -- | Shaped counterpart of 'AstFromIntegralK'. Primal span only.
  AstFromIntegralS :: (GoodScalar r1, Integral r1, GoodScalar r2)
      => AstTensor ms PrimalSpan (TKS sh r1)
      -> AstTensor ms PrimalSpan (TKS sh r2)
  -- | Shaped counterpart of 'AstCastK'.
  AstCastS :: (GoodScalar r1, RealFrac r1, GoodScalar r2, RealFrac r2)
      => AstTensor ms s (TKS sh r1)
      -> AstTensor ms s (TKS sh r2)
  -- | Index a tensor of shape @shm ++ shn@ with an @shm@-shaped integer
  -- index, giving a tensor of shape @shn@.
  AstIndexS :: forall shm shn x s ms.
      ShS shn
      -> AstTensor ms s (TKS2 (shm ++ shn) x) -> AstIxS ms shm
      -> AstTensor ms s (TKS2 shn x)
  -- | From shape @shm ++ shn@ to shape @shp ++ shn@, driven by variables
  -- over @shm@ and an index over @shp@ (by its name, a scatter;
  -- presumably the index is written in terms of those variables).
  AstScatterS :: forall shm shn shp x s ms.
      ShS shn -> AstTensor ms s (TKS2 (shm ++ shn) x)
      -> (AstVarListS shm, AstIxS ms shp)
      -> AstTensor ms s (TKS2 (shp ++ shn) x)
  -- | From shape @shp ++ shn@ to shape @shm ++ shn@, driven by variables
  -- over @shm@ and an index over @shp@ (by its name, a gather).
  AstGatherS :: forall shm shn shp x s ms.
      ShS shn -> AstTensor ms s (TKS2 (shp ++ shn) x)
      -> (AstVarListS shm, AstIxS ms shp)
      -> AstTensor ms s (TKS2 (shm ++ shn) x)
  -- | Drops the last dimension ('Init'). By its name, the index of the
  -- minimum along it.
  AstMinIndexS :: forall n sh r r2 ms. (GoodScalar r, GoodScalar r2)
      => AstTensor ms PrimalSpan (TKS (n ': sh) r)
      -> AstTensor ms PrimalSpan (TKS (Init (n ': sh)) r2)
  -- | Drops the last dimension ('Init'). By its name, the index of the
  -- maximum along it.
  AstMaxIndexS :: forall n sh r r2 ms. (GoodScalar r, GoodScalar r2)
      => AstTensor ms PrimalSpan (TKS (n ': sh) r)
      -> AstTensor ms PrimalSpan (TKS (Init (n ': sh)) r2)
  -- | A rank-1 tensor of length @n@ (by its name, an iota sequence).
  AstIotaS :: forall n r ms. GoodScalar r
      => SNat n -> AstTensor ms PrimalSpan (TKS '[n] r)
  -- | Join two tensors along the outermost dimension (sizes @m@ and @n@
  -- give @m + n@).
  AstAppendS :: forall m n sh x ms s.
      AstTensor ms s (TKS2 (m ': sh) x)
      -> AstTensor ms s (TKS2 (n ': sh) x)
      -> AstTensor ms s (TKS2 ((m + n) ': sh) x)
  -- | From an outer dimension of size @i + n + k@, keep a piece of size
  -- @n@ (by its name, a slice).
  AstSliceS :: SNat i -> SNat n -> SNat k
      -> AstTensor ms s (TKS2 (i + n + k ': sh) x)
      -> AstTensor ms s (TKS2 (n ': sh) x)
  -- | Same shape in and out. By its name, reverses the outer dimension.
  AstReverseS :: forall n sh x ms s.
      AstTensor ms s (TKS2 (n ': sh) x)
      -> AstTensor ms s (TKS2 (n ': sh) x)
  -- | Permute the leading dimensions by a permutation no longer than the
  -- rank of @sh@.
  AstTransposeS :: (Permutation.IsPermutation perm, Rank perm <= Rank sh)
      => Permutation.Perm perm -> AstTensor ms s (TKS2 sh x)
      -> AstTensor ms s (TKS2 (Permutation.PermutePrefix perm sh) x)
  -- | Change the shape while keeping the product of dimensions equal.
  AstReshapeS :: Product sh ~ Product sh2
      => ShS sh2
      -> AstTensor ms s (TKS2 sh x) -> AstTensor ms s (TKS2 sh2 x)
  -- | Change the tensor kind as described by a 'TKConversion'.
  AstConvert :: TKConversion a b -> AstTensor ms s a -> AstTensor ms s b
  -- | Reduce to rank 0 (by its name, the sum of all elements).
  AstSum0S :: AstTensor ms s (TKS2 sh x)
      -> AstTensor ms s (TKS2 '[] x)
  -- | Rank-0 result from two same-shaped tensors (by its name, a dot
  -- product).
  AstDot0S :: GoodScalar r
      => AstTensor ms s (TKS sh r) -> AstTensor ms s (TKS sh r)
      -> AstTensor ms s (TKS '[] r)
  -- | Removes a trailing dimension @n@ shared by both arguments (by its
  -- name, a dot product along it).
  AstDot1InS :: forall sh n r ms s. GoodScalar r
      => ShS sh -> SNat n
      -> AstTensor ms s (TKS (sh ++ '[n]) r)
      -> AstTensor ms s (TKS (sh ++ '[n]) r)
      -> AstTensor ms s (TKS sh r)
  -- | @[m, n]@ with @[n, p]@ gives @[m, p]@ (by its name, matrix
  -- multiplication).
  AstMatmul2S :: GoodScalar r
      => SNat m -> SNat n -> SNat p
      -> AstTensor ms s (TKS '[m, n] r)
      -> AstTensor ms s (TKS '[n, p] r)
      -> AstTensor ms s (TKS '[m, p] r)
deriving instance Show (AstTensor ms s y)
type role AstHFun nominal nominal nominal nominal
-- | Function terms, as used by 'AstApply' and the map-accumulate
-- constructors. By its constructor name, a lambda: a variable of span @s@
-- and kind @x@, plus an 'AstMethodLet' body of span @s2@ and kind @z@.
data AstHFun s s2 x z where
  -- Both fields are explicitly lazy (@~@). This only matters if the module
  -- is strict by default — NOTE(review): presumably StrictData; confirm.
  AstLambda :: ~(AstVarName s x)
            -> ~(AstTensor AstMethodLet s2 z)
            -> AstHFun s s2 x z
deriving instance Show (AstHFun s s2 x z)
type role AstBool nominal
-- | Boolean terms, indexed by the method of sharing @ms@ of the tensor
-- terms they compare.
data AstBool ms where
  -- | A boolean literal.
  AstBoolConst :: Bool -> AstBool ms
  -- | Negation (by its name).
  AstBoolNot :: AstBool ms -> AstBool ms
  -- | Conjunction (by its name).
  AstBoolAnd :: AstBool ms -> AstBool ms -> AstBool ms
  -- | Comparison of two primal scalar terms (by its name, less-or-equal).
  AstLeqK :: forall r ms. GoodScalar r
          => AstTensor ms PrimalSpan (TKScalar r)
          -> AstTensor ms PrimalSpan (TKScalar r)
          -> AstBool ms
  -- | Comparison of two same-shaped primal tensors (by its name,
  -- less-or-equal; how elements are combined is not visible here).
  AstLeqS :: forall sh r ms. GoodScalar r
          => AstTensor ms PrimalSpan (TKS sh r)
          -> AstTensor ms PrimalSpan (TKS sh r)
          -> AstBool ms
deriving instance Show (AstBool ms)
-- | Opcodes of the unary operations in 'AstN1K' and 'AstN1S'. Those
-- constructors require only 'GoodScalar'. By name, these are negation,
-- absolute value and sign.
data OpCodeNum1
  = NegateOp
  | AbsOp
  | SignumOp
  deriving (Eq, Show)
-- | Opcodes of the unary operations in 'AstR1K' and 'AstR1S', which
-- require 'RealFloatH'. By name: reciprocal; exponential, logarithm and
-- square root; trigonometric and hyperbolic functions and their inverses.
data OpCode1
  = RecipOp
  -- exponential family
  | ExpOp
  | LogOp
  | SqrtOp
  -- trigonometric
  | SinOp
  | CosOp
  | TanOp
  | AsinOp
  | AcosOp
  | AtanOp
  -- hyperbolic
  | SinhOp
  | CoshOp
  | TanhOp
  | AsinhOp
  | AcoshOp
  | AtanhOp
  deriving (Eq, Show)
-- | Opcodes of the binary operations in 'AstR2K' and 'AstR2S', which
-- require 'RealFloatH'. By name: division, power, logarithm to a base,
-- and two-argument arctangent.
data OpCode2
  = DivideOp
  | PowerOp
  | LogBaseOp
  | Atan2Op
  deriving (Eq, Show)
-- | Opcodes of the binary operations in 'AstI2K' and 'AstI2S', which
-- require 'IntegralH'. By name: quotient and remainder.
data OpCodeIntegral2
  = QuotOp
  | RemOp
  deriving (Eq, Show)