{-# LANGUAGE UndecidableInstances #-}
{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-}
{-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-}
-- | The grammar of delta expressions.
--
-- A delta expression can be viewed as a concise representation
-- of a linear map (which is the derivative of the objective function)
-- and its evaluation on a given argument (in module "HordeAd.Core.DeltaEval")
-- as an adjoint (in the algebraic sense) of the linear map
-- applied to that argument. Since linear maps can be represented
-- as matrices, this operation corresponds to a transposition
-- of the matrix. However, the matrix is not constructed,
-- but is represented and transposed preserving the sparsity
-- of the representation.
--
-- The \'sparsity\' is less obvious when a delta expression
-- contains big concrete tensors, e.g., via the 'DeltaScale' constructor.
-- However, via 'DeltaReplicate' and other constructors, the tensors
-- can be enlarged far beyond what's embedded in the delta term.
-- Also, if the expression refers to unknown inputs ('DeltaInput'),
-- it may denote, after evaluation, a still larger tensor.
--
-- The algebraic structure here is an extension of vector space
-- with some additional constructors. The crucial extra constructor
-- 'DeltaInput' replaces the usual one-hot access to parameters
-- with something cheaper and more uniform.
-- Many of the remaining additional constructors serve to introduce
-- and reduce dimensions of tensors; they mimic many of the operations
-- available for the primal value arrays.
module HordeAd.Core.Delta
  ( -- * Delta identifiers
    NodeId, mkNodeId, nodeIdToFTK
  , InputId, mkInputId, inputIdToFTK
    -- * The grammar of delta expressions
  , Delta(..), NestedTarget(..)
    -- * Full tensor kind derivation for delta expressions
  , ftkDelta
  ) where

import Prelude

import Control.Exception.Assert.Sugar
import Data.Dependent.EnumMap.Strict qualified as DMap
import Data.Kind (Type)
import Data.Some
import Data.Type.Equality (TestEquality (..), gcastWith, testEquality, (:~:))
import Data.Vector.Generic qualified as V
import Data.Vector.Strict qualified as Data.Vector
import GHC.TypeLits (type (+), type (<=))
import Text.Show.Functions ()

import Data.Array.Nested (type (++))
import Data.Array.Nested qualified as Nested
import Data.Array.Nested.Mixed.Shape
import Data.Array.Nested.Permutation qualified as Permutation
import Data.Array.Nested.Ranked.Shape
import Data.Array.Nested.Shaped.Shape
import Data.Array.Nested.Types (snatPlus, unsafeCoerceRefl)

import HordeAd.Core.Ops
import HordeAd.Core.TensorKind
import HordeAd.Core.Types

-- * Delta identifiers

-- | The identifiers for nodes of delta expression trees.
--
-- Each identifier carries the full shape of the node's tensor kind
-- together with an integer stamp. Use 'mkNodeId' to construct one
-- (it asserts that the stamp is non-negative) and 'nodeIdToFTK'
-- to read the shape back.
type role NodeId nominal nominal
data NodeId :: Target -> TK -> Type where
  NodeId :: forall target y. FullShapeTK y -> Int -> NodeId target y

-- No Eq instance to limit hacks outside this module.

-- | Shows just the integer stamp, at the given precedence, omitting the
-- full shape (less verbose, more readable).
instance Show (NodeId target y) where
  showsPrec prec nid = case nid of
    NodeId _shapeTK stamp -> showsPrec prec stamp

-- | Lets node identifiers key a "Data.Dependent.EnumMap.Strict" map:
-- the integer stamp is the key, and the full shape is kept, existentially
-- wrapped in 'Some', as the extra information needed to rebuild the
-- identifier in 'toEnum1'.
instance DMap.Enum1 (NodeId target) where
  type Enum1Info (NodeId target) = Some FullShapeTK
  fromEnum1 nid = case nid of
    NodeId shapeTK stamp -> (stamp, Some shapeTK)
  toEnum1 stamp someShape = case someShape of
    Some shapeTK -> Some (NodeId shapeTK stamp)

-- | Type equality of two node identifiers is decided solely by comparing
-- their full shapes with 'matchingFTK'; the integer stamps are ignored.
instance TestEquality (NodeId target) where
  testEquality lhs rhs = case (lhs, rhs) of
    (NodeId shapeL _, NodeId shapeR _) -> matchingFTK shapeL shapeR

-- | Build a t'NodeId' from a full shape and an integer stamp.
-- The stamp must be non-negative (only!); this is checked with 'assert'.
mkNodeId :: FullShapeTK y -> Int -> NodeId f y
mkNodeId shapeTK stamp = assert (0 <= stamp) (NodeId shapeTK stamp)

-- | Extract the full shape recorded in a node identifier.
nodeIdToFTK :: NodeId f y -> FullShapeTK y
nodeIdToFTK nid = case nid of
  NodeId shapeTK _stamp -> shapeTK

-- | The identifiers for input leaves of delta expressions.
--
-- Each identifier carries the full shape of the input's tensor kind
-- together with an integer stamp. Use 'mkInputId' to construct one
-- (it asserts that the stamp is non-negative) and 'inputIdToFTK'
-- to read the shape back.
type role InputId nominal nominal
data InputId :: Target -> TK -> Type where
  InputId :: forall target y. FullShapeTK y -> Int -> InputId target y

-- No Eq instance to limit hacks outside this module.

-- | Prints as @(InputId n)@, where @n@ is the integer stamp. The
-- parentheses are emitted regardless of the surrounding precedence and
-- the shape is omitted; this format is kept for backward compatibility.
instance Show (InputId target y) where
  showsPrec _prec iid = case iid of
    InputId _shapeTK stamp ->
      showChar '(' . showString "InputId " . shows stamp . showChar ')'

-- | Lets input identifiers key a "Data.Dependent.EnumMap.Strict" map:
-- the integer stamp is the key, and the full shape is kept, existentially
-- wrapped in 'Some', as the extra information needed to rebuild the
-- identifier in 'toEnum1'.
instance DMap.Enum1 (InputId target) where
  type Enum1Info (InputId target) = Some FullShapeTK
  fromEnum1 iid = case iid of
    InputId shapeTK stamp -> (stamp, Some shapeTK)
  toEnum1 stamp someShape = case someShape of
    Some shapeTK -> Some (InputId shapeTK stamp)

-- | Type equality of two input identifiers is decided solely by comparing
-- their full shapes with 'matchingFTK'; the integer stamps are ignored.
instance TestEquality (InputId target) where
  testEquality lhs rhs = case (lhs, rhs) of
    (InputId shapeL _, InputId shapeR _) -> matchingFTK shapeL shapeR

-- | Build a t'InputId' from a full shape and an integer stamp.
-- The stamp must be non-negative (only!); this is checked with 'assert'.
mkInputId :: FullShapeTK y -> Int -> InputId f y
mkInputId shapeTK stamp = assert (0 <= stamp) (InputId shapeTK stamp)

-- | Extract the full shape recorded in an input identifier.
inputIdToFTK :: InputId f y -> FullShapeTK y
inputIdToFTK iid = case iid of
  InputId shapeTK _stamp -> shapeTK


-- * The grammar of delta expressions

-- | The grammar of delta expressions.
--
-- The t'NodeId' identifier that appears in a @DeltaShare n d@ expression
-- is the unique identity stamp of subterm @d@, that is, there is
-- no different term @e@ such that @DeltaShare n e@ appears in any delta
-- expression term in memory during the same run of an executable.
-- The subterm identity is used to avoid evaluating shared
-- subterms repeatedly in gradient and derivative computations.
-- The identifiers also represent data dependencies among terms
-- for the purpose of gradient and derivative computation. Computation for
-- a term may depend only on data obtained from terms with lower value
-- of their node identifiers. Such data dependency determination
-- agrees with the subterm relation, but is faster than traversing
-- the term tree in order to determine the relation of terms.
--
-- When computing gradients, node identifiers are also used to index,
-- directly or indirectly, the data accumulated for each node,
-- in the form of cotangents, that is partial derivatives
-- of the objective function with respect to the position(s)
-- of the node in the whole objective function dual number term
-- (or, more precisely, with respect to the single node in the term DAG,
-- in which subterms with the same node identifier are collapsed).
-- Only the @DeltaInput@ nodes have a separate data storage.
-- The t'InputId' identifiers in the @DeltaInput@ term constructors
-- are indexes into a contiguous vector of cotangents of @DeltaInput@
-- subterms of the whole term. The value at that index is the partial
-- derivative of the objective function (represented by the whole term,
-- or more precisely by (the data flow graph of) its particular
-- evaluation from which the delta expression originates)
-- with respect to the input parameter component at that index
-- in the objective function domain.
--
-- The full shape of any delta term can be recovered with 'ftkDelta'.
type role Delta nominal nominal
data Delta :: Target -> Target where
  -- Sharing-related operations
  -- | Pairs a subterm with its node identifier; see the uniqueness
  -- requirement in the documentation of the type.
  DeltaShare :: NodeId target y -> Delta target y -> Delta target y
  -- | A leaf standing for the input parameter with the given identifier.
  DeltaInput :: InputId target y -> Delta target y

  -- General operations
  DeltaPair :: forall y z target.
               Delta target y -> Delta target z
            -> Delta target (TKProduct y z)
  DeltaProject1 :: forall y z target.
                   Delta target (TKProduct y z) -> Delta target y
  DeltaProject2 :: forall y z target.
                   Delta target (TKProduct y z) -> Delta target z
  -- | Builds a @BuildTensorKind k y@ term out of a vector of @y@ terms.
  -- NOTE(review): 'ftkDelta' reads the shape from the first element and
  -- fails on an empty vector, so the vector is presumably expected to
  -- hold exactly @k > 0@ elements; confirm.
  DeltaFromVector :: forall y k target.
                     SNat k -> SingletonTK y
                  -> Data.Vector.Vector (Delta target y)
                  -> Delta target (BuildTensorKind k y)
  -- | Reduces a @BuildTensorKind k y@ term back to kind @y@
  -- (by summation, per the name).
  DeltaSum :: forall y k target.
              SNat k -> SingletonTK y
           -> Delta target (BuildTensorKind k y)
           -> Delta target y
  -- | Lifts a @y@ term to kind @BuildTensorKind k y@
  -- (by replicating it, per the name).
  DeltaReplicate :: forall y k target.
                    SNat k -> SingletonTK y
                 -> Delta target y
                 -> Delta target (BuildTensorKind k y)
  -- | A @k@-step accumulating traversal with accumulator kind @accy@,
  -- per-step inputs of kind @ey@ and per-step outputs of kind @by@.
  -- The result has an @accy@ component and a @BuildTensorKind k by@
  -- component (presumably the final accumulator and the stacked outputs).
  -- The two concrete targets hold @k@ primal accumulator values and
  -- @k@ primal inputs. By their types, the first 'HFun' maps
  -- 'ADTensorKind' values of @(accy, ey)@ (paired with primal @(accy, ey)@)
  -- to 'ADTensorKind' values of @(accy, by)@, and the second maps the
  -- other way. They are presumably the step's derivative and its transpose;
  -- confirm in "HordeAd.Core.DeltaEval". Right-to-left variant, per the name.
  DeltaMapAccumR
    :: forall target k accy by ey.
       ( Show (target (BuildTensorKind k accy))
       , Show (target (BuildTensorKind k ey)) )
    => SNat k
    -> FullShapeTK by
    -> FullShapeTK ey
    -> target (BuildTensorKind k accy)
    -> target (BuildTensorKind k ey)
    -> HFun (TKProduct (ADTensorKind (TKProduct accy ey))
                       (TKProduct accy ey))
            (ADTensorKind (TKProduct accy by))
    -> HFun (TKProduct (ADTensorKind (TKProduct accy by))
                       (TKProduct accy ey))
            (ADTensorKind (TKProduct accy ey))
    -> Delta target accy
    -> Delta target (BuildTensorKind k ey)
    -> Delta target (TKProduct accy (BuildTensorKind k by))
  -- | As 'DeltaMapAccumR', with identical argument types; the
  -- left-to-right variant, per the name.
  DeltaMapAccumL
    :: forall target k accy by ey.
       ( Show (target (BuildTensorKind k accy))
       , Show (target (BuildTensorKind k ey)) )
    => SNat k
    -> FullShapeTK by
    -> FullShapeTK ey
    -> target (BuildTensorKind k accy)
    -> target (BuildTensorKind k ey)
    -> HFun (TKProduct (ADTensorKind (TKProduct accy ey))
                       (TKProduct accy ey))
            (ADTensorKind (TKProduct accy by))
    -> HFun (TKProduct (ADTensorKind (TKProduct accy by))
                       (TKProduct accy ey))
            (ADTensorKind (TKProduct accy ey))
    -> Delta target accy
    -> Delta target (BuildTensorKind k ey)
    -> Delta target (TKProduct accy (BuildTensorKind k by))

  -- Vector space operations
  -- | The zero delta of the given full shape.
  DeltaZero :: FullShapeTK y -> Delta target y
  -- | Scales a delta term by a concrete primal value, which is wrapped in
  -- 'NestedTarget' so that it is shown as @\<primal\>@.
  DeltaScale :: Num (target y)
             => NestedTarget target y -> Delta target y -> Delta target y
  DeltaAdd :: Num (target y)
           => Delta target y -> Delta target y -> Delta target y

  -- Scalar arithmetic
  -- | Casts between 'RealFrac' scalar types.
  DeltaCastK :: (GoodScalar r1, RealFrac r1, GoodScalar r2, RealFrac r2)
             => Delta target (TKScalar r1) -> Delta target (TKScalar r2)

  -- Ranked tensor operations
  DeltaCastR :: (GoodScalar r1, RealFrac r1, GoodScalar r2, RealFrac r2)
             => Delta target (TKR n r1) -> Delta target (TKR n r2)
  -- | Reduces a tensor of any rank to rank 0 (summing, per the name).
  DeltaSum0R :: Delta target (TKR2 n r) -> Delta target (TKR2 0 r)
  -- | Combines a concrete primal tensor with a delta term of the same
  -- shape into a rank-0 result (a dot product, per the name).
  DeltaDot0R :: (GoodScalar r, Show (target (TKR n r)))
             => target (TKR n r) -> Delta target (TKR n r)
             -> Delta target (TKR 0 r)
  -- | Indexes the outermost @m@ dimensions with an index of length @m@,
  -- leaving the remaining @n@ dimensions.
  DeltaIndexR :: forall m n r target.
                 SNat n
              -> Delta target (TKR2 (m + n) r) -> IxROf target m
              -> Delta target (TKR2 n r)
  -- | The explicit shape is the result shape. The function maps an index
  -- into the outer @m@ dimensions of the argument to an index into the
  -- outer @p@ dimensions of the result.
  DeltaScatterR :: forall m n p r target.
                   SNat m -> SNat n -> SNat p
                -> IShR (p + n) -> Delta target (TKR2 (m + n) r)
                -> (IxROf target m -> IxROf target p)
                -> Delta target (TKR2 (p + n) r)
  -- | The explicit shape is the result shape. The function maps an index
  -- into the outer @m@ dimensions of the result to an index into the
  -- outer @p@ dimensions of the argument.
  DeltaGatherR :: forall m n p r target.
                  SNat m -> SNat n -> SNat p
               -> IShR (m + n) -> Delta target (TKR2 (p + n) r)
               -> (IxROf target m -> IxROf target p)
               -> Delta target (TKR2 (m + n) r)
  -- | Concatenates along the outermost dimension.
  DeltaAppendR :: Delta target (TKR2 (1 + n) r)
               -> Delta target (TKR2 (1 + n) r)
               -> Delta target (TKR2 (1 + n) r)
  -- | The second 'Int' is the length of the result's outermost dimension;
  -- the first is presumably the start offset.
  DeltaSliceR :: Int -> Int -> Delta target (TKR2 (1 + n) r)
              -> Delta target (TKR2 (1 + n) r)
  DeltaReverseR :: Delta target (TKR2 (1 + n) r)
                -> Delta target (TKR2 (1 + n) r)
  -- | The result shape is the argument's shape permuted with
  -- 'shrPermutePrefix'.
  DeltaTransposeR :: Permutation.PermR -> Delta target (TKR2 n r)
                  -> Delta target (TKR2 n r)
  -- | The explicit shape is the result shape.
  DeltaReshapeR :: IShR m -> Delta target (TKR2 n r)
                -> Delta target (TKR2 m r)

  -- Shaped tensor operations
  DeltaCastS :: (GoodScalar r1, RealFrac r1, GoodScalar r2, RealFrac r2)
             => Delta target (TKS sh r1) -> Delta target (TKS sh r2)
  DeltaSum0S :: Delta target (TKS2 sh r) -> Delta target (TKS2 '[] r)
  DeltaDot0S :: (GoodScalar r, Show (target (TKS sh r)))
             => target (TKS sh r) -> Delta target (TKS sh r)
             -> Delta target (TKS '[] r)
  -- | Indexes the @shm@ prefix, leaving the @shn@ suffix.
  DeltaIndexS :: forall shm shn r target.
                 ShS shn
              -> Delta target (TKS2 (shm ++ shn) r) -> IxSOf target shm
              -> Delta target (TKS2 shn r)
  -- | The function maps an index into the @shm@ prefix of the argument
  -- to an index into the @shp@ prefix of the result.
  DeltaScatterS :: forall shm shn shp r target.
                   ShS shm -> ShS shn -> ShS shp
                -> Delta target (TKS2 (shm ++ shn) r)
                -> (IxSOf target shm -> IxSOf target shp)
                -> Delta target (TKS2 (shp ++ shn) r)
  -- | The function maps an index into the @shm@ prefix of the result
  -- to an index into the @shp@ prefix of the argument.
  DeltaGatherS :: forall shm shn shp r target.
                  ShS shm -> ShS shn -> ShS shp
               -> Delta target (TKS2 (shp ++ shn) r)
               -> (IxSOf target shm -> IxSOf target shp)
               -> Delta target (TKS2 (shm ++ shn) r)
  DeltaAppendS :: forall target r m n sh.
                  Delta target (TKS2 (m ': sh) r)
               -> Delta target (TKS2 (n ': sh) r)
               -> Delta target (TKS2 ((m + n) ': sh) r)
  -- | Keeps @n@ of the @i + n + k@ outermost elements (presumably
  -- dropping @i@ leading and @k@ trailing ones).
  DeltaSliceS :: SNat i -> SNat n -> SNat k
              -> Delta target (TKS2 (i + n + k ': sh) r)
              -> Delta target (TKS2 (n ': sh) r)
  DeltaReverseS :: Delta target (TKS2 (n ': sh) r)
                -> Delta target (TKS2 (n ': sh) r)
  DeltaTransposeS :: forall perm sh r target.
                     (Permutation.IsPermutation perm, Rank perm <= Rank sh)
                  => Permutation.Perm perm
                  -> Delta target (TKS2 sh r)
                  -> Delta target (TKS2 (Permutation.PermutePrefix perm sh) r)
  -- | Reshapes to @sh2@, which must have the same number of elements.
  DeltaReshapeS :: Product sh ~ Product sh2
                => ShS sh2
                -> Delta target (TKS2 sh r)
                -> Delta target (TKS2 sh2 r)

  -- Mixed tensor operations
  DeltaCastX :: (GoodScalar r1, RealFrac r1, GoodScalar r2, RealFrac r2)
             => Delta target (TKX sh r1) -> Delta target (TKX sh r2)
  DeltaSum0X :: Delta target (TKX2 sh r) -> Delta target (TKX2 '[] r)
  DeltaDot0X :: (GoodScalar r, Show (target (TKX sh r)))
             => target (TKX sh r) -> Delta target (TKX sh r)
             -> Delta target (TKX '[] r)
  DeltaIndexX :: forall shm shn r target.
                 StaticShX shn
              -> Delta target (TKX2 (shm ++ shn) r) -> IxXOf target shm
              -> Delta target (TKX2 shn r)
  -- | The 'IShX' argument is presumably the result shape, as in
  -- 'DeltaScatterR'.
  DeltaScatterX :: StaticShX shm -> StaticShX shn -> StaticShX shp
                -> IShX (shp ++ shn) -> Delta target (TKX2 (shm ++ shn) r)
                -> (IxXOf target shm -> IxXOf target shp)
                -> Delta target (TKX2 (shp ++ shn) r)
  -- | The 'IShX' argument is presumably the result shape, as in
  -- 'DeltaGatherR'.
  DeltaGatherX :: StaticShX shm -> StaticShX shn -> StaticShX shp
               -> IShX (shm ++ shn) -> Delta target (TKX2 (shp ++ shn) r)
               -> (IxXOf target shm -> IxXOf target shp)
               -> Delta target (TKX2 (shm ++ shn) r)
  DeltaAppendX :: Delta target (TKX2 (Just m ': sh) r)
               -> Delta target (TKX2 (Just n ': sh) r)
               -> Delta target (TKX2 (Just (m + n) ': sh) r)
  DeltaSliceX :: SNat i -> SNat n -> SNat k
              -> Delta target (TKX2 (Just (i + n + k) ': sh) r)
              -> Delta target (TKX2 (Just n ': sh) r)
  -- | The outermost dimension may be of known or unknown size.
  DeltaReverseX :: Delta target (TKX2 (mn ': sh) r)
                -> Delta target (TKX2 (mn ': sh) r)
  DeltaTransposeX :: forall perm sh r target.
                     (Permutation.IsPermutation perm, Rank perm <= Rank sh)
                  => Permutation.Perm perm
                  -> Delta target (TKX2 sh r)
                  -> Delta target (TKX2 (Permutation.PermutePrefix perm sh) r)
  -- | The explicit shape is presumably the result shape, as in
  -- 'DeltaReshapeR'.
  DeltaReshapeX :: IShX sh2 -> Delta target (TKX2 sh r)
                -> Delta target (TKX2 sh2 r)

  -- Conversions
  -- | Changes the tensor kind from @a@ to @b@ as described by the
  -- 'TKConversion'.
  DeltaConvert :: TKConversion a b -> Delta target a -> Delta target b

-- Function-valued fields (e.g. in the scatter/gather constructors) are
-- shown via the instance from "Text.Show.Functions". The 'IntOf'
-- constraint is presumably needed for the index arguments
-- ('IxROf' and friends).
deriving instance Show (IntOf target) => Show (Delta target y)

-- | A newtype defined only to cut the knot of 'Show' instances in 'DeltaScale'
-- that are problematic to pass around as dictionaries without
-- bloating each constructor. The @DeltaScale@ constructor appears
-- in delta expressions a lot and so the primal
-- subterm would bloat the pretty-printed output (though OTOH the primal
-- terms are often important).
--
-- Its 'Show' instance ignores the wrapped value and prints @\<primal\>@.
--
-- Possibly, @Has Show (Delta target)@ is a better solution.
type NestedTarget :: Target -> Target
-- Both parameters are nominal, like those of 'Delta'.
type role NestedTarget nominal nominal
newtype NestedTarget target y = NestedTarget (target y)

-- | Hides the wrapped primal value: every 'NestedTarget' is shown as
-- @\<primal\>@, whatever the precedence, without forcing the value.
instance Show (NestedTarget target y) where
  showsPrec _prec = const (showString "<primal>")


-- * Full tensor kind derivation for delta expressions

-- | Full tensor kind derivation for delta expressions.
ftkDelta :: forall target y.
            Delta target y -> FullShapeTK y
ftkDelta :: forall (target :: Target) (y :: TK).
Delta target y -> FullShapeTK y
ftkDelta = \case
  DeltaShare NodeId target y
i Delta target y
_ -> NodeId target y -> FullShapeTK y
forall (f :: Target) (y :: TK). NodeId f y -> FullShapeTK y
nodeIdToFTK NodeId target y
i
  DeltaInput InputId target y
i -> InputId target y -> FullShapeTK y
forall (f :: Target) (y :: TK). InputId f y -> FullShapeTK y
inputIdToFTK InputId target y
i

  DeltaPair Delta target y
t1 Delta target z
t2 -> FullShapeTK y -> FullShapeTK z -> FullShapeTK (TKProduct y z)
forall (y1 :: TK) (z :: TK).
FullShapeTK y1 -> FullShapeTK z -> FullShapeTK (TKProduct y1 z)
FTKProduct (Delta target y -> FullShapeTK y
forall (target :: Target) (y :: TK).
Delta target y -> FullShapeTK y
ftkDelta Delta target y
t1) (Delta target z -> FullShapeTK z
forall (target :: Target) (y :: TK).
Delta target y -> FullShapeTK y
ftkDelta Delta target z
t2)
  DeltaProject1 Delta target (TKProduct y z)
v -> case Delta target (TKProduct y z) -> FullShapeTK (TKProduct y z)
forall (target :: Target) (y :: TK).
Delta target y -> FullShapeTK y
ftkDelta Delta target (TKProduct y z)
v of
    FTKProduct FullShapeTK y1
ftk1 FullShapeTK z
_ -> FullShapeTK y
FullShapeTK y1
ftk1
  DeltaProject2 Delta target (TKProduct y y)
v -> case Delta target (TKProduct y y) -> FullShapeTK (TKProduct y y)
forall (target :: Target) (y :: TK).
Delta target y -> FullShapeTK y
ftkDelta Delta target (TKProduct y y)
v of
    FTKProduct FullShapeTK y1
_ FullShapeTK z
ftk2 -> FullShapeTK y
FullShapeTK z
ftk2
  DeltaFromVector SNat k
snat SingletonTK y
_ Vector (Delta target y)
l -> case Vector (Delta target y)
-> Maybe (Delta target y, Vector (Delta target y))
forall (v :: Type -> Type) a. Vector v a => v a -> Maybe (a, v a)
V.uncons Vector (Delta target y)
l of
    Maybe (Delta target y, Vector (Delta target y))
Nothing -> String -> FullShapeTK y
forall a. (?callStack::CallStack) => String -> a
error String
"ftkDelta: empty vector"
    Just (Delta target y
d, Vector (Delta target y)
_) -> SNat k -> FullShapeTK y -> FullShapeTK (BuildTensorKind k y)
forall (k :: Nat) (y :: TK).
SNat k -> FullShapeTK y -> FullShapeTK (BuildTensorKind k y)
buildFTK SNat k
snat (Delta target y -> FullShapeTK y
forall (target :: Target) (y :: TK).
Delta target y -> FullShapeTK y
ftkDelta Delta target y
d)
  DeltaSum SNat k
snat SingletonTK y
stk Delta target (BuildTensorKind k y)
d -> SNat k
-> SingletonTK y
-> FullShapeTK (BuildTensorKind k y)
-> FullShapeTK y
forall (y :: TK) (k :: Nat).
SNat k
-> SingletonTK y
-> FullShapeTK (BuildTensorKind k y)
-> FullShapeTK y
razeFTK SNat k
snat SingletonTK y
stk (Delta target (BuildTensorKind k y)
-> FullShapeTK (BuildTensorKind k y)
forall (target :: Target) (y :: TK).
Delta target y -> FullShapeTK y
ftkDelta Delta target (BuildTensorKind k y)
d)
  DeltaReplicate SNat k
snat SingletonTK y
_ Delta target y
d -> SNat k -> FullShapeTK y -> FullShapeTK (BuildTensorKind k y)
forall (k :: Nat) (y :: TK).
SNat k -> FullShapeTK y -> FullShapeTK (BuildTensorKind k y)
buildFTK SNat k
snat (Delta target y -> FullShapeTK y
forall (target :: Target) (y :: TK).
Delta target y -> FullShapeTK y
ftkDelta Delta target y
d)
  DeltaMapAccumR SNat k
k FullShapeTK by
bftk FullShapeTK ey
_eftk target (BuildTensorKind k accy)
_q target (BuildTensorKind k ey)
_es HFun
  (TKProduct (ADTensorKind (TKProduct accy ey)) (TKProduct accy ey))
  (ADTensorKind (TKProduct accy by))
_df HFun
  (TKProduct (ADTensorKind (TKProduct accy by)) (TKProduct accy ey))
  (ADTensorKind (TKProduct accy ey))
_rf Delta target accy
acc0' Delta target (BuildTensorKind k ey)
_es' ->
    FullShapeTK accy
-> FullShapeTK (BuildTensorKind k by)
-> FullShapeTK (TKProduct accy (BuildTensorKind k by))
forall (y1 :: TK) (z :: TK).
FullShapeTK y1 -> FullShapeTK z -> FullShapeTK (TKProduct y1 z)
FTKProduct (Delta target accy -> FullShapeTK accy
forall (target :: Target) (y :: TK).
Delta target y -> FullShapeTK y
ftkDelta Delta target accy
acc0') (SNat k -> FullShapeTK by -> FullShapeTK (BuildTensorKind k by)
forall (k :: Nat) (y :: TK).
SNat k -> FullShapeTK y -> FullShapeTK (BuildTensorKind k y)
buildFTK SNat k
k FullShapeTK by
bftk)
  DeltaMapAccumL SNat k
k FullShapeTK by
bftk FullShapeTK ey
_eftk target (BuildTensorKind k accy)
_q target (BuildTensorKind k ey)
_es HFun
  (TKProduct (ADTensorKind (TKProduct accy ey)) (TKProduct accy ey))
  (ADTensorKind (TKProduct accy by))
_df HFun
  (TKProduct (ADTensorKind (TKProduct accy by)) (TKProduct accy ey))
  (ADTensorKind (TKProduct accy ey))
_rf Delta target accy
acc0' Delta target (BuildTensorKind k ey)
_es' ->
    FullShapeTK accy
-> FullShapeTK (BuildTensorKind k by)
-> FullShapeTK (TKProduct accy (BuildTensorKind k by))
forall (y1 :: TK) (z :: TK).
FullShapeTK y1 -> FullShapeTK z -> FullShapeTK (TKProduct y1 z)
FTKProduct (Delta target accy -> FullShapeTK accy
forall (target :: Target) (y :: TK).
Delta target y -> FullShapeTK y
ftkDelta Delta target accy
acc0') (SNat k -> FullShapeTK by -> FullShapeTK (BuildTensorKind k by)
forall (k :: Nat) (y :: TK).
SNat k -> FullShapeTK y -> FullShapeTK (BuildTensorKind k y)
buildFTK SNat k
k FullShapeTK by
bftk)

  DeltaZero FullShapeTK y
ftk -> FullShapeTK y
ftk
  DeltaScale NestedTarget target y
_ Delta target y
d -> Delta target y -> FullShapeTK y
forall (target :: Target) (y :: TK).
Delta target y -> FullShapeTK y
ftkDelta Delta target y
d
  DeltaAdd (DeltaShare NodeId target y
i Delta target y
_) Delta target y
_ -> NodeId target y -> FullShapeTK y
forall (f :: Target) (y :: TK). NodeId f y -> FullShapeTK y
nodeIdToFTK NodeId target y
i
  DeltaAdd Delta target y
_ Delta target y
e -> Delta target y -> FullShapeTK y
forall (target :: Target) (y :: TK).
Delta target y -> FullShapeTK y
ftkDelta Delta target y
e

  DeltaCastK{} -> FullShapeTK y
FullShapeTK (TKScalar r2)
forall r. GoodScalar r => FullShapeTK (TKScalar r)
FTKScalar

  DeltaCastR Delta target (TKR n r1)
d -> case Delta target (TKR n r1) -> FullShapeTK (TKR n r1)
forall (target :: Target) (y :: TK).
Delta target y -> FullShapeTK y
ftkDelta Delta target (TKR n r1)
d of
    FTKR IShR n
sh FullShapeTK x
_ -> IShR n
-> FullShapeTK (TKScalar r2) -> FullShapeTK (TKR2 n (TKScalar r2))
forall (n :: Nat) (x :: TK).
IShR n -> FullShapeTK x -> FullShapeTK (TKR2 n x)
FTKR IShR n
sh FullShapeTK (TKScalar r2)
forall r. GoodScalar r => FullShapeTK (TKScalar r)
FTKScalar
  DeltaSum0R Delta target (TKR2 n r)
d -> case Delta target (TKR2 n r) -> FullShapeTK (TKR2 n r)
forall (target :: Target) (y :: TK).
Delta target y -> FullShapeTK y
ftkDelta Delta target (TKR2 n r)
d of
    FTKR IShR n
_ FullShapeTK x
x -> IShR 0 -> FullShapeTK x -> FullShapeTK (TKR2 0 x)
forall (n :: Nat) (x :: TK).
IShR n -> FullShapeTK x -> FullShapeTK (TKR2 n x)
FTKR IShR 0
forall (n :: Nat) i. ((n :: Nat) ~ (0 :: Nat)) => ShR n i
ZSR FullShapeTK x
x
  DeltaDot0R{} -> IShR 0
-> FullShapeTK (TKScalar r) -> FullShapeTK (TKR2 0 (TKScalar r))
forall (n :: Nat) (x :: TK).
IShR n -> FullShapeTK x -> FullShapeTK (TKR2 n x)
FTKR IShR 0
forall (n :: Nat) i. ((n :: Nat) ~ (0 :: Nat)) => ShR n i
ZSR FullShapeTK (TKScalar r)
forall r. GoodScalar r => FullShapeTK (TKScalar r)
FTKScalar
  DeltaIndexR SNat n
SNat Delta target (TKR2 (m + n) r)
d IxROf target m
ix | SNat m
SNat <- IxROf target m -> SNat m
forall (n :: Nat) i. IxR n i -> SNat n
ixrRank IxROf target m
ix -> case Delta target (TKR2 (m + n) r) -> FullShapeTK (TKR2 (m + n) r)
forall (target :: Target) (y :: TK).
Delta target y -> FullShapeTK y
ftkDelta Delta target (TKR2 (m + n) r)
d of
    FTKR IShR n
sh FullShapeTK x
x -> IShR n -> FullShapeTK x -> FullShapeTK (TKR2 n x)
forall (n :: Nat) (x :: TK).
IShR n -> FullShapeTK x -> FullShapeTK (TKR2 n x)
FTKR (ShR (m + n) Int -> IShR n
forall (m :: Nat) (n :: Nat) i.
(KnownNat m, KnownNat n) =>
ShR (m + n) i -> ShR n i
shrDrop IShR n
ShR (m + n) Int
sh) FullShapeTK x
x
  DeltaScatterR SNat m
_ SNat n
_ SNat p
_ IShR (p + n)
sh Delta target (TKR2 (m + n) r)
d IxROf target m -> IxROf target p
_ -> case Delta target (TKR2 (m + n) r) -> FullShapeTK (TKR2 (m + n) r)
forall (target :: Target) (y :: TK).
Delta target y -> FullShapeTK y
ftkDelta Delta target (TKR2 (m + n) r)
d of
    FTKR IShR n
_ FullShapeTK x
x -> IShR (p + n) -> FullShapeTK x -> FullShapeTK (TKR2 (p + n) x)
forall (n :: Nat) (x :: TK).
IShR n -> FullShapeTK x -> FullShapeTK (TKR2 n x)
FTKR IShR (p + n)
sh FullShapeTK x
x
  DeltaGatherR SNat m
_ SNat n
_ SNat p
_ IShR (m + n)
sh Delta target (TKR2 (p + n) r)
d IxROf target m -> IxROf target p
_ -> case Delta target (TKR2 (p + n) r) -> FullShapeTK (TKR2 (p + n) r)
forall (target :: Target) (y :: TK).
Delta target y -> FullShapeTK y
ftkDelta Delta target (TKR2 (p + n) r)
d of
    FTKR IShR n
_ FullShapeTK x
x -> IShR (m + n) -> FullShapeTK x -> FullShapeTK (TKR2 (m + n) x)
forall (n :: Nat) (x :: TK).
IShR n -> FullShapeTK x -> FullShapeTK (TKR2 n x)
FTKR IShR (m + n)
sh FullShapeTK x
x
  DeltaAppendR Delta target (TKR2 (1 + n) r)
a Delta target (TKR2 (1 + n) r)
b -> case Delta target (TKR2 (1 + n) r) -> FullShapeTK (TKR2 (1 + n) r)
forall (target :: Target) (y :: TK).
Delta target y -> FullShapeTK y
ftkDelta Delta target (TKR2 (1 + n) r)
a of
    FTKR ShR n Int
ZSR FullShapeTK x
_ -> String -> FullShapeTK y
forall a. (?callStack::CallStack) => String -> a
error String
"ftkDelta: impossible pattern needlessly required"
    FTKR (Int
ai :$: ShR n Int
ash) FullShapeTK x
x -> case Delta target (TKR2 n r) -> FullShapeTK (TKR2 n r)
forall (target :: Target) (y :: TK).
Delta target y -> FullShapeTK y
ftkDelta Delta target (TKR2 n r)
Delta target (TKR2 (1 + n) r)
b of
      FTKR ShR n Int
ZSR FullShapeTK x
_ -> String -> FullShapeTK y
forall a. (?callStack::CallStack) => String -> a
error String
"ftkDelta: impossible pattern needlessly required"
      FTKR (Int
bi :$: ShR n Int
_) FullShapeTK x
_ -> IShR n -> FullShapeTK x -> FullShapeTK (TKR2 n x)
forall (n :: Nat) (x :: TK).
IShR n -> FullShapeTK x -> FullShapeTK (TKR2 n x)
FTKR (Int
ai Int -> Int -> Int
forall a. Num a => a -> a -> a
+ Int
bi Int -> ShR n Int -> IShR n
forall {n1 :: Nat} {i} (n :: Nat).
((n + 1 :: Nat) ~ (n1 :: Nat)) =>
i -> ShR n i -> ShR n1 i
:$: ShR n Int
ash) FullShapeTK x
x
  DeltaSliceR Int
_ Int
n Delta target (TKR2 (1 + n) r)
d -> case Delta target (TKR2 (1 + n) r) -> FullShapeTK (TKR2 (1 + n) r)
forall (target :: Target) (y :: TK).
Delta target y -> FullShapeTK y
ftkDelta Delta target (TKR2 (1 + n) r)
d of
    FTKR IShR n
sh FullShapeTK x
x -> IShR n -> FullShapeTK x -> FullShapeTK (TKR2 n x)
forall (n :: Nat) (x :: TK).
IShR n -> FullShapeTK x -> FullShapeTK (TKR2 n x)
FTKR (Int
n Int -> ShR n Int -> IShR n
forall {n1 :: Nat} {i} (n :: Nat).
((n + 1 :: Nat) ~ (n1 :: Nat)) =>
i -> ShR n i -> ShR n1 i
:$: ShR (n + 1) Int -> ShR n Int
forall (n :: Nat) i. ShR (n + 1) i -> ShR n i
shrTail IShR n
ShR (n + 1) Int
sh) FullShapeTK x
x
  DeltaReverseR Delta target (TKR2 (1 + n) r)
d -> Delta target y -> FullShapeTK y
forall (target :: Target) (y :: TK).
Delta target y -> FullShapeTK y
ftkDelta Delta target y
Delta target (TKR2 (1 + n) r)
d
  DeltaTransposeR PermR
perm Delta target (TKR2 n r)
d -> case Delta target (TKR2 n r) -> FullShapeTK (TKR2 n r)
forall (target :: Target) (y :: TK).
Delta target y -> FullShapeTK y
ftkDelta Delta target (TKR2 n r)
d of
    FTKR IShR n
sh FullShapeTK x
x -> IShR n -> FullShapeTK x -> FullShapeTK (TKR2 n x)
forall (n :: Nat) (x :: TK).
IShR n -> FullShapeTK x -> FullShapeTK (TKR2 n x)
FTKR (PermR -> IShR n -> IShR n
forall (n :: Nat) i. PermR -> ShR n i -> ShR n i
shrPermutePrefix PermR
perm IShR n
sh) FullShapeTK x
x
  DeltaReshapeR IShR m
sh Delta target (TKR2 n r)
d -> case Delta target (TKR2 n r) -> FullShapeTK (TKR2 n r)
forall (target :: Target) (y :: TK).
Delta target y -> FullShapeTK y
ftkDelta Delta target (TKR2 n r)
d of
    FTKR IShR n
_ FullShapeTK x
x -> IShR m -> FullShapeTK x -> FullShapeTK (TKR2 m x)
forall (n :: Nat) (x :: TK).
IShR n -> FullShapeTK x -> FullShapeTK (TKR2 n x)
FTKR IShR m
sh FullShapeTK x
x

  DeltaCastS Delta target (TKS sh r1)
d -> case Delta target (TKS sh r1) -> FullShapeTK (TKS sh r1)
forall (target :: Target) (y :: TK).
Delta target y -> FullShapeTK y
ftkDelta Delta target (TKS sh r1)
d of
    FTKS ShS sh
sh FullShapeTK x
FTKScalar -> ShS sh
-> FullShapeTK (TKScalar r2) -> FullShapeTK (TKS2 sh (TKScalar r2))
forall (sh :: [Nat]) (x :: TK).
ShS sh -> FullShapeTK x -> FullShapeTK (TKS2 sh x)
FTKS ShS sh
sh FullShapeTK (TKScalar r2)
forall r. GoodScalar r => FullShapeTK (TKScalar r)
FTKScalar
  DeltaSum0S Delta target (TKS2 sh r)
d -> case Delta target (TKS2 sh r) -> FullShapeTK (TKS2 sh r)
forall (target :: Target) (y :: TK).
Delta target y -> FullShapeTK y
ftkDelta Delta target (TKS2 sh r)
d of
    FTKS ShS sh
_ FullShapeTK x
x -> ShS ('[] @Nat) -> FullShapeTK x -> FullShapeTK (TKS2 ('[] @Nat) x)
forall (sh :: [Nat]) (x :: TK).
ShS sh -> FullShapeTK x -> FullShapeTK (TKS2 sh x)
FTKS ShS ('[] @Nat)
forall (sh :: [Nat]).
((sh :: [Nat]) ~ ('[] @Nat :: [Nat])) =>
ShS sh
ZSS FullShapeTK x
x
  DeltaDot0S{} -> ShS ('[] @Nat)
-> FullShapeTK (TKScalar r)
-> FullShapeTK (TKS2 ('[] @Nat) (TKScalar r))
forall (sh :: [Nat]) (x :: TK).
ShS sh -> FullShapeTK x -> FullShapeTK (TKS2 sh x)
FTKS ShS ('[] @Nat)
forall (sh :: [Nat]).
((sh :: [Nat]) ~ ('[] @Nat :: [Nat])) =>
ShS sh
ZSS FullShapeTK (TKScalar r)
forall r. GoodScalar r => FullShapeTK (TKScalar r)
FTKScalar
  DeltaIndexS ShS shn
shn Delta target (TKS2 ((++) @Nat shm shn) r)
d IxSOf target shm
_ix -> case Delta target (TKS2 ((++) @Nat shm shn) r)
-> FullShapeTK (TKS2 ((++) @Nat shm shn) r)
forall (target :: Target) (y :: TK).
Delta target y -> FullShapeTK y
ftkDelta Delta target (TKS2 ((++) @Nat shm shn) r)
d of
    FTKS ShS sh
_ FullShapeTK x
x -> ShS shn -> FullShapeTK x -> FullShapeTK (TKS2 shn x)
forall (sh :: [Nat]) (x :: TK).
ShS sh -> FullShapeTK x -> FullShapeTK (TKS2 sh x)
FTKS ShS shn
shn FullShapeTK x
x
  DeltaScatterS ShS shm
_shm ShS shn
shn ShS shp
shp Delta target (TKS2 ((++) @Nat shm shn) r)
d IxSOf target shm -> IxSOf target shp
_ -> case Delta target (TKS2 ((++) @Nat shm shn) r)
-> FullShapeTK (TKS2 ((++) @Nat shm shn) r)
forall (target :: Target) (y :: TK).
Delta target y -> FullShapeTK y
ftkDelta Delta target (TKS2 ((++) @Nat shm shn) r)
d of
    FTKS ShS sh
_ FullShapeTK x
x -> ShS ((++) @Nat shp shn)
-> FullShapeTK x -> FullShapeTK (TKS2 ((++) @Nat shp shn) x)
forall (sh :: [Nat]) (x :: TK).
ShS sh -> FullShapeTK x -> FullShapeTK (TKS2 sh x)
FTKS (ShS shp
shp ShS shp -> ShS shn -> ShS ((++) @Nat shp shn)
forall (sh :: [Nat]) (sh' :: [Nat]).
ShS sh -> ShS sh' -> ShS ((++) @Nat sh sh')
`shsAppend` ShS shn
shn) FullShapeTK x
x
  DeltaGatherS ShS shm
shm ShS shn
shn ShS shp
_shp Delta target (TKS2 ((++) @Nat shp shn) r)
d IxSOf target shm -> IxSOf target shp
_ -> case Delta target (TKS2 ((++) @Nat shp shn) r)
-> FullShapeTK (TKS2 ((++) @Nat shp shn) r)
forall (target :: Target) (y :: TK).
Delta target y -> FullShapeTK y
ftkDelta Delta target (TKS2 ((++) @Nat shp shn) r)
d of
    FTKS ShS sh
_ FullShapeTK x
x -> ShS ((++) @Nat shm shn)
-> FullShapeTK x -> FullShapeTK (TKS2 ((++) @Nat shm shn) x)
forall (sh :: [Nat]) (x :: TK).
ShS sh -> FullShapeTK x -> FullShapeTK (TKS2 sh x)
FTKS (ShS shm
shm ShS shm -> ShS shn -> ShS ((++) @Nat shm shn)
forall (sh :: [Nat]) (sh' :: [Nat]).
ShS sh -> ShS sh' -> ShS ((++) @Nat sh sh')
`shsAppend` ShS shn
shn) FullShapeTK x
x
  DeltaAppendS Delta target (TKS2 ((':) @Nat m sh) r)
a Delta target (TKS2 ((':) @Nat n sh) r)
b -> case (Delta target (TKS2 ((':) @Nat m sh) r)
-> FullShapeTK (TKS2 ((':) @Nat m sh) r)
forall (target :: Target) (y :: TK).
Delta target y -> FullShapeTK y
ftkDelta Delta target (TKS2 ((':) @Nat m sh) r)
a, Delta target (TKS2 ((':) @Nat n sh) r)
-> FullShapeTK (TKS2 ((':) @Nat n sh) r)
forall (target :: Target) (y :: TK).
Delta target y -> FullShapeTK y
ftkDelta Delta target (TKS2 ((':) @Nat n sh) r)
b) of
    (FTKS (SNat n
m :$$ ShS sh
sh) FullShapeTK x
x, FTKS (SNat n
n :$$ ShS sh
_) FullShapeTK x
_) -> ShS ((':) @Nat (m + n) sh)
-> FullShapeTK x -> FullShapeTK (TKS2 ((':) @Nat (m + n) sh) x)
forall (sh :: [Nat]) (x :: TK).
ShS sh -> FullShapeTK x -> FullShapeTK (TKS2 sh x)
FTKS (SNat n -> SNat n -> SNat (n + n)
forall (n :: Nat) (m :: Nat). SNat n -> SNat m -> SNat (n + m)
snatPlus SNat n
m SNat n
n SNat (m + n) -> ShS sh -> ShS ((':) @Nat (m + n) sh)
forall {sh1 :: [Nat]} (n :: Nat) (sh :: [Nat]).
(KnownNat n, ((':) @Nat n sh :: [Nat]) ~ (sh1 :: [Nat])) =>
SNat n -> ShS sh -> ShS sh1
:$$ ShS sh
sh) FullShapeTK x
x
  DeltaSliceS SNat i
_ n :: SNat n
n@SNat n
SNat SNat k
_ Delta target (TKS2 ((':) @Nat ((i + n) + k) sh) r)
d -> case Delta target (TKS2 ((':) @Nat ((i + n) + k) sh) r)
-> FullShapeTK (TKS2 ((':) @Nat ((i + n) + k) sh) r)
forall (target :: Target) (y :: TK).
Delta target y -> FullShapeTK y
ftkDelta Delta target (TKS2 ((':) @Nat ((i + n) + k) sh) r)
d of
    FTKS (SNat n
_ :$$ ShS sh
sh) FullShapeTK x
x -> ShS ((':) @Nat n sh)
-> FullShapeTK x -> FullShapeTK (TKS2 ((':) @Nat n sh) x)
forall (sh :: [Nat]) (x :: TK).
ShS sh -> FullShapeTK x -> FullShapeTK (TKS2 sh x)
FTKS (SNat n
n SNat n -> ShS sh -> ShS ((':) @Nat n sh)
forall {sh1 :: [Nat]} (n :: Nat) (sh :: [Nat]).
(KnownNat n, ((':) @Nat n sh :: [Nat]) ~ (sh1 :: [Nat])) =>
SNat n -> ShS sh -> ShS sh1
:$$ ShS sh
sh) FullShapeTK x
x
  DeltaReverseS Delta target (TKS2 ((':) @Nat n sh) r)
d -> Delta target y -> FullShapeTK y
forall (target :: Target) (y :: TK).
Delta target y -> FullShapeTK y
ftkDelta Delta target y
Delta target (TKS2 ((':) @Nat n sh) r)
d
  DeltaTransposeS Perm perm
perm Delta target (TKS2 sh r)
d -> case Delta target (TKS2 sh r) -> FullShapeTK (TKS2 sh r)
forall (target :: Target) (y :: TK).
Delta target y -> FullShapeTK y
ftkDelta Delta target (TKS2 sh r)
d of
    FTKS ShS sh
sh FullShapeTK x
x -> ShS (PermutePrefix @Nat perm sh)
-> FullShapeTK x
-> FullShapeTK (TKS2 (PermutePrefix @Nat perm sh) x)
forall (sh :: [Nat]) (x :: TK).
ShS sh -> FullShapeTK x -> FullShapeTK (TKS2 sh x)
FTKS (Perm perm -> ShS sh -> ShS (PermutePrefix @Nat perm sh)
forall (is :: [Nat]) (sh :: [Nat]).
Perm is -> ShS sh -> ShS (PermutePrefix @Nat is sh)
shsPermutePrefix Perm perm
perm ShS sh
sh) FullShapeTK x
x
  DeltaReshapeS ShS sh2
sh2 Delta target (TKS2 sh r)
d -> case Delta target (TKS2 sh r) -> FullShapeTK (TKS2 sh r)
forall (target :: Target) (y :: TK).
Delta target y -> FullShapeTK y
ftkDelta Delta target (TKS2 sh r)
d of
    FTKS ShS sh
_ FullShapeTK x
x -> ShS sh2 -> FullShapeTK x -> FullShapeTK (TKS2 sh2 x)
forall (sh :: [Nat]) (x :: TK).
ShS sh -> FullShapeTK x -> FullShapeTK (TKS2 sh x)
FTKS ShS sh2
sh2 FullShapeTK x
x

  DeltaCastX Delta target (TKX sh r1)
d -> case Delta target (TKX sh r1) -> FullShapeTK (TKX sh r1)
forall (target :: Target) (y :: TK).
Delta target y -> FullShapeTK y
ftkDelta Delta target (TKX sh r1)
d of
    FTKX IShX sh
sh FullShapeTK x
FTKScalar -> IShX sh
-> FullShapeTK (TKScalar r2) -> FullShapeTK (TKX2 sh (TKScalar r2))
forall (sh :: [Maybe Nat]) (x :: TK).
IShX sh -> FullShapeTK x -> FullShapeTK (TKX2 sh x)
FTKX IShX sh
sh FullShapeTK (TKScalar r2)
forall r. GoodScalar r => FullShapeTK (TKScalar r)
FTKScalar
  DeltaSum0X Delta target (TKX2 sh r)
d -> case Delta target (TKX2 sh r) -> FullShapeTK (TKX2 sh r)
forall (target :: Target) (y :: TK).
Delta target y -> FullShapeTK y
ftkDelta Delta target (TKX2 sh r)
d of
    FTKX IShX sh
_ FullShapeTK x
x -> IShX ('[] @(Maybe Nat))
-> FullShapeTK x -> FullShapeTK (TKX2 ('[] @(Maybe Nat)) x)
forall (sh :: [Maybe Nat]) (x :: TK).
IShX sh -> FullShapeTK x -> FullShapeTK (TKX2 sh x)
FTKX IShX ('[] @(Maybe Nat))
forall (sh :: [Maybe Nat]) i.
((sh :: [Maybe Nat]) ~ ('[] @(Maybe Nat) :: [Maybe Nat])) =>
ShX sh i
ZSX FullShapeTK x
x
  DeltaDot0X{} -> IShX ('[] @(Maybe Nat))
-> FullShapeTK (TKScalar r)
-> FullShapeTK (TKX2 ('[] @(Maybe Nat)) (TKScalar r))
forall (sh :: [Maybe Nat]) (x :: TK).
IShX sh -> FullShapeTK x -> FullShapeTK (TKX2 sh x)
FTKX IShX ('[] @(Maybe Nat))
forall (sh :: [Maybe Nat]) i.
((sh :: [Maybe Nat]) ~ ('[] @(Maybe Nat) :: [Maybe Nat])) =>
ShX sh i
ZSX FullShapeTK (TKScalar r)
forall r. GoodScalar r => FullShapeTK (TKScalar r)
FTKScalar
  DeltaIndexX @shm @shn StaticShX shn
shn Delta target (TKX2 ((++) @(Maybe Nat) shm shn) r)
d IxXOf target shm
ix -> case Delta target (TKX2 ((++) @(Maybe Nat) shm shn) r)
-> FullShapeTK (TKX2 ((++) @(Maybe Nat) shm shn) r)
forall (target :: Target) (y :: TK).
Delta target y -> FullShapeTK y
ftkDelta Delta target (TKX2 ((++) @(Maybe Nat) shm shn) r)
d of
    FTKX IShX sh
sh FullShapeTK x
x | SNat @len <- IxXOf target shm -> SNat (Rank @(Maybe Nat) shm)
forall (sh :: [Maybe Nat]) i.
IxX sh i -> SNat (Rank @(Maybe Nat) sh)
ixxRank IxXOf target shm
ix ->
      (:~:)
  @[Maybe Nat] (Drop @(Maybe Nat) (Rank @(Maybe Nat) shm) sh) shn
-> (((Drop @(Maybe Nat) (Rank @(Maybe Nat) shm) sh :: [Maybe Nat])
     ~ (shn :: [Maybe Nat])) =>
    FullShapeTK y)
-> FullShapeTK y
forall {k} (a :: k) (b :: k) r.
(:~:) @k a b -> (((a :: k) ~ (b :: k)) => r) -> r
gcastWith ((:~:)
  @[Maybe Nat] (Drop @(Maybe Nat) (Rank @(Maybe Nat) shm) sh) shn
(:~:)
  @[Maybe Nat]
  (Drop
     @(Maybe Nat) (Rank @(Maybe Nat) shm) ((++) @(Maybe Nat) shm shn))
  shn
forall {k} (a :: k) (b :: k). (:~:) @k a b
unsafeCoerceRefl :: Drop (Rank shm) (shm ++ shn) :~: shn) ((((Drop @(Maybe Nat) (Rank @(Maybe Nat) shm) sh :: [Maybe Nat])
   ~ (shn :: [Maybe Nat])) =>
  FullShapeTK y)
 -> FullShapeTK y)
-> (((Drop @(Maybe Nat) (Rank @(Maybe Nat) shm) sh :: [Maybe Nat])
     ~ (shn :: [Maybe Nat])) =>
    FullShapeTK y)
-> FullShapeTK y
forall a b. (a -> b) -> a -> b
$
      StaticShX sh -> (KnownShX sh => FullShapeTK y) -> FullShapeTK y
forall (sh :: [Maybe Nat]) r.
StaticShX sh -> (KnownShX sh => r) -> r
withKnownShX (IShX sh -> StaticShX sh
forall (sh :: [Maybe Nat]) i. ShX sh i -> StaticShX sh
ssxFromShX IShX sh
sh) ((KnownShX sh => FullShapeTK y) -> FullShapeTK y)
-> (KnownShX sh => FullShapeTK y) -> FullShapeTK y
forall a b. (a -> b) -> a -> b
$
      StaticShX shn -> (KnownShX shn => FullShapeTK y) -> FullShapeTK y
forall (sh :: [Maybe Nat]) r.
StaticShX sh -> (KnownShX sh => r) -> r
withKnownShX StaticShX shn
shn ((KnownShX shn => FullShapeTK y) -> FullShapeTK y)
-> (KnownShX shn => FullShapeTK y) -> FullShapeTK y
forall a b. (a -> b) -> a -> b
$
      IShX shn -> FullShapeTK x -> FullShapeTK (TKX2 shn x)
forall (sh :: [Maybe Nat]) (x :: TK).
IShX sh -> FullShapeTK x -> FullShapeTK (TKX2 sh x)
FTKX (forall (len :: Nat) (sh :: [Maybe Nat]).
(KnownNat len, KnownShX sh, KnownShX (Drop @(Maybe Nat) len sh)) =>
IShX sh -> IShX (Drop @(Maybe Nat) len sh)
shxDrop @len IShX sh
sh) FullShapeTK x
x  -- TODO: (shxDropSSX sh (ssxFromIxX ix)) x
  DeltaScatterX StaticShX shm
_ StaticShX shn
_ StaticShX shp
_ IShX ((++) @(Maybe Nat) shp shn)
sh Delta target (TKX2 ((++) @(Maybe Nat) shm shn) r)
d IxXOf target shm -> IxXOf target shp
_ -> case Delta target (TKX2 ((++) @(Maybe Nat) shm shn) r)
-> FullShapeTK (TKX2 ((++) @(Maybe Nat) shm shn) r)
forall (target :: Target) (y :: TK).
Delta target y -> FullShapeTK y
ftkDelta Delta target (TKX2 ((++) @(Maybe Nat) shm shn) r)
d of
    FTKX IShX sh
_ FullShapeTK x
x -> IShX ((++) @(Maybe Nat) shp shn)
-> FullShapeTK x
-> FullShapeTK (TKX2 ((++) @(Maybe Nat) shp shn) x)
forall (sh :: [Maybe Nat]) (x :: TK).
IShX sh -> FullShapeTK x -> FullShapeTK (TKX2 sh x)
FTKX IShX ((++) @(Maybe Nat) shp shn)
sh FullShapeTK x
x
  DeltaGatherX StaticShX shm
_ StaticShX shn
_ StaticShX shp
_ IShX ((++) @(Maybe Nat) shm shn)
sh Delta target (TKX2 ((++) @(Maybe Nat) shp shn) r)
d IxXOf target shm -> IxXOf target shp
_ -> case Delta target (TKX2 ((++) @(Maybe Nat) shp shn) r)
-> FullShapeTK (TKX2 ((++) @(Maybe Nat) shp shn) r)
forall (target :: Target) (y :: TK).
Delta target y -> FullShapeTK y
ftkDelta Delta target (TKX2 ((++) @(Maybe Nat) shp shn) r)
d of
    FTKX IShX sh
_ FullShapeTK x
x -> IShX ((++) @(Maybe Nat) shm shn)
-> FullShapeTK x
-> FullShapeTK (TKX2 ((++) @(Maybe Nat) shm shn) x)
forall (sh :: [Maybe Nat]) (x :: TK).
IShX sh -> FullShapeTK x -> FullShapeTK (TKX2 sh x)
FTKX IShX ((++) @(Maybe Nat) shm shn)
sh FullShapeTK x
x
  DeltaAppendX Delta target (TKX2 ((':) @(Maybe Nat) ('Just @Nat m) sh) r)
a Delta target (TKX2 ((':) @(Maybe Nat) ('Just @Nat n) sh) r)
b -> case (Delta target (TKX2 ((':) @(Maybe Nat) ('Just @Nat m) sh) r)
-> FullShapeTK (TKX2 ((':) @(Maybe Nat) ('Just @Nat m) sh) r)
forall (target :: Target) (y :: TK).
Delta target y -> FullShapeTK y
ftkDelta Delta target (TKX2 ((':) @(Maybe Nat) ('Just @Nat m) sh) r)
a, Delta target (TKX2 ((':) @(Maybe Nat) ('Just @Nat n) sh) r)
-> FullShapeTK (TKX2 ((':) @(Maybe Nat) ('Just @Nat n) sh) r)
forall (target :: Target) (y :: TK).
Delta target y -> FullShapeTK y
ftkDelta Delta target (TKX2 ((':) @(Maybe Nat) ('Just @Nat n) sh) r)
b) of
    (FTKX (Nested.SKnown SNat n1
m :$% ShX sh Int
sh) FullShapeTK x
x, FTKX (Nested.SKnown SNat n1
n :$% ShX sh Int
_) FullShapeTK x
_) ->
      IShX ((':) @(Maybe Nat) ('Just @Nat (m + n)) sh)
-> FullShapeTK x
-> FullShapeTK (TKX2 ((':) @(Maybe Nat) ('Just @Nat (m + n)) sh) x)
forall (sh :: [Maybe Nat]) (x :: TK).
IShX sh -> FullShapeTK x -> FullShapeTK (TKX2 sh x)
FTKX (SNat (m + n) -> SMayNat @Nat Int SNat ('Just @Nat (m + n))
forall {k} (f :: k -> Type) (n1 :: k) i.
f n1 -> SMayNat @k i f ('Just @k n1)
Nested.SKnown (SNat n1 -> SNat n1 -> SNat (n1 + n1)
forall (n :: Nat) (m :: Nat). SNat n -> SNat m -> SNat (n + m)
snatPlus SNat n1
m SNat n1
n) SMayNat @Nat Int SNat ('Just @Nat (m + n))
-> ShX sh Int -> IShX ((':) @(Maybe Nat) ('Just @Nat (m + n)) sh)
forall {sh1 :: [Maybe Nat]} {i} (n :: Maybe Nat)
       (sh :: [Maybe Nat]).
(((':) @(Maybe Nat) n sh :: [Maybe Nat]) ~ (sh1 :: [Maybe Nat])) =>
SMayNat @Nat i SNat n -> ShX sh i -> ShX sh1 i
:$% ShX sh Int
sh) FullShapeTK x
x
  DeltaSliceX SNat i
_ n :: SNat n
n@SNat n
SNat SNat k
_ Delta
  target (TKX2 ((':) @(Maybe Nat) ('Just @Nat ((i + n) + k)) sh) r)
d -> case Delta
  target (TKX2 ((':) @(Maybe Nat) ('Just @Nat ((i + n) + k)) sh) r)
-> FullShapeTK
     (TKX2 ((':) @(Maybe Nat) ('Just @Nat ((i + n) + k)) sh) r)
forall (target :: Target) (y :: TK).
Delta target y -> FullShapeTK y
ftkDelta Delta
  target (TKX2 ((':) @(Maybe Nat) ('Just @Nat ((i + n) + k)) sh) r)
d of
    FTKX (SMayNat @Nat Int SNat n
_ :$% ShX sh Int
sh) FullShapeTK x
x -> IShX ((':) @(Maybe Nat) ('Just @Nat n) sh)
-> FullShapeTK x
-> FullShapeTK (TKX2 ((':) @(Maybe Nat) ('Just @Nat n) sh) x)
forall (sh :: [Maybe Nat]) (x :: TK).
IShX sh -> FullShapeTK x -> FullShapeTK (TKX2 sh x)
FTKX (SNat n -> SMayNat @Nat Int SNat ('Just @Nat n)
forall {k} (f :: k -> Type) (n1 :: k) i.
f n1 -> SMayNat @k i f ('Just @k n1)
Nested.SKnown SNat n
n SMayNat @Nat Int SNat ('Just @Nat n)
-> ShX sh Int -> IShX ((':) @(Maybe Nat) ('Just @Nat n) sh)
forall {sh1 :: [Maybe Nat]} {i} (n :: Maybe Nat)
       (sh :: [Maybe Nat]).
(((':) @(Maybe Nat) n sh :: [Maybe Nat]) ~ (sh1 :: [Maybe Nat])) =>
SMayNat @Nat i SNat n -> ShX sh i -> ShX sh1 i
:$% ShX sh Int
sh) FullShapeTK x
x
  DeltaReverseX Delta target (TKX2 ((':) @(Maybe Nat) mn sh) r)
d -> Delta target y -> FullShapeTK y
forall (target :: Target) (y :: TK).
Delta target y -> FullShapeTK y
ftkDelta Delta target y
Delta target (TKX2 ((':) @(Maybe Nat) mn sh) r)
d
  DeltaTransposeX Perm perm
perm Delta target (TKX2 sh r)
d -> case Delta target (TKX2 sh r) -> FullShapeTK (TKX2 sh r)
forall (target :: Target) (y :: TK).
Delta target y -> FullShapeTK y
ftkDelta Delta target (TKX2 sh r)
d of
    FTKX IShX sh
sh FullShapeTK x
x -> IShX (PermutePrefix @(Maybe Nat) perm sh)
-> FullShapeTK x
-> FullShapeTK (TKX2 (PermutePrefix @(Maybe Nat) perm sh) x)
forall (sh :: [Maybe Nat]) (x :: TK).
IShX sh -> FullShapeTK x -> FullShapeTK (TKX2 sh x)
FTKX (Perm perm
-> IShX sh -> ShX (PermutePrefix @(Maybe Nat) perm sh) Int
forall (is :: [Nat]) (sh :: [Maybe Nat]) i.
Perm is -> ShX sh i -> ShX (PermutePrefix @(Maybe Nat) is sh) i
shxPermutePrefix Perm perm
perm IShX sh
sh) FullShapeTK x
x
  DeltaReshapeX IShX sh2
sh2 Delta target (TKX2 sh r)
d -> case Delta target (TKX2 sh r) -> FullShapeTK (TKX2 sh r)
forall (target :: Target) (y :: TK).
Delta target y -> FullShapeTK y
ftkDelta Delta target (TKX2 sh r)
d of
    FTKX IShX sh
_ FullShapeTK x
x -> IShX sh2 -> FullShapeTK x -> FullShapeTK (TKX2 sh2 x)
forall (sh :: [Maybe Nat]) (x :: TK).
IShX sh -> FullShapeTK x -> FullShapeTK (TKX2 sh x)
FTKX IShX sh2
sh2 FullShapeTK x
x

  DeltaConvert TKConversion a y
c Delta target a
d -> TKConversion a y -> FullShapeTK a -> FullShapeTK y
forall (a :: TK) (b :: TK).
TKConversion a b -> FullShapeTK a -> FullShapeTK b
convertFTK TKConversion a y
c (FullShapeTK a -> FullShapeTK y) -> FullShapeTK a -> FullShapeTK y
forall a b. (a -> b) -> a -> b
$ Delta target a -> FullShapeTK a
forall (target :: Target) (y :: TK).
Delta target y -> FullShapeTK y
ftkDelta Delta target a
d