{-# LANGUAGE UndecidableInstances #-}
{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-}
{-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-}
module HordeAd.Core.Delta
(
NodeId, mkNodeId, nodeIdToFTK
, InputId, mkInputId, inputIdToFTK
, Delta(..), NestedTarget(..)
, ftkDelta
) where
import Prelude
import Control.Exception.Assert.Sugar
import Data.Dependent.EnumMap.Strict qualified as DMap
import Data.Kind (Type)
import Data.Some
import Data.Type.Equality (TestEquality (..), gcastWith, testEquality, (:~:))
import Data.Vector.Generic qualified as V
import Data.Vector.Strict qualified as Data.Vector
import GHC.TypeLits (type (+), type (<=))
import Text.Show.Functions ()
import Data.Array.Nested (type (++))
import Data.Array.Nested qualified as Nested
import Data.Array.Nested.Mixed.Shape
import Data.Array.Nested.Permutation qualified as Permutation
import Data.Array.Nested.Ranked.Shape
import Data.Array.Nested.Shaped.Shape
import Data.Array.Nested.Types (snatPlus, unsafeCoerceRefl)
import HordeAd.Core.Ops
import HordeAd.Core.TensorKind
import HordeAd.Core.Types
type NodeId :: Target -> TK -> Type
type role NodeId nominal nominal
-- | Identifier of a node labelled by 'DeltaShare'. Besides the integer
-- identifier it stores the node's full shape, which 'nodeIdToFTK'
-- hands back without needing the node itself.
data NodeId target y where
  NodeId :: forall target y. FullShapeTK y -> Int -> NodeId target y
-- | Shows only the integer identifier (at the requested precedence);
-- the stored shape is omitted.
instance Show (NodeId target y) where
  showsPrec prec nid = case nid of
    NodeId _ ident -> showsPrec prec ident
-- | Lets a 'NodeId' key a dependent enum map: the integer identifier is the
-- enumeration index and the shape travels alongside as 'Some'-wrapped info.
instance DMap.Enum1 (NodeId target) where
  type Enum1Info (NodeId target) = Some FullShapeTK
  fromEnum1 (NodeId shapeTK ident) = (ident, Some shapeTK)
  toEnum1 ident (Some shapeTK) = Some (NodeId shapeTK ident)
-- | Type-index equality is decided by 'matchingFTK' on the stored shapes;
-- the integer identifiers are ignored.
instance TestEquality (NodeId target) where
  testEquality (NodeId lhsTK _) (NodeId rhsTK _) = lhsTK `matchingFTK` rhsTK
-- | Smart constructor for 'NodeId'. The identifier is expected to be
-- non-negative; this is checked with 'assert'.
mkNodeId :: FullShapeTK y -> Int -> NodeId f y
mkNodeId shapeTK ident =
  assert (0 <= ident) (NodeId shapeTK ident)
-- | The full shape stored inside a 'NodeId'.
nodeIdToFTK :: NodeId f y -> FullShapeTK y
nodeIdToFTK = \case
  NodeId shapeTK _ -> shapeTK
type InputId :: Target -> TK -> Type
type role InputId nominal nominal
-- | Identifier of an input referenced by 'DeltaInput'. Besides the integer
-- identifier it stores the input's full shape, which 'inputIdToFTK'
-- hands back.
data InputId target y where
  InputId :: forall target y. FullShapeTK y -> Int -> InputId target y
-- | Renders as @(InputId n)@. The parentheses are always added, whatever
-- the precedence; the stored shape is omitted.
instance Show (InputId target y) where
  showsPrec _prec inputId = case inputId of
    InputId _ ident -> showParen True (showString "InputId " . shows ident)
-- | Lets an 'InputId' key a dependent enum map: the integer identifier is
-- the enumeration index and the shape travels alongside as 'Some'-wrapped
-- info.
instance DMap.Enum1 (InputId target) where
  type Enum1Info (InputId target) = Some FullShapeTK
  fromEnum1 (InputId shapeTK ident) = (ident, Some shapeTK)
  toEnum1 ident (Some shapeTK) = Some (InputId shapeTK ident)
-- | Type-index equality is decided by 'matchingFTK' on the stored shapes;
-- the integer identifiers are ignored.
instance TestEquality (InputId target) where
  testEquality (InputId lhsTK _) (InputId rhsTK _) = lhsTK `matchingFTK` rhsTK
-- | Smart constructor for 'InputId'. The identifier is expected to be
-- non-negative; this is checked with 'assert'.
mkInputId :: FullShapeTK y -> Int -> InputId f y
mkInputId shapeTK ident =
  assert (0 <= ident) (InputId shapeTK ident)
-- | The full shape stored inside an 'InputId'.
inputIdToFTK :: InputId f y -> FullShapeTK y
inputIdToFTK = \case
  InputId shapeTK _ -> shapeTK
type role Delta nominal nominal
-- | Delta expressions, indexed by a target and by the tensor kind of the
-- term. The full shape of any term can be recomputed with 'ftkDelta'.
data Delta :: Target -> Target where
  -- NOTE: the order of type variables (explicit @forall@s and, elsewhere,
  -- order of first occurrence) is part of this type's interface:
  -- 'ftkDelta' matches constructors with type applications, e.g.
  -- DeltaIndexX @shm @shn. Do not reorder quantifiers.

  -- Labelled subterms and inputs.
  -- | A subterm labelled with a 'NodeId'.
  DeltaShare :: NodeId target y -> Delta target y -> Delta target y
  -- | An input identified by an 'InputId'.
  DeltaInput :: InputId target y -> Delta target y

  -- Products.
  DeltaPair :: forall y z target.
       Delta target y -> Delta target z
    -> Delta target (TKProduct y z)
  DeltaProject1 :: forall y z target.
       Delta target (TKProduct y z) -> Delta target y
  DeltaProject2 :: forall y z target.
       Delta target (TKProduct y z) -> Delta target z

  -- Constructors generic in the tensor kind, relating @y@ and
  -- @BuildTensorKind k y@.
  DeltaFromVector :: forall y k target.
       SNat k -> SingletonTK y
    -> Data.Vector.Vector (Delta target y)
    -> Delta target (BuildTensorKind k y)
  DeltaSum :: forall y k target.
       SNat k -> SingletonTK y
    -> Delta target (BuildTensorKind k y)
    -> Delta target y
  DeltaReplicate :: forall y k target.
       SNat k -> SingletonTK y
    -> Delta target y
    -> Delta target (BuildTensorKind k y)

  -- The two map-accumulate constructors have identical field types. They
  -- store @target@ values directly. The 'Show' constraints on those values
  -- let the standalone derived 'Show' instance below cover them.
  DeltaMapAccumR
    :: forall target k accy by ey.
       ( Show (target (BuildTensorKind k accy))
       , Show (target (BuildTensorKind k ey)) )
    => SNat k
    -> FullShapeTK by
    -> FullShapeTK ey
    -> target (BuildTensorKind k accy)
    -> target (BuildTensorKind k ey)
    -> HFun (TKProduct (ADTensorKind (TKProduct accy ey))
                       (TKProduct accy ey))
            (ADTensorKind (TKProduct accy by))
    -> HFun (TKProduct (ADTensorKind (TKProduct accy by))
                       (TKProduct accy ey))
            (ADTensorKind (TKProduct accy ey))
    -> Delta target accy
    -> Delta target (BuildTensorKind k ey)
    -> Delta target (TKProduct accy (BuildTensorKind k by))
  DeltaMapAccumL
    :: forall target k accy by ey.
       ( Show (target (BuildTensorKind k accy))
       , Show (target (BuildTensorKind k ey)) )
    => SNat k
    -> FullShapeTK by
    -> FullShapeTK ey
    -> target (BuildTensorKind k accy)
    -> target (BuildTensorKind k ey)
    -> HFun (TKProduct (ADTensorKind (TKProduct accy ey))
                       (TKProduct accy ey))
            (ADTensorKind (TKProduct accy by))
    -> HFun (TKProduct (ADTensorKind (TKProduct accy by))
                       (TKProduct accy ey))
            (ADTensorKind (TKProduct accy ey))
    -> Delta target accy
    -> Delta target (BuildTensorKind k ey)
    -> Delta target (TKProduct accy (BuildTensorKind k by))

  -- Constructors valid at any tensor kind.
  -- | Has no subterms; carries only the full shape.
  DeltaZero :: FullShapeTK y -> Delta target y
  -- | The @target@ factor is wrapped in 'NestedTarget'. Its 'Show'
  -- instance prints a placeholder instead of the value.
  DeltaScale :: Num (target y)
             => NestedTarget target y -> Delta target y -> Delta target y
  DeltaAdd :: Num (target y)
           => Delta target y -> Delta target y -> Delta target y

  -- Scalars.
  DeltaCastK :: (GoodScalar r1, RealFrac r1, GoodScalar r2, RealFrac r2)
             => Delta target (TKScalar r1) -> Delta target (TKScalar r2)

  -- Ranked tensors (TKR / TKR2).
  DeltaCastR :: (GoodScalar r1, RealFrac r1, GoodScalar r2, RealFrac r2)
             => Delta target (TKR n r1) -> Delta target (TKR n r2)
  DeltaSum0R :: Delta target (TKR2 n r) -> Delta target (TKR2 0 r)
  DeltaDot0R :: (GoodScalar r, Show (target (TKR n r)))
             => target (TKR n r) -> Delta target (TKR n r)
             -> Delta target (TKR 0 r)
  DeltaIndexR :: forall m n r target.
       SNat n
    -> Delta target (TKR2 (m + n) r) -> IxROf target m
    -> Delta target (TKR2 n r)
  -- The index-mapping function fields of the scatter/gather constructors
  -- are shown through the Text.Show.Functions instance this module imports.
  DeltaScatterR :: forall m n p r target.
       SNat m -> SNat n -> SNat p
    -> IShR (p + n) -> Delta target (TKR2 (m + n) r)
    -> (IxROf target m -> IxROf target p)
    -> Delta target (TKR2 (p + n) r)
  DeltaGatherR :: forall m n p r target.
       SNat m -> SNat n -> SNat p
    -> IShR (m + n) -> Delta target (TKR2 (p + n) r)
    -> (IxROf target m -> IxROf target p)
    -> Delta target (TKR2 (m + n) r)
  DeltaAppendR :: Delta target (TKR2 (1 + n) r)
               -> Delta target (TKR2 (1 + n) r)
               -> Delta target (TKR2 (1 + n) r)
  DeltaSliceR :: Int -> Int -> Delta target (TKR2 (1 + n) r)
              -> Delta target (TKR2 (1 + n) r)
  DeltaReverseR :: Delta target (TKR2 (1 + n) r)
                -> Delta target (TKR2 (1 + n) r)
  DeltaTransposeR :: Permutation.PermR -> Delta target (TKR2 n r)
                  -> Delta target (TKR2 n r)
  DeltaReshapeR :: IShR m -> Delta target (TKR2 n r)
                -> Delta target (TKR2 m r)

  -- Shaped tensors (TKS / TKS2).
  DeltaCastS :: (GoodScalar r1, RealFrac r1, GoodScalar r2, RealFrac r2)
             => Delta target (TKS sh r1) -> Delta target (TKS sh r2)
  DeltaSum0S :: Delta target (TKS2 sh r) -> Delta target (TKS2 '[] r)
  DeltaDot0S :: (GoodScalar r, Show (target (TKS sh r)))
             => target (TKS sh r) -> Delta target (TKS sh r)
             -> Delta target (TKS '[] r)
  DeltaIndexS :: forall shm shn r target.
       ShS shn
    -> Delta target (TKS2 (shm ++ shn) r) -> IxSOf target shm
    -> Delta target (TKS2 shn r)
  DeltaScatterS :: forall shm shn shp r target.
       ShS shm -> ShS shn -> ShS shp
    -> Delta target (TKS2 (shm ++ shn) r)
    -> (IxSOf target shm -> IxSOf target shp)
    -> Delta target (TKS2 (shp ++ shn) r)
  DeltaGatherS :: forall shm shn shp r target.
       ShS shm -> ShS shn -> ShS shp
    -> Delta target (TKS2 (shp ++ shn) r)
    -> (IxSOf target shm -> IxSOf target shp)
    -> Delta target (TKS2 (shm ++ shn) r)
  DeltaAppendS :: forall target r m n sh.
       Delta target (TKS2 (m ': sh) r)
    -> Delta target (TKS2 (n ': sh) r)
    -> Delta target (TKS2 ((m + n) ': sh) r)
  DeltaSliceS :: SNat i -> SNat n -> SNat k
              -> Delta target (TKS2 (i + n + k ': sh) r)
              -> Delta target (TKS2 (n ': sh) r)
  DeltaReverseS :: Delta target (TKS2 (n ': sh) r)
                -> Delta target (TKS2 (n ': sh) r)
  DeltaTransposeS :: forall perm sh r target.
       (Permutation.IsPermutation perm, Rank perm <= Rank sh)
    => Permutation.Perm perm
    -> Delta target (TKS2 sh r)
    -> Delta target (TKS2 (Permutation.PermutePrefix perm sh) r)
  DeltaReshapeS :: Product sh ~ Product sh2
                => ShS sh2
                -> Delta target (TKS2 sh r)
                -> Delta target (TKS2 sh2 r)

  -- Mixed tensors (TKX / TKX2).
  DeltaCastX :: (GoodScalar r1, RealFrac r1, GoodScalar r2, RealFrac r2)
             => Delta target (TKX sh r1) -> Delta target (TKX sh r2)
  DeltaSum0X :: Delta target (TKX2 sh r) -> Delta target (TKX2 '[] r)
  DeltaDot0X :: (GoodScalar r, Show (target (TKX sh r)))
             => target (TKX sh r) -> Delta target (TKX sh r)
             -> Delta target (TKX '[] r)
  DeltaIndexX :: forall shm shn r target.
       StaticShX shn
    -> Delta target (TKX2 (shm ++ shn) r) -> IxXOf target shm
    -> Delta target (TKX2 shn r)
  DeltaScatterX :: StaticShX shm -> StaticShX shn -> StaticShX shp
                -> IShX (shp ++ shn) -> Delta target (TKX2 (shm ++ shn) r)
                -> (IxXOf target shm -> IxXOf target shp)
                -> Delta target (TKX2 (shp ++ shn) r)
  DeltaGatherX :: StaticShX shm -> StaticShX shn -> StaticShX shp
               -> IShX (shm ++ shn) -> Delta target (TKX2 (shp ++ shn) r)
               -> (IxXOf target shm -> IxXOf target shp)
               -> Delta target (TKX2 (shm ++ shn) r)
  DeltaAppendX :: Delta target (TKX2 (Just m ': sh) r)
               -> Delta target (TKX2 (Just n ': sh) r)
               -> Delta target (TKX2 (Just (m + n) ': sh) r)
  DeltaSliceX :: SNat i -> SNat n -> SNat k
              -> Delta target (TKX2 (Just (i + n + k) ': sh) r)
              -> Delta target (TKX2 (Just n ': sh) r)
  DeltaReverseX :: Delta target (TKX2 (mn ': sh) r)
                -> Delta target (TKX2 (mn ': sh) r)
  DeltaTransposeX :: forall perm sh r target.
       (Permutation.IsPermutation perm, Rank perm <= Rank sh)
    => Permutation.Perm perm
    -> Delta target (TKX2 sh r)
    -> Delta target (TKX2 (Permutation.PermutePrefix perm sh) r)
  DeltaReshapeX :: IShX sh2 -> Delta target (TKX2 sh r)
                -> Delta target (TKX2 sh2 r)

  -- Conversion between tensor kinds.
  DeltaConvert :: TKConversion a b -> Delta target a -> Delta target b

-- NOTE(review): the Show (IntOf target) context is presumably required by
-- the index fields (IxROf / IxSOf / IxXOf target), which likely contain
-- IntOf target values -- confirm against their definitions.
deriving instance Show (IntOf target) => Show (Delta target y)
type NestedTarget :: Target -> Target
type role NestedTarget nominal nominal
-- | A newtype around a @target@ value. It is used for the factor of
-- 'DeltaScale'. Its 'Show' instance prints a fixed placeholder rather
-- than the wrapped value.
newtype NestedTarget target y where
  NestedTarget :: forall target y. target y -> NestedTarget target y
-- | Ignores both the precedence and the wrapped value and always emits the
-- same placeholder text.
instance Show (NestedTarget target y) where
  showsPrec _ = const (showString "<primal>")
ftkDelta :: forall target y.
Delta target y -> FullShapeTK y
ftkDelta :: forall (target :: Target) (y :: TK).
Delta target y -> FullShapeTK y
ftkDelta = \case
DeltaShare NodeId target y
i Delta target y
_ -> NodeId target y -> FullShapeTK y
forall (f :: Target) (y :: TK). NodeId f y -> FullShapeTK y
nodeIdToFTK NodeId target y
i
DeltaInput InputId target y
i -> InputId target y -> FullShapeTK y
forall (f :: Target) (y :: TK). InputId f y -> FullShapeTK y
inputIdToFTK InputId target y
i
DeltaPair Delta target y
t1 Delta target z
t2 -> FullShapeTK y -> FullShapeTK z -> FullShapeTK (TKProduct y z)
forall (y1 :: TK) (z :: TK).
FullShapeTK y1 -> FullShapeTK z -> FullShapeTK (TKProduct y1 z)
FTKProduct (Delta target y -> FullShapeTK y
forall (target :: Target) (y :: TK).
Delta target y -> FullShapeTK y
ftkDelta Delta target y
t1) (Delta target z -> FullShapeTK z
forall (target :: Target) (y :: TK).
Delta target y -> FullShapeTK y
ftkDelta Delta target z
t2)
DeltaProject1 Delta target (TKProduct y z)
v -> case Delta target (TKProduct y z) -> FullShapeTK (TKProduct y z)
forall (target :: Target) (y :: TK).
Delta target y -> FullShapeTK y
ftkDelta Delta target (TKProduct y z)
v of
FTKProduct FullShapeTK y1
ftk1 FullShapeTK z
_ -> FullShapeTK y
FullShapeTK y1
ftk1
DeltaProject2 Delta target (TKProduct y y)
v -> case Delta target (TKProduct y y) -> FullShapeTK (TKProduct y y)
forall (target :: Target) (y :: TK).
Delta target y -> FullShapeTK y
ftkDelta Delta target (TKProduct y y)
v of
FTKProduct FullShapeTK y1
_ FullShapeTK z
ftk2 -> FullShapeTK y
FullShapeTK z
ftk2
DeltaFromVector SNat k
snat SingletonTK y
_ Vector (Delta target y)
l -> case Vector (Delta target y)
-> Maybe (Delta target y, Vector (Delta target y))
forall (v :: Type -> Type) a. Vector v a => v a -> Maybe (a, v a)
V.uncons Vector (Delta target y)
l of
Maybe (Delta target y, Vector (Delta target y))
Nothing -> String -> FullShapeTK y
forall a. (?callStack::CallStack) => String -> a
error String
"ftkDelta: empty vector"
Just (Delta target y
d, Vector (Delta target y)
_) -> SNat k -> FullShapeTK y -> FullShapeTK (BuildTensorKind k y)
forall (k :: Nat) (y :: TK).
SNat k -> FullShapeTK y -> FullShapeTK (BuildTensorKind k y)
buildFTK SNat k
snat (Delta target y -> FullShapeTK y
forall (target :: Target) (y :: TK).
Delta target y -> FullShapeTK y
ftkDelta Delta target y
d)
DeltaSum SNat k
snat SingletonTK y
stk Delta target (BuildTensorKind k y)
d -> SNat k
-> SingletonTK y
-> FullShapeTK (BuildTensorKind k y)
-> FullShapeTK y
forall (y :: TK) (k :: Nat).
SNat k
-> SingletonTK y
-> FullShapeTK (BuildTensorKind k y)
-> FullShapeTK y
razeFTK SNat k
snat SingletonTK y
stk (Delta target (BuildTensorKind k y)
-> FullShapeTK (BuildTensorKind k y)
forall (target :: Target) (y :: TK).
Delta target y -> FullShapeTK y
ftkDelta Delta target (BuildTensorKind k y)
d)
DeltaReplicate SNat k
snat SingletonTK y
_ Delta target y
d -> SNat k -> FullShapeTK y -> FullShapeTK (BuildTensorKind k y)
forall (k :: Nat) (y :: TK).
SNat k -> FullShapeTK y -> FullShapeTK (BuildTensorKind k y)
buildFTK SNat k
snat (Delta target y -> FullShapeTK y
forall (target :: Target) (y :: TK).
Delta target y -> FullShapeTK y
ftkDelta Delta target y
d)
DeltaMapAccumR SNat k
k FullShapeTK by
bftk FullShapeTK ey
_eftk target (BuildTensorKind k accy)
_q target (BuildTensorKind k ey)
_es HFun
(TKProduct (ADTensorKind (TKProduct accy ey)) (TKProduct accy ey))
(ADTensorKind (TKProduct accy by))
_df HFun
(TKProduct (ADTensorKind (TKProduct accy by)) (TKProduct accy ey))
(ADTensorKind (TKProduct accy ey))
_rf Delta target accy
acc0' Delta target (BuildTensorKind k ey)
_es' ->
FullShapeTK accy
-> FullShapeTK (BuildTensorKind k by)
-> FullShapeTK (TKProduct accy (BuildTensorKind k by))
forall (y1 :: TK) (z :: TK).
FullShapeTK y1 -> FullShapeTK z -> FullShapeTK (TKProduct y1 z)
FTKProduct (Delta target accy -> FullShapeTK accy
forall (target :: Target) (y :: TK).
Delta target y -> FullShapeTK y
ftkDelta Delta target accy
acc0') (SNat k -> FullShapeTK by -> FullShapeTK (BuildTensorKind k by)
forall (k :: Nat) (y :: TK).
SNat k -> FullShapeTK y -> FullShapeTK (BuildTensorKind k y)
buildFTK SNat k
k FullShapeTK by
bftk)
DeltaMapAccumL SNat k
k FullShapeTK by
bftk FullShapeTK ey
_eftk target (BuildTensorKind k accy)
_q target (BuildTensorKind k ey)
_es HFun
(TKProduct (ADTensorKind (TKProduct accy ey)) (TKProduct accy ey))
(ADTensorKind (TKProduct accy by))
_df HFun
(TKProduct (ADTensorKind (TKProduct accy by)) (TKProduct accy ey))
(ADTensorKind (TKProduct accy ey))
_rf Delta target accy
acc0' Delta target (BuildTensorKind k ey)
_es' ->
FullShapeTK accy
-> FullShapeTK (BuildTensorKind k by)
-> FullShapeTK (TKProduct accy (BuildTensorKind k by))
forall (y1 :: TK) (z :: TK).
FullShapeTK y1 -> FullShapeTK z -> FullShapeTK (TKProduct y1 z)
FTKProduct (Delta target accy -> FullShapeTK accy
forall (target :: Target) (y :: TK).
Delta target y -> FullShapeTK y
ftkDelta Delta target accy
acc0') (SNat k -> FullShapeTK by -> FullShapeTK (BuildTensorKind k by)
forall (k :: Nat) (y :: TK).
SNat k -> FullShapeTK y -> FullShapeTK (BuildTensorKind k y)
buildFTK SNat k
k FullShapeTK by
bftk)
DeltaZero FullShapeTK y
ftk -> FullShapeTK y
ftk
DeltaScale NestedTarget target y
_ Delta target y
d -> Delta target y -> FullShapeTK y
forall (target :: Target) (y :: TK).
Delta target y -> FullShapeTK y
ftkDelta Delta target y
d
DeltaAdd (DeltaShare NodeId target y
i Delta target y
_) Delta target y
_ -> NodeId target y -> FullShapeTK y
forall (f :: Target) (y :: TK). NodeId f y -> FullShapeTK y
nodeIdToFTK NodeId target y
i
DeltaAdd Delta target y
_ Delta target y
e -> Delta target y -> FullShapeTK y
forall (target :: Target) (y :: TK).
Delta target y -> FullShapeTK y
ftkDelta Delta target y
e
DeltaCastK{} -> FullShapeTK y
FullShapeTK (TKScalar r2)
forall r. GoodScalar r => FullShapeTK (TKScalar r)
FTKScalar
DeltaCastR Delta target (TKR n r1)
d -> case Delta target (TKR n r1) -> FullShapeTK (TKR n r1)
forall (target :: Target) (y :: TK).
Delta target y -> FullShapeTK y
ftkDelta Delta target (TKR n r1)
d of
FTKR IShR n
sh FullShapeTK x
_ -> IShR n
-> FullShapeTK (TKScalar r2) -> FullShapeTK (TKR2 n (TKScalar r2))
forall (n :: Nat) (x :: TK).
IShR n -> FullShapeTK x -> FullShapeTK (TKR2 n x)
FTKR IShR n
sh FullShapeTK (TKScalar r2)
forall r. GoodScalar r => FullShapeTK (TKScalar r)
FTKScalar
DeltaSum0R Delta target (TKR2 n r)
d -> case Delta target (TKR2 n r) -> FullShapeTK (TKR2 n r)
forall (target :: Target) (y :: TK).
Delta target y -> FullShapeTK y
ftkDelta Delta target (TKR2 n r)
d of
FTKR IShR n
_ FullShapeTK x
x -> IShR 0 -> FullShapeTK x -> FullShapeTK (TKR2 0 x)
forall (n :: Nat) (x :: TK).
IShR n -> FullShapeTK x -> FullShapeTK (TKR2 n x)
FTKR IShR 0
forall (n :: Nat) i. ((n :: Nat) ~ (0 :: Nat)) => ShR n i
ZSR FullShapeTK x
x
DeltaDot0R{} -> IShR 0
-> FullShapeTK (TKScalar r) -> FullShapeTK (TKR2 0 (TKScalar r))
forall (n :: Nat) (x :: TK).
IShR n -> FullShapeTK x -> FullShapeTK (TKR2 n x)
FTKR IShR 0
forall (n :: Nat) i. ((n :: Nat) ~ (0 :: Nat)) => ShR n i
ZSR FullShapeTK (TKScalar r)
forall r. GoodScalar r => FullShapeTK (TKScalar r)
FTKScalar
DeltaIndexR SNat n
SNat Delta target (TKR2 (m + n) r)
d IxROf target m
ix | SNat m
SNat <- IxROf target m -> SNat m
forall (n :: Nat) i. IxR n i -> SNat n
ixrRank IxROf target m
ix -> case Delta target (TKR2 (m + n) r) -> FullShapeTK (TKR2 (m + n) r)
forall (target :: Target) (y :: TK).
Delta target y -> FullShapeTK y
ftkDelta Delta target (TKR2 (m + n) r)
d of
FTKR IShR n
sh FullShapeTK x
x -> IShR n -> FullShapeTK x -> FullShapeTK (TKR2 n x)
forall (n :: Nat) (x :: TK).
IShR n -> FullShapeTK x -> FullShapeTK (TKR2 n x)
FTKR (ShR (m + n) Int -> IShR n
forall (m :: Nat) (n :: Nat) i.
(KnownNat m, KnownNat n) =>
ShR (m + n) i -> ShR n i
shrDrop IShR n
ShR (m + n) Int
sh) FullShapeTK x
x
DeltaScatterR SNat m
_ SNat n
_ SNat p
_ IShR (p + n)
sh Delta target (TKR2 (m + n) r)
d IxROf target m -> IxROf target p
_ -> case Delta target (TKR2 (m + n) r) -> FullShapeTK (TKR2 (m + n) r)
forall (target :: Target) (y :: TK).
Delta target y -> FullShapeTK y
ftkDelta Delta target (TKR2 (m + n) r)
d of
FTKR IShR n
_ FullShapeTK x
x -> IShR (p + n) -> FullShapeTK x -> FullShapeTK (TKR2 (p + n) x)
forall (n :: Nat) (x :: TK).
IShR n -> FullShapeTK x -> FullShapeTK (TKR2 n x)
FTKR IShR (p + n)
sh FullShapeTK x
x
DeltaGatherR SNat m
_ SNat n
_ SNat p
_ IShR (m + n)
sh Delta target (TKR2 (p + n) r)
d IxROf target m -> IxROf target p
_ -> case Delta target (TKR2 (p + n) r) -> FullShapeTK (TKR2 (p + n) r)
forall (target :: Target) (y :: TK).
Delta target y -> FullShapeTK y
ftkDelta Delta target (TKR2 (p + n) r)
d of
FTKR IShR n
_ FullShapeTK x
x -> IShR (m + n) -> FullShapeTK x -> FullShapeTK (TKR2 (m + n) x)
forall (n :: Nat) (x :: TK).
IShR n -> FullShapeTK x -> FullShapeTK (TKR2 n x)
FTKR IShR (m + n)
sh FullShapeTK x
x
DeltaAppendR Delta target (TKR2 (1 + n) r)
a Delta target (TKR2 (1 + n) r)
b -> case Delta target (TKR2 (1 + n) r) -> FullShapeTK (TKR2 (1 + n) r)
forall (target :: Target) (y :: TK).
Delta target y -> FullShapeTK y
ftkDelta Delta target (TKR2 (1 + n) r)
a of
FTKR ShR n Int
ZSR FullShapeTK x
_ -> String -> FullShapeTK y
forall a. (?callStack::CallStack) => String -> a
error String
"ftkDelta: impossible pattern needlessly required"
FTKR (Int
ai :$: ShR n Int
ash) FullShapeTK x
x -> case Delta target (TKR2 n r) -> FullShapeTK (TKR2 n r)
forall (target :: Target) (y :: TK).
Delta target y -> FullShapeTK y
ftkDelta Delta target (TKR2 n r)
Delta target (TKR2 (1 + n) r)
b of
FTKR ShR n Int
ZSR FullShapeTK x
_ -> String -> FullShapeTK y
forall a. (?callStack::CallStack) => String -> a
error String
"ftkDelta: impossible pattern needlessly required"
FTKR (Int
bi :$: ShR n Int
_) FullShapeTK x
_ -> IShR n -> FullShapeTK x -> FullShapeTK (TKR2 n x)
forall (n :: Nat) (x :: TK).
IShR n -> FullShapeTK x -> FullShapeTK (TKR2 n x)
FTKR (Int
ai Int -> Int -> Int
forall a. Num a => a -> a -> a
+ Int
bi Int -> ShR n Int -> IShR n
forall {n1 :: Nat} {i} (n :: Nat).
((n + 1 :: Nat) ~ (n1 :: Nat)) =>
i -> ShR n i -> ShR n1 i
:$: ShR n Int
ash) FullShapeTK x
x
DeltaSliceR Int
_ Int
n Delta target (TKR2 (1 + n) r)
d -> case Delta target (TKR2 (1 + n) r) -> FullShapeTK (TKR2 (1 + n) r)
forall (target :: Target) (y :: TK).
Delta target y -> FullShapeTK y
ftkDelta Delta target (TKR2 (1 + n) r)
d of
FTKR IShR n
sh FullShapeTK x
x -> IShR n -> FullShapeTK x -> FullShapeTK (TKR2 n x)
forall (n :: Nat) (x :: TK).
IShR n -> FullShapeTK x -> FullShapeTK (TKR2 n x)
FTKR (Int
n Int -> ShR n Int -> IShR n
forall {n1 :: Nat} {i} (n :: Nat).
((n + 1 :: Nat) ~ (n1 :: Nat)) =>
i -> ShR n i -> ShR n1 i
:$: ShR (n + 1) Int -> ShR n Int
forall (n :: Nat) i. ShR (n + 1) i -> ShR n i
shrTail IShR n
ShR (n + 1) Int
sh) FullShapeTK x
x
DeltaReverseR Delta target (TKR2 (1 + n) r)
d -> Delta target y -> FullShapeTK y
forall (target :: Target) (y :: TK).
Delta target y -> FullShapeTK y
ftkDelta Delta target y
Delta target (TKR2 (1 + n) r)
d
DeltaTransposeR PermR
perm Delta target (TKR2 n r)
d -> case Delta target (TKR2 n r) -> FullShapeTK (TKR2 n r)
forall (target :: Target) (y :: TK).
Delta target y -> FullShapeTK y
ftkDelta Delta target (TKR2 n r)
d of
FTKR IShR n
sh FullShapeTK x
x -> IShR n -> FullShapeTK x -> FullShapeTK (TKR2 n x)
forall (n :: Nat) (x :: TK).
IShR n -> FullShapeTK x -> FullShapeTK (TKR2 n x)
FTKR (PermR -> IShR n -> IShR n
forall (n :: Nat) i. PermR -> ShR n i -> ShR n i
shrPermutePrefix PermR
perm IShR n
sh) FullShapeTK x
x
DeltaReshapeR IShR m
sh Delta target (TKR2 n r)
d -> case Delta target (TKR2 n r) -> FullShapeTK (TKR2 n r)
forall (target :: Target) (y :: TK).
Delta target y -> FullShapeTK y
ftkDelta Delta target (TKR2 n r)
d of
FTKR IShR n
_ FullShapeTK x
x -> IShR m -> FullShapeTK x -> FullShapeTK (TKR2 m x)
forall (n :: Nat) (x :: TK).
IShR n -> FullShapeTK x -> FullShapeTK (TKR2 n x)
FTKR IShR m
sh FullShapeTK x
x
DeltaCastS Delta target (TKS sh r1)
d -> case Delta target (TKS sh r1) -> FullShapeTK (TKS sh r1)
forall (target :: Target) (y :: TK).
Delta target y -> FullShapeTK y
ftkDelta Delta target (TKS sh r1)
d of
FTKS ShS sh
sh FullShapeTK x
FTKScalar -> ShS sh
-> FullShapeTK (TKScalar r2) -> FullShapeTK (TKS2 sh (TKScalar r2))
forall (sh :: [Nat]) (x :: TK).
ShS sh -> FullShapeTK x -> FullShapeTK (TKS2 sh x)
FTKS ShS sh
sh FullShapeTK (TKScalar r2)
forall r. GoodScalar r => FullShapeTK (TKScalar r)
FTKScalar
DeltaSum0S Delta target (TKS2 sh r)
d -> case Delta target (TKS2 sh r) -> FullShapeTK (TKS2 sh r)
forall (target :: Target) (y :: TK).
Delta target y -> FullShapeTK y
ftkDelta Delta target (TKS2 sh r)
d of
FTKS ShS sh
_ FullShapeTK x
x -> ShS ('[] @Nat) -> FullShapeTK x -> FullShapeTK (TKS2 ('[] @Nat) x)
forall (sh :: [Nat]) (x :: TK).
ShS sh -> FullShapeTK x -> FullShapeTK (TKS2 sh x)
FTKS ShS ('[] @Nat)
forall (sh :: [Nat]).
((sh :: [Nat]) ~ ('[] @Nat :: [Nat])) =>
ShS sh
ZSS FullShapeTK x
x
DeltaDot0S{} -> ShS ('[] @Nat)
-> FullShapeTK (TKScalar r)
-> FullShapeTK (TKS2 ('[] @Nat) (TKScalar r))
forall (sh :: [Nat]) (x :: TK).
ShS sh -> FullShapeTK x -> FullShapeTK (TKS2 sh x)
FTKS ShS ('[] @Nat)
forall (sh :: [Nat]).
((sh :: [Nat]) ~ ('[] @Nat :: [Nat])) =>
ShS sh
ZSS FullShapeTK (TKScalar r)
forall r. GoodScalar r => FullShapeTK (TKScalar r)
FTKScalar
DeltaIndexS ShS shn
shn Delta target (TKS2 ((++) @Nat shm shn) r)
d IxSOf target shm
_ix -> case Delta target (TKS2 ((++) @Nat shm shn) r)
-> FullShapeTK (TKS2 ((++) @Nat shm shn) r)
forall (target :: Target) (y :: TK).
Delta target y -> FullShapeTK y
ftkDelta Delta target (TKS2 ((++) @Nat shm shn) r)
d of
FTKS ShS sh
_ FullShapeTK x
x -> ShS shn -> FullShapeTK x -> FullShapeTK (TKS2 shn x)
forall (sh :: [Nat]) (x :: TK).
ShS sh -> FullShapeTK x -> FullShapeTK (TKS2 sh x)
FTKS ShS shn
shn FullShapeTK x
x
DeltaScatterS ShS shm
_shm ShS shn
shn ShS shp
shp Delta target (TKS2 ((++) @Nat shm shn) r)
d IxSOf target shm -> IxSOf target shp
_ -> case Delta target (TKS2 ((++) @Nat shm shn) r)
-> FullShapeTK (TKS2 ((++) @Nat shm shn) r)
forall (target :: Target) (y :: TK).
Delta target y -> FullShapeTK y
ftkDelta Delta target (TKS2 ((++) @Nat shm shn) r)
d of
FTKS ShS sh
_ FullShapeTK x
x -> ShS ((++) @Nat shp shn)
-> FullShapeTK x -> FullShapeTK (TKS2 ((++) @Nat shp shn) x)
forall (sh :: [Nat]) (x :: TK).
ShS sh -> FullShapeTK x -> FullShapeTK (TKS2 sh x)
FTKS (ShS shp
shp ShS shp -> ShS shn -> ShS ((++) @Nat shp shn)
forall (sh :: [Nat]) (sh' :: [Nat]).
ShS sh -> ShS sh' -> ShS ((++) @Nat sh sh')
`shsAppend` ShS shn
shn) FullShapeTK x
x
DeltaGatherS ShS shm
shm ShS shn
shn ShS shp
_shp Delta target (TKS2 ((++) @Nat shp shn) r)
d IxSOf target shm -> IxSOf target shp
_ -> case Delta target (TKS2 ((++) @Nat shp shn) r)
-> FullShapeTK (TKS2 ((++) @Nat shp shn) r)
forall (target :: Target) (y :: TK).
Delta target y -> FullShapeTK y
ftkDelta Delta target (TKS2 ((++) @Nat shp shn) r)
d of
FTKS ShS sh
_ FullShapeTK x
x -> ShS ((++) @Nat shm shn)
-> FullShapeTK x -> FullShapeTK (TKS2 ((++) @Nat shm shn) x)
forall (sh :: [Nat]) (x :: TK).
ShS sh -> FullShapeTK x -> FullShapeTK (TKS2 sh x)
FTKS (ShS shm
shm ShS shm -> ShS shn -> ShS ((++) @Nat shm shn)
forall (sh :: [Nat]) (sh' :: [Nat]).
ShS sh -> ShS sh' -> ShS ((++) @Nat sh sh')
`shsAppend` ShS shn
shn) FullShapeTK x
x
DeltaAppendS Delta target (TKS2 ((':) @Nat m sh) r)
a Delta target (TKS2 ((':) @Nat n sh) r)
b -> case (Delta target (TKS2 ((':) @Nat m sh) r)
-> FullShapeTK (TKS2 ((':) @Nat m sh) r)
forall (target :: Target) (y :: TK).
Delta target y -> FullShapeTK y
ftkDelta Delta target (TKS2 ((':) @Nat m sh) r)
a, Delta target (TKS2 ((':) @Nat n sh) r)
-> FullShapeTK (TKS2 ((':) @Nat n sh) r)
forall (target :: Target) (y :: TK).
Delta target y -> FullShapeTK y
ftkDelta Delta target (TKS2 ((':) @Nat n sh) r)
b) of
(FTKS (SNat n
m :$$ ShS sh
sh) FullShapeTK x
x, FTKS (SNat n
n :$$ ShS sh
_) FullShapeTK x
_) -> ShS ((':) @Nat (m + n) sh)
-> FullShapeTK x -> FullShapeTK (TKS2 ((':) @Nat (m + n) sh) x)
forall (sh :: [Nat]) (x :: TK).
ShS sh -> FullShapeTK x -> FullShapeTK (TKS2 sh x)
FTKS (SNat n -> SNat n -> SNat (n + n)
forall (n :: Nat) (m :: Nat). SNat n -> SNat m -> SNat (n + m)
snatPlus SNat n
m SNat n
n SNat (m + n) -> ShS sh -> ShS ((':) @Nat (m + n) sh)
forall {sh1 :: [Nat]} (n :: Nat) (sh :: [Nat]).
(KnownNat n, ((':) @Nat n sh :: [Nat]) ~ (sh1 :: [Nat])) =>
SNat n -> ShS sh -> ShS sh1
:$$ ShS sh
sh) FullShapeTK x
x
DeltaSliceS SNat i
_ n :: SNat n
n@SNat n
SNat SNat k
_ Delta target (TKS2 ((':) @Nat ((i + n) + k) sh) r)
d -> case Delta target (TKS2 ((':) @Nat ((i + n) + k) sh) r)
-> FullShapeTK (TKS2 ((':) @Nat ((i + n) + k) sh) r)
forall (target :: Target) (y :: TK).
Delta target y -> FullShapeTK y
ftkDelta Delta target (TKS2 ((':) @Nat ((i + n) + k) sh) r)
d of
FTKS (SNat n
_ :$$ ShS sh
sh) FullShapeTK x
x -> ShS ((':) @Nat n sh)
-> FullShapeTK x -> FullShapeTK (TKS2 ((':) @Nat n sh) x)
forall (sh :: [Nat]) (x :: TK).
ShS sh -> FullShapeTK x -> FullShapeTK (TKS2 sh x)
FTKS (SNat n
n SNat n -> ShS sh -> ShS ((':) @Nat n sh)
forall {sh1 :: [Nat]} (n :: Nat) (sh :: [Nat]).
(KnownNat n, ((':) @Nat n sh :: [Nat]) ~ (sh1 :: [Nat])) =>
SNat n -> ShS sh -> ShS sh1
:$$ ShS sh
sh) FullShapeTK x
x
DeltaReverseS Delta target (TKS2 ((':) @Nat n sh) r)
d -> Delta target y -> FullShapeTK y
forall (target :: Target) (y :: TK).
Delta target y -> FullShapeTK y
ftkDelta Delta target y
Delta target (TKS2 ((':) @Nat n sh) r)
d
DeltaTransposeS Perm perm
perm Delta target (TKS2 sh r)
d -> case Delta target (TKS2 sh r) -> FullShapeTK (TKS2 sh r)
forall (target :: Target) (y :: TK).
Delta target y -> FullShapeTK y
ftkDelta Delta target (TKS2 sh r)
d of
FTKS ShS sh
sh FullShapeTK x
x -> ShS (PermutePrefix @Nat perm sh)
-> FullShapeTK x
-> FullShapeTK (TKS2 (PermutePrefix @Nat perm sh) x)
forall (sh :: [Nat]) (x :: TK).
ShS sh -> FullShapeTK x -> FullShapeTK (TKS2 sh x)
FTKS (Perm perm -> ShS sh -> ShS (PermutePrefix @Nat perm sh)
forall (is :: [Nat]) (sh :: [Nat]).
Perm is -> ShS sh -> ShS (PermutePrefix @Nat is sh)
shsPermutePrefix Perm perm
perm ShS sh
sh) FullShapeTK x
x
DeltaReshapeS ShS sh2
sh2 Delta target (TKS2 sh r)
d -> case Delta target (TKS2 sh r) -> FullShapeTK (TKS2 sh r)
forall (target :: Target) (y :: TK).
Delta target y -> FullShapeTK y
ftkDelta Delta target (TKS2 sh r)
d of
FTKS ShS sh
_ FullShapeTK x
x -> ShS sh2 -> FullShapeTK x -> FullShapeTK (TKS2 sh2 x)
forall (sh :: [Nat]) (x :: TK).
ShS sh -> FullShapeTK x -> FullShapeTK (TKS2 sh x)
FTKS ShS sh2
sh2 FullShapeTK x
x
DeltaCastX Delta target (TKX sh r1)
d -> case Delta target (TKX sh r1) -> FullShapeTK (TKX sh r1)
forall (target :: Target) (y :: TK).
Delta target y -> FullShapeTK y
ftkDelta Delta target (TKX sh r1)
d of
FTKX IShX sh
sh FullShapeTK x
FTKScalar -> IShX sh
-> FullShapeTK (TKScalar r2) -> FullShapeTK (TKX2 sh (TKScalar r2))
forall (sh :: [Maybe Nat]) (x :: TK).
IShX sh -> FullShapeTK x -> FullShapeTK (TKX2 sh x)
FTKX IShX sh
sh FullShapeTK (TKScalar r2)
forall r. GoodScalar r => FullShapeTK (TKScalar r)
FTKScalar
DeltaSum0X Delta target (TKX2 sh r)
d -> case Delta target (TKX2 sh r) -> FullShapeTK (TKX2 sh r)
forall (target :: Target) (y :: TK).
Delta target y -> FullShapeTK y
ftkDelta Delta target (TKX2 sh r)
d of
FTKX IShX sh
_ FullShapeTK x
x -> IShX ('[] @(Maybe Nat))
-> FullShapeTK x -> FullShapeTK (TKX2 ('[] @(Maybe Nat)) x)
forall (sh :: [Maybe Nat]) (x :: TK).
IShX sh -> FullShapeTK x -> FullShapeTK (TKX2 sh x)
FTKX IShX ('[] @(Maybe Nat))
forall (sh :: [Maybe Nat]) i.
((sh :: [Maybe Nat]) ~ ('[] @(Maybe Nat) :: [Maybe Nat])) =>
ShX sh i
ZSX FullShapeTK x
x
DeltaDot0X{} -> IShX ('[] @(Maybe Nat))
-> FullShapeTK (TKScalar r)
-> FullShapeTK (TKX2 ('[] @(Maybe Nat)) (TKScalar r))
forall (sh :: [Maybe Nat]) (x :: TK).
IShX sh -> FullShapeTK x -> FullShapeTK (TKX2 sh x)
FTKX IShX ('[] @(Maybe Nat))
forall (sh :: [Maybe Nat]) i.
((sh :: [Maybe Nat]) ~ ('[] @(Maybe Nat) :: [Maybe Nat])) =>
ShX sh i
ZSX FullShapeTK (TKScalar r)
forall r. GoodScalar r => FullShapeTK (TKScalar r)
FTKScalar
DeltaIndexX @shm @shn StaticShX shn
shn Delta target (TKX2 ((++) @(Maybe Nat) shm shn) r)
d IxXOf target shm
ix -> case Delta target (TKX2 ((++) @(Maybe Nat) shm shn) r)
-> FullShapeTK (TKX2 ((++) @(Maybe Nat) shm shn) r)
forall (target :: Target) (y :: TK).
Delta target y -> FullShapeTK y
ftkDelta Delta target (TKX2 ((++) @(Maybe Nat) shm shn) r)
d of
FTKX IShX sh
sh FullShapeTK x
x | SNat @len <- IxXOf target shm -> SNat (Rank @(Maybe Nat) shm)
forall (sh :: [Maybe Nat]) i.
IxX sh i -> SNat (Rank @(Maybe Nat) sh)
ixxRank IxXOf target shm
ix ->
(:~:)
@[Maybe Nat] (Drop @(Maybe Nat) (Rank @(Maybe Nat) shm) sh) shn
-> (((Drop @(Maybe Nat) (Rank @(Maybe Nat) shm) sh :: [Maybe Nat])
~ (shn :: [Maybe Nat])) =>
FullShapeTK y)
-> FullShapeTK y
forall {k} (a :: k) (b :: k) r.
(:~:) @k a b -> (((a :: k) ~ (b :: k)) => r) -> r
gcastWith ((:~:)
@[Maybe Nat] (Drop @(Maybe Nat) (Rank @(Maybe Nat) shm) sh) shn
(:~:)
@[Maybe Nat]
(Drop
@(Maybe Nat) (Rank @(Maybe Nat) shm) ((++) @(Maybe Nat) shm shn))
shn
forall {k} (a :: k) (b :: k). (:~:) @k a b
unsafeCoerceRefl :: Drop (Rank shm) (shm ++ shn) :~: shn) ((((Drop @(Maybe Nat) (Rank @(Maybe Nat) shm) sh :: [Maybe Nat])
~ (shn :: [Maybe Nat])) =>
FullShapeTK y)
-> FullShapeTK y)
-> (((Drop @(Maybe Nat) (Rank @(Maybe Nat) shm) sh :: [Maybe Nat])
~ (shn :: [Maybe Nat])) =>
FullShapeTK y)
-> FullShapeTK y
forall a b. (a -> b) -> a -> b
$
StaticShX sh -> (KnownShX sh => FullShapeTK y) -> FullShapeTK y
forall (sh :: [Maybe Nat]) r.
StaticShX sh -> (KnownShX sh => r) -> r
withKnownShX (IShX sh -> StaticShX sh
forall (sh :: [Maybe Nat]) i. ShX sh i -> StaticShX sh
ssxFromShX IShX sh
sh) ((KnownShX sh => FullShapeTK y) -> FullShapeTK y)
-> (KnownShX sh => FullShapeTK y) -> FullShapeTK y
forall a b. (a -> b) -> a -> b
$
StaticShX shn -> (KnownShX shn => FullShapeTK y) -> FullShapeTK y
forall (sh :: [Maybe Nat]) r.
StaticShX sh -> (KnownShX sh => r) -> r
withKnownShX StaticShX shn
shn ((KnownShX shn => FullShapeTK y) -> FullShapeTK y)
-> (KnownShX shn => FullShapeTK y) -> FullShapeTK y
forall a b. (a -> b) -> a -> b
$
IShX shn -> FullShapeTK x -> FullShapeTK (TKX2 shn x)
forall (sh :: [Maybe Nat]) (x :: TK).
IShX sh -> FullShapeTK x -> FullShapeTK (TKX2 sh x)
FTKX (forall (len :: Nat) (sh :: [Maybe Nat]).
(KnownNat len, KnownShX sh, KnownShX (Drop @(Maybe Nat) len sh)) =>
IShX sh -> IShX (Drop @(Maybe Nat) len sh)
shxDrop @len IShX sh
sh) FullShapeTK x
x
DeltaScatterX StaticShX shm
_ StaticShX shn
_ StaticShX shp
_ IShX ((++) @(Maybe Nat) shp shn)
sh Delta target (TKX2 ((++) @(Maybe Nat) shm shn) r)
d IxXOf target shm -> IxXOf target shp
_ -> case Delta target (TKX2 ((++) @(Maybe Nat) shm shn) r)
-> FullShapeTK (TKX2 ((++) @(Maybe Nat) shm shn) r)
forall (target :: Target) (y :: TK).
Delta target y -> FullShapeTK y
ftkDelta Delta target (TKX2 ((++) @(Maybe Nat) shm shn) r)
d of
FTKX IShX sh
_ FullShapeTK x
x -> IShX ((++) @(Maybe Nat) shp shn)
-> FullShapeTK x
-> FullShapeTK (TKX2 ((++) @(Maybe Nat) shp shn) x)
forall (sh :: [Maybe Nat]) (x :: TK).
IShX sh -> FullShapeTK x -> FullShapeTK (TKX2 sh x)
FTKX IShX ((++) @(Maybe Nat) shp shn)
sh FullShapeTK x
x
DeltaGatherX StaticShX shm
_ StaticShX shn
_ StaticShX shp
_ IShX ((++) @(Maybe Nat) shm shn)
sh Delta target (TKX2 ((++) @(Maybe Nat) shp shn) r)
d IxXOf target shm -> IxXOf target shp
_ -> case Delta target (TKX2 ((++) @(Maybe Nat) shp shn) r)
-> FullShapeTK (TKX2 ((++) @(Maybe Nat) shp shn) r)
forall (target :: Target) (y :: TK).
Delta target y -> FullShapeTK y
ftkDelta Delta target (TKX2 ((++) @(Maybe Nat) shp shn) r)
d of
FTKX IShX sh
_ FullShapeTK x
x -> IShX ((++) @(Maybe Nat) shm shn)
-> FullShapeTK x
-> FullShapeTK (TKX2 ((++) @(Maybe Nat) shm shn) x)
forall (sh :: [Maybe Nat]) (x :: TK).
IShX sh -> FullShapeTK x -> FullShapeTK (TKX2 sh x)
FTKX IShX ((++) @(Maybe Nat) shm shn)
sh FullShapeTK x
x
DeltaAppendX Delta target (TKX2 ((':) @(Maybe Nat) ('Just @Nat m) sh) r)
a Delta target (TKX2 ((':) @(Maybe Nat) ('Just @Nat n) sh) r)
b -> case (Delta target (TKX2 ((':) @(Maybe Nat) ('Just @Nat m) sh) r)
-> FullShapeTK (TKX2 ((':) @(Maybe Nat) ('Just @Nat m) sh) r)
forall (target :: Target) (y :: TK).
Delta target y -> FullShapeTK y
ftkDelta Delta target (TKX2 ((':) @(Maybe Nat) ('Just @Nat m) sh) r)
a, Delta target (TKX2 ((':) @(Maybe Nat) ('Just @Nat n) sh) r)
-> FullShapeTK (TKX2 ((':) @(Maybe Nat) ('Just @Nat n) sh) r)
forall (target :: Target) (y :: TK).
Delta target y -> FullShapeTK y
ftkDelta Delta target (TKX2 ((':) @(Maybe Nat) ('Just @Nat n) sh) r)
b) of
(FTKX (Nested.SKnown SNat n1
m :$% ShX sh Int
sh) FullShapeTK x
x, FTKX (Nested.SKnown SNat n1
n :$% ShX sh Int
_) FullShapeTK x
_) ->
IShX ((':) @(Maybe Nat) ('Just @Nat (m + n)) sh)
-> FullShapeTK x
-> FullShapeTK (TKX2 ((':) @(Maybe Nat) ('Just @Nat (m + n)) sh) x)
forall (sh :: [Maybe Nat]) (x :: TK).
IShX sh -> FullShapeTK x -> FullShapeTK (TKX2 sh x)
FTKX (SNat (m + n) -> SMayNat @Nat Int SNat ('Just @Nat (m + n))
forall {k} (f :: k -> Type) (n1 :: k) i.
f n1 -> SMayNat @k i f ('Just @k n1)
Nested.SKnown (SNat n1 -> SNat n1 -> SNat (n1 + n1)
forall (n :: Nat) (m :: Nat). SNat n -> SNat m -> SNat (n + m)
snatPlus SNat n1
m SNat n1
n) SMayNat @Nat Int SNat ('Just @Nat (m + n))
-> ShX sh Int -> IShX ((':) @(Maybe Nat) ('Just @Nat (m + n)) sh)
forall {sh1 :: [Maybe Nat]} {i} (n :: Maybe Nat)
(sh :: [Maybe Nat]).
(((':) @(Maybe Nat) n sh :: [Maybe Nat]) ~ (sh1 :: [Maybe Nat])) =>
SMayNat @Nat i SNat n -> ShX sh i -> ShX sh1 i
:$% ShX sh Int
sh) FullShapeTK x
x
DeltaSliceX SNat i
_ n :: SNat n
n@SNat n
SNat SNat k
_ Delta
target (TKX2 ((':) @(Maybe Nat) ('Just @Nat ((i + n) + k)) sh) r)
d -> case Delta
target (TKX2 ((':) @(Maybe Nat) ('Just @Nat ((i + n) + k)) sh) r)
-> FullShapeTK
(TKX2 ((':) @(Maybe Nat) ('Just @Nat ((i + n) + k)) sh) r)
forall (target :: Target) (y :: TK).
Delta target y -> FullShapeTK y
ftkDelta Delta
target (TKX2 ((':) @(Maybe Nat) ('Just @Nat ((i + n) + k)) sh) r)
d of
FTKX (SMayNat @Nat Int SNat n
_ :$% ShX sh Int
sh) FullShapeTK x
x -> IShX ((':) @(Maybe Nat) ('Just @Nat n) sh)
-> FullShapeTK x
-> FullShapeTK (TKX2 ((':) @(Maybe Nat) ('Just @Nat n) sh) x)
forall (sh :: [Maybe Nat]) (x :: TK).
IShX sh -> FullShapeTK x -> FullShapeTK (TKX2 sh x)
FTKX (SNat n -> SMayNat @Nat Int SNat ('Just @Nat n)
forall {k} (f :: k -> Type) (n1 :: k) i.
f n1 -> SMayNat @k i f ('Just @k n1)
Nested.SKnown SNat n
n SMayNat @Nat Int SNat ('Just @Nat n)
-> ShX sh Int -> IShX ((':) @(Maybe Nat) ('Just @Nat n) sh)
forall {sh1 :: [Maybe Nat]} {i} (n :: Maybe Nat)
(sh :: [Maybe Nat]).
(((':) @(Maybe Nat) n sh :: [Maybe Nat]) ~ (sh1 :: [Maybe Nat])) =>
SMayNat @Nat i SNat n -> ShX sh i -> ShX sh1 i
:$% ShX sh Int
sh) FullShapeTK x
x
DeltaReverseX Delta target (TKX2 ((':) @(Maybe Nat) mn sh) r)
d -> Delta target y -> FullShapeTK y
forall (target :: Target) (y :: TK).
Delta target y -> FullShapeTK y
ftkDelta Delta target y
Delta target (TKX2 ((':) @(Maybe Nat) mn sh) r)
d
DeltaTransposeX Perm perm
perm Delta target (TKX2 sh r)
d -> case Delta target (TKX2 sh r) -> FullShapeTK (TKX2 sh r)
forall (target :: Target) (y :: TK).
Delta target y -> FullShapeTK y
ftkDelta Delta target (TKX2 sh r)
d of
FTKX IShX sh
sh FullShapeTK x
x -> IShX (PermutePrefix @(Maybe Nat) perm sh)
-> FullShapeTK x
-> FullShapeTK (TKX2 (PermutePrefix @(Maybe Nat) perm sh) x)
forall (sh :: [Maybe Nat]) (x :: TK).
IShX sh -> FullShapeTK x -> FullShapeTK (TKX2 sh x)
FTKX (Perm perm
-> IShX sh -> ShX (PermutePrefix @(Maybe Nat) perm sh) Int
forall (is :: [Nat]) (sh :: [Maybe Nat]) i.
Perm is -> ShX sh i -> ShX (PermutePrefix @(Maybe Nat) is sh) i
shxPermutePrefix Perm perm
perm IShX sh
sh) FullShapeTK x
x
DeltaReshapeX IShX sh2
sh2 Delta target (TKX2 sh r)
d -> case Delta target (TKX2 sh r) -> FullShapeTK (TKX2 sh r)
forall (target :: Target) (y :: TK).
Delta target y -> FullShapeTK y
ftkDelta Delta target (TKX2 sh r)
d of
FTKX IShX sh
_ FullShapeTK x
x -> IShX sh2 -> FullShapeTK x -> FullShapeTK (TKX2 sh2 x)
forall (sh :: [Maybe Nat]) (x :: TK).
IShX sh -> FullShapeTK x -> FullShapeTK (TKX2 sh x)
FTKX IShX sh2
sh2 FullShapeTK x
x
DeltaConvert TKConversion a y
c Delta target a
d -> TKConversion a y -> FullShapeTK a -> FullShapeTK y
forall (a :: TK) (b :: TK).
TKConversion a b -> FullShapeTK a -> FullShapeTK b
convertFTK TKConversion a y
c (FullShapeTK a -> FullShapeTK y) -> FullShapeTK a -> FullShapeTK y
forall a b. (a -> b) -> a -> b
$ Delta target a -> FullShapeTK a
forall (target :: Target) (y :: TK).
Delta target y -> FullShapeTK y
ftkDelta Delta target a
d