{-# LANGUAGE UndecidableInstances #-}
{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-}
-- | Types and functions needed to define general tensor operations
-- that work for any tensor kind, including nested (product) arrays
-- and an assortment of such operations.
module HordeAd.Core.Unwind
  ( addTarget, multTarget, sum0Target, dot0Target
  , replTarget, defTarget, concreteTarget
  , toADTensorKindShared, fromADTensorKindShared
  ) where

import Prelude

import Data.Default
import Data.Type.Equality (gcastWith, testEquality, (:~:) (Refl))
import GHC.TypeLits (type (+))
import Type.Reflection (typeRep)

import Data.Array.Nested (MapJust, Replicate, type (++))
import Data.Array.Nested qualified as Nested
import Data.Array.Nested.Convert
import Data.Array.Nested.Mixed.Shape
import Data.Array.Nested.Ranked.Shape
import Data.Array.Nested.Shaped.Shape
import Data.Array.Nested.Types (unsafeCoerceRefl)

import HordeAd.Core.CarriersConcrete
import HordeAd.Core.ConvertTensor
import HordeAd.Core.Ops
import HordeAd.Core.TensorKind
import HordeAd.Core.Types
import HordeAd.OpsTensor

-- * Winding and unwinding

-- | Unwound value of tensor kind @y@ over the carrier @target@.
--
-- This is the normal form of the type family @UnWind@. Every leaf is a
-- scalar, ranked, shaped or mixed tensor whose element type satisfies
-- 'GoodScalar', and nesting happens only through 'WTKProduct'. It
-- corresponds to the fragment of ox-arrays for which 'Num' is defined.
type role RepW nominal nominal
data RepW target y where
  -- | A bare scalar leaf.
  WTKScalar
    :: GoodScalar r
    => target (TKScalar r)
    -> RepW target (TKScalar r)
  -- | A ranked-tensor leaf.
  WTKR
    :: GoodScalar r
    => target (TKR n r)
    -> RepW target (TKR n r)
  -- | A shaped-tensor leaf.
  WTKS
    :: GoodScalar r
    => target (TKS sh r)
    -> RepW target (TKS sh r)
  -- | A mixed-tensor leaf.
  WTKX
    :: GoodScalar r
    => target (TKX sh r)
    -> RepW target (TKX sh r)
  -- | A pair of unwound values.
  WTKProduct
    :: RepW target x
    -> RepW target z
    -> RepW target (TKProduct x z)

-- | Full-shape singleton in unwound normal form.
--
-- This is the normal form of the type family @UnWind@ applied to full
-- shape singletons. It has one constructor per leaf kind of 'RepW', each
-- carrying the runtime shape of that leaf, plus 'WFTKProduct' for pairs.
type role FullShapeTKW nominal
data FullShapeTKW y where
  -- | Scalar leaf; there is no shape to record.
  WFTKScalar
    :: GoodScalar r
    => FullShapeTKW (TKScalar r)
  -- | Ranked leaf, with its ranked shape.
  WFTKR
    :: GoodScalar r
    => IShR n
    -> FullShapeTKW (TKR n r)
  -- | Shaped leaf, with its shape singleton.
  WFTKS
    :: GoodScalar r
    => ShS sh
    -> FullShapeTKW (TKS sh r)
  -- | Mixed leaf, with its mixed shape.
  WFTKX
    :: GoodScalar r
    => IShX sh
    -> FullShapeTKW (TKX sh r)
  -- | Shapes of both components of a product.
  WFTKProduct
    :: FullShapeTKW y
    -> FullShapeTKW z
    -> FullShapeTKW (TKProduct y z)

-- | Add two unwound values leaf by leaf.
--
-- Each scalar, ranked, shaped or mixed leaf is combined with @+@ from the
-- 'Num' instance of the leaf's carrier type. Products are handled by
-- recursing into both components. Both arguments share the index @y@,
-- so only matching constructor pairs need to be considered.
addRepW :: forall y target. BaseTensor target
        => RepW target y -> RepW target y -> RepW target y
addRepW a b = case (a, b) of
  (WTKScalar ta, WTKScalar tb) -> WTKScalar $ ta + tb
  (WTKR ta, WTKR tb) -> WTKR $ ta + tb
  (WTKS ta, WTKS tb) -> WTKS $ ta + tb
  (WTKX ta, WTKX tb) -> WTKX $ ta + tb
  (WTKProduct ta1 ta2, WTKProduct tb1 tb2) ->
    WTKProduct (addRepW ta1 tb1) (addRepW ta2 tb2)

-- | Multiply two unwound values leaf by leaf (elementwise).
--
-- Each scalar, ranked, shaped or mixed leaf is combined with @*@ from the
-- 'Num' instance of the leaf's carrier type. Products are handled by
-- recursing into both components. Both arguments share the index @y@,
-- so only matching constructor pairs need to be considered.
multRepW :: forall y target. BaseTensor target
         => RepW target y -> RepW target y -> RepW target y
multRepW a b = case (a, b) of
  (WTKScalar ta, WTKScalar tb) -> WTKScalar $ ta * tb
  (WTKR ta, WTKR tb) -> WTKR $ ta * tb
  (WTKS ta, WTKS tb) -> WTKS $ ta * tb
  (WTKX ta, WTKX tb) -> WTKX $ ta * tb
  (WTKProduct ta1 ta2, WTKProduct tb1 tb2) ->
    WTKProduct (multRepW ta1 tb1) (multRepW ta2 tb2)

-- | Sum every element of an unwound value into a single @Double@ scalar.
--
-- For each leaf:
--
-- * its elements are summed with 'rsum0' / 'ssum0' / 'xsum0' (a scalar
--   leaf is taken as is);
-- * the rank-0 result is converted to a scalar with 'kfromR' / 'kfromS' /
--   'kfromX';
-- * the scalar is cast to @Double@ with 'kcast'.
--
-- The cast sits inside the first argument of 'ifDifferentiable', where a
-- @Differentiable r@ context is available. The second argument, @0@, is
-- the fallback; presumably it is used for element types that are not
-- differentiable, which then contribute nothing. Product components are
-- summed recursively and the partial results added.
--
-- The full-shape argument supplies the singletons needed to bring
-- @KnownNat@ / 'KnownShS' / 'KnownShX' into scope for the reductions.
sum0RepW :: forall y target. (BaseTensor target, ConvertTensor target)
         => FullShapeTKW y -> RepW target y
         -> target (TKScalar Double)
sum0RepW ftk a = case (ftk, a) of
  (_, WTKScalar @r ta) ->
    ifDifferentiable @r (kcast ta) 0
  (WFTKR sh, WTKR @r ta) | SNat <- shrRank sh ->
    -- matching on SNat provides the KnownNat needed by rsum0
    ifDifferentiable @r (kcast $ kfromR $ rsum0 ta) 0
  (WFTKS sh, WTKS @r ta) ->
    withKnownShS sh $
    ifDifferentiable @r (kcast $ kfromS $ ssum0 ta) 0
  (WFTKX sh, WTKX @r ta) ->
    withKnownShX (ssxFromShX sh) $
    ifDifferentiable @r (kcast $ kfromX $ xsum0 ta) 0
  (WFTKProduct ftk1 ftk2, WTKProduct ta1 ta2) ->
    sum0RepW ftk1 ta1 + sum0RepW ftk2 ta2

-- | Dot product of two unwound values, accumulated as a @Double@ scalar.
--
-- For each pair of matching leaves:
--
-- * their dot product is computed with 'rdot0' / 'sdot0' / 'xdot0'
--   (plain @*@ for scalar leaves);
-- * the rank-0 result is converted to a scalar with 'kfromR' / 'kfromS' /
--   'kfromX';
-- * the scalar is cast to @Double@ with 'kcast'.
--
-- As in 'sum0RepW', the cast is the first argument of 'ifDifferentiable'
-- and @0@ is the fallback, presumably for non-differentiable element
-- types. Product components are handled recursively and their
-- contributions added.
--
-- The full-shape argument supplies the singletons needed to bring
-- @KnownNat@ / 'KnownShS' / 'KnownShX' into scope for the dot products.
dot0RepW :: forall y target. (BaseTensor target, ConvertTensor target)
         => FullShapeTKW y -> RepW target y -> RepW target y
         -> target (TKScalar Double)
dot0RepW ftk a b = case (ftk, a, b) of
  (_, WTKScalar @r ta, WTKScalar tb) ->
    ifDifferentiable @r (kcast $ ta * tb) 0
  (WFTKR sh, WTKR @r ta, WTKR tb) | SNat <- shrRank sh ->
    -- matching on SNat provides the KnownNat needed by rdot0
    ifDifferentiable @r (kcast $ kfromR $ rdot0 ta tb) 0
  (WFTKS sh, WTKS @r ta, WTKS tb) ->
    withKnownShS sh $
    ifDifferentiable @r (kcast $ kfromS $ sdot0 ta tb) 0
  (WFTKX sh, WTKX @r ta, WTKX tb) ->
    withKnownShX (ssxFromShX sh) $
    ifDifferentiable @r (kcast $ kfromX $ xdot0 ta tb) 0
  (WFTKProduct ftk1 ftk2, WTKProduct ta1 ta2, WTKProduct tb1 tb2) ->
    dot0RepW ftk1 ta1 tb1 + dot0RepW ftk2 ta2 tb2

-- | Build an unwound value of the given full shape filled with a constant.
--
-- The constant @r@ is polymorphic in its 'GoodScalar' type, so it is
-- instantiated at each leaf's element type:
--
-- * scalar leaves use 'kconcrete';
-- * ranked leaves use 'rrepl';
-- * shaped leaves use 'sconcrete' of 'Nested.sreplicateScal';
-- * mixed leaves use 'xrepl'.
--
-- Product shapes are filled componentwise.
replRepW :: forall y target. BaseTensor target
         => (forall r. GoodScalar r => r)
         -> FullShapeTKW y -> RepW target y
replRepW r = \case
  WFTKScalar -> WTKScalar $ kconcrete r
  WFTKR sh -> WTKR $ rrepl sh r
  WFTKS sh -> WTKS $ sconcrete $ Nested.sreplicateScal sh r
  WFTKX sh -> WTKX $ xrepl sh r
  WFTKProduct ftk1 ftk2 -> WTKProduct (replRepW r ftk1) (replRepW r ftk2)

-- | Build an unwound value of the given full shape filled with 'def'.
--
-- Every leaf element is the 'Default' value of that leaf's element type.
-- Leaves are built with the same primitives as in 'replRepW':
--
-- * scalar leaves use 'kconcrete';
-- * ranked leaves use 'rrepl';
-- * shaped leaves use 'sconcrete' of 'Nested.sreplicateScal';
-- * mixed leaves use 'xrepl'.
--
-- Product shapes are filled componentwise.
defRepW :: forall y target. BaseTensor target
        => FullShapeTKW y -> RepW target y
defRepW = \case
  WFTKScalar -> WTKScalar $ kconcrete def
  WFTKR sh -> WTKR $ rrepl sh def
  WFTKS sh -> WTKS $ sconcrete $ Nested.sreplicateScal sh def
  WFTKX sh -> WTKX $ xrepl sh def
  WFTKProduct ftk1 ftk2 -> WTKProduct (defRepW ftk1) (defRepW ftk2)

concreteRepW
  :: forall y target. (ConvertTensor Concrete, ConvertTensor target)
  => (forall r. GoodScalar r => Concrete (TKScalar r) -> target (TKScalar r))
  -> (forall r sh. GoodScalar r => Concrete (TKS sh r) -> target (TKS sh r))
  -> (forall x z. FullShapeTK z -> target x -> target z)
  -> RepW Concrete y -> RepW target y
{-# INLINE concreteRepW #-}
concreteRepW :: forall (y :: TK) (target :: Target).
(ConvertTensor Concrete, ConvertTensor target) =>
(forall r.
 GoodScalar r =>
 Concrete (TKScalar r) -> target (TKScalar r))
-> (forall r (sh :: [Nat]).
    GoodScalar r =>
    Concrete (TKS sh r) -> target (TKS sh r))
-> (forall (x :: TK) (z :: TK).
    FullShapeTK z -> target x -> target z)
-> RepW Concrete y
-> RepW target y
concreteRepW forall r.
GoodScalar r =>
Concrete (TKScalar r) -> target (TKScalar r)
concreteK forall r (sh :: [Nat]).
GoodScalar r =>
Concrete (TKS sh r) -> target (TKS sh r)
concreteS forall (x :: TK) (z :: TK). FullShapeTK z -> target x -> target z
fromS RepW Concrete y
w = case RepW Concrete y
w of
  WTKScalar Concrete (TKScalar r)
v -> target (TKScalar r) -> RepW target (TKScalar r)
forall y (target :: Target).
GoodScalar y =>
target (TKScalar y) -> RepW target (TKScalar y)
WTKScalar (target (TKScalar r) -> RepW target (TKScalar r))
-> target (TKScalar r) -> RepW target (TKScalar r)
forall a b. (a -> b) -> a -> b
$ Concrete (TKScalar r) -> target (TKScalar r)
forall r.
GoodScalar r =>
Concrete (TKScalar r) -> target (TKScalar r)
concreteK Concrete (TKScalar r)
v
  WTKR Concrete (TKR2 n (TKScalar r))
v -> target (TKR2 n (TKScalar r)) -> RepW target (TKR2 n (TKScalar r))
forall y (target :: Target) (z :: Nat).
GoodScalar y =>
target (TKR z y) -> RepW target (TKR z y)
WTKR (target (TKR2 n (TKScalar r)) -> RepW target (TKR2 n (TKScalar r)))
-> target (TKR2 n (TKScalar r))
-> RepW target (TKR2 n (TKScalar r))
forall a b. (a -> b) -> a -> b
$
    let sh' :: IShR n
sh' = Ranked n r -> IShR n
forall a (n :: Nat). Elt a => Ranked n a -> IShR n
Nested.rshape (Ranked n r -> IShR n) -> Ranked n r -> IShR n
forall a b. (a -> b) -> a -> b
$ Concrete (TKR2 n (TKScalar r)) -> RepConcrete (TKR2 n (TKScalar r))
forall (y :: TK). Concrete y -> RepConcrete y
unConcrete Concrete (TKR2 n (TKScalar r))
v
    in IShR n
-> (forall (sh :: [Nat]).
    ((Rank @Nat sh :: Nat) ~ (n :: Nat)) =>
    ShS sh -> target (TKR2 n (TKScalar r)))
-> target (TKR2 n (TKScalar r))
forall (n :: Nat) r.
IShR n
-> (forall (sh :: [Nat]).
    ((Rank @Nat sh :: Nat) ~ (n :: Nat)) =>
    ShS sh -> r)
-> r
withShsFromShR IShR n
sh' ((forall (sh :: [Nat]).
  ((Rank @Nat sh :: Nat) ~ (n :: Nat)) =>
  ShS sh -> target (TKR2 n (TKScalar r)))
 -> target (TKR2 n (TKScalar r)))
-> (forall (sh :: [Nat]).
    ((Rank @Nat sh :: Nat) ~ (n :: Nat)) =>
    ShS sh -> target (TKR2 n (TKScalar r)))
-> target (TKR2 n (TKScalar r))
forall a b. (a -> b) -> a -> b
$ \(ShS sh
sh :: ShS sh) ->
      ShS sh
-> (KnownShS sh => target (TKR2 n (TKScalar r)))
-> target (TKR2 n (TKScalar r))
forall (sh :: [Nat]) r. ShS sh -> (KnownShS sh => r) -> r
withKnownShS ShS sh
sh ((KnownShS sh => target (TKR2 n (TKScalar r)))
 -> target (TKR2 n (TKScalar r)))
-> (KnownShS sh => target (TKR2 n (TKScalar r)))
-> target (TKR2 n (TKScalar r))
forall a b. (a -> b) -> a -> b
$
      FullShapeTK (TKR2 n (TKScalar r))
-> target (TKS sh r) -> target (TKR2 n (TKScalar r))
forall (x :: TK) (z :: TK). FullShapeTK z -> target x -> target z
fromS (IShR n
-> FullShapeTK (TKScalar r) -> FullShapeTK (TKR2 n (TKScalar r))
forall (n :: Nat) (x :: TK).
IShR n -> FullShapeTK x -> FullShapeTK (TKR2 n x)
FTKR IShR n
sh' FullShapeTK (TKScalar r)
forall r. GoodScalar r => FullShapeTK (TKScalar r)
FTKScalar)
      (target (TKS sh r) -> target (TKR2 n (TKScalar r)))
-> target (TKS sh r) -> target (TKR2 n (TKScalar r))
forall a b. (a -> b) -> a -> b
$ Concrete (TKS sh r) -> target (TKS sh r)
forall r (sh :: [Nat]).
GoodScalar r =>
Concrete (TKS sh r) -> target (TKS sh r)
concreteS (forall (target :: Target) (sh :: [Nat]) (x :: TK).
(ConvertTensor target, KnownShS sh, KnownSTK x) =>
target (TKR2 (Rank @Nat sh) x) -> target (TKS2 sh x)
sfromR @_ @sh Concrete (TKR2 n (TKScalar r))
Concrete (TKR2 (Rank @Nat sh) (TKScalar r))
v)
  WTKS Concrete (TKS2 sh (TKScalar r))
v -> target (TKS2 sh (TKScalar r)) -> RepW target (TKS2 sh (TKScalar r))
forall y (target :: Target) (z :: [Nat]).
GoodScalar y =>
target (TKS z y) -> RepW target (TKS z y)
WTKS (target (TKS2 sh (TKScalar r))
 -> RepW target (TKS2 sh (TKScalar r)))
-> target (TKS2 sh (TKScalar r))
-> RepW target (TKS2 sh (TKScalar r))
forall a b. (a -> b) -> a -> b
$ Concrete (TKS2 sh (TKScalar r)) -> target (TKS2 sh (TKScalar r))
forall r (sh :: [Nat]).
GoodScalar r =>
Concrete (TKS sh r) -> target (TKS sh r)
concreteS Concrete (TKS2 sh (TKScalar r))
v
  WTKX Concrete (TKX2 sh (TKScalar r))
v -> target (TKX2 sh (TKScalar r)) -> RepW target (TKX2 sh (TKScalar r))
forall y (target :: Target) (z :: [Maybe Nat]).
GoodScalar y =>
target (TKX z y) -> RepW target (TKX z y)
WTKX (target (TKX2 sh (TKScalar r))
 -> RepW target (TKX2 sh (TKScalar r)))
-> target (TKX2 sh (TKScalar r))
-> RepW target (TKX2 sh (TKScalar r))
forall a b. (a -> b) -> a -> b
$
    let sh' :: IShX sh
sh' = Mixed sh r -> IShX sh
forall (sh :: [Maybe Nat]). Mixed sh r -> IShX sh
forall a (sh :: [Maybe Nat]). Elt a => Mixed sh a -> IShX sh
Nested.mshape (Mixed sh r -> IShX sh) -> Mixed sh r -> IShX sh
forall a b. (a -> b) -> a -> b
$ Concrete (TKX2 sh (TKScalar r))
-> RepConcrete (TKX2 sh (TKScalar r))
forall (y :: TK). Concrete y -> RepConcrete y
unConcrete Concrete (TKX2 sh (TKScalar r))
v
    in IShX sh
-> (forall (sh :: [Nat]).
    ((Rank @Nat sh :: Nat) ~ (Rank @(Maybe Nat) sh :: Nat)) =>
    ShS sh -> target (TKX2 sh (TKScalar r)))
