{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-}
{-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-}
-- | Two kinds of singletons for tensor kinds, the constraints
-- corresponding to them, and lemmas associated with the singletons.
module HordeAd.Core.TensorKind
  ( -- * Tensor kind singletons
    SingletonTK(..), KnownSTK(..)
  , TKConversion(..), convertSTK, convertFTK, buildTKConversion
  , withKnownSTK, lemKnownSTK, sameKnownSTK, sameSTK
  , stkUnit, buildSTK, razeSTK, adSTK
  , lemKnownSTKOfBuild, lemKnownSTKOfAD, lemBuildOfAD, lengthSTK, widthSTK
    -- * Full shape tensor kind quasi-singletons
  , FullShapeTK(..)
  , matchingFTK, ftkToSTK, ftkUnit, buildFTK, razeFTK, adFTK, differentiableFTK
  , DummyDualTarget(..)
  ) where

import Prelude hiding ((.))

import Control.Category
import Data.Proxy (Proxy (Proxy))
import Data.Type.Equality (gcastWith, testEquality, (:~:) (Refl))
import GHC.Exts (withDict)
import GHC.TypeLits (KnownNat, OrderingI (..), cmpNat, fromSNat, type (+))
import Type.Reflection (typeRep)

import Data.Array.Nested (MapJust, Replicate, type (++))
import Data.Array.Nested.Convert (shrFromShX, shsFromShX, shxFromShS, shxFromShR, shsFromSSX)
import Data.Array.Nested.Lemmas
import Data.Array.Nested.Mixed.Shape
import Data.Array.Nested.Ranked.Shape
import Data.Array.Nested.Shaped.Shape
import Data.Array.Nested.Types (unsafeCoerceRefl)

import HordeAd.Core.Types

-- * Tensor kind singletons

-- | Tensor kind singleton type.
--
-- A value of type @SingletonTK y@ mirrors the structure of the tensor kind
-- @y@ at the term level, so pattern matching on it refines @y@.
type role SingletonTK nominal
data SingletonTK y where
  -- | A scalar kind; the 'GoodScalar' dictionary for @r@ is packed
  -- into the constructor.
  STKScalar :: GoodScalar r
            => SingletonTK (TKScalar r)
  -- | A ranked kind: the rank as an 'SNat' and the element kind.
  STKR :: SNat n -> SingletonTK x -> SingletonTK (TKR2 n x)
  -- | A shaped kind: the fully static shape as an 'ShS' and the element kind.
  STKS :: ShS sh -> SingletonTK x -> SingletonTK (TKS2 sh x)
  -- | A mixed kind: the shape as a 'StaticShX', where each dimension
  -- may be statically known or unknown, and the element kind.
  STKX :: StaticShX sh -> SingletonTK x -> SingletonTK (TKX2 sh x)
  -- | A product (pair) of two tensor kinds.
  STKProduct :: SingletonTK y -> SingletonTK z
             -> SingletonTK (TKProduct y z)

deriving instance Show (SingletonTK y)

-- | The constraints corresponding to 'SingletonTK'.
--
-- The class has a single method, so a 'SingletonTK' value can serve
-- directly as its dictionary (this is what 'withKnownSTK' relies on).
class KnownSTK (y :: TK) where
  knownSTK :: SingletonTK y

-- | A scalar kind is known whenever its scalar type is a 'GoodScalar'.
instance GoodScalar r => KnownSTK (TKScalar r) where
  knownSTK = STKScalar

-- | A ranked kind is known when its rank and its element kind are known.
instance (KnownSTK x, KnownNat n)
         => KnownSTK (TKR2 n x) where
  knownSTK = STKR SNat knownSTK

-- | A shaped kind is known when its shape and its element kind are known.
instance (KnownSTK x, KnownShS sh)
         => KnownSTK (TKS2 sh x) where
  knownSTK = STKS knownShS knownSTK

-- | A mixed kind is known when its static shape info and its element kind
-- are known.
instance (KnownSTK x, KnownShX sh)
         => KnownSTK (TKX2 sh x) where
  knownSTK = STKX knownShX knownSTK

-- | A product kind is known when both of its components are known.
instance (KnownSTK y, KnownSTK z)
         => KnownSTK (TKProduct y z) where
  knownSTK = STKProduct (knownSTK @y) (knownSTK @z)

-- | Turn a singleton into a 'KnownSTK' constraint, made available
-- to the continuation.
--
-- 'KnownSTK' has exactly one method, of type @SingletonTK y@, so
-- 'withDict' can use the singleton itself as the class dictionary.
withKnownSTK :: forall y r. SingletonTK y -> (KnownSTK y => r) -> r
withKnownSTK = withDict @(KnownSTK y)

-- | Turn a singleton into a reified dictionary for the corresponding
-- 'KnownSTK' constraint.
--
-- The singleton is traversed recursively. Each case brings into scope the
-- constraints required by the matching 'KnownSTK' instance.
lemKnownSTK :: SingletonTK y -> Dict KnownSTK y
lemKnownSTK = \case
  STKScalar -> Dict
  -- Matching on the 'SNat' pattern brings @KnownNat n@ into scope.
  STKR SNat x | Dict <- lemKnownSTK x -> Dict
  STKS sh x | Dict <- lemKnownSTK x -> withKnownShS sh Dict
  STKX sh x | Dict <- lemKnownSTK x -> withKnownShX sh Dict
  STKProduct stk1 stk2 | Dict <- lemKnownSTK stk1
                       , Dict <- lemKnownSTK stk2 -> Dict

-- | Decide whether two tensor kinds are equal, using the singletons
-- provided by their 'KnownSTK' instances.
sameKnownSTK :: forall y1 y2. (KnownSTK y1, KnownSTK y2)
             => Maybe (y1 :~: y2)
sameKnownSTK = sameSTK (knownSTK @y1) (knownSTK @y2)

-- | A plausible implementation of `testEquality` on `SingletonTK`.
--
-- Returns @Just Refl@ when both singletons are built with the same
-- constructors and all their components are equal: scalar types (via
-- 'typeRep'), ranks, shapes and, recursively, element and component kinds.
-- Returns 'Nothing' otherwise.
sameSTK :: SingletonTK y1 -> SingletonTK y2 -> Maybe (y1 :~: y2)
sameSTK stk1 stk2 = case (stk1, stk2) of
  (STKScalar @r1, STKScalar @r2)
    | Just Refl <- testEquality (typeRep @r1) (typeRep @r2) ->
      Just Refl
  (STKR snat1 x1, STKR snat2 x2)
    | Just Refl <- sameSTK x1 x2, Just Refl <- testEquality snat1 snat2 ->
      Just Refl
  (STKS sh1 x1, STKS sh2 x2)
    | Just Refl <- sameSTK x1 x2, Just Refl <- testEquality sh1 sh2 ->
      Just Refl
  (STKX sh1 x1, STKX sh2 x2)
    | Just Refl <- sameSTK x1 x2, Just Refl <- testEquality sh1 sh2 ->
      Just Refl
  (STKProduct x1 y1, STKProduct x2 y2)
    | Just Refl <- sameSTK x1 x2, Just Refl <- sameSTK y1 y2 ->
      Just Refl
  _ -> Nothing

-- | The singleton of the unit tensor kind 'TKUnit'.
stkUnit :: SingletonTK TKUnit
stkUnit = STKScalar

-- | The singleton of @BuildTensorKind k y@, i.e., of @y@ extended with
-- a new outermost dimension of size @k@:
--
-- * a scalar becomes a shaped tensor of shape @[k]@,
-- * a ranked tensor gets its rank increased by one,
-- * a shaped tensor gets @k@ prepended to its shape,
-- * a mixed tensor gets a statically known @k@ prepended to its shape,
-- * a product is extended component-wise.
buildSTK :: SNat k -> SingletonTK y -> SingletonTK (BuildTensorKind k y)
-- Matching 'SNat' brings @KnownNat k@ into scope, which ':$$' requires.
buildSTK snat@SNat = \case
  stk@STKScalar -> STKS (snat :$$ ZSS) stk
  STKR SNat x -> STKR SNat x
  STKS sh x -> STKS (snat :$$ sh) x
  STKX sh x -> STKX (SKnown snat :!% sh) x
  STKProduct stk1 stk2 -> STKProduct (buildSTK snat stk1) (buildSTK snat stk2)

-- | The singleton of @RazeTensorKind z@, i.e., of @z@ with its outermost
-- dimension removed. Products are handled component-wise.
--
-- Calls 'error' when there is no outermost dimension to remove: for
-- scalars, rank-0 ranked tensors and empty shapes. It also calls 'error'
-- when the outermost dimension of a mixed shape is not statically known.
razeSTK :: SingletonTK z -> SingletonTK (RazeTensorKind z)
razeSTK = \case
  STKScalar -> error "razeSTK: impossible argument"
  STKR snat@SNat x ->
    case cmpNat (SNat @1) snat of
      LTI -> STKR SNat x  -- rank at least 2
      EQI -> STKR SNat x  -- rank 1 becomes rank 0
      _ -> error "razeSTK: impossible argument"  -- rank 0
  STKS ZSS _ -> error "razeSTK: impossible argument"
  STKS (_ :$$ sh) x -> STKS sh x
  STKX ZKX _ -> error "razeSTK: impossible argument"
  STKX (SUnknown _ :!% _) _ -> error "razeSTK: impossible argument"
  STKX (SKnown _ :!% sh) x -> STKX sh x
  STKProduct stk1 stk2 -> STKProduct (razeSTK stk1) (razeSTK stk2)

-- | The singleton of @ADTensorKind y@. 'Double' and 'Float' scalars are
-- kept unchanged, and every other scalar is replaced by the unit scalar
-- @Z1@. Tensor shapes are preserved, and element and component kinds are
-- transformed recursively.
adSTK :: SingletonTK y -> SingletonTK (ADTensorKind y)
adSTK = \case
  t@(STKScalar @r)
    | Just Refl <- testEquality (typeRep @r) (typeRep @Double) -> t
    | Just Refl <- testEquality (typeRep @r) (typeRep @Float) -> t
    | otherwise ->
      -- NOTE(review): unchecked coercion; relies on the type family
      -- 'ADTensorScalar' (defined outside this module) reducing to 'Z1'
      -- for every scalar other than 'Double' and 'Float'.
      gcastWith (unsafeCoerceRefl :: ADTensorScalar r :~: Z1) STKScalar
  STKR snat x -> STKR snat $ adSTK x
  STKS sh x -> STKS sh $ adSTK x
  STKX sh x -> STKX sh $ adSTK x
  STKProduct stk1 stk2 -> STKProduct (adSTK stk1) (adSTK stk2)

-- | A 'KnownSTK' dictionary for the kind produced by 'buildSTK'.
lemKnownSTKOfBuild :: SNat k -> SingletonTK y
                   -> Dict KnownSTK (BuildTensorKind k y)
lemKnownSTKOfBuild snat = lemKnownSTK . buildSTK snat

-- | A 'KnownSTK' dictionary for the kind produced by 'adSTK'.
lemKnownSTKOfAD :: SingletonTK y
                -> Dict KnownSTK (ADTensorKind y)
lemKnownSTKOfAD = lemKnownSTK . adSTK

-- | 'BuildTensorKind' and 'ADTensorKind' commute.
--
-- GHC checks the scalar case and, via recursion, the product case. The
-- ranked, shaped and mixed cases are asserted with 'unsafeCoerceRefl' and
-- are not checked by the type checker.
lemBuildOfAD :: SNat k -> SingletonTK y
             -> BuildTensorKind k (ADTensorKind y)
                :~: ADTensorKind (BuildTensorKind k y)
lemBuildOfAD snat@SNat = \case
  STKScalar -> Refl
  STKR{} -> unsafeCoerceRefl
  STKS{} -> unsafeCoerceRefl
  STKX{} -> unsafeCoerceRefl
  STKProduct stk1 stk2 | Refl <- lemBuildOfAD snat stk1
                       , Refl <- lemBuildOfAD snat stk2 -> Refl

-- | The number of outer (tensor) dimensions of a kind. Scalars have none.
-- For a product, this is the larger of the two components' values.
lengthSTK :: SingletonTK x -> Int
lengthSTK STKScalar = 0
lengthSTK (STKR snat _) = fromInteger $ fromSNat snat
lengthSTK (STKS sh _) = shsLength sh
lengthSTK (STKX sh _) = ssxLength sh
lengthSTK (STKProduct sy sz) = lengthSTK sy `max` lengthSTK sz

-- | The number of leaves of a (possibly nested) product kind. Scalars of
-- type 'Z1' count as 0; every other scalar and every tensor counts as 1.
widthSTK :: SingletonTK y -> Int
widthSTK stk = case stk of
  STKScalar @r -> case testEquality (typeRep @r) (typeRep @Z1) of
    Just Refl -> 0
    _ -> 1
  STKR{} -> 1
  STKS{} -> 1
  STKX{} -> 1
  STKProduct stk1 stk2 -> widthSTK stk1 + widthSTK stk2