-> target (TKX2 sh (TKScalar r))
forall (sh' :: [Maybe Nat]) r.
IShX sh'
-> (forall (sh :: [Nat]).
    ((Rank @Nat sh :: Nat) ~ (Rank @(Maybe Nat) sh' :: Nat)) =>
    ShS sh -> r)
-> r
withShsFromShX IShX sh
sh' ((forall (sh :: [Nat]).
  ((Rank @Nat sh :: Nat) ~ (Rank @(Maybe Nat) sh :: Nat)) =>
  ShS sh -> target (TKX2 sh (TKScalar r)))
 -> target (TKX2 sh (TKScalar r)))
-> (forall (sh :: [Nat]).
    ((Rank @Nat sh :: Nat) ~ (Rank @(Maybe Nat) sh :: Nat)) =>
    ShS sh -> target (TKX2 sh (TKScalar r)))
-> target (TKX2 sh (TKScalar r))
forall a b. (a -> b) -> a -> b
$ \(ShS sh
sh :: ShS sh) ->
      ShS sh
-> (KnownShS sh => target (TKX2 sh (TKScalar r)))
-> target (TKX2 sh (TKScalar r))
forall (sh :: [Nat]) r. ShS sh -> (KnownShS sh => r) -> r
withKnownShS ShS sh
sh ((KnownShS sh => target (TKX2 sh (TKScalar r)))
 -> target (TKX2 sh (TKScalar r)))
-> (KnownShS sh => target (TKX2 sh (TKScalar r)))
-> target (TKX2 sh (TKScalar r))
forall a b. (a -> b) -> a -> b
$
      FullShapeTK (TKX2 sh (TKScalar r))
-> target (TKS sh r) -> target (TKX2 sh (TKScalar r))
forall (x :: TK) (z :: TK). FullShapeTK z -> target x -> target z
fromS (IShX sh
-> FullShapeTK (TKScalar r) -> FullShapeTK (TKX2 sh (TKScalar r))
forall (sh :: [Maybe Nat]) (x :: TK).
IShX sh -> FullShapeTK x -> FullShapeTK (TKX2 sh x)
FTKX IShX sh
sh' FullShapeTK (TKScalar r)
forall r. GoodScalar r => FullShapeTK (TKScalar r)
FTKScalar)
      (target (TKS sh r) -> target (TKX2 sh (TKScalar r)))
-> target (TKS sh r) -> target (TKX2 sh (TKScalar r))
forall a b. (a -> b) -> a -> b
$ Concrete (TKS sh r) -> target (TKS sh r)
forall r (sh :: [Nat]).
GoodScalar r =>
Concrete (TKS sh r) -> target (TKS sh r)
concreteS (forall (target :: Target) (sh :: [Nat]) (sh' :: [Maybe Nat])
       (x :: TK).
(ConvertTensor target, KnownShS sh,
 (Rank @Nat sh :: Nat) ~ (Rank @(Maybe Nat) sh' :: Nat),
 KnownSTK x) =>
target (TKX2 sh' x) -> target (TKS2 sh x)
sfromX @_ @sh Concrete (TKX2 sh (TKScalar r))
v)
  WTKProduct RepW Concrete x
v1 RepW Concrete z
v2 ->
    RepW target x -> RepW target z -> RepW target (TKProduct x z)
forall (target :: Target) (y :: TK) (z :: TK).
RepW target y -> RepW target z -> RepW target (TKProduct y z)
WTKProduct ((forall r.
 GoodScalar r =>
 Concrete (TKScalar r) -> target (TKScalar r))
-> (forall r (sh :: [Nat]).
    GoodScalar r =>
    Concrete (TKS sh r) -> target (TKS sh r))
-> (forall (x :: TK) (z :: TK).
    FullShapeTK z -> target x -> target z)
-> RepW Concrete x
-> RepW target x
forall (y :: TK) (target :: Target).
(ConvertTensor Concrete, ConvertTensor target) =>
(forall r.
 GoodScalar r =>
 Concrete (TKScalar r) -> target (TKScalar r))
-> (forall r (sh :: [Nat]).
    GoodScalar r =>
    Concrete (TKS sh r) -> target (TKS sh r))
-> (forall (x :: TK) (z :: TK).
    FullShapeTK z -> target x -> target z)
-> RepW Concrete y
-> RepW target y
concreteRepW Concrete (TKScalar r) -> target (TKScalar r)
forall r.
GoodScalar r =>
Concrete (TKScalar r) -> target (TKScalar r)
concreteK Concrete (TKS sh r) -> target (TKS sh r)
forall r (sh :: [Nat]).
GoodScalar r =>
Concrete (TKS sh r) -> target (TKS sh r)
concreteS FullShapeTK z -> target x -> target z
forall (x :: TK) (z :: TK). FullShapeTK z -> target x -> target z
fromS RepW Concrete x
v1)
               ((forall r.
 GoodScalar r =>
 Concrete (TKScalar r) -> target (TKScalar r))
-> (forall r (sh :: [Nat]).
    GoodScalar r =>
    Concrete (TKS sh r) -> target (TKS sh r))
-> (forall (x :: TK) (z :: TK).
    FullShapeTK z -> target x -> target z)
-> RepW Concrete z
-> RepW target z
forall (y :: TK) (target :: Target).
(ConvertTensor Concrete, ConvertTensor target) =>
(forall r.
 GoodScalar r =>
 Concrete (TKScalar r) -> target (TKScalar r))
-> (forall r (sh :: [Nat]).
    GoodScalar r =>
    Concrete (TKS sh r) -> target (TKS sh r))
-> (forall (x :: TK) (z :: TK).
    FullShapeTK z -> target x -> target z)
-> RepW Concrete y
-> RepW target y
concreteRepW Concrete (TKScalar r) -> target (TKScalar r)
forall r.
GoodScalar r =>
Concrete (TKScalar r) -> target (TKScalar r)
concreteK Concrete (TKS sh r) -> target (TKS sh r)
forall r (sh :: [Nat]).
GoodScalar r =>
Concrete (TKS sh r) -> target (TKS sh r)
concreteS FullShapeTK z -> target x -> target z
forall (x :: TK) (z :: TK). FullShapeTK z -> target x -> target z
fromS RepW Concrete z
v2)

-- | Convert a normal-form ('RepW') value into its 'ADTensorKind'
-- counterpart, using the accompanying full shape to rebuild leaves that
-- are not kept.
--
-- Every scalar leaf is dispatched on its element type via 'typeRep'.
-- Leaves whose element type is @Double@ or @Float@ are returned unchanged.
-- Leaves of any other element type are discarded and replaced by a
-- constant 'Z1' leaf of the same shape. That leaf is built with
-- 'kconcrete', 'rrepl', 'Nested.sreplicateScal' + 'sconcrete', or 'xrepl',
-- depending on the array kind. Products are converted componentwise.
--
-- NOTE(review): the fallback branches assert @ADTensorScalar r :~: Z1@ via
-- 'unsafeCoerceRefl'. That is only sound if 'ADTensorScalar' maps every
-- scalar other than @Double@/@Float@ to 'Z1'; confirm against its
-- definition whenever new scalar types are added.
toADTensorKindW
  :: forall y target. BaseTensor target
  => RepW target y -> FullShapeTKW y -> RepW target (ADTensorKind y)
toADTensorKindW t = \case
  WFTKScalar @r
    | Just Refl <- testEquality (typeRep @r) (typeRep @Double) -> t
    | Just Refl <- testEquality (typeRep @r) (typeRep @Float) -> t
    | otherwise ->
        gcastWith (unsafeCoerceRefl :: ADTensorScalar r :~: Z1) $
          WTKScalar $ kconcrete Z1
  WFTKR @r sh
    | Just Refl <- testEquality (typeRep @r) (typeRep @Double) -> t
    | Just Refl <- testEquality (typeRep @r) (typeRep @Float) -> t
    | otherwise ->
        gcastWith (unsafeCoerceRefl :: ADTensorScalar r :~: Z1) $
          WTKR $ rrepl @_ @_ @target sh Z1
  WFTKS @r sh
    | Just Refl <- testEquality (typeRep @r) (typeRep @Double) -> t
    | Just Refl <- testEquality (typeRep @r) (typeRep @Float) -> t
    | otherwise ->
        gcastWith (unsafeCoerceRefl :: ADTensorScalar r :~: Z1) $
          WTKS $ sconcrete $ Nested.sreplicateScal sh Z1
  WFTKX @r sh
    | Just Refl <- testEquality (typeRep @r) (typeRep @Double) -> t
    | Just Refl <- testEquality (typeRep @r) (typeRep @Float) -> t
    | otherwise ->
        gcastWith (unsafeCoerceRefl :: ADTensorScalar r :~: Z1) $
          WTKX $ xrepl @_ @_ @target sh Z1
  WFTKProduct ftk1 ftk2 -> case t of
    -- The only 'RepW' constructor indexed by a product kind.
    WTKProduct t1 t2 ->
      WTKProduct (toADTensorKindW t1 ftk1) (toADTensorKindW t2 ftk2)

-- | Opposite direction of 'toADTensorKindW': reinterpret a normal-form
-- value stored at @'ADTensorKind' y@ as a value at @y@, guided by the
-- singleton of @y@.
--
-- For each scalar leaf, the element type demanded by @y@ (@r1@) is compared
-- with the element type actually stored (@r2@).
--
-- If they coincide, the stored value is returned unchanged. Otherwise it is
-- discarded and replaced by @'replRepW' 0@ at element type @r1@
-- (presumably a constant-zero leaf). The leaf's shape is read from the
-- stored value ('rshape', 'xshape') or, in the shaped case, from the
-- singleton.
--
-- Products are handled componentwise. Any other singleton/representation
-- pairing calls 'error'.
fromADTensorKindW
  :: forall y target. BaseTensor target
  => SingletonTK y -> RepW target (ADTensorKind y) -> RepW target y
fromADTensorKindW stk t = case (stk, t) of
  (STKScalar @r1, WTKScalar @r2 _)
    | Just Refl <- testEquality (typeRep @r1) (typeRep @r2) -> t
    | otherwise -> replRepW 0 WFTKScalar
  (STKR _ (STKScalar @r1), WTKR @r2 v)
    | Just Refl <- testEquality (typeRep @r1) (typeRep @r2) -> t
    | otherwise -> replRepW 0 (WFTKR (rshape v))
  (STKS sh (STKScalar @r1), WTKS @r2 _)
    | Just Refl <- testEquality (typeRep @r1) (typeRep @r2) -> t
    | otherwise -> replRepW 0 (WFTKS sh)
  (STKX _ (STKScalar @r1), WTKX @r2 v)
    | Just Refl <- testEquality (typeRep @r1) (typeRep @r2) -> t
    | otherwise -> replRepW 0 (WFTKX (xshape v))
  (STKProduct stk1 stk2, WTKProduct t1 t2) ->
    WTKProduct (fromADTensorKindW stk1 t1) (fromADTensorKindW stk2 t2)
  _ -> error "fromADTensorKindW: impossible SingletonTK"

-- | Normalise ("unwind") a tensor kind.
--
-- Nested array kinds are merged into a single array kind over a scalar.
-- A product nested inside an array is pushed outwards, so products occur
-- only above array kinds. The result is always built from @TKScalar@,
-- @TKR2 n (TKScalar r)@, @TKS2 sh (TKScalar r)@, @TKX2 sh (TKScalar r)@
-- and products of those.
--
-- When kinds are mixed, the merged array becomes mixed-shaped ('TKX2').
-- Dimensions coming from a ranked array become 'Nothing' ('Replicate').
-- Dimensions coming from a shaped array become 'Just' ('MapJust').
-- Outer dimensions always precede inner ones ('++').
type family UnWind y where
  UnWind (TKScalar r) =
    TKScalar r
  -- Ranked outer array: ranked-in-ranked adds the ranks; shaped or mixed
  -- inner arrays turn the whole array mixed; products are distributed.
  UnWind (TKR2 n (TKScalar r)) =
    TKR2 n (TKScalar r)
  UnWind (TKR2 n (TKR2 m x)) =
    UnWind (TKR2 (n + m) x)
  UnWind (TKR2 n (TKS2 sh2 x)) =
    UnWind (TKX2 (Replicate n Nothing ++ MapJust sh2) x)
  UnWind (TKR2 n (TKX2 sh2 x)) =
    UnWind (TKX2 (Replicate n Nothing ++ sh2) x)
  UnWind (TKR2 n (TKProduct y z)) =
    TKProduct (UnWind (TKR2 n y)) (UnWind (TKR2 n z))
  -- Shaped outer array: shaped-in-shaped concatenates the shapes; ranked
  -- or mixed inner arrays turn the whole array mixed; products are
  -- distributed.
  UnWind (TKS2 sh1 (TKScalar r)) =
    TKS2 sh1 (TKScalar r)
  UnWind (TKS2 sh1 (TKR2 m x)) =
    UnWind (TKX2 (MapJust sh1 ++ Replicate m Nothing) x)
  UnWind (TKS2 sh1 (TKS2 sh2 x)) =
    UnWind (TKS2 (sh1 ++ sh2) x)
  UnWind (TKS2 sh1 (TKX2 sh2 x)) =
    UnWind (TKX2 (MapJust sh1 ++ sh2) x)
  UnWind (TKS2 sh1 (TKProduct y z)) =
    TKProduct (UnWind (TKS2 sh1 y)) (UnWind (TKS2 sh1 z))
  -- Mixed outer array: any inner array is appended to the mixed shape;
  -- products are distributed.
  UnWind (TKX2 sh1 (TKScalar r)) =
    TKX2 sh1 (TKScalar r)
  UnWind (TKX2 sh1 (TKR2 m x)) =
    UnWind (TKX2 (sh1 ++ Replicate m Nothing) x)
  UnWind (TKX2 sh1 (TKS2 sh2 x)) =
    UnWind (TKX2 (sh1 ++ MapJust sh2) x)
  UnWind (TKX2 sh1 (TKX2 sh2 x)) =
    UnWind (TKX2 (sh1 ++ sh2) x)
  UnWind (TKX2 sh1 (TKProduct y z)) =
    TKProduct (UnWind (TKX2 sh1 y)) (UnWind (TKX2 sh1 z))
  -- Top-level products unwind componentwise.
  UnWind (TKProduct y z) =
    TKProduct (UnWind y) (UnWind z)

unWindSTK :: SingletonTK y -> SingletonTK (UnWind y)
unWindSTK :: forall (y :: TK). SingletonTK y -> SingletonTK (UnWind y)
unWindSTK = \case
  stk :: SingletonTK y
stk@SingletonTK y
STKScalar -> SingletonTK y
SingletonTK (UnWind y)
stk
  stk :: SingletonTK y
stk@(STKR SNat n
_ SingletonTK x
STKScalar) -> SingletonTK y
SingletonTK (UnWind y)
stk
  STKR (SNat @n) (STKR (SNat @m) SingletonTK x
stk2) ->
    SingletonTK (TKR2 (n + n) x)
-> SingletonTK (UnWind (TKR2 (n + n) x))
forall (y :: TK). SingletonTK y -> SingletonTK (UnWind y)
unWindSTK (SingletonTK (TKR2 (n + n) x)
 -> SingletonTK (UnWind (TKR2 (n + n) x)))
-> SingletonTK (TKR2 (n + n) x)
-> SingletonTK (UnWind (TKR2 (n + n) x))
forall a b. (a -> b) -> a -> b
$ SNat (n + n) -> SingletonTK x -> SingletonTK (TKR2 (n + n) x)
forall (n :: Nat) (x :: TK).
SNat n -> SingletonTK x -> SingletonTK (TKR2 n x)
STKR (forall (n :: Nat). KnownNat n => SNat n
SNat @(n + m)) SingletonTK x
stk2
  STKR SNat n
n (STKS ShS sh
sh2 SingletonTK x
stk2) ->
    SingletonTK
  (TKX2
     ((++)
        @(Maybe Nat)
        (Replicate @(Maybe Nat) n ('Nothing @Nat))
        (MapJust @Nat sh))
     x)
-> SingletonTK
     (UnWind
        (TKX2
           ((++)
              @(Maybe Nat)
              (Replicate @(Maybe Nat) n ('Nothing @Nat))
              (MapJust @Nat sh))
           x))
forall (y :: TK). SingletonTK y -> SingletonTK (UnWind y)
unWindSTK
    (SingletonTK
   (TKX2
      ((++)
         @(Maybe Nat)
         (Replicate @(Maybe Nat) n ('Nothing @Nat))
         (MapJust @Nat sh))
      x)
 -> SingletonTK
      (UnWind
         (TKX2
            ((++)
               @(Maybe Nat)
               (Replicate @(Maybe Nat) n ('Nothing @Nat))
               (MapJust @Nat sh))
            x)))
-> SingletonTK
     (TKX2
        ((++)
           @(Maybe Nat)
           (Replicate @(Maybe Nat) n ('Nothing @Nat))
           (MapJust @Nat sh))
        x)
-> SingletonTK
     (UnWind
        (TKX2
           ((++)
              @(Maybe Nat)
              (Replicate @(Maybe Nat) n ('Nothing @Nat))
              (MapJust @Nat sh))
           x))
forall a b. (a -> b) -> a -> b
$ StaticShX
  ((++)
     @(Maybe Nat)
     (Replicate @(Maybe Nat) n ('Nothing @Nat))
     (MapJust @Nat sh))
-> SingletonTK x
-> SingletonTK
     (TKX2
        ((++)
           @(Maybe Nat)
           (Replicate @(Maybe Nat) n ('Nothing @Nat))
           (MapJust @Nat sh))
        x)
forall (sh :: [Maybe Nat]) (x :: TK).
StaticShX sh -> SingletonTK x -> SingletonTK (TKX2 sh x)
STKX (SNat n -> StaticShX (Replicate @(Maybe Nat) n ('Nothing @Nat))
forall (n :: Nat).
SNat n -> StaticShX (Replicate @(Maybe Nat) n ('Nothing @Nat))
ssxReplicate SNat n
n StaticShX (Replicate @(Maybe Nat) n ('Nothing @Nat))
-> StaticShX (MapJust @Nat sh)
-> StaticShX
     ((++)
        @(Maybe Nat)
        (Replicate @(Maybe Nat) n ('Nothing @Nat))
        (MapJust @Nat sh))
forall (sh :: [Maybe Nat]) (sh' :: [Maybe Nat]).
StaticShX sh
-> StaticShX sh' -> StaticShX ((++) @(Maybe Nat) sh sh')
`ssxAppend` ShX (MapJust @Nat sh) Int -> StaticShX (MapJust @Nat sh)
forall (sh :: [Maybe Nat]) i. ShX sh i -> StaticShX sh
ssxFromShX (ShS sh -> ShX (MapJust @Nat sh) Int
forall (sh :: [Nat]). ShS sh -> IShX (MapJust @Nat sh)
shxFromShS ShS sh
sh2)) SingletonTK x
stk2
  STKR SNat n
n (STKX StaticShX sh
sh2 SingletonTK x
stk2) ->
    SingletonTK
  (TKX2
     ((++) @(Maybe Nat) (Replicate @(Maybe Nat) n ('Nothing @Nat)) sh)
     x)
-> SingletonTK
     (UnWind
        (TKX2
           ((++) @(Maybe Nat) (Replicate @(Maybe Nat) n ('Nothing @Nat)) sh)
           x))
forall (y :: TK). SingletonTK y -> SingletonTK (UnWind y)
unWindSTK (SingletonTK
   (TKX2
      ((++) @(Maybe Nat) (Replicate @(Maybe Nat) n ('Nothing @Nat)) sh)
      x)
 -> SingletonTK
      (UnWind
         (TKX2
            ((++) @(Maybe Nat) (Replicate @(Maybe Nat) n ('Nothing @Nat)) sh)
            x)))
-> SingletonTK
     (TKX2
        ((++) @(Maybe Nat) (Replicate @(Maybe Nat) n ('Nothing @Nat)) sh)
        x)
-> SingletonTK
     (UnWind
        (TKX2
           ((++) @(Maybe Nat) (Replicate @(Maybe Nat) n ('Nothing @Nat)) sh)
           x))
forall a b. (a -> b) -> a -> b
$ StaticShX
  ((++) @(Maybe Nat) (Replicate @(Maybe Nat) n ('Nothing @Nat)) sh)
-> SingletonTK x
-> SingletonTK
     (TKX2
        ((++) @(Maybe Nat) (Replicate @(Maybe Nat) n ('Nothing @Nat)) sh)
        x)
forall (sh :: [Maybe Nat]) (x :: TK).
StaticShX sh -> SingletonTK x -> SingletonTK (TKX2 sh x)
STKX (SNat n -> StaticShX (Replicate @(Maybe Nat) n ('Nothing @Nat))
forall (n :: Nat).
SNat n -> StaticShX (Replicate @(Maybe Nat) n ('Nothing @Nat))
ssxReplicate SNat n
n StaticShX (Replicate @(Maybe Nat) n ('Nothing @Nat))
-> StaticShX sh
-> StaticShX
     ((++) @(Maybe Nat) (Replicate @(Maybe Nat) n ('Nothing @Nat)) sh)
forall (sh :: [Maybe Nat]) (sh' :: [Maybe Nat]).
StaticShX sh
-> StaticShX sh' -> StaticShX ((++) @(Maybe Nat) sh sh')
`ssxAppend` StaticShX sh
sh2) SingletonTK x
stk2
  STKR n :: SNat n
n@SNat n
SNat (STKProduct SingletonTK y1
y SingletonTK z
z) ->
    SingletonTK (TKProduct (TKR2 n y1) (TKR2 n z))
-> SingletonTK (UnWind (TKProduct (TKR2 n y1) (TKR2 n z)))
forall (y :: TK). SingletonTK y -> SingletonTK (UnWind y)
unWindSTK (SingletonTK (TKProduct (TKR2 n y1) (TKR2 n z))
 -> SingletonTK (UnWind (TKProduct (TKR2 n y1) (TKR2 n z))))
-> SingletonTK (TKProduct (TKR2 n y1) (TKR2 n z))
-> SingletonTK (UnWind (TKProduct (TKR2 n y1) (TKR2 n z)))
forall a b. (a -> b) -> a -> b
$ SingletonTK (TKR2 n y1)
-> SingletonTK (TKR2 n z)
-> SingletonTK (TKProduct (TKR2 n y1) (TKR2 n z))
forall (y1 :: TK) (z :: TK).
SingletonTK y1 -> SingletonTK z -> SingletonTK (TKProduct y1 z)
STKProduct (SNat n -> SingletonTK y1 -> SingletonTK (TKR2 n y1)
forall (n :: Nat) (x :: TK).
SNat n -> SingletonTK x -> SingletonTK (TKR2 n x)
STKR SNat n
n SingletonTK y1
y) (SNat n -> SingletonTK z -> SingletonTK (TKR2 n z)
forall (n :: Nat) (x :: TK).
SNat n -> SingletonTK x -> SingletonTK (TKR2 n x)
STKR SNat n
n SingletonTK z
z)
  stk :: SingletonTK y
stk@(STKS ShS sh
_ SingletonTK x
STKScalar) -> SingletonTK y
SingletonTK (UnWind y)
stk
  STKS ShS sh
sh1 (STKR SNat n
m SingletonTK x
stk2) ->
    SingletonTK
  (TKX2
     ((++)
        @(Maybe Nat)
        (MapJust @Nat sh)
        (Replicate @(Maybe Nat) n ('Nothing @Nat)))
     x)
-> SingletonTK
     (UnWind
        (TKX2
           ((++)
              @(Maybe Nat)
              (MapJust @Nat sh)
              (Replicate @(Maybe Nat) n ('Nothing @Nat)))
           x))
forall (y :: TK). SingletonTK y -> SingletonTK (UnWind y)
unWindSTK
    (SingletonTK
   (TKX2
      ((++)
         @(Maybe Nat)
         (MapJust @Nat sh)
         (Replicate @(Maybe Nat) n ('Nothing @Nat)))
      x)
 -> SingletonTK
      (UnWind
         (TKX2
            ((++)
               @(Maybe Nat)
               (MapJust @Nat sh)
               (Replicate @(Maybe Nat) n ('Nothing @Nat)))
            x)))
-> SingletonTK
     (TKX2
        ((++)
           @(Maybe Nat)
           (MapJust @Nat sh)
           (Replicate @(Maybe Nat) n ('Nothing @Nat)))
        x)
-> SingletonTK
     (UnWind
        (TKX2
           ((++)
              @(Maybe Nat)
              (MapJust @Nat sh)
              (Replicate @(Maybe Nat) n ('Nothing @Nat)))
           x))
forall a b. (a -> b) -> a -> b
$ StaticShX
  ((++)
     @(Maybe Nat)
     (MapJust @Nat sh)
     (Replicate @(Maybe Nat) n ('Nothing @Nat)))
-> SingletonTK x
-> SingletonTK
     (TKX2
        ((++)
           @(Maybe Nat)
           (MapJust @Nat sh)
           (Replicate @(Maybe Nat) n ('Nothing @Nat)))
        x)
forall (sh :: [Maybe Nat]) (x :: TK).
StaticShX sh -> SingletonTK x -> SingletonTK (TKX2 sh x)
STKX (ShX (MapJust @Nat sh) Int -> StaticShX (MapJust @Nat sh)
forall (sh :: [Maybe Nat]) i. ShX sh i -> StaticShX sh
ssxFromShX (ShS sh -> ShX (MapJust @Nat sh) Int
forall (sh :: [Nat]). ShS sh -> IShX (MapJust @Nat sh)
shxFromShS ShS sh
sh1) StaticShX (MapJust @Nat sh)
-> StaticShX (Replicate @(Maybe Nat) n ('Nothing @Nat))
-> StaticShX
     ((++)
        @(Maybe Nat)
        (MapJust @Nat sh)
        (Replicate @(Maybe Nat) n ('Nothing @Nat)))
forall (sh :: [Maybe Nat]) (sh' :: [Maybe Nat]).
StaticShX sh
-> StaticShX sh' -> StaticShX ((++) @(Maybe Nat) sh sh')
`ssxAppend` SNat n -> StaticShX (Replicate @(Maybe Nat) n ('Nothing @Nat))
forall (n :: Nat).
SNat n -> StaticShX (Replicate @(Maybe Nat) n ('Nothing @Nat))
ssxReplicate SNat n
m) SingletonTK x
stk2
  STKS ShS sh
sh1 (STKS ShS sh
sh2 SingletonTK x
stk2) ->
    SingletonTK (TKS2 ((++) @Nat sh sh) x)
-> SingletonTK (UnWind (TKS2 ((++) @Nat sh sh) x))
forall (y :: TK). SingletonTK y -> SingletonTK (UnWind y)
unWindSTK (SingletonTK (TKS2 ((++) @Nat sh sh) x)
 -> SingletonTK (UnWind (TKS2 ((++) @Nat sh sh) x)))
-> SingletonTK (TKS2 ((++) @Nat sh sh) x)
-> SingletonTK (UnWind (TKS2 ((++) @Nat sh sh) x))
forall a b. (a -> b) -> a -> b
$ ShS ((++) @Nat sh sh)
-> SingletonTK x -> SingletonTK (TKS2 ((++) @Nat sh sh) x)
forall (sh :: [Nat]) (x :: TK).
ShS sh -> SingletonTK x -> SingletonTK (TKS2 sh x)
STKS (ShS sh
sh1 ShS sh -> ShS sh -> ShS ((++) @Nat sh sh)
forall (sh :: [Nat]) (sh' :: [Nat]).
ShS sh -> ShS sh' -> ShS ((++) @Nat sh sh')
`shsAppend` ShS sh
sh2) SingletonTK x
stk2
  STKS ShS sh
sh1 (STKX StaticShX sh
sh2 SingletonTK x
stk2) ->
    SingletonTK (TKX2 ((++) @(Maybe Nat) (MapJust @Nat sh) sh) x)
-> SingletonTK
     (UnWind (TKX2 ((++) @(Maybe Nat) (MapJust @Nat sh) sh) x))
forall (y :: TK). SingletonTK y -> SingletonTK (UnWind y)
unWindSTK (SingletonTK (TKX2 ((++) @(Maybe Nat) (MapJust @Nat sh) sh) x)
 -> SingletonTK
      (UnWind (TKX2 ((++) @(Maybe Nat) (MapJust @Nat sh) sh) x)))
-> SingletonTK (TKX2 ((++) @(Maybe Nat) (MapJust @Nat sh) sh) x)
-> SingletonTK
     (UnWind (TKX2 ((++) @(Maybe Nat) (MapJust @Nat sh) sh) x))
forall a b. (a -> b) -> a -> b
$ StaticShX ((++) @(Maybe Nat) (MapJust @Nat sh) sh)
-> SingletonTK x
-> SingletonTK (TKX2 ((++) @(Maybe Nat) (MapJust @Nat sh) sh) x)
forall (sh :: [Maybe Nat]) (x :: TK).
StaticShX sh -> SingletonTK x -> SingletonTK (TKX2 sh x)
STKX (ShX (MapJust @Nat sh) Int -> StaticShX (MapJust @Nat sh)
forall (sh :: [Maybe Nat]) i. ShX sh i -> StaticShX sh
ssxFromShX (ShS sh -> ShX (MapJust @Nat sh) Int
forall (sh :: [Nat]). ShS sh -> IShX (MapJust @Nat sh)
shxFromShS ShS sh
sh1) StaticShX (MapJust @Nat sh)
-> StaticShX sh
-> StaticShX ((++) @(Maybe Nat) (MapJust @Nat sh) sh)
forall (sh :: [Maybe Nat]) (sh' :: [Maybe Nat]).
StaticShX sh
-> StaticShX sh' -> StaticShX ((++) @(Maybe Nat) sh sh')
`ssxAppend` StaticShX sh
sh2) SingletonTK x
stk2
  STKS ShS sh
sh1 (STKProduct SingletonTK y1
y SingletonTK z
z)->
    SingletonTK (TKProduct (TKS2 sh y1) (TKS2 sh z))
-> SingletonTK (UnWind (TKProduct (TKS2 sh y1) (TKS2 sh z)))
forall (y :: TK). SingletonTK y -> SingletonTK (UnWind y)
unWindSTK (SingletonTK (TKProduct (TKS2 sh y1) (TKS2 sh z))
 -> SingletonTK (UnWind (TKProduct (TKS2 sh y1) (TKS2 sh z))))
-> SingletonTK (TKProduct (TKS2 sh y1) (TKS2 sh z))
-> SingletonTK (UnWind (TKProduct (TKS2 sh y1) (TKS2 sh z)))
forall a b. (a -> b) -> a -> b
$ SingletonTK (TKS2 sh y1)
-> SingletonTK (TKS2 sh z)
-> SingletonTK (TKProduct (TKS2 sh y1) (TKS2 sh z))
forall (y1 :: TK) (z :: TK).
SingletonTK y1 -> SingletonTK z -> SingletonTK (TKProduct y1 z)
STKProduct (ShS sh -> SingletonTK y1 -> SingletonTK (TKS2 sh y1)
forall (sh :: [Nat]) (x :: TK).
ShS sh -> SingletonTK x -> SingletonTK (TKS2 sh x)
STKS ShS sh
sh1 SingletonTK y1
y) (ShS sh -> SingletonTK z -> SingletonTK (TKS2 sh z)
forall (sh :: [Nat]) (x :: TK).
ShS sh -> SingletonTK x -> SingletonTK (TKS2 sh x)
STKS ShS sh
sh1 SingletonTK z
z)
  stk :: SingletonTK y
stk@(STKX StaticShX sh
_ SingletonTK x
STKScalar) -> SingletonTK y
SingletonTK (UnWind y)
stk
  STKX StaticShX sh
sh1 (STKR SNat n
m SingletonTK x
stk2) ->
    SingletonTK
  (TKX2
     ((++) @(Maybe Nat) sh (Replicate @(Maybe Nat) n ('Nothing @Nat)))
     x)
-> SingletonTK
     (UnWind
        (TKX2
           ((++) @(Maybe Nat) sh (Replicate @(Maybe Nat) n ('Nothing @Nat)))
           x))
forall (y :: TK). SingletonTK y -> SingletonTK (UnWind y)
unWindSTK (SingletonTK
   (TKX2
      ((++) @(Maybe Nat) sh (Replicate @(Maybe Nat) n ('Nothing @Nat)))
      x)
 -> SingletonTK
      (UnWind
         (TKX2
            ((++) @(Maybe Nat) sh (Replicate @(Maybe Nat) n ('Nothing @Nat)))
            x)))
-> SingletonTK
     (TKX2
        ((++) @(Maybe Nat) sh (Replicate @(Maybe Nat) n ('Nothing @Nat)))
        x)
-> SingletonTK
     (UnWind
        (TKX2
           ((++) @(Maybe Nat) sh (Replicate @(Maybe Nat) n ('Nothing @Nat)))
           x))
forall a b. (a -> b) -> a -> b
$ StaticShX
  ((++) @(Maybe Nat) sh (Replicate @(Maybe Nat) n ('Nothing @Nat)))
-> SingletonTK x
-> SingletonTK
     (TKX2
        ((++) @(Maybe Nat) sh (Replicate @(Maybe Nat) n ('Nothing @Nat)))
        x)
forall (sh :: [Maybe Nat]) (x :: TK).
StaticShX sh -> SingletonTK x -> SingletonTK (TKX2 sh x)
STKX (StaticShX sh
sh1 StaticShX sh
-> StaticShX (Replicate @(Maybe Nat) n ('Nothing @Nat))
-> StaticShX
     ((++) @(Maybe Nat) sh (Replicate @(Maybe Nat) n ('Nothing @Nat)))
forall (sh :: [Maybe Nat]) (sh' :: [Maybe Nat]).
StaticShX sh
-> StaticShX sh' -> StaticShX ((++) @(Maybe Nat) sh sh')
`ssxAppend` SNat n -> StaticShX (Replicate @(Maybe Nat) n ('Nothing @Nat))
forall (n :: Nat).
SNat n -> StaticShX (Replicate @(Maybe Nat) n ('Nothing @Nat))
ssxReplicate SNat n
m) SingletonTK x
stk2
  STKX StaticShX sh
sh1 (STKS ShS sh
sh2 SingletonTK x
stk2) ->
    SingletonTK (TKX2 ((++) @(Maybe Nat) sh (MapJust @Nat sh)) x)
-> SingletonTK
     (UnWind (TKX2 ((++) @(Maybe Nat) sh (MapJust @Nat sh)) x))
forall (y :: TK). SingletonTK y -> SingletonTK (UnWind y)
unWindSTK (SingletonTK (TKX2 ((++) @(Maybe Nat) sh (MapJust @Nat sh)) x)
 -> SingletonTK
      (UnWind (TKX2 ((++) @(Maybe Nat) sh (MapJust @Nat sh)) x)))
-> SingletonTK (TKX2 ((++) @(Maybe Nat) sh (MapJust @Nat sh)) x)
-> SingletonTK
     (UnWind (TKX2 ((++) @(Maybe Nat) sh (MapJust @Nat sh)) x))
forall a b. (a -> b) -> a -> b
$ StaticShX ((++) @(Maybe Nat) sh (MapJust @Nat sh))
-> SingletonTK x
-> SingletonTK (TKX2 ((++) @(Maybe Nat) sh (MapJust @Nat sh)) x)
forall (sh :: [Maybe Nat]) (x :: TK).
StaticShX sh -> SingletonTK x -> SingletonTK (TKX2 sh x)
STKX (StaticShX sh
sh1 StaticShX sh
-> StaticShX (MapJust @Nat sh)
-> StaticShX ((++) @(Maybe Nat) sh (MapJust @Nat sh))
forall (sh :: [Maybe Nat]) (sh' :: [Maybe Nat]).
StaticShX sh
-> StaticShX sh' -> StaticShX ((++) @(Maybe Nat) sh sh')
`ssxAppend` ShX (MapJust @Nat sh) Int -> StaticShX (MapJust @Nat sh)
forall (sh :: [Maybe Nat]) i. ShX sh i -> StaticShX sh
ssxFromShX (ShS sh -> ShX (MapJust @Nat sh) Int
forall (sh :: [Nat]). ShS sh -> IShX (MapJust @Nat sh)
shxFromShS ShS sh
sh2)) SingletonTK x
stk2
  STKX StaticShX sh
sh1 (STKX StaticShX sh
sh2 SingletonTK x
stk2) ->
    SingletonTK (TKX2 ((++) @(Maybe Nat) sh sh) x)
-> SingletonTK (UnWind (TKX2 ((++) @(Maybe Nat) sh sh) x))
forall (y :: TK). SingletonTK y -> SingletonTK (UnWind y)
unWindSTK (SingletonTK (TKX2 ((++) @(Maybe Nat) sh sh) x)
 -> SingletonTK (UnWind (TKX2 ((++) @(Maybe Nat) sh sh) x)))
-> SingletonTK (TKX2 ((++) @(Maybe Nat) sh sh) x)
-> SingletonTK (UnWind (TKX2 ((++) @(Maybe Nat) sh sh) x))
forall a b. (a -> b) -> a -> b
$ StaticShX ((++) @(Maybe Nat) sh sh)
-> SingletonTK x -> SingletonTK (TKX2 ((++) @(Maybe Nat) sh sh) x)
forall (sh :: [Maybe Nat]) (x :: TK).
StaticShX sh -> SingletonTK x -> SingletonTK (TKX2 sh x)
STKX (StaticShX sh
sh1 StaticShX sh -> StaticShX sh -> StaticShX ((++) @(Maybe Nat) sh sh)
forall (sh :: [Maybe Nat]) (sh' :: [Maybe Nat]).
StaticShX sh
-> StaticShX sh' -> StaticShX ((++) @(Maybe Nat) sh sh')
`ssxAppend` StaticShX sh
sh2) SingletonTK x
stk2
  STKX StaticShX sh
sh1 (STKProduct SingletonTK y1
y SingletonTK z
z) ->
    SingletonTK (TKProduct (TKX2 sh y1) (TKX2 sh z))
-> SingletonTK (UnWind (TKProduct (TKX2 sh y1) (TKX2 sh z)))
forall (y :: TK). SingletonTK y -> SingletonTK (UnWind y)
unWindSTK (SingletonTK (TKProduct (TKX2 sh y1) (TKX2 sh z))
 -> SingletonTK (UnWind (TKProduct (TKX2 sh y1) (TKX2 sh z))))
-> SingletonTK (TKProduct (TKX2 sh y1) (TKX2 sh z))
-> SingletonTK (UnWind (TKProduct (TKX2 sh y1) (TKX2 sh z)))
forall a b. (a -> b) -> a -> b
$ SingletonTK (TKX2 sh y1)
-> SingletonTK (TKX2 sh z)
-> SingletonTK (TKProduct (TKX2 sh y1) (TKX2 sh z))
forall (y1 :: TK) (z :: TK).
SingletonTK y1 -> SingletonTK z -> SingletonTK (TKProduct y1 z)
STKProduct (StaticShX sh -> SingletonTK y1 -> SingletonTK (TKX2 sh y1)
forall (sh :: [Maybe Nat]) (x :: TK).
StaticShX sh -> SingletonTK x -> SingletonTK (TKX2 sh x)
STKX StaticShX sh
sh1 SingletonTK y1
y) (StaticShX sh -> SingletonTK z -> SingletonTK (TKX2 sh z)
forall (sh :: [Maybe Nat]) (x :: TK).
StaticShX sh -> SingletonTK x -> SingletonTK (TKX2 sh x)
STKX StaticShX sh
sh1 SingletonTK z
z)
  STKProduct SingletonTK y1
y SingletonTK z
z -> SingletonTK (UnWind y1)
-> SingletonTK (UnWind z)
-> SingletonTK (TKProduct (UnWind y1) (UnWind z))
forall (y1 :: TK) (z :: TK).
SingletonTK y1 -> SingletonTK z -> SingletonTK (TKProduct y1 z)
STKProduct (SingletonTK y1 -> SingletonTK (UnWind y1)
forall (y :: TK). SingletonTK y -> SingletonTK (UnWind y)
unWindSTK SingletonTK y1
y) (SingletonTK z -> SingletonTK (UnWind z)
forall (y :: TK). SingletonTK y -> SingletonTK (UnWind y)
unWindSTK SingletonTK z
z)

-- | Normalise a full shape singleton into the form described by the
-- type family 'UnWind', yielding a 'FullShapeTKW'.
--
-- * A scalar, or a ranked/shaped/mixed tensor of scalars, maps directly
--   to the corresponding 'FullShapeTKW' constructor.
-- * A tensor whose elements are themselves tensors is flattened by
--   appending the outer shape to the inner shape and recursing.
-- * A tensor whose elements are products is rewritten as a product of
--   tensors (the outer shape is copied onto both components), and the
--   result is unwound recursively.
-- * A product is unwound componentwise.
unWindFTK :: FullShapeTK y -> FullShapeTKW (UnWind y)
unWindFTK = \case
  FTKScalar -> WFTKScalar
  -- Ranked outer tensor.
  FTKR sh FTKScalar -> WFTKR sh
  FTKR sh1 (FTKR sh2 ftk2) ->
    unWindFTK $ FTKR (sh1 `shrAppend` sh2) ftk2
  -- Ranked over shaped (or over mixed) cannot be expressed as purely
  -- ranked or purely shaped, so both shapes are converted to mixed ones.
  FTKR sh1 (FTKS sh2 ftk2) ->
    unWindFTK $ FTKX (shxFromShR sh1 `shxAppend` shxFromShS sh2) ftk2
  FTKR sh1 (FTKX sh2 ftk2) ->
    unWindFTK $ FTKX (shxFromShR sh1 `shxAppend` sh2) ftk2
  -- Push the product outward: a tensor of pairs becomes a pair of tensors.
  FTKR sh1 (FTKProduct y z) ->
    unWindFTK $ FTKProduct (FTKR sh1 y) (FTKR sh1 z)
  -- Shaped outer tensor.
  FTKS sh FTKScalar -> WFTKS sh
  FTKS sh1 (FTKR sh2 ftk2) ->
    unWindFTK $ FTKX (shxFromShS sh1 `shxAppend` shxFromShR sh2) ftk2
  FTKS sh1 (FTKS sh2 ftk2) ->
    unWindFTK $ FTKS (sh1 `shsAppend` sh2) ftk2
  FTKS sh1 (FTKX sh2 ftk2) ->
    unWindFTK $ FTKX (shxFromShS sh1 `shxAppend` sh2) ftk2
  FTKS sh1 (FTKProduct y z) ->
    unWindFTK $ FTKProduct (FTKS sh1 y) (FTKS sh1 z)
  -- Mixed outer tensor.
  FTKX sh FTKScalar -> WFTKX sh
  FTKX sh1 (FTKR sh2 ftk2) ->
    unWindFTK $ FTKX (sh1 `shxAppend` shxFromShR sh2) ftk2
  FTKX sh1 (FTKS sh2 ftk2) ->
    unWindFTK $ FTKX (sh1 `shxAppend` shxFromShS sh2) ftk2
  FTKX sh1 (FTKX sh2 ftk2) ->
    unWindFTK $ FTKX (sh1 `shxAppend` sh2) ftk2
  FTKX sh1 (FTKProduct y z) ->
    unWindFTK $ FTKProduct (FTKX sh1 y) (FTKX sh1 z)
  -- Products are unwound component by component.
  FTKProduct y z -> WFTKProduct (unWindFTK y) (unWindFTK z)

-- This uses tunpairConv, so, to preserve sharing, @target@ either has
-- to have a `ShareTensor` instance or the argument has to be duplicable.
-- Only the argument of the first call, not of recursive calls,
-- is assumed to be duplicable. In the AST case, this creates
-- a tower of projections for products, but if it's balanced,
-- that's of logarithmic length, so maybe even better than sharing
-- excessively, which is hard for technical typing reasons.
unWindTarget :: ConvertTensor target
             => SingletonTK y -> target y -> RepW target (UnWind y)
unWindTarget :: forall (target :: Target) (y :: TK).
ConvertTensor target =>
SingletonTK y -> target y -> RepW target (UnWind y)
unWindTarget SingletonTK y
stk target y
t = case SingletonTK y
stk of
  SingletonTK y
STKScalar -> target (TKScalar r) -> RepW target (TKScalar r)
forall y (target :: Target).
GoodScalar y =>
target (TKScalar y) -> RepW target (TKScalar y)
WTKScalar target y
target (TKScalar r)
t
  STKR SNat n
SNat SingletonTK x
STKScalar -> target (TKR n r) -> RepW target (TKR n r)
forall y (target :: Target) (z :: Nat).
GoodScalar y =>
target (TKR z y) -> RepW target (TKR z y)
WTKR target y
target (TKR n r)
t
  STKR (SNat @n) (STKR (SNat @m) SingletonTK x
stk2) | Dict @TK KnownSTK x
Dict <- SingletonTK x -> Dict @TK KnownSTK x
forall (y :: TK). SingletonTK y -> Dict @TK KnownSTK y
lemKnownSTK SingletonTK x
stk2 ->
    SingletonTK (TKR2 (n + n) x)
-> target (TKR2 (n + n) x) -> RepW target (UnWind (TKR2 (n + n) x))
forall (target :: Target) (y :: TK).
ConvertTensor target =>
SingletonTK y -> target y -> RepW target (UnWind y)
unWindTarget (SNat (n + n) -> SingletonTK x -> SingletonTK (TKR2 (n + n) x)
forall (n :: Nat) (x :: TK).
SNat n -> SingletonTK x -> SingletonTK (TKR2 n x)
STKR (forall (n :: Nat). KnownNat n => SNat n
SNat @(n + m)) SingletonTK x
stk2) (target (TKR2 n (TKR2 n x)) -> target (TKR2 (n + n) x)
forall (n :: Nat) (m :: Nat) (x :: TK).
(KnownNat n, KnownNat m, KnownSTK x) =>
target (TKR2 n (TKR2 m x)) -> target (TKR2 (n + m) x)
forall (target :: Target) (n :: Nat) (m :: Nat) (x :: TK).
(ConvertTensor target, KnownNat n, KnownNat m, KnownSTK x) =>
target (TKR2 n (TKR2 m x)) -> target (TKR2 (n + m) x)
runNest target y
target (TKR2 n (TKR2 n x))
t)
  STKR n :: SNat n
n@SNat n
SNat (STKS ShS sh
sh2 SingletonTK x
stk2) | Dict @TK KnownSTK x
Dict <- SingletonTK x -> Dict @TK KnownSTK x
forall (y :: TK). SingletonTK y -> Dict @TK KnownSTK y
lemKnownSTK SingletonTK x
stk2 ->
    ShS sh
-> (KnownShS sh => RepW target (UnWind y))
-> RepW target (UnWind y)
forall (sh :: [Nat]) r. ShS sh -> (KnownShS sh => r) -> r
withKnownShS ShS sh
sh2 ((KnownShS sh => RepW target (UnWind y)) -> RepW target (UnWind y))
-> (KnownShS sh => RepW target (UnWind y))
-> RepW target (UnWind y)
forall a b. (a -> b) -> a -> b
$
    SingletonTK
  (TKX2
     ((++)
        @(Maybe Nat)
        (Replicate @(Maybe Nat) n ('Nothing @Nat))
        (MapJust @Nat sh))
     x)
-> target
     (TKX2
        ((++)
           @(Maybe Nat)
           (Replicate @(Maybe Nat) n ('Nothing @Nat))
           (MapJust @Nat sh))
        x)
-> RepW
     target
     (UnWind
        (TKX2
           ((++)
              @(Maybe Nat)
              (Replicate @(Maybe Nat) n ('Nothing @Nat))
              (MapJust @Nat sh))
           x))
forall (target :: Target) (y :: TK).
ConvertTensor target =>
SingletonTK y -> target y -> RepW target (UnWind y)
unWindTarget (StaticShX
  ((++)
     @(Maybe Nat)
     (Replicate @(Maybe Nat) n ('Nothing @Nat))
     (MapJust @Nat sh))
-> SingletonTK x
-> SingletonTK
     (TKX2
        ((++)
           @(Maybe Nat)
           (Replicate @(Maybe Nat) n ('Nothing @Nat))
           (MapJust @Nat sh))
        x)
forall (sh :: [Maybe Nat]) (x :: TK).
StaticShX sh -> SingletonTK x -> SingletonTK (TKX2 sh x)
STKX (SNat n -> StaticShX (Replicate @(Maybe Nat) n ('Nothing @Nat))
forall (n :: Nat).
SNat n -> StaticShX (Replicate @(Maybe Nat) n ('Nothing @Nat))
ssxReplicate SNat n
n
                        StaticShX (Replicate @(Maybe Nat) n ('Nothing @Nat))
-> StaticShX (MapJust @Nat sh)
-> StaticShX
     ((++)
        @(Maybe Nat)
        (Replicate @(Maybe Nat) n ('Nothing @Nat))
        (MapJust @Nat sh))
forall (sh :: [Maybe Nat]) (sh' :: [Maybe Nat]).
StaticShX sh
-> StaticShX sh' -> StaticShX ((++) @(Maybe Nat) sh sh')
`ssxAppend` ShX (MapJust @Nat sh) Int -> StaticShX (MapJust @Nat sh)
forall (sh :: [Maybe Nat]) i. ShX sh i -> StaticShX sh
ssxFromShX (ShS sh -> ShX (MapJust @Nat sh) Int
forall (sh :: [Nat]). ShS sh -> IShX (MapJust @Nat sh)
shxFromShS ShS sh
sh2)) SingletonTK x
stk2)
                 (target (TKR2 n (TKS2 sh x))
-> target
     (TKX2
        ((++)
           @(Maybe Nat)
           (Replicate @(Maybe Nat) n ('Nothing @Nat))
           (MapJust @Nat sh))
        x)
forall (n :: Nat) (sh2 :: [Nat]) (x :: TK).
(KnownNat n, KnownShS sh2, KnownSTK x) =>
target (TKR2 n (TKS2 sh2 x))
-> target
     (TKX2
        ((++)
           @(Maybe Nat)
           (Replicate @(Maybe Nat) n ('Nothing @Nat))
           (MapJust @Nat sh2))
        x)
forall (target :: Target) (n :: Nat) (sh2 :: [Nat]) (x :: TK).
(ConvertTensor target, KnownNat n, KnownShS sh2, KnownSTK x) =>
target (TKR2 n (TKS2 sh2 x))
-> target
     (TKX2
        ((++)
           @(Maybe Nat)
           (Replicate @(Maybe Nat) n ('Nothing @Nat))
           (MapJust @Nat sh2))
        x)
runNestS target y
target (TKR2 n (TKS2 sh x))
t)
  STKR n :: SNat n
n@SNat n
SNat (STKX StaticShX sh
sh2 SingletonTK x
stk2) | Dict @TK KnownSTK x
Dict <- SingletonTK x -> Dict @TK KnownSTK x
forall (y :: TK). SingletonTK y -> Dict @TK KnownSTK y
lemKnownSTK SingletonTK x
stk2 ->
    StaticShX sh
-> (KnownShX sh => RepW target (UnWind y))
-> RepW target (UnWind y)
forall (sh :: [Maybe Nat]) r.
StaticShX sh -> (KnownShX sh => r) -> r
withKnownShX StaticShX sh
sh2 ((KnownShX sh => RepW target (UnWind y)) -> RepW target (UnWind y))
-> (KnownShX sh => RepW target (UnWind y))
-> RepW target (UnWind y)
forall a b. (a -> b) -> a -> b
$
    SingletonTK
  (TKX2
     ((++) @(Maybe Nat) (Replicate @(Maybe Nat) n ('Nothing @Nat)) sh)
     x)
-> target
     (TKX2
        ((++) @(Maybe Nat) (Replicate @(Maybe Nat) n ('Nothing @Nat)) sh)
        x)
-> RepW
     target
     (UnWind
        (TKX2
           ((++) @(Maybe Nat) (Replicate @(Maybe Nat) n ('Nothing @Nat)) sh)
           x))
forall (target :: Target) (y :: TK).
ConvertTensor target =>
SingletonTK y -> target y -> RepW target (UnWind y)
unWindTarget (StaticShX
  ((++) @(Maybe Nat) (Replicate @(Maybe Nat) n ('Nothing @Nat)) sh)
-> SingletonTK x
-> SingletonTK
     (TKX2
        ((++) @(Maybe Nat) (Replicate @(Maybe Nat) n ('Nothing @Nat)) sh)
        x)
forall (sh :: [Maybe Nat]) (x :: TK).
StaticShX sh -> SingletonTK x -> SingletonTK (TKX2 sh x)
STKX (SNat n -> StaticShX (Replicate @(Maybe Nat) n ('Nothing @Nat))
forall (n :: Nat).
SNat n -> StaticShX (Replicate @(Maybe Nat) n ('Nothing @Nat))
ssxReplicate SNat n
n StaticShX (Replicate @(Maybe Nat) n ('Nothing @Nat))
-> StaticShX sh
-> StaticShX
     ((++) @(Maybe Nat) (Replicate @(Maybe Nat) n ('Nothing @Nat)) sh)
forall (sh :: [Maybe Nat]) (sh' :: [Maybe Nat]).
StaticShX sh
-> StaticShX sh' -> StaticShX ((++) @(Maybe Nat) sh sh')
`ssxAppend` StaticShX sh
sh2) SingletonTK x
stk2)
                 (target (TKR2 n (TKX2 sh x))
-> target
     (TKX2
        ((++) @(Maybe Nat) (Replicate @(Maybe Nat) n ('Nothing @Nat)) sh)
        x)
forall (n :: Nat) (sh2 :: [Maybe Nat]) (x :: TK).
(KnownNat n, KnownShX sh2, KnownSTK x) =>
target (TKR2 n (TKX2 sh2 x))
-> target
     (TKX2
        ((++) @(Maybe Nat) (Replicate @(Maybe Nat) n ('Nothing @Nat)) sh2)
        x)
forall (target :: Target) (n :: Nat) (sh2 :: [Maybe Nat])
       (x :: TK).
(ConvertTensor target, KnownNat n, KnownShX sh2, KnownSTK x) =>
target (TKR2 n (TKX2 sh2 x))
-> target
     (TKX2
        ((++) @(Maybe Nat) (Replicate @(Maybe Nat) n ('Nothing @Nat)) sh2)
        x)
runNestX target y
target (TKR2 n (TKX2 sh x))
t)
  STKR n :: SNat n
n@SNat n
SNat (STKProduct SingletonTK y1
stk1 SingletonTK z
stk2) ->
    SingletonTK (TKProduct (TKR2 n y1) (TKR2 n z))
-> target (TKProduct (TKR2 n y1) (TKR2 n z))
-> RepW target (UnWind (TKProduct (TKR2 n y1) (TKR2 n z)))
forall (target :: Target) (y :: TK).
ConvertTensor target =>
SingletonTK y -> target y -> RepW target (UnWind y)
unWindTarget (SingletonTK (TKR2 n y1)
-> SingletonTK (TKR2 n z)
-> SingletonTK (TKProduct (TKR2 n y1) (TKR2 n z))
forall (y1 :: TK) (z :: TK).
SingletonTK y1 -> SingletonTK z -> SingletonTK (TKProduct y1 z)
STKProduct (SNat n -> SingletonTK y1 -> SingletonTK (TKR2 n y1)
forall (n :: Nat) (x :: TK).
SNat n -> SingletonTK x -> SingletonTK (TKR2 n x)
STKR SNat n
n SingletonTK y1
stk1) (SNat n -> SingletonTK z -> SingletonTK (TKR2 n z)
forall (n :: Nat) (x :: TK).
SNat n -> SingletonTK x -> SingletonTK (TKR2 n x)
STKR SNat n
n SingletonTK z
stk2)) (target (TKR2 n (TKProduct y1 z))
-> target (TKProduct (TKR2 n y1) (TKR2 n z))
forall (y :: TK) (z :: TK) (n :: Nat).
target (TKR2 n (TKProduct y z))
-> target (TKProduct (TKR2 n y) (TKR2 n z))
forall (target :: Target) (y :: TK) (z :: TK) (n :: Nat).
ConvertTensor target =>
target (TKR2 n (TKProduct y z))
-> target (TKProduct (TKR2 n y) (TKR2 n z))
runzip target y
target (TKR2 n (TKProduct y1 z))
t)
  STKS ShS sh
sh1 SingletonTK x
STKScalar -> ShS sh
-> (KnownShS sh => RepW target (UnWind y))
-> RepW target (UnWind y)
forall (sh :: [Nat]) r. ShS sh -> (KnownShS sh => r) -> r
withKnownShS ShS sh
sh1 ((KnownShS sh => RepW target (UnWind y)) -> RepW target (UnWind y))
-> (KnownShS sh => RepW target (UnWind y))
-> RepW target (UnWind y)
forall a b. (a -> b) -> a -> b
$ target (TKS sh r) -> RepW target (TKS sh r)
forall y (target :: Target) (z :: [Nat]).
GoodScalar y =>
target (TKS z y) -> RepW target (TKS z y)
WTKS target y
target (TKS sh r)
t
  STKS ShS sh
sh1 (STKR m :: SNat n
m@(SNat @m) SingletonTK x
stk2) | Dict @TK KnownSTK x
Dict <- SingletonTK x -> Dict @TK KnownSTK x
forall (y :: TK). SingletonTK y -> Dict @TK KnownSTK y
lemKnownSTK SingletonTK x
stk2 ->
    ShS sh
-> (KnownShS sh => RepW target (UnWind y))
-> RepW target (UnWind y)
forall (sh :: [Nat]) r. ShS sh -> (KnownShS sh => r) -> r
withKnownShS ShS sh
sh1 ((KnownShS sh => RepW target (UnWind y)) -> RepW target (UnWind y))
-> (KnownShS sh => RepW target (UnWind y))
-> RepW target (UnWind y)
forall a b. (a -> b) -> a -> b
$
    SingletonTK
  (TKX2
     ((++)
        @(Maybe Nat)
        (MapJust @Nat sh)
        (Replicate @(Maybe Nat) n ('Nothing @Nat)))
     x)
-> target
     (TKX2
        ((++)
           @(Maybe Nat)
           (MapJust @Nat sh)
           (Replicate @(Maybe Nat) n ('Nothing @Nat)))
        x)
-> RepW
     target
     (UnWind
        (TKX2
           ((++)
              @(Maybe Nat)
              (MapJust @Nat sh)
              (Replicate @(Maybe Nat) n ('Nothing @Nat)))
           x))
forall (target :: Target) (y :: TK).
ConvertTensor target =>
SingletonTK y -> target y -> RepW target (UnWind y)
unWindTarget (StaticShX
  ((++)
     @(Maybe Nat)
     (MapJust @Nat sh)
     (Replicate @(Maybe Nat) n ('Nothing @Nat)))
-> SingletonTK x
-> SingletonTK
     (TKX2
        ((++)
           @(Maybe Nat)
           (MapJust @Nat sh)
           (Replicate @(Maybe Nat) n ('Nothing @Nat)))
        x)
forall (sh :: [Maybe Nat]) (x :: TK).
StaticShX sh -> SingletonTK x -> SingletonTK (TKX2 sh x)
STKX (ShX (MapJust @Nat sh) Int -> StaticShX (MapJust @Nat sh)
forall (sh :: [Maybe Nat]) i. ShX sh i -> StaticShX sh
ssxFromShX (ShS sh -> ShX (MapJust @Nat sh) Int
forall (sh :: [Nat]). ShS sh -> IShX (MapJust @Nat sh)
shxFromShS ShS sh
sh1)
                        StaticShX (MapJust @Nat sh)
-> StaticShX (Replicate @(Maybe Nat) n ('Nothing @Nat))
-> StaticShX
     ((++)
        @(Maybe Nat)
        (MapJust @Nat sh)
        (Replicate @(Maybe Nat) n ('Nothing @Nat)))
forall (sh :: [Maybe Nat]) (sh' :: [Maybe Nat]).
StaticShX sh
-> StaticShX sh' -> StaticShX ((++) @(Maybe Nat) sh sh')
`ssxAppend` SNat n -> StaticShX (Replicate @(Maybe Nat) n ('Nothing @Nat))
forall (n :: Nat).
SNat n -> StaticShX (Replicate @(Maybe Nat) n ('Nothing @Nat))
ssxReplicate SNat n
m) SingletonTK x
stk2) (forall (target :: Target) (sh1 :: [Nat]) (m :: Nat) (x :: TK).
(ConvertTensor target, KnownShS sh1, KnownNat m, KnownSTK x) =>
target (TKS2 sh1 (TKR2 m x))
-> target
     (TKX2
        ((++)
           @(Maybe Nat)
           (MapJust @Nat sh1)
           (Replicate @(Maybe Nat) m ('Nothing @Nat)))
        x)
sunNestR @_ @_ @m target y
target (TKS2 sh (TKR2 n x))
t)
  STKS ShS sh
sh1 (STKS ShS sh
sh2 SingletonTK x
stk2) | Dict @TK KnownSTK x
Dict <- SingletonTK x -> Dict @TK KnownSTK x
forall (y :: TK). SingletonTK y -> Dict @TK KnownSTK y
lemKnownSTK SingletonTK x
stk2 ->
    ShS sh
-> (KnownShS sh => RepW target (UnWind y))
-> RepW target (UnWind y)
forall (sh :: [Nat]) r. ShS sh -> (KnownShS sh => r) -> r
withKnownShS ShS sh
sh1 ((KnownShS sh => RepW target (UnWind y)) -> RepW target (UnWind y))
-> (KnownShS sh => RepW target (UnWind y))
-> RepW target (UnWind y)
forall a b. (a -> b) -> a -> b
$ ShS sh
-> (KnownShS sh => RepW target (UnWind y))
-> RepW target (UnWind y)
forall (sh :: [Nat]) r. ShS sh -> (KnownShS sh => r) -> r
withKnownShS ShS sh
sh2 ((KnownShS sh => RepW target (UnWind y)) -> RepW target (UnWind y))
-> (KnownShS sh => RepW target (UnWind y))
-> RepW target (UnWind y)
forall a b. (a -> b) -> a -> b
$
    SingletonTK (TKS2 ((++) @Nat sh sh) x)
-> target (TKS2 ((++) @Nat sh sh) x)
-> RepW target (UnWind (TKS2 ((++) @Nat sh sh) x))
forall (target :: Target) (y :: TK).
ConvertTensor target =>
SingletonTK y -> target y -> RepW target (UnWind y)
unWindTarget (ShS ((++) @Nat sh sh)
-> SingletonTK x -> SingletonTK (TKS2 ((++) @Nat sh sh) x)
forall (sh :: [Nat]) (x :: TK).
ShS sh -> SingletonTK x -> SingletonTK (TKS2 sh x)
STKS (ShS sh
sh1 ShS sh -> ShS sh -> ShS ((++) @Nat sh sh)
forall (sh :: [Nat]) (sh' :: [Nat]).
ShS sh -> ShS sh' -> ShS ((++) @Nat sh sh')
`shsAppend` ShS sh
sh2) SingletonTK x
stk2) (target (TKS2 sh (TKS2 sh x)) -> target (TKS2 ((++) @Nat sh sh) x)
forall (sh1 :: [Nat]) (sh2 :: [Nat]) (x :: TK).
(KnownShS sh1, KnownShS sh2, KnownSTK x) =>
target (TKS2 sh1 (TKS2 sh2 x))
-> target (TKS2 ((++) @Nat sh1 sh2) x)
forall (target :: Target) (sh1 :: [Nat]) (sh2 :: [Nat]) (x :: TK).
(ConvertTensor target, KnownShS sh1, KnownShS sh2, KnownSTK x) =>
target (TKS2 sh1 (TKS2 sh2 x))
-> target (TKS2 ((++) @Nat sh1 sh2) x)
sunNest target y
target (TKS2 sh (TKS2 sh x))
t)
  STKS ShS sh
sh1 (STKX StaticShX sh
sh2 SingletonTK x
stk2) | Dict @TK KnownSTK x
Dict <- SingletonTK x -> Dict @TK KnownSTK x
forall (y :: TK). SingletonTK y -> Dict @TK KnownSTK y
lemKnownSTK SingletonTK x
stk2 ->
    StaticShX sh
-> (KnownShX sh => RepW target (UnWind y))
-> RepW target (UnWind y)
forall (sh :: [Maybe Nat]) r.
StaticShX sh -> (KnownShX sh => r) -> r
withKnownShX StaticShX sh
sh2 ((KnownShX sh => RepW target (UnWind y)) -> RepW target (UnWind y))
-> (KnownShX sh => RepW target (UnWind y))
-> RepW target (UnWind y)
forall a b. (a -> b) -> a -> b
$ ShS sh
-> (KnownShS sh => RepW target (UnWind y))
-> RepW target (UnWind y)
forall (sh :: [Nat]) r. ShS sh -> (KnownShS sh => r) -> r
withKnownShS ShS sh
sh1 ((KnownShS sh => RepW target (UnWind y)) -> RepW target (UnWind y))
-> (KnownShS sh => RepW target (UnWind y))
-> RepW target (UnWind y)
forall a b. (a -> b) -> a -> b
$
    SingletonTK (TKX2 ((++) @(Maybe Nat) (MapJust @Nat sh) sh) x)
-> target (TKX2 ((++) @(Maybe Nat) (MapJust @Nat sh) sh) x)
-> RepW
     target (UnWind (TKX2 ((++) @(Maybe Nat) (MapJust @Nat sh) sh) x))
forall (target :: Target) (y :: TK).
ConvertTensor target =>
SingletonTK y -> target y -> RepW target (UnWind y)
unWindTarget (StaticShX ((++) @(Maybe Nat) (MapJust @Nat sh) sh)
-> SingletonTK x
-> SingletonTK (TKX2 ((++) @(Maybe Nat) (MapJust @Nat sh) sh) x)
forall (sh :: [Maybe Nat]) (x :: TK).
StaticShX sh -> SingletonTK x -> SingletonTK (TKX2 sh x)
STKX (ShX (MapJust @Nat sh) Int -> StaticShX (MapJust @Nat sh)
forall (sh :: [Maybe Nat]) i. ShX sh i -> StaticShX sh
ssxFromShX (ShS sh -> ShX (MapJust @Nat sh) Int
forall (sh :: [Nat]). ShS sh -> IShX (MapJust @Nat sh)
shxFromShS ShS sh
sh1) StaticShX (MapJust @Nat sh)
-> StaticShX sh
-> StaticShX ((++) @(Maybe Nat) (MapJust @Nat sh) sh)
forall (sh :: [Maybe Nat]) (sh' :: [Maybe Nat]).
StaticShX sh
-> StaticShX sh' -> StaticShX ((++) @(Maybe Nat) sh sh')
`ssxAppend` StaticShX sh
sh2) SingletonTK x
stk2)
                 (target (TKS2 sh (TKX2 sh x))
-> target (TKX2 ((++) @(Maybe Nat) (MapJust @Nat sh) sh) x)
forall (sh1 :: [Nat]) (sh2 :: [Maybe Nat]) (x :: TK).
(KnownShS sh1, KnownShX sh2, KnownSTK x) =>
target (TKS2 sh1 (TKX2 sh2 x))
-> target (TKX2 ((++) @(Maybe Nat) (MapJust @Nat sh1) sh2) x)
forall (target :: Target) (sh1 :: [Nat]) (sh2 :: [Maybe Nat])
       (x :: TK).
(ConvertTensor target, KnownShS sh1, KnownShX sh2, KnownSTK x) =>
target (TKS2 sh1 (TKX2 sh2 x))
-> target (TKX2 ((++) @(Maybe Nat) (MapJust @Nat sh1) sh2) x)
sunNestX target y
target (TKS2 sh (TKX2 sh x))
t)
  STKS ShS sh
sh1 (STKProduct SingletonTK y1
stk1 SingletonTK z
stk2)->
    SingletonTK (TKProduct (TKS2 sh y1) (TKS2 sh z))
-> target (TKProduct (TKS2 sh y1) (TKS2 sh z))
-> RepW target (UnWind (TKProduct (TKS2 sh y1) (TKS2 sh z)))
forall (target :: Target) (y :: TK).
ConvertTensor target =>
SingletonTK y -> target y -> RepW target (UnWind y)
unWindTarget (SingletonTK (TKS2 sh y1)
-> SingletonTK (TKS2 sh z)
-> SingletonTK (TKProduct (TKS2 sh y1) (TKS2 sh z))
forall (y1 :: TK) (z :: TK).
SingletonTK y1 -> SingletonTK z -> SingletonTK (TKProduct y1 z)
STKProduct (ShS sh -> SingletonTK y1 -> SingletonTK (TKS2 sh y1)
forall (sh :: [Nat]) (x :: TK).
ShS sh -> SingletonTK x -> SingletonTK (TKS2 sh x)
STKS ShS sh
sh1 SingletonTK y1
stk1) (ShS sh -> SingletonTK z -> SingletonTK (TKS2 sh z)
forall (sh :: [Nat]) (x :: TK).
ShS sh -> SingletonTK x -> SingletonTK (TKS2 sh x)
STKS ShS sh
sh1 SingletonTK z
stk2)) (target (TKS2 sh (TKProduct y1 z))
-> target (TKProduct (TKS2 sh y1) (TKS2 sh z))
forall (y :: TK) (z :: TK) (sh :: [Nat]).
target (TKS2 sh (TKProduct y z))
-> target (TKProduct (TKS2 sh y) (TKS2 sh z))
forall (target :: Target) (y :: TK) (z :: TK) (sh :: [Nat]).
ConvertTensor target =>
target (TKS2 sh (TKProduct y z))
-> target (TKProduct (TKS2 sh y) (TKS2 sh z))
sunzip target y
target (TKS2 sh (TKProduct y1 z))
t)
  STKX StaticShX sh
sh1 SingletonTK x
STKScalar -> StaticShX sh
-> (KnownShX sh => RepW target (UnWind y))
-> RepW target (UnWind y)
forall (sh :: [Maybe Nat]) r.
StaticShX sh -> (KnownShX sh => r) -> r
withKnownShX StaticShX sh
sh1 ((KnownShX sh => RepW target (UnWind y)) -> RepW target (UnWind y))
-> (KnownShX sh => RepW target (UnWind y))
-> RepW target (UnWind y)
forall a b. (a -> b) -> a -> b
$ target (TKX sh r) -> RepW target (TKX sh r)
forall y (target :: Target) (z :: [Maybe Nat]).
GoodScalar y =>
target (TKX z y) -> RepW target (TKX z y)
WTKX target y
target (TKX sh r)
t
  STKX StaticShX sh
sh1 (STKR m :: SNat n
m@(SNat @m) SingletonTK x
stk2) | Dict @TK KnownSTK x
Dict <- SingletonTK x -> Dict @TK KnownSTK x
forall (y :: TK). SingletonTK y -> Dict @TK KnownSTK y
lemKnownSTK SingletonTK x
stk2 ->
    StaticShX sh
-> (KnownShX sh => RepW target (UnWind y))
-> RepW target (UnWind y)
forall (sh :: [Maybe Nat]) r.
StaticShX sh -> (KnownShX sh => r) -> r
withKnownShX StaticShX sh
sh1 ((KnownShX sh => RepW target (UnWind y)) -> RepW target (UnWind y))
-> (KnownShX sh => RepW target (UnWind y))
-> RepW target (UnWind y)
forall a b. (a -> b) -> a -> b
$
    SingletonTK
  (TKX2
     ((++) @(Maybe Nat) sh (Replicate @(Maybe Nat) n ('Nothing @Nat)))
     x)
-> target
     (TKX2
        ((++) @(Maybe Nat) sh (Replicate @(Maybe Nat) n ('Nothing @Nat)))
        x)
-> RepW
     target
     (UnWind
        (TKX2
           ((++) @(Maybe Nat) sh (Replicate @(Maybe Nat) n ('Nothing @Nat)))
           x))
forall (target :: Target) (y :: TK).
ConvertTensor target =>
SingletonTK y -> target y -> RepW target (UnWind y)
unWindTarget (StaticShX
  ((++) @(Maybe Nat) sh (Replicate @(Maybe Nat) n ('Nothing @Nat)))
-> SingletonTK x
-> SingletonTK
     (TKX2
        ((++) @(Maybe Nat) sh (Replicate @(Maybe Nat) n ('Nothing @Nat)))
        x)
forall (sh :: [Maybe Nat]) (x :: TK).
StaticShX sh -> SingletonTK x -> SingletonTK (TKX2 sh x)
STKX (StaticShX sh
sh1 StaticShX sh
-> StaticShX (Replicate @(Maybe Nat) n ('Nothing @Nat))
-> StaticShX
     ((++) @(Maybe Nat) sh (Replicate @(Maybe Nat) n ('Nothing @Nat)))
forall (sh :: [Maybe Nat]) (sh' :: [Maybe Nat]).
StaticShX sh
-> StaticShX sh' -> StaticShX ((++) @(Maybe Nat) sh sh')
`ssxAppend` SNat n -> StaticShX (Replicate @(Maybe Nat) n ('Nothing @Nat))
forall (n :: Nat).
SNat n -> StaticShX (Replicate @(Maybe Nat) n ('Nothing @Nat))
ssxReplicate SNat n
m) SingletonTK x
stk2)
                 (forall (target :: Target) (sh1 :: [Maybe Nat]) (m :: Nat)
       (x :: TK).
(ConvertTensor target, KnownShX sh1, KnownNat m, KnownSTK x) =>
target (TKX2 sh1 (TKR2 m x))
-> target
     (TKX2
        ((++) @(Maybe Nat) sh1 (Replicate @(Maybe Nat) m ('Nothing @Nat)))
        x)
xunNestR @_ @_ @m target y
target (TKX2 sh (TKR2 n x))
t)
  STKX StaticShX sh
sh1 (STKS ShS sh
sh2 SingletonTK x
stk2) | Dict @TK KnownSTK x
Dict <- SingletonTK x -> Dict @TK KnownSTK x
forall (y :: TK). SingletonTK y -> Dict @TK KnownSTK y
lemKnownSTK SingletonTK x
stk2 ->
    StaticShX sh
-> (KnownShX sh => RepW target (UnWind y))
-> RepW target (UnWind y)
forall (sh :: [Maybe Nat]) r.
StaticShX sh -> (KnownShX sh => r) -> r
withKnownShX StaticShX sh
sh1 ((KnownShX sh => RepW target (UnWind y)) -> RepW target (UnWind y))
-> (KnownShX sh => RepW target (UnWind y))
-> RepW target (UnWind y)
forall a b. (a -> b) -> a -> b
$ ShS sh
-> (KnownShS sh => RepW target (UnWind y))
-> RepW target (UnWind y)
forall (sh :: [Nat]) r. ShS sh -> (KnownShS sh => r) -> r
withKnownShS ShS sh
sh2 ((KnownShS sh => RepW target (UnWind y)) -> RepW target (UnWind y))
-> (KnownShS sh => RepW target (UnWind y))
-> RepW target (UnWind y)
forall a b. (a -> b) -> a -> b
$
    SingletonTK (TKX2 ((++) @(Maybe Nat) sh (MapJust @Nat sh)) x)
-> target (TKX2 ((++) @(Maybe Nat) sh (MapJust @Nat sh)) x)
-> RepW
     target (UnWind (TKX2 ((++) @(Maybe Nat) sh (MapJust @Nat sh)) x))
forall (target :: Target) (y :: TK).
ConvertTensor target =>
SingletonTK y -> target y -> RepW target (UnWind y)
unWindTarget (StaticShX ((++) @(Maybe Nat) sh (MapJust @Nat sh))
-> SingletonTK x
-> SingletonTK (TKX2 ((++) @(Maybe Nat) sh (MapJust @Nat sh)) x)
forall (sh :: [Maybe Nat]) (x :: TK).
StaticShX sh -> SingletonTK x -> SingletonTK (TKX2 sh x)
STKX (StaticShX sh
sh1 StaticShX sh
-> StaticShX (MapJust @Nat sh)
-> StaticShX ((++) @(Maybe Nat) sh (MapJust @Nat sh))
forall (sh :: [Maybe Nat]) (sh' :: [Maybe Nat]).
StaticShX sh
-> StaticShX sh' -> StaticShX ((++) @(Maybe Nat) sh sh')
`ssxAppend` ShX (MapJust @Nat sh) Int -> StaticShX (MapJust @Nat sh)
forall (sh :: [Maybe Nat]) i. ShX sh i -> StaticShX sh
ssxFromShX (ShS sh -> ShX (MapJust @Nat sh) Int
forall (sh :: [Nat]). ShS sh -> IShX (MapJust @Nat sh)
shxFromShS ShS sh
sh2)) SingletonTK x
stk2)
                 (target (TKX2 sh (TKS2 sh x))
-> target (TKX2 ((++) @(Maybe Nat) sh (MapJust @Nat sh)) x)
forall (sh1 :: [Maybe Nat]) (sh2 :: [Nat]) (x :: TK).
(KnownShX sh1, KnownShS sh2, KnownSTK x) =>
target (TKX2 sh1 (TKS2 sh2 x))
-> target (TKX2 ((++) @(Maybe Nat) sh1 (MapJust @Nat sh2)) x)
forall (target :: Target) (sh1 :: [Maybe Nat]) (sh2 :: [Nat])
       (x :: TK).
(ConvertTensor target, KnownShX sh1, KnownShS sh2, KnownSTK x) =>
target (TKX2 sh1 (TKS2 sh2 x))
-> target (TKX2 ((++) @(Maybe Nat) sh1 (MapJust @Nat sh2)) x)
xunNestS target y
target (TKX2 sh (TKS2 sh x))
t)
  STKX StaticShX sh
sh1 (STKX StaticShX sh
sh2 SingletonTK x
stk2) | Dict @TK KnownSTK x
Dict <- SingletonTK x -> Dict @TK KnownSTK x
forall (y :: TK). SingletonTK y -> Dict @TK KnownSTK y
lemKnownSTK SingletonTK x
stk2 ->
    StaticShX sh
-> (KnownShX sh => RepW target (UnWind y))
-> RepW target (UnWind y)
forall (sh :: [Maybe Nat]) r.
StaticShX sh -> (KnownShX sh => r) -> r
withKnownShX StaticShX sh
sh1 ((KnownShX sh => RepW target (UnWind y)) -> RepW target (UnWind y))
-> (KnownShX sh => RepW target (UnWind y))
-> RepW target (UnWind y)
forall a b. (a -> b) -> a -> b
$ StaticShX sh
-> (KnownShX sh => RepW target (UnWind y))
-> RepW target (UnWind y)
forall (sh :: [Maybe Nat]) r.
StaticShX sh -> (KnownShX sh => r) -> r
withKnownShX StaticShX sh
sh2 ((KnownShX sh => RepW target (UnWind y)) -> RepW target (UnWind y))
-> (KnownShX sh => RepW target (UnWind y))
-> RepW target (UnWind y)
forall a b. (a -> b) -> a -> b
$
    SingletonTK (TKX2 ((++) @(Maybe Nat) sh sh) x)
-> target (TKX2 ((++) @(Maybe Nat) sh sh) x)
-> RepW target (UnWind (TKX2 ((++) @(Maybe Nat) sh sh) x))
forall (target :: Target) (y :: TK).
ConvertTensor target =>
SingletonTK y -> target y -> RepW target (UnWind y)
unWindTarget (StaticShX ((++) @(Maybe Nat) sh sh)
-> SingletonTK x -> SingletonTK (TKX2 ((++) @(Maybe Nat) sh sh) x)
forall (sh :: [Maybe Nat]) (x :: TK).
StaticShX sh -> SingletonTK x -> SingletonTK (TKX2 sh x)
STKX (StaticShX sh
sh1 StaticShX sh -> StaticShX sh -> StaticShX ((++) @(Maybe Nat) sh sh)
forall (sh :: [Maybe Nat]) (sh' :: [Maybe Nat]).
StaticShX sh
-> StaticShX sh' -> StaticShX ((++) @(Maybe Nat) sh sh')
`ssxAppend` StaticShX sh
sh2) SingletonTK x
stk2) (target (TKX2 sh (TKX2 sh x))
-> target (TKX2 ((++) @(Maybe Nat) sh sh) x)
forall (sh1 :: [Maybe Nat]) (sh2 :: [Maybe Nat]) (x :: TK).
(KnownShX sh1, KnownShX sh2, KnownSTK x) =>
target (TKX2 sh1 (TKX2 sh2 x))
-> target (TKX2 ((++) @(Maybe Nat) sh1 sh2) x)
forall (target :: Target) (sh1 :: [Maybe Nat]) (sh2 :: [Maybe Nat])
       (x :: TK).
(ConvertTensor target, KnownShX sh1, KnownShX sh2, KnownSTK x) =>
target (TKX2 sh1 (TKX2 sh2 x))
-> target (TKX2 ((++) @(Maybe Nat) sh1 sh2) x)
xunNest target y
target (TKX2 sh (TKX2 sh x))
t)
  STKX StaticShX sh
sh1 (STKProduct SingletonTK y1
stk1 SingletonTK z
stk2) ->
    SingletonTK (TKProduct (TKX2 sh y1) (TKX2 sh z))
-> target (TKProduct (TKX2 sh y1) (TKX2 sh z))
-> RepW target (UnWind (TKProduct (TKX2 sh y1) (TKX2 sh z)))
forall (target :: Target) (y :: TK).
ConvertTensor target =>
SingletonTK y -> target y -> RepW target (UnWind y)
unWindTarget (SingletonTK (TKX2 sh y1)
-> SingletonTK (TKX2 sh z)
-> SingletonTK (TKProduct (TKX2 sh y1) (TKX2 sh z))
forall (y1 :: TK) (z :: TK).
SingletonTK y1 -> SingletonTK z -> SingletonTK (TKProduct y1 z)
STKProduct (StaticShX sh -> SingletonTK y1 -> SingletonTK (TKX2 sh y1)
forall (sh :: [Maybe Nat]) (x :: TK).
StaticShX sh -> SingletonTK x -> SingletonTK (TKX2 sh x)
STKX StaticShX sh
sh1 SingletonTK y1
stk1) (StaticShX sh -> SingletonTK z -> SingletonTK (TKX2 sh z)
forall (sh :: [Maybe Nat]) (x :: TK).
StaticShX sh -> SingletonTK x -> SingletonTK (TKX2 sh x)
STKX StaticShX sh
sh1 SingletonTK z
stk2)) (target (TKX2 sh (TKProduct y1 z))
-> target (TKProduct (TKX2 sh y1) (TKX2 sh z))
forall (y :: TK) (z :: TK) (sh :: [Maybe Nat]).
target (TKX2 sh (TKProduct y z))
-> target (TKProduct (TKX2 sh y) (TKX2 sh z))
forall (target :: Target) (y :: TK) (z :: TK) (sh :: [Maybe Nat]).
ConvertTensor target =>
target (TKX2 sh (TKProduct y z))
-> target (TKProduct (TKX2 sh y) (TKX2 sh z))
xunzip target y
target (TKX2 sh (TKProduct y1 z))
t)
  STKProduct SingletonTK y1
stk1 SingletonTK z
stk2 ->
    let (target y1
t1, target z
t2) = target (TKProduct y1 z) -> (target y1, target z)
forall (x :: TK) (z :: TK).
target (TKProduct x z) -> (target x, target z)
forall (target :: Target) (x :: TK) (z :: TK).
ConvertTensor target =>
target (TKProduct x z) -> (target x, target z)
tunpairConv target y
target (TKProduct y1 z)
t
    in RepW target (UnWind y1)
-> RepW target (UnWind z)
-> RepW target (TKProduct (UnWind y1) (UnWind z))
forall (target :: Target) (y :: TK) (z :: TK).
RepW target y -> RepW target z -> RepW target (TKProduct y z)
WTKProduct (SingletonTK y1 -> target y1 -> RepW target (UnWind y1)
forall (target :: Target) (y :: TK).
ConvertTensor target =>
SingletonTK y -> target y -> RepW target (UnWind y)
unWindTarget SingletonTK y1
stk1 target y1
t1) (SingletonTK z -> target z -> RepW target (UnWind z)
forall (target :: Target) (y :: TK).
ConvertTensor target =>
SingletonTK y -> target y -> RepW target (UnWind y)
unWindTarget SingletonTK z
stk2 target z
t2)

windTarget :: ConvertTensor target
           => SingletonTK y -> RepW target (UnWind y) -> target y
windTarget :: forall (target :: Target) (y :: TK).
ConvertTensor target =>
SingletonTK y -> RepW target (UnWind y) -> target y
windTarget SingletonTK y
stk RepW target (UnWind y)
t = case (SingletonTK y
stk, RepW target (UnWind y)
t) of
  (SingletonTK y
STKScalar, WTKScalar target (TKScalar r)
v) -> target y
target (TKScalar r)
v
  (STKR SNat n
_ SingletonTK x
STKScalar, WTKR target (TKR2 n (TKScalar r))
v) -> target y
target (TKR2 n (TKScalar r))
v
  (STKR n :: SNat n
n@(SNat @n) (STKR (SNat @m) SingletonTK x
stk2), RepW target (UnWind y)
_)
   | Dict @TK KnownSTK x
Dict <- SingletonTK x -> Dict @TK KnownSTK x
forall (y :: TK). SingletonTK y -> Dict @TK KnownSTK y
lemKnownSTK SingletonTK x
stk2 ->
    SNat n -> target (TKR2 (n + n) x) -> target (TKR2 n (TKR2 n x))
forall (n :: Nat) (m :: Nat) (x :: TK).
(KnownNat m, KnownSTK x) =>
SNat n -> target (TKR2 (n + m) x) -> target (TKR2 n (TKR2 m x))
forall (target :: Target) (n :: Nat) (m :: Nat) (x :: TK).
(ConvertTensor target, KnownNat m, KnownSTK x) =>
SNat n -> target (TKR2 (n + m) x) -> target (TKR2 n (TKR2 m x))
rnest SNat n
n (target (TKR2 (n + n) x) -> target (TKR2 n (TKR2 n x)))
-> target (TKR2 (n + n) x) -> target (TKR2 n (TKR2 n x))
forall a b. (a -> b) -> a -> b
$ SingletonTK (TKR2 (n + n) x)
-> RepW target (UnWind (TKR2 (n + n) x)) -> target (TKR2 (n + n) x)
forall (target :: Target) (y :: TK).
ConvertTensor target =>
SingletonTK y -> RepW target (UnWind y) -> target y
windTarget (SNat (n + n) -> SingletonTK x -> SingletonTK (TKR2 (n + n) x)
forall (n :: Nat) (x :: TK).
SNat n -> SingletonTK x -> SingletonTK (TKR2 n x)
STKR (forall (n :: Nat). KnownNat n => SNat n
SNat @(n + m)) SingletonTK x
stk2) RepW target (UnWind y)
RepW target (UnWind (TKR2 (n + n) x))
t
  (STKR SNat n
n (STKS ShS sh
sh2 SingletonTK x
stk2), RepW target (UnWind y)
_) | Dict @TK KnownSTK x
Dict <- SingletonTK x -> Dict @TK KnownSTK x
forall (y :: TK). SingletonTK y -> Dict @TK KnownSTK y
lemKnownSTK SingletonTK x
stk2 ->
    ShS sh -> (KnownShS sh => target y) -> target y
forall (sh :: [Nat]) r. ShS sh -> (KnownShS sh => r) -> r
withKnownShS ShS sh
sh2 ((KnownShS sh => target y) -> target y)
-> (KnownShS sh => target y) -> target y
forall a b. (a -> b) -> a -> b
$
    SNat n
-> target
     (TKX2
        ((++)
           @(Maybe Nat)
           (Replicate @(Maybe Nat) n ('Nothing @Nat))
           (MapJust @Nat sh))
        x)
-> target (TKR2 n (TKS2 sh x))
forall (n :: Nat) (sh2 :: [Nat]) (x :: TK).
(KnownShS sh2, KnownSTK x) =>
SNat n
-> target
     (TKX2
        ((++)
           @(Maybe Nat)
           (Replicate @(Maybe Nat) n ('Nothing @Nat))
           (MapJust @Nat sh2))
        x)
-> target (TKR2 n (TKS2 sh2 x))
forall (target :: Target) (n :: Nat) (sh2 :: [Nat]) (x :: TK).
(ConvertTensor target, KnownShS sh2, KnownSTK x) =>
SNat n
-> target
     (TKX2
        ((++)
           @(Maybe Nat)
           (Replicate @(Maybe Nat) n ('Nothing @Nat))
           (MapJust @Nat sh2))
        x)
-> target (TKR2 n (TKS2 sh2 x))
rnestS SNat n
n
    (target
   (TKX2
      ((++)
         @(Maybe Nat)
         (Replicate @(Maybe Nat) n ('Nothing @Nat))
         (MapJust @Nat sh))
      x)
 -> target (TKR2 n (TKS2 sh x)))
-> target
     (TKX2
        ((++)
           @(Maybe Nat)
           (Replicate @(Maybe Nat) n ('Nothing @Nat))
           (MapJust @Nat sh))
        x)
-> target (TKR2 n (TKS2 sh x))
forall a b. (a -> b) -> a -> b
$ SingletonTK
  (TKX2
     ((++)
        @(Maybe Nat)
        (Replicate @(Maybe Nat) n ('Nothing @Nat))
        (MapJust @Nat sh))
     x)
-> RepW
     target
     (UnWind
        (TKX2
           ((++)
              @(Maybe Nat)
              (Replicate @(Maybe Nat) n ('Nothing @Nat))
              (MapJust @Nat sh))
           x))
-> target
     (TKX2
        ((++)
           @(Maybe Nat)
           (Replicate @(Maybe Nat) n ('Nothing @Nat))
           (MapJust @Nat sh))
        x)
forall (target :: Target) (y :: TK).
ConvertTensor target =>
SingletonTK y -> RepW target (UnWind y) -> target y
windTarget (StaticShX
  ((++)
     @(Maybe Nat)
     (Replicate @(Maybe Nat) n ('Nothing @Nat))
     (MapJust @Nat sh))
-> SingletonTK x
-> SingletonTK
     (TKX2
        ((++)
           @(Maybe Nat)
           (Replicate @(Maybe Nat) n ('Nothing @Nat))
           (MapJust @Nat sh))
        x)
forall (sh :: [Maybe Nat]) (x :: TK).
StaticShX sh -> SingletonTK x -> SingletonTK (TKX2 sh x)
STKX (SNat n -> StaticShX (Replicate @(Maybe Nat) n ('Nothing @Nat))
forall (n :: Nat).
SNat n -> StaticShX (Replicate @(Maybe Nat) n ('Nothing @Nat))
ssxReplicate SNat n
n
                        StaticShX (Replicate @(Maybe Nat) n ('Nothing @Nat))
-> StaticShX (MapJust @Nat sh)
-> StaticShX
     ((++)
        @(Maybe Nat)
        (Replicate @(Maybe Nat) n ('Nothing @Nat))
        (MapJust @Nat sh))
forall (sh :: [Maybe Nat]) (sh' :: [Maybe Nat]).
StaticShX sh
-> StaticShX sh' -> StaticShX ((++) @(Maybe Nat) sh sh')
`ssxAppend` ShX (MapJust @Nat sh) Int -> StaticShX (MapJust @Nat sh)
forall (sh :: [Maybe Nat]) i. ShX sh i -> StaticShX sh
ssxFromShX (ShS sh -> ShX (MapJust @Nat sh) Int
forall (sh :: [Nat]). ShS sh -> IShX (MapJust @Nat sh)
shxFromShS ShS sh
sh2)) SingletonTK x
stk2) RepW target (UnWind y)
RepW
  target
  (UnWind
     (TKX2
        ((++)
           @(Maybe Nat)
           (Replicate @(Maybe Nat) n ('Nothing @Nat))
           (MapJust @Nat sh))
        x))
t
  (STKR SNat n
n (STKX StaticShX sh
sh2 SingletonTK x
stk2), RepW target (UnWind y)
_) | Dict @TK KnownSTK x
Dict <- SingletonTK x -> Dict @TK KnownSTK x
forall (y :: TK). SingletonTK y -> Dict @TK KnownSTK y
lemKnownSTK SingletonTK x
stk2 ->
    StaticShX sh -> (KnownShX sh => target y) -> target y
forall (sh :: [Maybe Nat]) r.
StaticShX sh -> (KnownShX sh => r) -> r
withKnownShX StaticShX sh
sh2 ((KnownShX sh => target y) -> target y)
-> (KnownShX sh => target y) -> target y
forall a b. (a -> b) -> a -> b
$
    SNat n
-> target
     (TKX2
        ((++) @(Maybe Nat) (Replicate @(Maybe Nat) n ('Nothing @Nat)) sh)
        x)
-> target (TKR2 n (TKX2 sh x))
forall (n :: Nat) (sh2 :: [Maybe Nat]) (x :: TK).
(KnownShX sh2, KnownSTK x) =>
SNat n
-> target
     (TKX2
        ((++) @(Maybe Nat) (Replicate @(Maybe Nat) n ('Nothing @Nat)) sh2)
        x)
-> target (TKR2 n (TKX2 sh2 x))
forall (target :: Target) (n :: Nat) (sh2 :: [Maybe Nat])
       (x :: TK).
(ConvertTensor target, KnownShX sh2, KnownSTK x) =>
SNat n
-> target
     (TKX2
        ((++) @(Maybe Nat) (Replicate @(Maybe Nat) n ('Nothing @Nat)) sh2)
        x)
-> target (TKR2 n (TKX2 sh2 x))
rnestX SNat n
n
    (target
   (TKX2
      ((++) @(Maybe Nat) (Replicate @(Maybe Nat) n ('Nothing @Nat)) sh)
      x)
 -> target (TKR2 n (TKX2 sh x)))
-> target
     (TKX2
        ((++) @(Maybe Nat) (Replicate @(Maybe Nat) n ('Nothing @Nat)) sh)
        x)
-> target (TKR2 n (TKX2 sh x))
forall a b. (a -> b) -> a -> b
$ SingletonTK
  (TKX2
     ((++) @(Maybe Nat) (Replicate @(Maybe Nat) n ('Nothing @Nat)) sh)
     x)
-> RepW
     target
     (UnWind
        (TKX2
           ((++) @(Maybe Nat) (Replicate @(Maybe Nat) n ('Nothing @Nat)) sh)
           x))
-> target
     (TKX2
        ((++) @(Maybe Nat) (Replicate @(Maybe Nat) n ('Nothing @Nat)) sh)
        x)
forall (target :: Target) (y :: TK).
ConvertTensor target =>
SingletonTK y -> RepW target (UnWind y) -> target y
windTarget (StaticShX
  ((++) @(Maybe Nat) (Replicate @(Maybe Nat) n ('Nothing @Nat)) sh)
-> SingletonTK x
-> SingletonTK
     (TKX2
        ((++) @(Maybe Nat) (Replicate @(Maybe Nat) n ('Nothing @Nat)) sh)
        x)
forall (sh :: [Maybe Nat]) (x :: TK).
StaticShX sh -> SingletonTK x -> SingletonTK (TKX2 sh x)
STKX (SNat n -> StaticShX (Replicate @(Maybe Nat) n ('Nothing @Nat))
forall (n :: Nat).
SNat n -> StaticShX (Replicate @(Maybe Nat) n ('Nothing @Nat))
ssxReplicate SNat n
n StaticShX (Replicate @(Maybe Nat) n ('Nothing @Nat))
-> StaticShX sh
-> StaticShX
     ((++) @(Maybe Nat) (Replicate @(Maybe Nat) n ('Nothing @Nat)) sh)
forall (sh :: [Maybe Nat]) (sh' :: [Maybe Nat]).
StaticShX sh
-> StaticShX sh' -> StaticShX ((++) @(Maybe Nat) sh sh')
`ssxAppend` StaticShX sh
sh2) SingletonTK x
stk2) RepW target (UnWind y)
RepW
  target
  (UnWind
     (TKX2
        ((++) @(Maybe Nat) (Replicate @(Maybe Nat) n ('Nothing @Nat)) sh)
        x))
t
  (STKR n :: SNat n
n@SNat n
SNat (STKProduct SingletonTK y1
stk1 SingletonTK z
stk2), RepW target (UnWind y)
_) | Dict @TK KnownSTK y1
Dict <- SingletonTK y1 -> Dict @TK KnownSTK y1
forall (y :: TK). SingletonTK y -> Dict @TK KnownSTK y
lemKnownSTK SingletonTK y1
stk1
                                          , Dict @TK KnownSTK z
Dict <- SingletonTK z -> Dict @TK KnownSTK z
forall (y :: TK). SingletonTK y -> Dict @TK KnownSTK y
lemKnownSTK SingletonTK z
stk2 ->
    target (TKProduct (TKR2 n y1) (TKR2 n z))
-> target (TKR2 n (TKProduct y1 z))
forall (y :: TK) (z :: TK) (n :: Nat).
(KnownSTK y, KnownSTK z) =>
target (TKProduct (TKR2 n y) (TKR2 n z))
-> target (TKR2 n (TKProduct y z))
forall (target :: Target) (y :: TK) (z :: TK) (n :: Nat).
(ConvertTensor target, KnownSTK y, KnownSTK z) =>
target (TKProduct (TKR2 n y) (TKR2 n z))
-> target (TKR2 n (TKProduct y z))
rzip (target (TKProduct (TKR2 n y1) (TKR2 n z))
 -> target (TKR2 n (TKProduct y1 z)))
-> target (TKProduct (TKR2 n y1) (TKR2 n z))
-> target (TKR2 n (TKProduct y1 z))
forall a b. (a -> b) -> a -> b
$ SingletonTK (TKProduct (TKR2 n y1) (TKR2 n z))
-> RepW target (UnWind (TKProduct (TKR2 n y1) (TKR2 n z)))
-> target (TKProduct (TKR2 n y1) (TKR2 n z))
forall (target :: Target) (y :: TK).
ConvertTensor target =>
SingletonTK y -> RepW target (UnWind y) -> target y
windTarget (SingletonTK (TKR2 n y1)
-> SingletonTK (TKR2 n z)
-> SingletonTK (TKProduct (TKR2 n y1) (TKR2 n z))
forall (y1 :: TK) (z :: TK).
SingletonTK y1 -> SingletonTK z -> SingletonTK (TKProduct y1 z)
STKProduct (SNat n -> SingletonTK y1 -> SingletonTK (TKR2 n y1)
forall (n :: Nat) (x :: TK).
SNat n -> SingletonTK x -> SingletonTK (TKR2 n x)
STKR SNat n
n SingletonTK y1
stk1) (SNat n -> SingletonTK z -> SingletonTK (TKR2 n z)
forall (n :: Nat) (x :: TK).
SNat n -> SingletonTK x -> SingletonTK (TKR2 n x)
STKR SNat n
n SingletonTK z
stk2)) RepW target (UnWind y)
RepW target (UnWind (TKProduct (TKR2 n y1) (TKR2 n z)))
t
  (STKS ShS sh
_ SingletonTK x
STKScalar, WTKS target (TKS2 sh (TKScalar r))
v) -> target y
target (TKS2 sh (TKScalar r))
v
  (STKS ShS sh
sh1 (STKR m :: SNat n
m@SNat n
SNat SingletonTK x
stk2), RepW target (UnWind y)
_) | Dict @TK KnownSTK x
Dict <- SingletonTK x -> Dict @TK KnownSTK x
forall (y :: TK). SingletonTK y -> Dict @TK KnownSTK y
lemKnownSTK SingletonTK x
stk2 ->
    ShS sh
-> target
     (TKX2
        ((++)
           @(Maybe Nat)
           (MapJust @Nat sh)
           (Replicate @(Maybe Nat) n ('Nothing @Nat)))
        x)
-> target (TKS2 sh (TKR2 n x))
forall (sh1 :: [Nat]) (m :: Nat) (x :: TK).
(KnownNat m, KnownSTK x) =>
ShS sh1
-> target
     (TKX2
        ((++)
           @(Maybe Nat)
           (MapJust @Nat sh1)
           (Replicate @(Maybe Nat) m ('Nothing @Nat)))
        x)
-> target (TKS2 sh1 (TKR2 m x))
forall (target :: Target) (sh1 :: [Nat]) (m :: Nat) (x :: TK).
(ConvertTensor target, KnownNat m, KnownSTK x) =>
ShS sh1
-> target
     (TKX2
        ((++)
           @(Maybe Nat)
           (MapJust @Nat sh1)
           (Replicate @(Maybe Nat) m ('Nothing @Nat)))
        x)
-> target (TKS2 sh1 (TKR2 m x))
snestR ShS sh
sh1
    (target
   (TKX2
      ((++)
         @(Maybe Nat)
         (MapJust @Nat sh)
         (Replicate @(Maybe Nat) n ('Nothing @Nat)))
      x)
 -> target (TKS2 sh (TKR2 n x)))
-> target
     (TKX2
        ((++)
           @(Maybe Nat)
           (MapJust @Nat sh)
           (Replicate @(Maybe Nat) n ('Nothing @Nat)))
        x)
-> target (TKS2 sh (TKR2 n x))
forall a b. (a -> b) -> a -> b
$ SingletonTK
  (TKX2
     ((++)
        @(Maybe Nat)
        (MapJust @Nat sh)
        (Replicate @(Maybe Nat) n ('Nothing @Nat)))
     x)
-> RepW
     target
     (UnWind
        (TKX2
           ((++)
              @(Maybe Nat)
              (MapJust @Nat sh)
              (Replicate @(Maybe Nat) n ('Nothing @Nat)))
           x))
-> target
     (TKX2
        ((++)
           @(Maybe Nat)
           (MapJust @Nat sh)
           (Replicate @(Maybe Nat) n ('Nothing @Nat)))
        x)
forall (target :: Target) (y :: TK).
ConvertTensor target =>
SingletonTK y -> RepW target (UnWind y) -> target y
windTarget (StaticShX
  ((++)
     @(Maybe Nat)
     (MapJust @Nat sh)
     (Replicate @(Maybe Nat) n ('Nothing @Nat)))
-> SingletonTK x
-> SingletonTK
     (TKX2
        ((++)
           @(Maybe Nat)
           (MapJust @Nat sh)
           (Replicate @(Maybe Nat) n ('Nothing @Nat)))
        x)
forall (sh :: [Maybe Nat]) (x :: TK).
StaticShX sh -> SingletonTK x -> SingletonTK (TKX2 sh x)
STKX (ShX (MapJust @Nat sh) Int -> StaticShX (MapJust @Nat sh)
forall (sh :: [Maybe Nat]) i. ShX sh i -> StaticShX sh
ssxFromShX (ShS sh -> ShX (MapJust @Nat sh) Int
forall (sh :: [Nat]). ShS sh -> IShX (MapJust @Nat sh)
shxFromShS ShS sh
sh1)
                        StaticShX (MapJust @Nat sh)
-> StaticShX (Replicate @(Maybe Nat) n ('Nothing @Nat))
-> StaticShX
     ((++)
        @(Maybe Nat)
        (MapJust @Nat sh)
        (Replicate @(Maybe Nat) n ('Nothing @Nat)))
forall (sh :: [Maybe Nat]) (sh' :: [Maybe Nat]).
StaticShX sh
-> StaticShX sh' -> StaticShX ((++) @(Maybe Nat) sh sh')
`ssxAppend` SNat n -> StaticShX (Replicate @(Maybe Nat) n ('Nothing @Nat))
forall (n :: Nat).
SNat n -> StaticShX (Replicate @(Maybe Nat) n ('Nothing @Nat))
ssxReplicate SNat n
m) SingletonTK x
stk2) RepW target (UnWind y)
RepW
  target
  (UnWind
     (TKX2
        ((++)
           @(Maybe Nat)
           (MapJust @Nat sh)
           (Replicate @(Maybe Nat) n ('Nothing @Nat)))
        x))
t
  (STKS ShS sh
sh1 (STKS ShS sh
sh2 SingletonTK x
stk2), RepW target (UnWind y)
_) | Dict @TK KnownSTK x
Dict <- SingletonTK x -> Dict @TK KnownSTK x
forall (y :: TK). SingletonTK y -> Dict @TK KnownSTK y
lemKnownSTK SingletonTK x
stk2 ->
    ShS sh -> (KnownShS sh => target y) -> target y
forall (sh :: [Nat]) r. ShS sh -> (KnownShS sh => r) -> r
withKnownShS ShS sh
sh2 ((KnownShS sh => target y) -> target y)
-> (KnownShS sh => target y) -> target y
forall a b. (a -> b) -> a -> b
$
    ShS sh
-> target (TKS2 ((++) @Nat sh sh) x)
-> target (TKS2 sh (TKS2 sh x))
forall (sh1 :: [Nat]) (sh2 :: [Nat]) (x :: TK).
(KnownShS sh2, KnownSTK x) =>
ShS sh1
-> target (TKS2 ((++) @Nat sh1 sh2) x)
-> target (TKS2 sh1 (TKS2 sh2 x))
forall (target :: Target) (sh1 :: [Nat]) (sh2 :: [Nat]) (x :: TK).
(ConvertTensor target, KnownShS sh2, KnownSTK x) =>
ShS sh1
-> target (TKS2 ((++) @Nat sh1 sh2) x)
-> target (TKS2 sh1 (TKS2 sh2 x))
snest ShS sh
sh1 (target (TKS2 ((++) @Nat sh sh) x) -> target (TKS2 sh (TKS2 sh x)))
-> target (TKS2 ((++) @Nat sh sh) x)
-> target (TKS2 sh (TKS2 sh x))
forall a b. (a -> b) -> a -> b
$ SingletonTK (TKS2 ((++) @Nat sh sh) x)
-> RepW target (UnWind (TKS2 ((++) @Nat sh sh) x))
-> target (TKS2 ((++) @Nat sh sh) x)
forall (target :: Target) (y :: TK).
ConvertTensor target =>
SingletonTK y -> RepW target (UnWind y) -> target y
windTarget (ShS ((++) @Nat sh sh)
-> SingletonTK x -> SingletonTK (TKS2 ((++) @Nat sh sh) x)
forall (sh :: [Nat]) (x :: TK).
ShS sh -> SingletonTK x -> SingletonTK (TKS2 sh x)
STKS (ShS sh -> ShS sh -> ShS ((++) @Nat sh sh)
forall (sh :: [Nat]) (sh' :: [Nat]).
ShS sh -> ShS sh' -> ShS ((++) @Nat sh sh')
shsAppend ShS sh
sh1 ShS sh
sh2) SingletonTK x
stk2) RepW target (UnWind y)
RepW target (UnWind (TKS2 ((++) @Nat sh sh) x))
t
  (STKS ShS sh
sh1 (STKX StaticShX sh
sh2 SingletonTK x
stk2), RepW target (UnWind y)
_) | Dict @TK KnownSTK x
Dict <- SingletonTK x -> Dict @TK KnownSTK x
forall (y :: TK). SingletonTK y -> Dict @TK KnownSTK y
lemKnownSTK SingletonTK x
stk2 ->
    StaticShX sh -> (KnownShX sh => target y) -> target y
forall (sh :: [Maybe Nat]) r.
StaticShX sh -> (KnownShX sh => r) -> r
withKnownShX StaticShX sh
sh2 ((KnownShX sh => target y) -> target y)
-> (KnownShX sh => target y) -> target y
forall a b. (a -> b) -> a -> b
$
    ShS sh
-> target (TKX2 ((++) @(Maybe Nat) (MapJust @Nat sh) sh) x)
-> target (TKS2 sh (TKX2 sh x))
forall (sh1 :: [Nat]) (sh2 :: [Maybe Nat]) (x :: TK).
(KnownShX sh2, KnownSTK x) =>
ShS sh1
-> target (TKX2 ((++) @(Maybe Nat) (MapJust @Nat sh1) sh2) x)
-> target (TKS2 sh1 (TKX2 sh2 x))
forall (target :: Target) (sh1 :: [Nat]) (sh2 :: [Maybe Nat])
       (x :: TK).
(ConvertTensor target, KnownShX sh2, KnownSTK x) =>
ShS sh1
-> target (TKX2 ((++) @(Maybe Nat) (MapJust @Nat sh1) sh2) x)
-> target (TKS2 sh1 (TKX2 sh2 x))
snestX ShS sh
sh1 (target (TKX2 ((++) @(Maybe Nat) (MapJust @Nat sh) sh) x)
 -> target (TKS2 sh (TKX2 sh x)))
-> target (TKX2 ((++) @(Maybe Nat) (MapJust @Nat sh) sh) x)
-> target (TKS2 sh (TKX2 sh x))
forall a b. (a -> b) -> a -> b
$ SingletonTK (TKX2 ((++) @(Maybe Nat) (MapJust @Nat sh) sh) x)
-> RepW
     target (UnWind (TKX2 ((++) @(Maybe Nat) (MapJust @Nat sh) sh) x))
-> target (TKX2 ((++) @(Maybe Nat) (MapJust @Nat sh) sh) x)
forall (target :: Target) (y :: TK).
ConvertTensor target =>
SingletonTK y -> RepW target (UnWind y) -> target y
windTarget (StaticShX ((++) @(Maybe Nat) (MapJust @Nat sh) sh)
-> SingletonTK x
-> SingletonTK (TKX2 ((++) @(Maybe Nat) (MapJust @Nat sh) sh) x)
forall (sh :: [Maybe Nat]) (x :: TK).
StaticShX sh -> SingletonTK x -> SingletonTK (TKX2 sh x)
STKX (ShX (MapJust @Nat sh) Int -> StaticShX (MapJust @Nat sh)
forall (sh :: [Maybe Nat]) i. ShX sh i -> StaticShX sh
ssxFromShX (ShS sh -> ShX (MapJust @Nat sh) Int
forall (sh :: [Nat]). ShS sh -> IShX (MapJust @Nat sh)
shxFromShS ShS sh
sh1)
                                   StaticShX (MapJust @Nat sh)
-> StaticShX sh
-> StaticShX ((++) @(Maybe Nat) (MapJust @Nat sh) sh)
forall (sh :: [Maybe Nat]) (sh' :: [Maybe Nat]).
StaticShX sh
-> StaticShX sh' -> StaticShX ((++) @(Maybe Nat) sh sh')
`ssxAppend` StaticShX sh
sh2) SingletonTK x
stk2) RepW target (UnWind y)
RepW
  target (UnWind (TKX2 ((++) @(Maybe Nat) (MapJust @Nat sh) sh) x))
t
  (STKS ShS sh
sh1 (STKProduct SingletonTK y1
stk1 SingletonTK z
stk2), RepW target (UnWind y)
_) | Dict @TK KnownSTK y1
Dict <- SingletonTK y1 -> Dict @TK KnownSTK y1
forall (y :: TK). SingletonTK y -> Dict @TK KnownSTK y
lemKnownSTK SingletonTK y1
stk1
                                       , Dict @TK KnownSTK z
Dict <- SingletonTK z -> Dict @TK KnownSTK z
forall (y :: TK). SingletonTK y -> Dict @TK KnownSTK y
lemKnownSTK SingletonTK z
stk2 ->
    target (TKProduct (TKS2 sh y1) (TKS2 sh z))
-> target (TKS2 sh (TKProduct y1 z))
forall (y :: TK) (z :: TK) (sh :: [Nat]).
(KnownSTK y, KnownSTK z) =>
target (TKProduct (TKS2 sh y) (TKS2 sh z))
-> target (TKS2 sh (TKProduct y z))
forall (target :: Target) (y :: TK) (z :: TK) (sh :: [Nat]).
(ConvertTensor target, KnownSTK y, KnownSTK z) =>
target (TKProduct (TKS2 sh y) (TKS2 sh z))
-> target (TKS2 sh (TKProduct y z))
szip (target (TKProduct (TKS2 sh y1) (TKS2 sh z))
 -> target (TKS2 sh (TKProduct y1 z)))
-> target (TKProduct (TKS2 sh y1) (TKS2 sh z))
-> target (TKS2 sh (TKProduct y1 z))
forall a b. (a -> b) -> a -> b
$ SingletonTK (TKProduct (TKS2 sh y1) (TKS2 sh z))
-> RepW target (UnWind (TKProduct (TKS2 sh y1) (TKS2 sh z)))
-> target (TKProduct (TKS2 sh y1) (TKS2 sh z))
forall (target :: Target) (y :: TK).
ConvertTensor target =>
SingletonTK y -> RepW target (UnWind y) -> target y
windTarget (SingletonTK (TKS2 sh y1)
-> SingletonTK (TKS2 sh z)
-> SingletonTK (TKProduct (TKS2 sh y1) (TKS2 sh z))
forall (y1 :: TK) (z :: TK).
SingletonTK y1 -> SingletonTK z -> SingletonTK (TKProduct y1 z)
STKProduct (ShS sh -> SingletonTK y1 -> SingletonTK (TKS2 sh y1)
forall (sh :: [Nat]) (x :: TK).
ShS sh -> SingletonTK x -> SingletonTK (TKS2 sh x)
STKS ShS sh
sh1 SingletonTK y1
stk1) (ShS sh -> SingletonTK z -> SingletonTK (TKS2 sh z)
forall (sh :: [Nat]) (x :: TK).
ShS sh -> SingletonTK x -> SingletonTK (TKS2 sh x)
STKS ShS sh
sh1 SingletonTK z
stk2)) RepW target (UnWind y)
RepW target (UnWind (TKProduct (TKS2 sh y1) (TKS2 sh z)))
t
  (STKX StaticShX sh
_ SingletonTK x
STKScalar, WTKX target (TKX2 sh (TKScalar r))
v) -> target y
target (TKX2 sh (TKScalar r))
v
  (STKX StaticShX sh
sh1 (STKR m :: SNat n
m@SNat n
SNat SingletonTK x
stk2), RepW target (UnWind y)
_) | Dict @TK KnownSTK x
Dict <- SingletonTK x -> Dict @TK KnownSTK x
forall (y :: TK). SingletonTK y -> Dict @TK KnownSTK y
lemKnownSTK SingletonTK x
stk2 ->
    StaticShX sh
-> target
     (TKX2
        ((++) @(Maybe Nat) sh (Replicate @(Maybe Nat) n ('Nothing @Nat)))
        x)
-> target (TKX2 sh (TKR2 n x))
forall (sh1 :: [Maybe Nat]) (m :: Nat) (x :: TK).
(KnownNat m, KnownSTK x) =>
StaticShX sh1
-> target
     (TKX2
        ((++) @(Maybe Nat) sh1 (Replicate @(Maybe Nat) m ('Nothing @Nat)))
        x)
-> target (TKX2 sh1 (TKR2 m x))
forall (target :: Target) (sh1 :: [Maybe Nat]) (m :: Nat)
       (x :: TK).
(ConvertTensor target, KnownNat m, KnownSTK x) =>
StaticShX sh1
-> target
     (TKX2
        ((++) @(Maybe Nat) sh1 (Replicate @(Maybe Nat) m ('Nothing @Nat)))
        x)
-> target (TKX2 sh1 (TKR2 m x))
xnestR StaticShX sh
sh1
    (target
   (TKX2
      ((++) @(Maybe Nat) sh (Replicate @(Maybe Nat) n ('Nothing @Nat)))
      x)
 -> target (TKX2 sh (TKR2 n x)))
-> target
     (TKX2
        ((++) @(Maybe Nat) sh (Replicate @(Maybe Nat) n ('Nothing @Nat)))
        x)
-> target (TKX2 sh (TKR2 n x))
forall a b. (a -> b) -> a -> b
$ SingletonTK
  (TKX2
     ((++) @(Maybe Nat) sh (Replicate @(Maybe Nat) n ('Nothing @Nat)))
     x)
-> RepW
     target
     (UnWind
        (TKX2
           ((++) @(Maybe Nat) sh (Replicate @(Maybe Nat) n ('Nothing @Nat)))
           x))
-> target
     (TKX2
        ((++) @(Maybe Nat) sh (Replicate @(Maybe Nat) n ('Nothing @Nat)))
        x)
forall (target :: Target) (y :: TK).
ConvertTensor target =>
SingletonTK y -> RepW target (UnWind y) -> target y
windTarget (StaticShX
  ((++) @(Maybe Nat) sh (Replicate @(Maybe Nat) n ('Nothing @Nat)))
-> SingletonTK x
-> SingletonTK
     (TKX2
        ((++) @(Maybe Nat) sh (Replicate @(Maybe Nat) n ('Nothing @Nat)))
        x)
forall (sh :: [Maybe Nat]) (x :: TK).
StaticShX sh -> SingletonTK x -> SingletonTK (TKX2 sh x)
STKX (StaticShX sh
sh1 StaticShX sh
-> StaticShX (Replicate @(Maybe Nat) n ('Nothing @Nat))
-> StaticShX
     ((++) @(Maybe Nat) sh (Replicate @(Maybe Nat) n ('Nothing @Nat)))
forall (sh :: [Maybe Nat]) (sh' :: [Maybe Nat]).
StaticShX sh
-> StaticShX sh' -> StaticShX ((++) @(Maybe Nat) sh sh')
`ssxAppend` SNat n -> StaticShX (Replicate @(Maybe Nat) n ('Nothing @Nat))
forall (n :: Nat).
SNat n -> StaticShX (Replicate @(Maybe Nat) n ('Nothing @Nat))
ssxReplicate SNat n
m) SingletonTK x
stk2) RepW target (UnWind y)
RepW
  target
  (UnWind
     (TKX2
        ((++) @(Maybe Nat) sh (Replicate @(Maybe Nat) n ('Nothing @Nat)))
        x))
t
  (STKX StaticShX sh
sh1 (STKS ShS sh
sh2 SingletonTK x
stk2), RepW target (UnWind y)
_) | Dict @TK KnownSTK x
Dict <- SingletonTK x -> Dict @TK KnownSTK x
forall (y :: TK). SingletonTK y -> Dict @TK KnownSTK y
lemKnownSTK SingletonTK x
stk2 ->
    ShS sh -> (KnownShS sh => target y) -> target y
forall (sh :: [Nat]) r. ShS sh -> (KnownShS sh => r) -> r
withKnownShS ShS sh
sh2 ((KnownShS sh => target y) -> target y)
-> (KnownShS sh => target y) -> target y
forall a b. (a -> b) -> a -> b
$
    StaticShX sh
-> target (TKX2 ((++) @(Maybe Nat) sh (MapJust @Nat sh)) x)
-> target (TKX2 sh (TKS2 sh x))
forall (sh1 :: [Maybe Nat]) (sh2 :: [Nat]) (x :: TK).
(KnownShS sh2, KnownSTK x) =>
StaticShX sh1
-> target (TKX2 ((++) @(Maybe Nat) sh1 (MapJust @Nat sh2)) x)
-> target (TKX2 sh1 (TKS2 sh2 x))
forall (target :: Target) (sh1 :: [Maybe Nat]) (sh2 :: [Nat])
       (x :: TK).
(ConvertTensor target, KnownShS sh2, KnownSTK x) =>
StaticShX sh1
-> target (TKX2 ((++) @(Maybe Nat) sh1 (MapJust @Nat sh2)) x)
-> target (TKX2 sh1 (TKS2 sh2 x))
xnestS StaticShX sh
sh1
    (target (TKX2 ((++) @(Maybe Nat) sh (MapJust @Nat sh)) x)
 -> target (TKX2 sh (TKS2 sh x)))
-> target (TKX2 ((++) @(Maybe Nat) sh (MapJust @Nat sh)) x)
-> target (TKX2 sh (TKS2 sh x))
forall a b. (a -> b) -> a -> b
$ SingletonTK (TKX2 ((++) @(Maybe Nat) sh (MapJust @Nat sh)) x)
-> RepW
     target (UnWind (TKX2 ((++) @(Maybe Nat) sh (MapJust @Nat sh)) x))
-> target (TKX2 ((++) @(Maybe Nat) sh (MapJust @Nat sh)) x)
forall (target :: Target) (y :: TK).
ConvertTensor target =>
SingletonTK y -> RepW target (UnWind y) -> target y
windTarget (StaticShX ((++) @(Maybe Nat) sh (MapJust @Nat sh))
-> SingletonTK x
-> SingletonTK (TKX2 ((++) @(Maybe Nat) sh (MapJust @Nat sh)) x)
forall (sh :: [Maybe Nat]) (x :: TK).
StaticShX sh -> SingletonTK x -> SingletonTK (TKX2 sh x)
STKX (StaticShX sh
sh1 StaticShX sh
-> StaticShX (MapJust @Nat sh)
-> StaticShX ((++) @(Maybe Nat) sh (MapJust @Nat sh))
forall (sh :: [Maybe Nat]) (sh' :: [Maybe Nat]).
StaticShX sh
-> StaticShX sh' -> StaticShX ((++) @(Maybe Nat) sh sh')
`ssxAppend` ShX (MapJust @Nat sh) Int -> StaticShX (MapJust @Nat sh)
forall (sh :: [Maybe Nat]) i. ShX sh i -> StaticShX sh
ssxFromShX (ShS sh -> ShX (MapJust @Nat sh) Int
forall (sh :: [Nat]). ShS sh -> IShX (MapJust @Nat sh)
shxFromShS ShS sh
sh2)) SingletonTK x
stk2) RepW target (UnWind y)
RepW
  target (UnWind (TKX2 ((++) @(Maybe Nat) sh (MapJust @Nat sh)) x))
t
  (STKX StaticShX sh
sh1 (STKX StaticShX sh
sh2 SingletonTK x
stk2), RepW target (UnWind y)
_) | Dict @TK KnownSTK x
Dict <- SingletonTK x -> Dict @TK KnownSTK x
forall (y :: TK). SingletonTK y -> Dict @TK KnownSTK y
lemKnownSTK SingletonTK x
stk2 ->
    StaticShX sh -> (KnownShX sh => target y) -> target y
forall (sh :: [Maybe Nat]) r.
StaticShX sh -> (KnownShX sh => r) -> r
withKnownShX StaticShX sh
sh2 ((KnownShX sh => target y) -> target y)
-> (KnownShX sh => target y) -> target y
forall a b. (a -> b) -> a -> b
$
    StaticShX sh
-> target (TKX2 ((++) @(Maybe Nat) sh sh) x)
-> target (TKX2 sh (TKX2 sh x))
forall (sh1 :: [Maybe Nat]) (sh2 :: [Maybe Nat]) (x :: TK).
(KnownShX sh2, KnownSTK x) =>
StaticShX sh1
-> target (TKX2 ((++) @(Maybe Nat) sh1 sh2) x)
-> target (TKX2 sh1 (TKX2 sh2 x))
forall (target :: Target) (sh1 :: [Maybe Nat]) (sh2 :: [Maybe Nat])
       (x :: TK).
(ConvertTensor target, KnownShX sh2, KnownSTK x) =>
StaticShX sh1
-> target (TKX2 ((++) @(Maybe Nat) sh1 sh2) x)
-> target (TKX2 sh1 (TKX2 sh2 x))
xnest StaticShX sh
sh1 (target (TKX2 ((++) @(Maybe Nat) sh sh) x)
 -> target (TKX2 sh (TKX2 sh x)))
-> target (TKX2 ((++) @(Maybe Nat) sh sh) x)
-> target (TKX2 sh (TKX2 sh x))
forall a b. (a -> b) -> a -> b
$ SingletonTK (TKX2 ((++) @(Maybe Nat) sh sh) x)
-> RepW target (UnWind (TKX2 ((++) @(Maybe Nat) sh sh) x))
-> target (TKX2 ((++) @(Maybe Nat) sh sh) x)
forall (target :: Target) (y :: TK).
ConvertTensor target =>
SingletonTK y -> RepW target (UnWind y) -> target y
windTarget (StaticShX ((++) @(Maybe Nat) sh sh)
-> SingletonTK x -> SingletonTK (TKX2 ((++) @(Maybe Nat) sh sh) x)
forall (sh :: [Maybe Nat]) (x :: TK).
StaticShX sh -> SingletonTK x -> SingletonTK (TKX2 sh x)
STKX (StaticShX sh -> StaticShX sh -> StaticShX ((++) @(Maybe Nat) sh sh)
forall (sh :: [Maybe Nat]) (sh' :: [Maybe Nat]).
StaticShX sh
-> StaticShX sh' -> StaticShX ((++) @(Maybe Nat) sh sh')
ssxAppend StaticShX sh
sh1 StaticShX sh
sh2) SingletonTK x
stk2) RepW target (UnWind y)
RepW target (UnWind (TKX2 ((++) @(Maybe Nat) sh sh) x))
t
  (STKX StaticShX sh
sh1 (STKProduct SingletonTK y1
stk1 SingletonTK z
stk2), RepW target (UnWind y)
_) | Dict @TK KnownSTK y1
Dict <- SingletonTK y1 -> Dict @TK KnownSTK y1
forall (y :: TK). SingletonTK y -> Dict @TK KnownSTK y
lemKnownSTK SingletonTK y1
stk1
                                       , Dict @TK KnownSTK z
Dict <- SingletonTK z -> Dict @TK KnownSTK z
forall (y :: TK). SingletonTK y -> Dict @TK KnownSTK y
lemKnownSTK SingletonTK z
stk2 ->
    target (TKProduct (TKX2 sh y1) (TKX2 sh z))
-> target (TKX2 sh (TKProduct y1 z))
forall (y :: TK) (z :: TK) (sh :: [Maybe Nat]).
(KnownSTK y, KnownSTK z) =>
target (TKProduct (TKX2 sh y) (TKX2 sh z))
-> target (TKX2 sh (TKProduct y z))
forall (target :: Target) (y :: TK) (z :: TK) (sh :: [Maybe Nat]).
(ConvertTensor target, KnownSTK y, KnownSTK z) =>
target (TKProduct (TKX2 sh y) (TKX2 sh z))
-> target (TKX2 sh (TKProduct y z))
xzip (target (TKProduct (TKX2 sh y1) (TKX2 sh z))
 -> target (TKX2 sh (TKProduct y1 z)))
-> target (TKProduct (TKX2 sh y1) (TKX2 sh z))
-> target (TKX2 sh (TKProduct y1 z))
forall a b. (a -> b) -> a -> b
$ SingletonTK (TKProduct (TKX2 sh y1) (TKX2 sh z))
-> RepW target (UnWind (TKProduct (TKX2 sh y1) (TKX2 sh z)))
-> target (TKProduct (TKX2 sh y1) (TKX2 sh z))
forall (target :: Target) (y :: TK).
ConvertTensor target =>
SingletonTK y -> RepW target (UnWind y) -> target y
windTarget (SingletonTK (TKX2 sh y1)
-> SingletonTK (TKX2 sh z)
-> SingletonTK (TKProduct (TKX2 sh y1) (TKX2 sh z))
forall (y1 :: TK) (z :: TK).
SingletonTK y1 -> SingletonTK z -> SingletonTK (TKProduct y1 z)
STKProduct (StaticShX sh -> SingletonTK y1 -> SingletonTK (TKX2 sh y1)
forall (sh :: [Maybe Nat]) (x :: TK).
StaticShX sh -> SingletonTK x -> SingletonTK (TKX2 sh x)
STKX StaticShX sh
sh1 SingletonTK y1
stk1) (StaticShX sh -> SingletonTK z -> SingletonTK (TKX2 sh z)
forall (sh :: [Maybe Nat]) (x :: TK).
StaticShX sh -> SingletonTK x -> SingletonTK (TKX2 sh x)
STKX StaticShX sh
sh1 SingletonTK z
stk2)) RepW target (UnWind y)
RepW target (UnWind (TKProduct (TKX2 sh y1) (TKX2 sh z)))
t
  (STKProduct SingletonTK y1
stk1 SingletonTK z
stk2, WTKProduct RepW target x
t1 RepW target z
t2) ->
    target y1 -> target z -> target (TKProduct y1 z)
forall (x :: TK) (z :: TK).
target x -> target z -> target (TKProduct x z)
forall (target :: Target) (x :: TK) (z :: TK).
ConvertTensor target =>
target x -> target z -> target (TKProduct x z)
tpairConv (SingletonTK y1 -> RepW target (UnWind y1) -> target y1
forall (target :: Target) (y :: TK).
ConvertTensor target =>
SingletonTK y -> RepW target (UnWind y) -> target y
windTarget SingletonTK y1
stk1 RepW target x
RepW target (UnWind y1)
t1) (SingletonTK z -> RepW target (UnWind z) -> target z
forall (target :: Target) (y :: TK).
ConvertTensor target =>
SingletonTK y -> RepW target (UnWind y) -> target y
windTarget SingletonTK z
stk2 RepW target z
RepW target (UnWind z)
t2)


-- * Operations defined using unwinding

-- | Add two (nested pairs of) tensors. Requires duplicable arguments
-- or a `ShareTensor` instance.
--
-- Both arguments are unwound (via `unWindTarget`) into the `RepW` normal
-- form dictated by the singleton, combined with `addRepW`, and the result
-- is wound back (via `windTarget`) into the original tensor kind.
addTarget :: (BaseTensor target, ConvertTensor target)
          => SingletonTK y -> target y -> target y -> target y
addTarget stk a b =
  let a2 = unWindTarget stk a
      b2 = unWindTarget stk b
  in windTarget stk $ addRepW a2 b2

-- | Multiply two (nested pairs of) tensors. Requires duplicable arguments
-- or a `ShareTensor` instance.
--
-- Both arguments are unwound (via `unWindTarget`) into the `RepW` normal
-- form dictated by the singleton, combined with `multRepW`, and the result
-- is wound back (via `windTarget`) into the original tensor kind.
multTarget :: (BaseTensor target, ConvertTensor target)
           => SingletonTK y -> target y -> target y -> target y
multTarget stk a b =
  let a2 = unWindTarget stk a
      b2 = unWindTarget stk b
  in windTarget stk $ multRepW a2 b2

-- | Sum all dimensions of each component and then sum it all.
-- Requires duplicable arguments or a `ShareTensor` instance.
--
-- The tensor is unwound into its `RepW` normal form and reduced by
-- `sum0RepW`, guided by the matching unwound full shape (`unWindFTK`).
-- The result is always a `Double` scalar, whatever the component
-- element types are.
sum0Target :: (BaseTensor target, ConvertTensor target)
           => FullShapeTK y -> target y
           -> target (TKScalar Double)
sum0Target ftk a =
  let a2 = unWindTarget (ftkToSTK ftk) a
  in sum0RepW (unWindFTK ftk) a2

-- | Dot product each component and then sum it all.
-- Requires duplicable arguments or a `ShareTensor` instance.
--
-- Both tensors are unwound into their `RepW` normal form (using the
-- singleton derived from the full shape) and reduced by `dot0RepW`,
-- guided by the matching unwound full shape (`unWindFTK`).
-- The result is always a `Double` scalar.
dot0Target :: (BaseTensor target, ConvertTensor target)
           => FullShapeTK y -> target y -> target y
           -> target (TKScalar Double)
dot0Target ftk a b =
  let stk = ftkToSTK ftk  -- shared by both unwindings
      a2 = unWindTarget stk a
      b2 = unWindTarget stk b
  in dot0RepW (unWindFTK ftk) a2 b2

-- | Replicate a scalar along the given full shape singleton.
--
-- The scalar must be polymorphic in the element type (any `GoodScalar`),
-- because each component of a nested product may have a different
-- element type. `replRepW` builds the `RepW` normal form, which is then
-- wound back into the requested tensor kind.
replTarget :: forall y target. (BaseTensor target, ConvertTensor target)
           => (forall r. GoodScalar r => r)
           -> FullShapeTK y -> target y
replTarget r ftk =
  windTarget (ftkToSTK ftk) $ replRepW r (unWindFTK ftk)

-- | Replicate the default value along the given full shape singleton.
--
-- `defRepW` builds the `RepW` normal form for the unwound full shape,
-- which is then wound back into the requested tensor kind.
defTarget :: forall y target. (BaseTensor target, ConvertTensor target)
          => FullShapeTK y -> target y
defTarget ftk =
  windTarget (ftkToSTK ftk) $ defRepW (unWindFTK ftk)

-- | Transfer a (nested pairs of) `Concrete` tensor(s) into another target.
--
-- The value is unwound into its `RepW` normal form, every leaf is
-- converted by `concreteRepW` using the supplied conversion functions,
-- and the result is wound back into the original tensor kind.
--
-- Parameters:
--
-- * @concreteK@ converts a concrete scalar into the target.
-- * @concreteS@ converts a concrete shaped array into the target.
-- * @fromS@ is a conversion to a given full shape, passed through
--   unchanged to `concreteRepW`. NOTE(review): presumably used there to
--   turn shaped results into ranked/mixed leaves; confirm against
--   `concreteRepW`.
-- * @stk@ is the singleton describing the tensor kind of @v@.
concreteTarget
  :: forall y target. (ConvertTensor Concrete, ConvertTensor target)
  => (forall r. GoodScalar r => Concrete (TKScalar r) -> target (TKScalar r))
  -> (forall r sh. GoodScalar r => Concrete (TKS sh r) -> target (TKS sh r))
  -> (forall x z. FullShapeTK z -> target x -> target z)
  -> SingletonTK y -> Concrete y
  -> target y
concreteTarget concreteK concreteS fromS stk v =
  windTarget stk
  $ concreteRepW concreteK concreteS fromS
  $ unWindTarget stk v

-- | Type-level lemma: unwinding commutes with 'ADTensorKind'.
--
-- The singleton argument is ignored; it only fixes @y@ at call sites.
-- NOTE(review): the equation is asserted via 'unsafeCoerceRefl', not
-- derived, so GHC does not check it. It must be kept in sync by hand with
-- the definitions of the 'UnWind' and 'ADTensorKind' type families; if they
-- ever stop commuting for some @y@, callers matching on this 'Refl' become
-- unsound.
lemUnWindOfAD :: SingletonTK y
              -> UnWind (ADTensorKind y) :~: ADTensorKind (UnWind y)
lemUnWindOfAD _ = unsafeCoerceRefl

-- | Convert a tensor into a tensor with only trivial non-differentiable
-- scalars. The `ShareTensor` constraint is needed, despite what GHC says,
-- in order not to require duplicable arguments.
--
-- The tensor is first unwound into its 'RepW' normal form, then converted
-- component-wise by 'toADTensorKindW' (guided by the unwound full shape),
-- and finally wound back at the 'adSTK' of the original singleton.
toADTensorKindShared
  :: (BaseTensor target, ConvertTensor target, ShareTensor target)
  => FullShapeTK y -> target y
  -> target (ADTensorKind y)
toADTensorKindShared ftk a | Refl <- lemUnWindOfAD stk =
  -- The 'Refl' above identifies @UnWind (ADTensorKind y)@ with
  -- @ADTensorKind (UnWind y)@, which is what lets the result of
  -- 'toADTensorKindW' be fed to 'windTarget' at the AD singleton.
  windTarget (adSTK stk)
  $ toADTensorKindW (unWindTarget stk a) (unWindFTK ftk)
 where
  -- Derived once and reused for the lemma, unwinding and rewinding.
  stk = ftkToSTK ftk

-- | Convert a tensor with only trivial non-differentiable scalars
-- into a tensor with the non-differentiable scalars given by the singleton
-- and with zero values at the non-differentiable types. The `ShareTensor`
-- constraint is needed, despite what GHC says, in order not to require
-- duplicable arguments.
--
-- The input is unwound at @adSTK stk@, converted component-wise by
-- 'fromADTensorKindW' (guided by the unwound singleton), and wound back
-- at @stk@.
fromADTensorKindShared
  :: (BaseTensor target, ConvertTensor target, ShareTensor target)
  => SingletonTK y -> target (ADTensorKind y)
  -> target y
fromADTensorKindShared stk a | Refl <- lemUnWindOfAD stk =
  -- The 'Refl' above lets the unwound AD tensor, typed
  -- @RepW target (UnWind (ADTensorKind y))@, be passed where
  -- 'fromADTensorKindW' expects @RepW target (ADTensorKind (UnWind y))@.
  windTarget stk
  $ fromADTensorKindW (unWindSTK stk)
  $ unWindTarget (adSTK stk) a