-- | This is copied, with modifications, from ox-arrays.
--
-- This is a recipe for converting arrays, not always followed,
-- and a proof that a conversion is possible, with some proof obligations
-- delayed to runtime (in ConvXS' and ConvXX', where not only the ranks
-- of the shapes need to agree, but also the dimensions of the input
-- array and of the output shape, which is not all captured in the type).
-- As in ox-arrays, conversions only change the meta-data, not the underlying
-- vector representation of the array.
type role TKConversion nominal nominal
data TKConversion (a :: TK) (b :: TK) where
  -- | The identity conversion.
  ConvId  :: TKConversion a a
  -- | Composition: @ConvCmp g f@ converts with @f@ first, then with @g@.
  ConvCmp :: TKConversion b c -> TKConversion a b -> TKConversion a c

  -- Ranked and shaped tensors viewed as mixed tensors.
  ConvRX  :: TKConversion (TKR2 n a) (TKX2 (Replicate n Nothing) a)
  ConvSX  :: TKConversion (TKS2 sh a) (TKX2 (MapJust sh) a)

  -- Mixed tensors viewed as ranked or shaped tensors. ConvXS' takes the
  -- target full shape; only the ranks are checked statically.
  ConvXR  :: SingletonTK a -> TKConversion (TKX2 sh a) (TKR2 (Rank sh) a)
  ConvXS  :: TKConversion (TKX2 (MapJust sh) a) (TKS2 sh a)
  ConvXS' :: Rank sh ~ Rank sh'
          => FullShapeTK (TKS2 sh' a)
          -> TKConversion (TKX2 sh a) (TKS2 sh' a)

  -- | Re-typing a mixed tensor to another mixed shape of the same rank,
  -- given the target full shape.
  ConvXX' :: Rank sh ~ Rank sh'
          => FullShapeTK (TKX2 sh' a)
          -> TKConversion (TKX2 sh a) (TKX2 sh' a)

  -- Lifting a conversion to the elements of a tensor or to the
  -- components of a product.
  ConvRR  :: TKConversion a b -> TKConversion (TKR2 n a) (TKR2 n b)
  ConvSS  :: TKConversion a b -> TKConversion (TKS2 sh a) (TKS2 sh b)
  ConvXX  :: TKConversion a b -> TKConversion (TKX2 sh a) (TKX2 sh b)
  ConvT2  :: TKConversion a a'
          -> TKConversion b b'
          -> TKConversion (TKProduct a b) (TKProduct a' b')

  -- Wrapping a value as a rank-0 mixed tensor and unwrapping it.
  Conv0X  :: SingletonTK a -> TKConversion a (TKX2 '[] a)
  ConvX0  :: TKConversion (TKX2 '[] a) a

  -- Splitting a mixed shape @sh ++ sh'@ into an outer tensor of shape @sh@
  -- with elements of shape @sh'@, and the inverse.
  ConvNest :: SingletonTK (TKX2 sh a)
           -> TKConversion (TKX2 (sh ++ sh') a) (TKX2 sh (TKX2 sh' a))
  ConvUnnest :: TKConversion (TKX2 sh (TKX2 sh' a)) (TKX2 (sh ++ sh') a)

  -- A pair of equally shaped tensors to a tensor of pairs, and back.
  ConvZip   :: SingletonTK a -> SingletonTK b
            -> TKConversion (TKProduct (TKX2 sh a) (TKX2 sh b))
                            (TKX2 sh (TKProduct a b))
  ConvUnzip :: SingletonTK a -> SingletonTK b
            -> TKConversion (TKX2 sh (TKProduct a b))
                            (TKProduct (TKX2 sh a) (TKX2 sh b))

deriving instance Show (TKConversion a b)

-- | Conversions form a category. The identity is 'ConvId'. Composition is
-- the 'ConvCmp' constructor applied directly, so composing two recipes
-- never simplifies them.
instance Category TKConversion where
  id :: TKConversion a a
  id = ConvId

  (.) :: TKConversion b c -> TKConversion a b -> TKConversion a c
  later . earlier = ConvCmp later earlier

-- | Compute the singleton of a conversion's result from the singleton of
-- its input.
--
-- A composite conversion is interpreted right to left, as with function
-- composition. For 'ConvXS'' and 'ConvXX'', the resulting shape is read
-- from the full shape carried by the conversion, and the element kind is
-- kept from the input. For 'ConvNest', the inner static shape is whatever
-- remains of the input's static shape after the outer prefix is dropped.
convertSTK :: TKConversion a b -> SingletonTK a -> SingletonTK b
convertSTK ConvId stk = stk
convertSTK (ConvCmp later earlier) stk =
  convertSTK later (convertSTK earlier stk)
convertSTK ConvRX (STKR rank el) = STKX (ssxReplicate rank) el
convertSTK ConvSX (STKS shs el) = STKX (ssxFromShX (shxFromShS shs)) el
convertSTK (ConvXR _) (STKX ssx el) = STKR (ssxRank ssx) el
convertSTK ConvXS (STKX ssx el) = STKS (shsFromSSX ssx) el
convertSTK (ConvXS' (FTKS shs _)) (STKX _ el) = STKS shs el
convertSTK (ConvXX' (FTKX shx _)) (STKX _ el) = STKX (ssxFromShX shx) el
convertSTK (ConvRR c) (STKR rank el) = STKR rank (convertSTK c el)
convertSTK (ConvSS c) (STKS shs el) = STKS shs (convertSTK c el)
convertSTK (ConvXX c) (STKX ssx el) = STKX ssx (convertSTK c el)
convertSTK (ConvT2 cLeft cRight) (STKProduct stkLeft stkRight) =
  STKProduct (convertSTK cLeft stkLeft) (convertSTK cRight stkRight)
convertSTK (Conv0X _) stk = STKX ZKX stk
convertSTK ConvX0 (STKX ZKX stk) = stk
convertSTK (ConvNest (STKX ssxOuter el)) (STKX ssxAll _) =
  STKX ssxOuter (STKX (ssxDropSSX ssxOuter ssxAll) el)
convertSTK ConvUnnest (STKX ssxOuter (STKX ssxInner el)) =
  STKX (ssxOuter `ssxAppend` ssxInner) el
convertSTK (ConvZip _ _) (STKProduct (STKX ssx el1) (STKX _ el2)) =
  STKX ssx (STKProduct el1 el2)
convertSTK (ConvUnzip _ _) (STKX ssx (STKProduct el1 el2)) =
  STKProduct (STKX ssx el1) (STKX ssx el2)

-- | Compute the full shape of a conversion's result from the full shape of
-- its input.
--
-- A composite conversion is interpreted right to left, as with function
-- composition. For 'ConvXS'' and 'ConvXX'', the result is exactly the full
-- shape stored in the conversion. The input is not examined, so this
-- function does not check that the input dimensions match it.
convertFTK :: TKConversion a b -> FullShapeTK a -> FullShapeTK b
convertFTK ConvId ftk = ftk
convertFTK (ConvCmp later earlier) ftk =
  convertFTK later (convertFTK earlier ftk)
convertFTK ConvRX (FTKR shr el) = FTKX (shxFromShR shr) el
convertFTK ConvSX (FTKS shs el) = FTKX (shxFromShS shs) el
convertFTK (ConvXR _) (FTKX shx el) = FTKR (shrFromShX shx) el
convertFTK ConvXS (FTKX shx el) = FTKS (shsFromShX shx) el
convertFTK (ConvXS' target) _ = target
convertFTK (ConvXX' target) _ = target
convertFTK (ConvRR c) (FTKR shr el) = FTKR shr (convertFTK c el)
convertFTK (ConvSS c) (FTKS shs el) = FTKS shs (convertFTK c el)
convertFTK (ConvXX c) (FTKX shx el) = FTKX shx (convertFTK c el)
convertFTK (ConvT2 cLeft cRight) (FTKProduct ftkLeft ftkRight) =
  FTKProduct (convertFTK cLeft ftkLeft) (convertFTK cRight ftkRight)
convertFTK (Conv0X _) ftk = FTKX ZSX ftk
convertFTK ConvX0 (FTKX ZSX ftk) = ftk
convertFTK (ConvNest @_ @_ @sh' (STKX ssxOuter _)) (FTKX shxAll el) =
  -- The static prefix from the conversion splits the runtime shape into
  -- the outer dimensions and the inner ones.
  FTKX (shxTakeSSX (Proxy @sh') ssxOuter shxAll)
       (FTKX (shxDropSSX ssxOuter shxAll) el)
convertFTK ConvUnnest (FTKX shxOuter (FTKX shxInner el)) =
  FTKX (shxOuter `shxAppend` shxInner) el
convertFTK (ConvZip _ _) (FTKProduct (FTKX shx el1) (FTKX _ el2)) =
  FTKX shx (FTKProduct el1 el2)
convertFTK (ConvUnzip _ _) (FTKX shx (FTKProduct el1 el2)) =
  FTKProduct (FTKX shx el1) (FTKX shx el2)

buildTKConversion :: SNat k -> FullShapeTK a
                  -> TKConversion a b
                  -> TKConversion (BuildTensorKind k a) (BuildTensorKind k b)
buildTKConversion :: forall (k :: Nat) (a :: TK) (b :: TK).
SNat k
-> FullShapeTK a
-> TKConversion a b
-> TKConversion (BuildTensorKind k a) (BuildTensorKind k b)
buildTKConversion SNat k
k FullShapeTK a
aftk TKConversion a b
c0 = case TKConversion a b
c0 of
  TKConversion a b
ConvId -> TKConversion (BuildTensorKind k a) (BuildTensorKind k a)
TKConversion (BuildTensorKind k a) (BuildTensorKind k b)
forall (a :: TK). TKConversion a a
ConvId
  ConvCmp TKConversion b b
c1 TKConversion a b
c2 -> TKConversion (BuildTensorKind k b) (BuildTensorKind k b)
-> TKConversion (BuildTensorKind k a) (BuildTensorKind k b)
-> TKConversion (BuildTensorKind k a) (BuildTensorKind k b)
forall (b :: TK) (c :: TK) (a :: TK).
TKConversion b c -> TKConversion a b -> TKConversion a c
ConvCmp (SNat k
-> FullShapeTK b
-> TKConversion b b
-> TKConversion (BuildTensorKind k b) (BuildTensorKind k b)
forall (k :: Nat) (a :: TK) (b :: TK).
SNat k
-> FullShapeTK a
-> TKConversion a b
-> TKConversion (BuildTensorKind k a) (BuildTensorKind k b)
buildTKConversion SNat k
k (TKConversion a b -> FullShapeTK a -> FullShapeTK b
forall (a :: TK) (b :: TK).
TKConversion a b -> FullShapeTK a -> FullShapeTK b
convertFTK TKConversion a b
c2 FullShapeTK a
aftk) TKConversion b b
c1)
                           (SNat k
-> FullShapeTK a
-> TKConversion a b
-> TKConversion (BuildTensorKind k a) (BuildTensorKind k b)
forall (k :: Nat) (a :: TK) (b :: TK).
SNat k
-> FullShapeTK a
-> TKConversion a b
-> TKConversion (BuildTensorKind k a) (BuildTensorKind k b)
buildTKConversion SNat k
k FullShapeTK a
aftk TKConversion a b
c2)
  TKConversion a b
ConvRX | FTKR @n IShR n
shr FullShapeTK x
xstk <- FullShapeTK a
aftk
         , (:~:)
  @Nat
  (Rank @(Maybe Nat) (Replicate @(Maybe Nat) n ('Nothing @Nat)))
  n
Refl <- Proxy @Nat n
-> (:~:)
     @Nat
     (Rank @(Maybe Nat) (Replicate @(Maybe Nat) n ('Nothing @Nat)))
     n
forall (proxy :: Nat -> Type) (n :: Nat).
proxy n
-> (:~:)
     @Nat
     (Rank @(Maybe Nat) (Replicate @(Maybe Nat) n ('Nothing @Nat)))
     n
lemRankReplicate (forall (t :: Nat). Proxy @Nat t
forall {k} (t :: k). Proxy @k t
Proxy @n)
         , (:~:)
  @Nat
  (Rank
     @(Maybe Nat) (Replicate @(Maybe Nat) (1 + n) ('Nothing @Nat)))
  (1 + n)
Refl <- Proxy @Nat (1 + n)
-> (:~:)
     @Nat
     (Rank
        @(Maybe Nat) (Replicate @(Maybe Nat) (1 + n) ('Nothing @Nat)))
     (1 + n)
forall (proxy :: Nat -> Type) (n :: Nat).
proxy n
-> (:~:)
     @Nat
     (Rank @(Maybe Nat) (Replicate @(Maybe Nat) n ('Nothing @Nat)))
     n
lemRankReplicate (forall (t :: Nat). Proxy @Nat t
forall {k} (t :: k). Proxy @k t
Proxy @(1 + n)) ->
    TKConversion
  (TKX2 (Replicate @(Maybe Nat) (1 + n) ('Nothing @Nat)) x)
  (TKX2
     ((':)
        @(Maybe Nat)
        ('Just @Nat k)
        (Replicate @(Maybe Nat) n ('Nothing @Nat)))
     x)
-> TKConversion
     (TKR2 (1 + n) x)
     (TKX2 (Replicate @(Maybe Nat) (1 + n) ('Nothing @Nat)) x)
-> TKConversion
     (TKR2 (1 + n) x)
     (TKX2
        ((':)
           @(Maybe Nat)
           ('Just @Nat k)
           (Replicate @(Maybe Nat) n ('Nothing @Nat)))
        x)
forall (b :: TK) (c :: TK) (a :: TK).
TKConversion b c -> TKConversion a b -> TKConversion a c
ConvCmp (FullShapeTK
  (TKX2
     ((':)
        @(Maybe Nat)
        ('Just @Nat k)
        (Replicate @(Maybe Nat) n ('Nothing @Nat)))
     x)
-> TKConversion
     (TKX2 (Replicate @(Maybe Nat) (1 + n) ('Nothing @Nat)) x)
     (TKX2
        ((':)
           @(Maybe Nat)
           ('Just @Nat k)
           (Replicate @(Maybe Nat) n ('Nothing @Nat)))
        x)
forall (r :: [Maybe Nat]) (b :: [Maybe Nat]) (sh :: TK).
((Rank @(Maybe Nat) r :: Nat) ~ (Rank @(Maybe Nat) b :: Nat)) =>
FullShapeTK (TKX2 b sh) -> TKConversion (TKX2 r sh) (TKX2 b sh)
ConvXX' (IShX
  ((':)
     @(Maybe Nat)
     ('Just @Nat k)
     (Replicate @(Maybe Nat) n ('Nothing @Nat)))
-> FullShapeTK x
-> FullShapeTK
     (TKX2
        ((':)
           @(Maybe Nat)
           ('Just @Nat k)
           (Replicate @(Maybe Nat) n ('Nothing @Nat)))
        x)
forall (r :: [Maybe Nat]) (b :: TK).
IShX r -> FullShapeTK b -> FullShapeTK (TKX2 r b)
FTKX (SNat k -> SMayNat @Nat Int SNat ('Just @Nat k)
forall {k} (f :: k -> Type) (n1 :: k) i.
f n1 -> SMayNat @k i f ('Just @k n1)
SKnown SNat k
k SMayNat @Nat Int SNat ('Just @Nat k)
-> ShX (Replicate @(Maybe Nat) n ('Nothing @Nat)) Int
-> IShX
     ((':)
        @(Maybe Nat)
        ('Just @Nat k)
        (Replicate @(Maybe Nat) n ('Nothing @Nat)))
forall {sh1 :: [Maybe Nat]} {i} (n :: Maybe Nat)
       (sh :: [Maybe Nat]).
(((':) @(Maybe Nat) n sh :: [Maybe Nat]) ~ (sh1 :: [Maybe Nat])) =>
SMayNat @Nat i SNat n -> ShX sh i -> ShX sh1 i
:$% IShR n -> ShX (Replicate @(Maybe Nat) n ('Nothing @Nat)) Int
forall (n :: Nat) i.
ShR n i -> ShX (Replicate @(Maybe Nat) n ('Nothing @Nat)) i
shxFromShR IShR n
shr) FullShapeTK x
xstk)) TKConversion
  (TKR2 (1 + n) x)
  (TKX2 (Replicate @(Maybe Nat) (1 + n) ('Nothing @Nat)) x)
forall (r :: Nat) (b :: TK).
TKConversion
  (TKR2 r b) (TKX2 (Replicate @(Maybe Nat) r ('Nothing @Nat)) b)
ConvRX
  TKConversion a b
ConvSX -> TKConversion (BuildTensorKind k a) (BuildTensorKind k b)
TKConversion
  (TKS2 ((':) @Nat k sh) a) (TKX2 (MapJust @Nat ((':) @Nat k sh)) a)
forall (r :: [Nat]) (b :: TK).
TKConversion (TKS2 r b) (TKX2 (MapJust @Nat r) b)
ConvSX
  ConvXR SingletonTK a
stk -> SingletonTK a
-> TKConversion
     (TKX2 ((':) @(Maybe Nat) ('Just @Nat k) sh) a)
     (TKR2 (Rank @(Maybe Nat) ((':) @(Maybe Nat) ('Just @Nat k) sh)) a)
forall (r :: TK) (b :: [Maybe Nat]).
SingletonTK r
-> TKConversion (TKX2 b r) (TKR2 (Rank @(Maybe Nat) b) r)
ConvXR SingletonTK a
stk
  TKConversion a b
ConvXS -> TKConversion (BuildTensorKind k a) (BuildTensorKind k b)
TKConversion
  (TKX2 (MapJust @Nat ((':) @Nat k sh)) a) (TKS2 ((':) @Nat k sh) a)
forall (r :: [Nat]) (b :: TK).
TKConversion (TKX2 (MapJust @Nat r) b) (TKS2 r b)
ConvXS
  ConvXS' FullShapeTK (TKS2 sh' a)
ftk -> FullShapeTK (TKS2 ((':) @Nat k sh') a)
-> TKConversion
     (TKX2 ((':) @(Maybe Nat) ('Just @Nat k) sh) a)
     (TKS2 ((':) @Nat k sh') a)
forall (r :: [Maybe Nat]) (b :: [Nat]) (sh :: TK).
((Rank @(Maybe Nat) r :: Nat) ~ (Rank @Nat b :: Nat)) =>
FullShapeTK (TKS2 b sh) -> TKConversion (TKX2 r sh) (TKS2 b sh)
ConvXS' (SNat k
-> FullShapeTK (TKS2 sh' a)
-> FullShapeTK (BuildTensorKind k (TKS2 sh' a))
forall (k :: Nat) (y :: TK).
SNat k -> FullShapeTK y -> FullShapeTK (BuildTensorKind k y)
buildFTK SNat k
k FullShapeTK (TKS2 sh' a)
ftk)
  ConvXX' FullShapeTK (TKX2 sh' a)
ftk -> FullShapeTK (TKX2 ((':) @(Maybe Nat) ('Just @Nat k) sh') a)
-> TKConversion
     (TKX2 ((':) @(Maybe Nat) ('Just @Nat k) sh) a)
     (TKX2 ((':) @(Maybe Nat) ('Just @Nat k) sh') a)
forall (r :: [Maybe Nat]) (b :: [Maybe Nat]) (sh :: TK).
((Rank @(Maybe Nat) r :: Nat) ~ (Rank @(Maybe Nat) b :: Nat)) =>
FullShapeTK (TKX2 b sh) -> TKConversion (TKX2 r sh) (TKX2 b sh)
ConvXX' (SNat k
-> FullShapeTK (TKX2 sh' a)
-> FullShapeTK (BuildTensorKind k (TKX2 sh' a))
forall (k :: Nat) (y :: TK).
SNat k -> FullShapeTK y -> FullShapeTK (BuildTensorKind k y)
buildFTK SNat k
k FullShapeTK (TKX2 sh' a)
ftk)
  ConvRR TKConversion a b
c -> TKConversion a b -> TKConversion (TKR2 (1 + n) a) (TKR2 (1 + n) b)
forall (r :: TK) (b :: TK) (sh :: Nat).
TKConversion r b -> TKConversion (TKR2 sh r) (TKR2 sh b)
ConvRR TKConversion a b
c
  ConvSS TKConversion a b
c -> TKConversion a b
-> TKConversion (TKS2 ((':) @Nat k sh) a) (TKS2 ((':) @Nat k sh) b)
forall (r :: TK) (b :: TK) (sh :: [Nat]).
TKConversion r b -> TKConversion (TKS2 sh r) (TKS2 sh b)
ConvSS TKConversion a b
c
  ConvXX TKConversion a b
c -> TKConversion a b
-> TKConversion
     (TKX2 ((':) @(Maybe Nat) ('Just @Nat k) sh) a)
     (TKX2 ((':) @(Maybe Nat) ('Just @Nat k) sh) b)
forall (r :: TK) (b :: TK) (sh :: [Maybe Nat]).
TKConversion r b -> TKConversion (TKX2 sh r) (TKX2 sh b)
ConvXX TKConversion a b
c
  ConvT2 TKConversion a a'
c1 TKConversion b b'
c2 | FTKProduct FullShapeTK y
ftk1 FullShapeTK z
ftk2 <- FullShapeTK a
aftk ->
    TKConversion (BuildTensorKind k a) (BuildTensorKind k a')
-> TKConversion (BuildTensorKind k b) (BuildTensorKind k b')
-> TKConversion
     (TKProduct (BuildTensorKind k a) (BuildTensorKind k b))
     (TKProduct (BuildTensorKind k a') (BuildTensorKind k b'))
forall (r :: TK) (b :: TK) (sh :: TK) (b' :: TK).
TKConversion r b
-> TKConversion sh b'
-> TKConversion (TKProduct r sh) (TKProduct b b')
ConvT2 (SNat k
-> FullShapeTK y
-> TKConversion y a'
-> TKConversion (BuildTensorKind k y) (BuildTensorKind k a')
forall (k :: Nat) (a :: TK) (b :: TK).
SNat k
-> FullShapeTK a
-> TKConversion a b
-> TKConversion (BuildTensorKind k a) (BuildTensorKind k b)
buildTKConversion SNat k
k FullShapeTK y
ftk1 TKConversion a a'
TKConversion y a'
c1) (SNat k
-> FullShapeTK z
-> TKConversion z b'
-> TKConversion (BuildTensorKind k z) (BuildTensorKind k b')
forall (k :: Nat) (a :: TK) (b :: TK).
SNat k
-> FullShapeTK a
-> TKConversion a b
-> TKConversion (BuildTensorKind k a) (BuildTensorKind k b)
buildTKConversion SNat k
k FullShapeTK z
ftk2 TKConversion b b'
TKConversion z b'
c2)
  Conv0X SingletonTK a
_astk -> case FullShapeTK a
aftk of
    FullShapeTK a
FTKScalar -> TKConversion (BuildTensorKind k a) (BuildTensorKind k b)
TKConversion
  (TKS2 ((':) @Nat k ('[] @Nat)) (TKScalar r))
  (TKX2 (MapJust @Nat ((':) @Nat k ('[] @Nat))) (TKScalar r))
forall (r :: [Nat]) (b :: TK).
TKConversion (TKS2 r b) (TKX2 (MapJust @Nat r) b)
ConvSX
    FTKR @n IShR n
shr FullShapeTK x
x | (:~:)
  @Nat
  (Rank @(Maybe Nat) (Replicate @(Maybe Nat) n ('Nothing @Nat)))
  n
Refl <- Proxy @Nat n
-> (:~:)
     @Nat
     (Rank @(Maybe Nat) (Replicate @(Maybe Nat) n ('Nothing @Nat)))
     n
forall (proxy :: Nat -> Type) (n :: Nat).
proxy n
-> (:~:)
     @Nat
     (Rank @(Maybe Nat) (Replicate @(Maybe Nat) n ('Nothing @Nat)))
     n
lemRankReplicate (forall (t :: Nat). Proxy @Nat t
forall {k} (t :: k). Proxy @k t
Proxy @n)
                  , (:~:)
  @Nat
  (Rank
     @(Maybe Nat) (Replicate @(Maybe Nat) (1 + n) ('Nothing @Nat)))
  (1 + n)
Refl <- Proxy @Nat (1 + n)
-> (:~:)
     @Nat
     (Rank
        @(Maybe Nat) (Replicate @(Maybe Nat) (1 + n) ('Nothing @Nat)))
     (1 + n)
forall (proxy :: Nat -> Type) (n :: Nat).
proxy n
-> (:~:)
     @Nat
     (Rank @(Maybe Nat) (Replicate @(Maybe Nat) n ('Nothing @Nat)))
     n
lemRankReplicate (forall (t :: Nat). Proxy @Nat t
forall {k} (t :: k). Proxy @k t
Proxy @(1 + n)) ->
      TKConversion
  (TKX2
     ((':) @(Maybe Nat) ('Just @Nat k) ('[] @(Maybe Nat)))
     (TKX2 (Replicate @(Maybe Nat) n ('Nothing @Nat)) x))
  (TKX2
     ((':) @(Maybe Nat) ('Just @Nat k) ('[] @(Maybe Nat))) (TKR2 n x))
-> TKConversion
     (TKR2 (1 + n) x)
     (TKX2
        ((':) @(Maybe Nat) ('Just @Nat k) ('[] @(Maybe Nat)))
        (TKX2 (Replicate @(Maybe Nat) n ('Nothing @Nat)) x))
-> TKConversion
     (TKR2 (1 + n) x)
     (TKX2
        ((':) @(Maybe Nat) ('Just @Nat k) ('[] @(Maybe Nat))) (TKR2 n x))
forall (b :: TK) (c :: TK) (a :: TK).
TKConversion b c -> TKConversion a b -> TKConversion a c
ConvCmp (TKConversion
  (TKX2 (Replicate @(Maybe Nat) n ('Nothing @Nat)) x) (TKR2 n x)
-> TKConversion
     (TKX2
        ((':) @(Maybe Nat) ('Just @Nat k) ('[] @(Maybe Nat)))
        (TKX2 (Replicate @(Maybe Nat) n ('Nothing @Nat)) x))
     (TKX2
        ((':) @(Maybe Nat) ('Just @Nat k) ('[] @(Maybe Nat))) (TKR2 n x))
forall (r :: TK) (b :: TK) (sh :: [Maybe Nat]).
TKConversion r b -> TKConversion (TKX2 sh r) (TKX2 sh b)
ConvXX (SingletonTK x
-> TKConversion
     (TKX2 (Replicate @(Maybe Nat) n ('Nothing @Nat)) x)
     (TKR2
        (Rank @(Maybe Nat) (Replicate @(Maybe Nat) n ('Nothing @Nat))) x)
forall (r :: TK) (b :: [Maybe Nat]).
SingletonTK r
-> TKConversion (TKX2 b r) (TKR2 (Rank @(Maybe Nat) b) r)
ConvXR (FullShapeTK x -> SingletonTK x
forall (y :: TK). FullShapeTK y -> SingletonTK y
ftkToSTK FullShapeTK x
x)))
              (TKConversion
  (TKX2
     ((':)
        @(Maybe Nat)
        ('Just @Nat k)
        (Replicate @(Maybe Nat) n ('Nothing @Nat)))
     x)
  (TKX2
     ((':) @(Maybe Nat) ('Just @Nat k) ('[] @(Maybe Nat)))
     (TKX2 (Replicate @(Maybe Nat) n ('Nothing @Nat)) x))
-> TKConversion
     (TKR2 (1 + n) x)
     (TKX2
        ((':)
           @(Maybe Nat)
           ('Just @Nat k)
           (Replicate @(Maybe Nat) n ('Nothing @Nat)))
        x)
-> TKConversion
     (TKR2 (1 + n) x)
     (TKX2
        ((':) @(Maybe Nat) ('Just @Nat k) ('[] @(Maybe Nat)))
        (TKX2 (Replicate @(Maybe Nat) n ('Nothing @Nat)) x))
forall (b :: TK) (c :: TK) (a :: TK).
TKConversion b c -> TKConversion a b -> TKConversion a c
ConvCmp (SingletonTK
  (TKX2 ((':) @(Maybe Nat) ('Just @Nat k) ('[] @(Maybe Nat))) x)
-> TKConversion
     (TKX2
        ((++)
           @(Maybe Nat)
           ((':) @(Maybe Nat) ('Just @Nat k) ('[] @(Maybe Nat)))
           (Replicate @(Maybe Nat) n ('Nothing @Nat)))
        x)
     (TKX2
        ((':) @(Maybe Nat) ('Just @Nat k) ('[] @(Maybe Nat)))
        (TKX2 (Replicate @(Maybe Nat) n ('Nothing @Nat)) x))
forall (r :: [Maybe Nat]) (b :: TK) (sh :: [Maybe Nat]).
SingletonTK (TKX2 r b)
-> TKConversion
     (TKX2 ((++) @(Maybe Nat) r sh) b) (TKX2 r (TKX2 sh b))
ConvNest (StaticShX ((':) @(Maybe Nat) ('Just @Nat k) ('[] @(Maybe Nat)))
-> SingletonTK x
-> SingletonTK
     (TKX2 ((':) @(Maybe Nat) ('Just @Nat k) ('[] @(Maybe Nat))) x)
forall (r :: [Maybe Nat]) (b :: TK).
StaticShX r -> SingletonTK b -> SingletonTK (TKX2 r b)
STKX (SNat k -> SMayNat @Nat () SNat ('Just @Nat k)
forall {k} (f :: k -> Type) (n1 :: k) i.
f n1 -> SMayNat @k i f ('Just @k n1)
SKnown SNat k
k SMayNat @Nat () SNat ('Just @Nat k)
-> StaticShX ('[] @(Maybe Nat))
-> StaticShX ((':) @(Maybe Nat) ('Just @Nat k) ('[] @(Maybe Nat)))
forall {sh1 :: [Maybe Nat]} (n :: Maybe Nat) (sh :: [Maybe Nat]).
(((':) @(Maybe Nat) n sh :: [Maybe Nat]) ~ (sh1 :: [Maybe Nat])) =>
SMayNat @Nat () SNat n -> StaticShX sh -> StaticShX sh1
:!% StaticShX ('[] @(Maybe Nat))
forall (sh :: [Maybe Nat]).
((sh :: [Maybe Nat]) ~ ('[] @(Maybe Nat) :: [Maybe Nat])) =>
StaticShX sh
ZKX) (FullShapeTK x -> SingletonTK x
forall (y :: TK). FullShapeTK y -> SingletonTK y
ftkToSTK FullShapeTK x
x)))
                       (TKConversion
  (TKX2 (Replicate @(Maybe Nat) (1 + n) ('Nothing @Nat)) x)
  (TKX2
     ((':)
        @(Maybe Nat)
        ('Just @Nat k)
        (Replicate @(Maybe Nat) n ('Nothing @Nat)))
     x)
-> TKConversion
     (TKR2 (1 + n) x)
     (TKX2 (Replicate @(Maybe Nat) (1 + n) ('Nothing @Nat)) x)
-> TKConversion
     (TKR2 (1 + n) x)
     (TKX2
        ((':)
           @(Maybe Nat)
           ('Just @Nat k)
           (Replicate @(Maybe Nat) n ('Nothing @Nat)))
        x)
forall (b :: TK) (c :: TK) (a :: TK).
TKConversion b c -> TKConversion a b -> TKConversion a c
ConvCmp
                          (FullShapeTK
  (TKX2
     ((':)
        @(Maybe Nat)
        ('Just @Nat k)
        (Replicate @(Maybe Nat) n ('Nothing @Nat)))
     x)
-> TKConversion
     (TKX2 (Replicate @(Maybe Nat) (1 + n) ('Nothing @Nat)) x)
     (TKX2
        ((':)
           @(Maybe Nat)
           ('Just @Nat k)
           (Replicate @(Maybe Nat) n ('Nothing @Nat)))
        x)
forall (r :: [Maybe Nat]) (b :: [Maybe Nat]) (sh :: TK).
((Rank @(Maybe Nat) r :: Nat) ~ (Rank @(Maybe Nat) b :: Nat)) =>
FullShapeTK (TKX2 b sh) -> TKConversion (TKX2 r sh) (TKX2 b sh)
ConvXX' (IShX
  ((':)
     @(Maybe Nat)
     ('Just @Nat k)
     (Replicate @(Maybe Nat) n ('Nothing @Nat)))
-> FullShapeTK x
-> FullShapeTK
     (TKX2
        ((':)
           @(Maybe Nat)
           ('Just @Nat k)
           (Replicate @(Maybe Nat) n ('Nothing @Nat)))
        x)
forall (r :: [Maybe Nat]) (b :: TK).
IShX r -> FullShapeTK b -> FullShapeTK (TKX2 r b)
FTKX (SNat k -> SMayNat @Nat Int SNat ('Just @Nat k)
forall {k} (f :: k -> Type) (n1 :: k) i.
f n1 -> SMayNat @k i f ('Just @k n1)
SKnown SNat k
k SMayNat @Nat Int SNat ('Just @Nat k)
-> ShX (Replicate @(Maybe Nat) n ('Nothing @Nat)) Int
-> IShX
     ((':)
        @(Maybe Nat)
        ('Just @Nat k)
        (Replicate @(Maybe Nat) n ('Nothing @Nat)))
forall {sh1 :: [Maybe Nat]} {i} (n :: Maybe Nat)
       (sh :: [Maybe Nat]).
(((':) @(Maybe Nat) n sh :: [Maybe Nat]) ~ (sh1 :: [Maybe Nat])) =>
SMayNat @Nat i SNat n -> ShX sh i -> ShX sh1 i
:$% IShR n -> ShX (Replicate @(Maybe Nat) n ('Nothing @Nat)) Int
forall (n :: Nat) i.
ShR n i -> ShX (Replicate @(Maybe Nat) n ('Nothing @Nat)) i
shxFromShR IShR n
shr) FullShapeTK x
x))
                          TKConversion
  (TKR2 (1 + n) x)
  (TKX2 (Replicate @(Maybe Nat) (1 + n) ('Nothing @Nat)) x)
forall (r :: Nat) (b :: TK).
TKConversion
  (TKR2 r b) (TKX2 (Replicate @(Maybe Nat) r ('Nothing @Nat)) b)
ConvRX))
    FTKS ShS sh
_sh FullShapeTK x
x ->
      TKConversion
  (TKX2
     ((':) @(Maybe Nat) ('Just @Nat k) ('[] @(Maybe Nat)))
     (TKX2 (MapJust @Nat sh) x))
  (TKX2
     ((':) @(Maybe Nat) ('Just @Nat k) ('[] @(Maybe Nat))) (TKS2 sh x))
-> TKConversion
     (TKS2 ((':) @Nat k sh) x)
     (TKX2
        ((':) @(Maybe Nat) ('Just @Nat k) ('[] @(Maybe Nat)))
        (TKX2 (MapJust @Nat sh) x))
-> TKConversion
     (TKS2 ((':) @Nat k sh) x)
     (TKX2
        ((':) @(Maybe Nat) ('Just @Nat k) ('[] @(Maybe Nat))) (TKS2 sh x))
forall (b :: TK) (c :: TK) (a :: TK).
TKConversion b c -> TKConversion a b -> TKConversion a c
ConvCmp (TKConversion (TKX2 (MapJust @Nat sh) x) (TKS2 sh x)
-> TKConversion
     (TKX2
        ((':) @(Maybe Nat) ('Just @Nat k) ('[] @(Maybe Nat)))
        (TKX2 (MapJust @Nat sh) x))
     (TKX2
        ((':) @(Maybe Nat) ('Just @Nat k) ('[] @(Maybe Nat))) (TKS2 sh x))
forall (r :: TK) (b :: TK) (sh :: [Maybe Nat]).
TKConversion r b -> TKConversion (TKX2 sh r) (TKX2 sh b)
ConvXX TKConversion (TKX2 (MapJust @Nat sh) x) (TKS2 sh x)
forall (r :: [Nat]) (b :: TK).
TKConversion (TKX2 (MapJust @Nat r) b) (TKS2 r b)
ConvXS)
              (TKConversion
  (TKX2 ((':) @(Maybe Nat) ('Just @Nat k) (MapJust @Nat sh)) x)
  (TKX2
     ((':) @(Maybe Nat) ('Just @Nat k) ('[] @(Maybe Nat)))
     (TKX2 (MapJust @Nat sh) x))
-> TKConversion
     (TKS2 ((':) @Nat k sh) x)
     (TKX2 ((':) @(Maybe Nat) ('Just @Nat k) (MapJust @Nat sh)) x)
-> TKConversion
     (TKS2 ((':) @Nat k sh) x)
     (TKX2
        ((':) @(Maybe Nat) ('Just @Nat k) ('[] @(Maybe Nat)))
        (TKX2 (MapJust @Nat sh) x))
forall (b :: TK) (c :: TK) (a :: TK).
TKConversion b c -> TKConversion a b -> TKConversion a c
ConvCmp (SingletonTK
  (TKX2 ((':) @(Maybe Nat) ('Just @Nat k) ('[] @(Maybe Nat))) x)
-> TKConversion
     (TKX2
        ((++)
           @(Maybe Nat)
           ((':) @(Maybe Nat) ('Just @Nat k) ('[] @(Maybe Nat)))
           (MapJust @Nat sh))
        x)
     (TKX2
        ((':) @(Maybe Nat) ('Just @Nat k) ('[] @(Maybe Nat)))
        (TKX2 (MapJust @Nat sh) x))
forall (r :: [Maybe Nat]) (b :: TK) (sh :: [Maybe Nat]).
SingletonTK (TKX2 r b)
-> TKConversion
     (TKX2 ((++) @(Maybe Nat) r sh) b) (TKX2 r (TKX2 sh b))
ConvNest (StaticShX ((':) @(Maybe Nat) ('Just @Nat k) ('[] @(Maybe Nat)))
-> SingletonTK x
-> SingletonTK
     (TKX2 ((':) @(Maybe Nat) ('Just @Nat k) ('[] @(Maybe Nat))) x)
forall (r :: [Maybe Nat]) (b :: TK).
StaticShX r -> SingletonTK b -> SingletonTK (TKX2 r b)
STKX (SNat k -> SMayNat @Nat () SNat ('Just @Nat k)
forall {k} (f :: k -> Type) (n1 :: k) i.
f n1 -> SMayNat @k i f ('Just @k n1)
SKnown SNat k
k SMayNat @Nat () SNat ('Just @Nat k)
-> StaticShX ('[] @(Maybe Nat))
-> StaticShX ((':) @(Maybe Nat) ('Just @Nat k) ('[] @(Maybe Nat)))
forall {sh1 :: [Maybe Nat]} (n :: Maybe Nat) (sh :: [Maybe Nat]).
(((':) @(Maybe Nat) n sh :: [Maybe Nat]) ~ (sh1 :: [Maybe Nat])) =>
SMayNat @Nat () SNat n -> StaticShX sh -> StaticShX sh1
:!% StaticShX ('[] @(Maybe Nat))
forall (sh :: [Maybe Nat]).
((sh :: [Maybe Nat]) ~ ('[] @(Maybe Nat) :: [Maybe Nat])) =>
StaticShX sh
ZKX) (FullShapeTK x -> SingletonTK x
forall (y :: TK). FullShapeTK y -> SingletonTK y
ftkToSTK FullShapeTK x
x)))
                       TKConversion
  (TKS2 ((':) @Nat k sh) x)
  (TKX2 ((':) @(Maybe Nat) ('Just @Nat k) (MapJust @Nat sh)) x)
TKConversion
  (TKS2 ((':) @Nat k sh) x) (TKX2 (MapJust @Nat ((':) @Nat k sh)) x)
forall (r :: [Nat]) (b :: TK).
TKConversion (TKS2 r b) (TKX2 (MapJust @Nat r) b)
ConvSX)
    FTKX IShX sh
_ssx FullShapeTK x
x -> SingletonTK
  (TKX2 ((':) @(Maybe Nat) ('Just @Nat k) ('[] @(Maybe Nat))) x)
-> TKConversion
     (TKX2
        ((++)
           @(Maybe Nat)
           ((':) @(Maybe Nat) ('Just @Nat k) ('[] @(Maybe Nat)))
           sh)
        x)
     (TKX2
        ((':) @(Maybe Nat) ('Just @Nat k) ('[] @(Maybe Nat))) (TKX2 sh x))
forall (r :: [Maybe Nat]) (b :: TK) (sh :: [Maybe Nat]).
SingletonTK (TKX2 r b)
-> TKConversion
     (TKX2 ((++) @(Maybe Nat) r sh) b) (TKX2 r (TKX2 sh b))
ConvNest (StaticShX ((':) @(Maybe Nat) ('Just @Nat k) ('[] @(Maybe Nat)))
-> SingletonTK x
-> SingletonTK
     (TKX2 ((':) @(Maybe Nat) ('Just @Nat k) ('[] @(Maybe Nat))) x)
forall (r :: [Maybe Nat]) (b :: TK).
StaticShX r -> SingletonTK b -> SingletonTK (TKX2 r b)
STKX (SNat k -> SMayNat @Nat () SNat ('Just @Nat k)
forall {k} (f :: k -> Type) (n1 :: k) i.
f n1 -> SMayNat @k i f ('Just @k n1)
SKnown SNat k
k SMayNat @Nat () SNat ('Just @Nat k)
-> StaticShX ('[] @(Maybe Nat))
-> StaticShX ((':) @(Maybe Nat) ('Just @Nat k) ('[] @(Maybe Nat)))
forall {sh1 :: [Maybe Nat]} (n :: Maybe Nat) (sh :: [Maybe Nat]).
(((':) @(Maybe Nat) n sh :: [Maybe Nat]) ~ (sh1 :: [Maybe Nat])) =>
SMayNat @Nat () SNat n -> StaticShX sh -> StaticShX sh1
:!% StaticShX ('[] @(Maybe Nat))
forall (sh :: [Maybe Nat]).
((sh :: [Maybe Nat]) ~ ('[] @(Maybe Nat) :: [Maybe Nat])) =>
StaticShX sh
ZKX) (FullShapeTK x -> SingletonTK x
forall (y :: TK). FullShapeTK y -> SingletonTK y
ftkToSTK FullShapeTK x
x))
    FTKProduct FullShapeTK y
aftk1 FullShapeTK z
aftk2 ->
      SNat k
-> FullShapeTK a
-> TKConversion a (TKX2 ('[] @(Maybe Nat)) (TKProduct y z))
-> TKConversion
     (BuildTensorKind k a)
     (BuildTensorKind k (TKX2 ('[] @(Maybe Nat)) (TKProduct y z)))
forall (k :: Nat) (a :: TK) (b :: TK).
SNat k
-> FullShapeTK a
-> TKConversion a b
-> TKConversion (BuildTensorKind k a) (BuildTensorKind k b)
buildTKConversion
        SNat k
k FullShapeTK a
aftk (TKConversion
  (TKProduct (TKX2 ('[] @(Maybe Nat)) y) (TKX2 ('[] @(Maybe Nat)) z))
  (TKX2 ('[] @(Maybe Nat)) (TKProduct y z))
-> TKConversion
     a
     (TKProduct (TKX2 ('[] @(Maybe Nat)) y) (TKX2 ('[] @(Maybe Nat)) z))
-> TKConversion a (TKX2 ('[] @(Maybe Nat)) (TKProduct y z))
forall (b :: TK) (c :: TK) (a :: TK).
TKConversion b c -> TKConversion a b -> TKConversion a c
ConvCmp (SingletonTK y
-> SingletonTK z
-> TKConversion
     (TKProduct (TKX2 ('[] @(Maybe Nat)) y) (TKX2 ('[] @(Maybe Nat)) z))
     (TKX2 ('[] @(Maybe Nat)) (TKProduct y z))
forall (r :: TK) (b :: TK) (sh :: [Maybe Nat]).
SingletonTK r
-> SingletonTK b
-> TKConversion
     (TKProduct (TKX2 sh r) (TKX2 sh b)) (TKX2 sh (TKProduct r b))
ConvZip (FullShapeTK y -> SingletonTK y
forall (y :: TK). FullShapeTK y -> SingletonTK y
ftkToSTK FullShapeTK y
aftk1) (FullShapeTK z -> SingletonTK z
forall (y :: TK). FullShapeTK y -> SingletonTK y
ftkToSTK FullShapeTK z
aftk2))
                        (TKConversion y (TKX2 ('[] @(Maybe Nat)) y)
-> TKConversion z (TKX2 ('[] @(Maybe Nat)) z)
-> TKConversion
     (TKProduct y z)
     (TKProduct (TKX2 ('[] @(Maybe Nat)) y) (TKX2 ('[] @(Maybe Nat)) z))
forall (r :: TK) (b :: TK) (sh :: TK) (b' :: TK).
TKConversion r b
-> TKConversion sh b'
-> TKConversion (TKProduct r sh) (TKProduct b b')
ConvT2 (SingletonTK y -> TKConversion y (TKX2 ('[] @(Maybe Nat)) y)
forall (a :: TK).
SingletonTK a -> TKConversion a (TKX2 ('[] @(Maybe Nat)) a)
Conv0X (FullShapeTK y -> SingletonTK y
forall (y :: TK). FullShapeTK y -> SingletonTK y
ftkToSTK FullShapeTK y
aftk1))
                                (SingletonTK z -> TKConversion z (TKX2 ('[] @(Maybe Nat)) z)
forall (a :: TK).
SingletonTK a -> TKConversion a (TKX2 ('[] @(Maybe Nat)) a)
Conv0X (FullShapeTK z -> SingletonTK z
forall (y :: TK). FullShapeTK y -> SingletonTK y
ftkToSTK FullShapeTK z
aftk2))))
  TKConversion a b
ConvX0 -> case FullShapeTK a
aftk of
    FTKX ShX sh Int
ZSX FullShapeTK x
FTKScalar -> TKConversion (BuildTensorKind k a) (BuildTensorKind k b)
TKConversion
  (TKX2 (MapJust @Nat ((':) @Nat k ('[] @Nat))) (TKScalar r))
  (TKS2 ((':) @Nat k ('[] @Nat)) (TKScalar r))
forall (r :: [Nat]) (b :: TK).
TKConversion (TKX2 (MapJust @Nat r) b) (TKS2 r b)
ConvXS
    FTKX ShX sh Int
ZSX (FTKR @n IShR n
_n FullShapeTK x
x) | (:~:)
  @Nat
  (Rank @(Maybe Nat) (Replicate @(Maybe Nat) n ('Nothing @Nat)))
  n
Refl <- Proxy @Nat n
-> (:~:)
     @Nat
     (Rank @(Maybe Nat) (Replicate @(Maybe Nat) n ('Nothing @Nat)))
     n
forall (proxy :: Nat -> Type) (n :: Nat).
proxy n
-> (:~:)
     @Nat
     (Rank @(Maybe Nat) (Replicate @(Maybe Nat) n ('Nothing @Nat)))
     n
lemRankReplicate (forall (t :: Nat). Proxy @Nat t
forall {k} (t :: k). Proxy @k t
Proxy @n) ->
      TKConversion
  (TKX2
     ((':)
        @(Maybe Nat)
        ('Just @Nat k)
        (Replicate @(Maybe Nat) n ('Nothing @Nat)))
     x)
  (TKR2 (1 + n) x)
-> TKConversion
     (TKX2
        ((':) @(Maybe Nat) ('Just @Nat k) ('[] @(Maybe Nat))) (TKR2 n x))
     (TKX2
        ((':)
           @(Maybe Nat)
           ('Just @Nat k)
           (Replicate @(Maybe Nat) n ('Nothing @Nat)))
        x)
-> TKConversion
     (TKX2
        ((':) @(Maybe Nat) ('Just @Nat k) ('[] @(Maybe Nat))) (TKR2 n x))
     (TKR2 (1 + n) x)
forall (b :: TK) (c :: TK) (a :: TK).
TKConversion b c -> TKConversion a b -> TKConversion a c
ConvCmp (SingletonTK x
-> TKConversion
     (TKX2
        ((':)
           @(Maybe Nat)
           ('Just @Nat k)
           (Replicate @(Maybe Nat) n ('Nothing @Nat)))
        x)
     (TKR2
        (Rank
           @(Maybe Nat)
           ((':)
              @(Maybe Nat)
              ('Just @Nat k)
              (Replicate @(Maybe Nat) n ('Nothing @Nat))))
        x)
forall (r :: TK) (b :: [Maybe Nat]).
SingletonTK r
-> TKConversion (TKX2 b r) (TKR2 (Rank @(Maybe Nat) b) r)
ConvXR (FullShapeTK x -> SingletonTK x
forall (y :: TK). FullShapeTK y -> SingletonTK y
ftkToSTK FullShapeTK x
x))
              (TKConversion
  (TKX2
     ((':) @(Maybe Nat) ('Just @Nat k) ('[] @(Maybe Nat)))
     (TKX2 (Replicate @(Maybe Nat) n ('Nothing @Nat)) x))
  (TKX2
     ((':)
        @(Maybe Nat)
        ('Just @Nat k)
        (Replicate @(Maybe Nat) n ('Nothing @Nat)))
     x)
-> TKConversion
     (TKX2
        ((':) @(Maybe Nat) ('Just @Nat k) ('[] @(Maybe Nat))) (TKR2 n x))
     (TKX2
        ((':) @(Maybe Nat) ('Just @Nat k) ('[] @(Maybe Nat)))
        (TKX2 (Replicate @(Maybe Nat) n ('Nothing @Nat)) x))
-> TKConversion
     (TKX2
        ((':) @(Maybe Nat) ('Just @Nat k) ('[] @(Maybe Nat))) (TKR2 n x))
     (TKX2
        ((':)
           @(Maybe Nat)
           ('Just @Nat k)
           (Replicate @(Maybe Nat) n ('Nothing @Nat)))
        x)
forall (b :: TK) (c :: TK) (a :: TK).
TKConversion b c -> TKConversion a b -> TKConversion a c
ConvCmp TKConversion
  (TKX2
     ((':) @(Maybe Nat) ('Just @Nat k) ('[] @(Maybe Nat)))
     (TKX2 (Replicate @(Maybe Nat) n ('Nothing @Nat)) x))
  (TKX2
     ((':)
        @(Maybe Nat)
        ('Just @Nat k)
        (Replicate @(Maybe Nat) n ('Nothing @Nat)))
     x)
TKConversion
  (TKX2
     ((':) @(Maybe Nat) ('Just @Nat k) ('[] @(Maybe Nat)))
     (TKX2 (Replicate @(Maybe Nat) n ('Nothing @Nat)) x))
  (TKX2
     ((++)
        @(Maybe Nat)
        ((':) @(Maybe Nat) ('Just @Nat k) ('[] @(Maybe Nat)))
        (Replicate @(Maybe Nat) n ('Nothing @Nat)))
     x)
forall (r :: [Maybe Nat]) (b :: [Maybe Nat]) (sh :: TK).
TKConversion (TKX2 r (TKX2 b sh)) (TKX2 ((++) @(Maybe Nat) r b) sh)
ConvUnnest (TKConversion
  (TKR2 n x) (TKX2 (Replicate @(Maybe Nat) n ('Nothing @Nat)) x)
-> TKConversion
     (TKX2
        ((':) @(Maybe Nat) ('Just @Nat k) ('[] @(Maybe Nat))) (TKR2 n x))
     (TKX2
        ((':) @(Maybe Nat) ('Just @Nat k) ('[] @(Maybe Nat)))
        (TKX2 (Replicate @(Maybe Nat) n ('Nothing @Nat)) x))
forall (r :: TK) (b :: TK) (sh :: [Maybe Nat]).
TKConversion r b -> TKConversion (TKX2 sh r) (TKX2 sh b)
ConvXX TKConversion
  (TKR2 n x) (TKX2 (Replicate @(Maybe Nat) n ('Nothing @Nat)) x)
forall (r :: Nat) (b :: TK).
TKConversion
  (TKR2 r b) (TKX2 (Replicate @(Maybe Nat) r ('Nothing @Nat)) b)
ConvRX))
    FTKX ShX sh Int
ZSX FTKS{} ->
      TKConversion
  (TKX2 ((':) @(Maybe Nat) ('Just @Nat k) (MapJust @Nat sh)) x)
  (TKS2 ((':) @Nat k sh) x)
-> TKConversion
     (TKX2
        ((':) @(Maybe Nat) ('Just @Nat k) ('[] @(Maybe Nat))) (TKS2 sh x))
     (TKX2 ((':) @(Maybe Nat) ('Just @Nat k) (MapJust @Nat sh)) x)
-> TKConversion
     (TKX2
        ((':) @(Maybe Nat) ('Just @Nat k) ('[] @(Maybe Nat))) (TKS2 sh x))
     (TKS2 ((':) @Nat k sh) x)
forall (b :: TK) (c :: TK) (a :: TK).
TKConversion b c -> TKConversion a b -> TKConversion a c
ConvCmp TKConversion
  (TKX2 ((':) @(Maybe Nat) ('Just @Nat k) (MapJust @Nat sh)) x)
  (TKS2 ((':) @Nat k sh) x)
TKConversion
  (TKX2 (MapJust @Nat ((':) @Nat k sh)) x) (TKS2 ((':) @Nat k sh) x)
forall (r :: [Nat]) (b :: TK).
TKConversion (TKX2 (MapJust @Nat r) b) (TKS2 r b)
ConvXS
              (TKConversion
  (TKX2
     ((':) @(Maybe Nat) ('Just @Nat k) ('[] @(Maybe Nat)))
     (TKX2 (MapJust @Nat sh) x))
  (TKX2 ((':) @(Maybe Nat) ('Just @Nat k) (MapJust @Nat sh)) x)
-> TKConversion
     (TKX2
        ((':) @(Maybe Nat) ('Just @Nat k) ('[] @(Maybe Nat))) (TKS2 sh x))
     (TKX2
        ((':) @(Maybe Nat) ('Just @Nat k) ('[] @(Maybe Nat)))
        (TKX2 (MapJust @Nat sh) x))
-> TKConversion
     (TKX2
        ((':) @(Maybe Nat) ('Just @Nat k) ('[] @(Maybe Nat))) (TKS2 sh x))
     (TKX2 ((':) @(Maybe Nat) ('Just @Nat k) (MapJust @Nat sh)) x)
forall (b :: TK) (c :: TK) (a :: TK).
TKConversion b c -> TKConversion a b -> TKConversion a c
ConvCmp TKConversion
  (TKX2
     ((':) @(Maybe Nat) ('Just @Nat k) ('[] @(Maybe Nat)))
     (TKX2 (MapJust @Nat sh) x))
  (TKX2 ((':) @(Maybe Nat) ('Just @Nat k) (MapJust @Nat sh)) x)
TKConversion
  (TKX2
     ((':) @(Maybe Nat) ('Just @Nat k) ('[] @(Maybe Nat)))
     (TKX2 (MapJust @Nat sh) x))
  (TKX2
     ((++)
        @(Maybe Nat)
        ((':) @(Maybe Nat) ('Just @Nat k) ('[] @(Maybe Nat)))
        (MapJust @Nat sh))
     x)
forall (r :: [Maybe Nat]) (b :: [Maybe Nat]) (sh :: TK).
TKConversion (TKX2 r (TKX2 b sh)) (TKX2 ((++) @(Maybe Nat) r b) sh)
ConvUnnest (TKConversion (TKS2 sh x) (TKX2 (MapJust @Nat sh) x)
-> TKConversion
     (TKX2
        ((':) @(Maybe Nat) ('Just @Nat k) ('[] @(Maybe Nat))) (TKS2 sh x))
     (TKX2
        ((':) @(Maybe Nat) ('Just @Nat k) ('[] @(Maybe Nat)))
        (TKX2 (MapJust @Nat sh) x))
forall (r :: TK) (b :: TK) (sh :: [Maybe Nat]).
TKConversion r b -> TKConversion (TKX2 sh r) (TKX2 sh b)
ConvXX TKConversion (TKS2 sh x) (TKX2 (MapJust @Nat sh) x)
forall (r :: [Nat]) (b :: TK).
TKConversion (TKS2 r b) (TKX2 (MapJust @Nat r) b)
ConvSX))
    FTKX ShX sh Int
ZSX FTKX{} -> TKConversion (BuildTensorKind k a) (BuildTensorKind k b)
TKConversion
  (TKX2
     ((':) @(Maybe Nat) ('Just @Nat k) ('[] @(Maybe Nat))) (TKX2 sh x))
  (TKX2
     ((++)
        @(Maybe Nat)
        ((':) @(Maybe Nat) ('Just @Nat k) ('[] @(Maybe Nat)))
        sh)
     x)
forall (r :: [Maybe Nat]) (b :: [Maybe Nat]) (sh :: TK).
TKConversion (TKX2 r (TKX2 b sh)) (TKX2 ((++) @(Maybe Nat) r b) sh)
ConvUnnest
    FTKX ShX sh Int
ZSX (FTKProduct FullShapeTK y
aftk1 FullShapeTK z
aftk2) ->
      SNat k
-> FullShapeTK a
-> TKConversion a (TKProduct y z)
-> TKConversion
     (BuildTensorKind k a) (BuildTensorKind k (TKProduct y z))
forall (k :: Nat) (a :: TK) (b :: TK).
SNat k
-> FullShapeTK a
-> TKConversion a b
-> TKConversion (BuildTensorKind k a) (BuildTensorKind k b)
buildTKConversion
        SNat k
k FullShapeTK a
aftk (TKConversion
  (TKProduct (TKX2 ('[] @(Maybe Nat)) y) (TKX2 ('[] @(Maybe Nat)) z))
  (TKProduct y z)
-> TKConversion
     a
     (TKProduct (TKX2 ('[] @(Maybe Nat)) y) (TKX2 ('[] @(Maybe Nat)) z))
-> TKConversion a (TKProduct y z)
forall (b :: TK) (c :: TK) (a :: TK).
TKConversion b c -> TKConversion a b -> TKConversion a c
ConvCmp (TKConversion (TKX2 ('[] @(Maybe Nat)) y) y
-> TKConversion (TKX2 ('[] @(Maybe Nat)) z) z
-> TKConversion
     (TKProduct (TKX2 ('[] @(Maybe Nat)) y) (TKX2 ('[] @(Maybe Nat)) z))
     (TKProduct y z)
forall (r :: TK) (b :: TK) (sh :: TK) (b' :: TK).
TKConversion r b
-> TKConversion sh b'
-> TKConversion (TKProduct r sh) (TKProduct b b')
ConvT2 TKConversion (TKX2 ('[] @(Maybe Nat)) y) y
forall (a :: TK). TKConversion (TKX2 ('[] @(Maybe Nat)) a) a
ConvX0 TKConversion (TKX2 ('[] @(Maybe Nat)) z) z
forall (a :: TK). TKConversion (TKX2 ('[] @(Maybe Nat)) a) a
ConvX0)
                        (SingletonTK y
-> SingletonTK z
-> TKConversion
     (TKX2 ('[] @(Maybe Nat)) (TKProduct y z))
     (TKProduct (TKX2 ('[] @(Maybe Nat)) y) (TKX2 ('[] @(Maybe Nat)) z))
forall (r :: TK) (b :: TK) (sh :: [Maybe Nat]).
SingletonTK r
-> SingletonTK b
-> TKConversion
     (TKX2 sh (TKProduct r b)) (TKProduct (TKX2 sh r) (TKX2 sh b))
ConvUnzip (FullShapeTK y -> SingletonTK y
forall (y :: TK). FullShapeTK y -> SingletonTK y
ftkToSTK FullShapeTK y
aftk1) (FullShapeTK z -> SingletonTK z
forall (y :: TK). FullShapeTK y -> SingletonTK y
ftkToSTK FullShapeTK z
aftk2)))
  ConvNest (STKX StaticShX sh
sh SingletonTK x
x) -> SingletonTK (TKX2 ((':) @(Maybe Nat) ('Just @Nat k) sh) x)
-> TKConversion
     (TKX2
        ((++) @(Maybe Nat) ((':) @(Maybe Nat) ('Just @Nat k) sh) sh') x)
     (TKX2 ((':) @(Maybe Nat) ('Just @Nat k) sh) (TKX2 sh' x))
forall (r :: [Maybe Nat]) (b :: TK) (sh :: [Maybe Nat]).
SingletonTK (TKX2 r b)
-> TKConversion
     (TKX2 ((++) @(Maybe Nat) r sh) b) (TKX2 r (TKX2 sh b))
ConvNest (StaticShX ((':) @(Maybe Nat) ('Just @Nat k) sh)
-> SingletonTK x
-> SingletonTK (TKX2 ((':) @(Maybe Nat) ('Just @Nat k) sh) x)
forall (r :: [Maybe Nat]) (b :: TK).
StaticShX r -> SingletonTK b -> SingletonTK (TKX2 r b)
STKX (SNat k -> SMayNat @Nat () SNat ('Just @Nat k)
forall {k} (f :: k -> Type) (n1 :: k) i.
f n1 -> SMayNat @k i f ('Just @k n1)
SKnown SNat k
k SMayNat @Nat () SNat ('Just @Nat k)
-> StaticShX sh -> StaticShX ((':) @(Maybe Nat) ('Just @Nat k) sh)
forall {sh1 :: [Maybe Nat]} (n :: Maybe Nat) (sh :: [Maybe Nat]).
(((':) @(Maybe Nat) n sh :: [Maybe Nat]) ~ (sh1 :: [Maybe Nat])) =>
SMayNat @Nat () SNat n -> StaticShX sh -> StaticShX sh1
:!% StaticShX sh
sh) SingletonTK x
x)
  TKConversion a b
ConvUnnest -> TKConversion (BuildTensorKind k a) (BuildTensorKind k b)
TKConversion
  (TKX2 ((':) @(Maybe Nat) ('Just @Nat k) sh) (TKX2 sh' a))
  (TKX2
     ((++) @(Maybe Nat) ((':) @(Maybe Nat) ('Just @Nat k) sh) sh') a)
forall (r :: [Maybe Nat]) (b :: [Maybe Nat]) (sh :: TK).
TKConversion (TKX2 r (TKX2 b sh)) (TKX2 ((++) @(Maybe Nat) r b) sh)
ConvUnnest
  ConvZip SingletonTK a
astk1 SingletonTK b
astk2 -> SingletonTK a
-> SingletonTK b
-> TKConversion
     (TKProduct
        (TKX2 ((':) @(Maybe Nat) ('Just @Nat k) sh) a)
        (TKX2 ((':) @(Maybe Nat) ('Just @Nat k) sh) b))
     (TKX2 ((':) @(Maybe Nat) ('Just @Nat k) sh) (TKProduct a b))
forall (r :: TK) (b :: TK) (sh :: [Maybe Nat]).
SingletonTK r
-> SingletonTK b
-> TKConversion
     (TKProduct (TKX2 sh r) (TKX2 sh b)) (TKX2 sh (TKProduct r b))
ConvZip SingletonTK a
astk1 SingletonTK b
astk2
  ConvUnzip SingletonTK a
astk1 SingletonTK b
astk2 -> SingletonTK a
-> SingletonTK b
-> TKConversion
     (TKX2 ((':) @(Maybe Nat) ('Just @Nat k) sh) (TKProduct a b))
     (TKProduct
        (TKX2 ((':) @(Maybe Nat) ('Just @Nat k) sh) a)
        (TKX2 ((':) @(Maybe Nat) ('Just @Nat k) sh) b))
forall (r :: TK) (b :: TK) (sh :: [Maybe Nat]).
SingletonTK r
-> SingletonTK b
-> TKConversion
     (TKX2 sh (TKProduct r b)) (TKProduct (TKX2 sh r) (TKX2 sh b))
ConvUnzip SingletonTK a
astk1 SingletonTK b
astk2


-- * Full shape tensor kind quasi-singletons

-- | Full shape tensor kind singleton type.
--
-- Unlike 'SingletonTK', the ranked and mixed constructors carry the full
-- runtime shape ('IShR' / 'IShX') rather than just the rank ('SNat') or
-- the static shape ('StaticShX'). Hence "quasi-singleton": the value is
-- not uniquely determined by the type index.
type role FullShapeTK nominal
data FullShapeTK y where
  -- Scalar of element type @r@.
  FTKScalar :: GoodScalar r
            => FullShapeTK (TKScalar r)
  -- Ranked tensor: full runtime shape of rank @n@, plus element kind.
  FTKR :: IShR n -> FullShapeTK x -> FullShapeTK (TKR2 n x)
  -- Shaped tensor: type-level shape @sh@, plus element kind.
  FTKS :: ShS sh -> FullShapeTK x -> FullShapeTK (TKS2 sh x)
  -- Mixed tensor: runtime shape, including the sizes of statically
  -- unknown dimensions, plus element kind.
  FTKX :: IShX sh -> FullShapeTK x -> FullShapeTK (TKX2 sh x)
  -- Pair of two tensor kinds.
  FTKProduct :: FullShapeTK y -> FullShapeTK z
             -> FullShapeTK (TKProduct y z)

-- The derived Eq compares full shapes, including runtime dimension sizes.
deriving instance Show (FullShapeTK y)
deriving instance Eq (FullShapeTK y)

-- | A plausible implementation of `testEquality` on `FullShapeTK`:
-- returns @Just Refl@ when the two full shapes describe the same tensor
-- kind.
--
-- It does not take into account the runtime shape differences in ranked
-- and mixed tensors that `FullShapeTK`, but not `SingletonTK`, captures:
-- for ranked tensors only the ranks are compared, and for mixed tensors
-- only the static shapes are compared. To also compare runtime sizes,
-- combine a successful match with the derived 'Eq' instance.
matchingFTK :: FullShapeTK y1 -> FullShapeTK y2 -> Maybe (y1 :~: y2)
matchingFTK ftk1 ftk2 = case (ftk1, ftk2) of
  (FTKScalar @r1, FTKScalar @r2)
    | Just Refl <- testEquality (typeRep @r1) (typeRep @r2) ->
      Just Refl
  (FTKR sh1 x1, FTKR sh2 x2)
    | Just Refl <- matchingFTK x1 x2
      -- Weaker than shape equality: only the ranks are compared.
    , Just Refl <- testEquality (shrRank sh1) (shrRank sh2) ->
      Just Refl
  (FTKS sh1 x1, FTKS sh2 x2)
    | Just Refl <- matchingFTK x1 x2
    , Just Refl <- testEquality sh1 sh2 ->
      Just Refl
  (FTKX sh1 x1, FTKX sh2 x2)
    | Just Refl <- matchingFTK x1 x2
      -- Only the static shapes are compared; the runtime sizes of
      -- statically unknown dimensions are dropped by 'ssxFromShX'.
    , Just Refl <- testEquality (ssxFromShX sh1) (ssxFromShX sh2) ->
      Just Refl
  (FTKProduct x1 y1, FTKProduct x2 y2)
    | Just Refl <- matchingFTK x1 x2
    , Just Refl <- matchingFTK y1 y2 ->
      Just Refl
  _ -> Nothing

-- | A conversion that is fully determined by the property that it
-- commutes with the `testEquality` implementations.
--
-- It forgets runtime dimension sizes:
--
-- * a ranked shape is reduced to its rank ('shrRank');
-- * a mixed shape is reduced to its static shape ('ssxFromShX');
-- * a shaped shape is kept as is, since it is already type-level.
ftkToSTK :: FullShapeTK y -> SingletonTK y
ftkToSTK = \case
  FTKScalar -> STKScalar
  FTKR sh x -> STKR (shrRank sh) (ftkToSTK x)
  FTKS sh x -> STKS sh (ftkToSTK x)
  FTKX sh x -> STKX (ssxFromShX sh) (ftkToSTK x)
  FTKProduct ftk1 ftk2 -> STKProduct (ftkToSTK ftk1) (ftkToSTK ftk2)

-- | The full shape of the unit tensor kind 'TKUnit'.
ftkUnit :: FullShapeTK TKUnit
ftkUnit = FTKScalar

-- | The full shape of a tensor built from @k@ copies of a tensor of the
-- given full shape, i.e., with an extra outermost dimension of size @k@:
--
-- * a scalar becomes a shaped tensor of shape @[k]@;
-- * a ranked tensor gets @k@ prepended to its runtime shape;
-- * a shaped tensor gets @k@ prepended to its type-level shape;
-- * a mixed tensor gets a statically known dimension @k@ prepended;
-- * a product is built componentwise.
buildFTK :: SNat k -> FullShapeTK y -> FullShapeTK (BuildTensorKind k y)
-- Matching on 'SNat' brings @KnownNat k@ into scope, which ':$$' needs.
buildFTK snat@SNat = \case
  FTKScalar -> FTKS (snat :$$ ZSS) FTKScalar
  FTKR sh x -> FTKR (sNatValue snat :$: sh) x
  FTKS sh x -> FTKS (snat :$$ sh) x
  FTKX sh x -> FTKX (SKnown snat :$% sh) x
  FTKProduct ftk1 ftk2 ->
    FTKProduct (buildFTK snat ftk1) (buildFTK snat ftk2)

-- | Undo 'buildFTK': given the build width and the singleton of the
-- element kind @y@, strip the outermost dimension from a full shape of
-- kind @BuildTensorKind k y@ and return the full shape of one element.
--
-- The length of the stripped dimension is not compared against the
-- given 'SNat'; it is simply discarded.
razeFTK :: forall y k.
           SNat k -> SingletonTK y
        -> FullShapeTK (BuildTensorKind k y)
        -> FullShapeTK y
razeFTK snat@SNat stk ftk = case (stk, ftk) of
  -- A built scalar is a rank-1 shaped tensor of that scalar.
  (STKScalar, FTKS (_ :$$ ZSS) FTKScalar) -> FTKScalar
  (STKR{}, FTKR (_ :$: sh) x) -> FTKR sh x
  -- Unreachable: a built ranked shape always has an outermost dimension.
  -- Only the ranked case spells this out, presumably because the checker
  -- cannot refute 'ZSR' from the rank arithmetic alone.
  (STKR{}, FTKR ZSR _) -> error "razeFTK: impossible built tensor kind"
  (STKS{}, FTKS (_ :$$ sh) x) -> FTKS sh x
  (STKX{}, FTKX (_ :$% sh) x) -> FTKX sh x
  -- Products are razed componentwise with the same width.
  (STKProduct stk1 stk2, FTKProduct ftk1 ftk2) ->
    FTKProduct (razeFTK snat stk1 ftk1) (razeFTK snat stk2 ftk2)

-- | The full shape of the 'ADTensorKind' counterpart of a tensor
-- collection.
--
-- Ranked, shaped, mixed and product structure is preserved unchanged.
-- @Double@ and @Float@ scalars are kept as they are. Every other scalar
-- type is replaced by the scalar type 'Z1'.
adFTK :: FullShapeTK y
      -> FullShapeTK (ADTensorKind y)
adFTK = \case
  t@(FTKScalar @r) -> case testEquality (typeRep @r) (typeRep @Double) of
    Just Refl -> t
    _ -> case testEquality (typeRep @r) (typeRep @Float) of
      Just Refl -> t
      -- For an abstract @r@ the type family 'ADTensorScalar' cannot be
      -- reduced, so the equation is asserted with 'unsafeCoerceRefl'.
      -- NOTE(review): this is sound only if 'ADTensorScalar' maps every
      -- scalar other than Double and Float to 'Z1'; keep this branch in
      -- sync with that definition.
      _ -> gcastWith (unsafeCoerceRefl :: ADTensorScalar r :~: Z1) $
           FTKScalar @Z1
  FTKR sh x -> FTKR sh $ adFTK x
  FTKS sh x -> FTKS sh $ adFTK x
  FTKX sh x -> FTKX sh $ adFTK x
  FTKProduct ftk1 ftk2 -> FTKProduct (adFTK ftk1) (adFTK ftk2)

-- | A test of whether the argument tensor collection is free
-- of any non-differentiable types, such as integers.
--
-- Returns 'True' exactly when every scalar occurring in the full shape
-- is @Double@ or @Float@. These are the same two types that 'adFTK'
-- leaves unchanged. Tensor shapes themselves are ignored.
differentiableFTK :: FullShapeTK y -> Bool
differentiableFTK = \case
  FTKScalar @r -> case testEquality (typeRep @r) (typeRep @Double) of
    Just Refl -> True
    _ -> case testEquality (typeRep @r) (typeRep @Float) of
      Just Refl -> True
      _ -> False
  FTKR _ x -> differentiableFTK x
  FTKS _ x -> differentiableFTK x
  FTKX _ x -> differentiableFTK x
  FTKProduct ftk1 ftk2 -> differentiableFTK ftk1 && differentiableFTK ftk2

type DummyDualTarget :: Target
type role DummyDualTarget nominal
-- | A 'Target' whose only payload is the full shape of the tensor kind.
-- It carries no tensor data.
--
-- NOTE(review): presumably a placeholder for the dual component where
-- only shape information is needed; confirm against its use sites.
newtype DummyDualTarget y = DummyDualTarget (FullShapeTK y)