{-# LANGUAGE AllowAmbiguousTypes, OverloadedLists, QuantifiedConstraints,
             UndecidableInstances #-}
{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-}
{-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-}
-- | A collection of classes containing array operations,
-- with some extra algebraic operations and dual numbers
-- operations added in.
--
-- Note that @Ast*@ modules rarely depend on @Ops*@ and @Carriers*@ modules
-- (except for "HordeAd.Core.AstInterpret" and "HordeAd.Core.AstEnv"
-- that describe how to go from @Ast*@ to @Ops*@). Similarly, @Ops*@
-- and @Carriers*@ modules rarely depend on @Ast*@ modules
-- (except for "HordeAd.Core.OpsAst" and "HordeAd.Core.CarriersAst"
-- that describe how to define @Ops*@ in terms of @Ast*@).
-- Syntax is relatively separated from semantics and they meet
-- in the interpreter ("HordeAd.Core.AstInterpret")
-- and in the semantic model constructed from syntax ("HordeAd.Core.OpsAst").
--
-- (A copy of the text above is in "HordeAd.Core.Ast".)
module HordeAd.Core.Ops
  ( -- * The tensor classes and support datatypes
    LetTensor(..), ShareTensor(..), BaseTensor(..), HFun(..)
    -- * The giga-constraint
  , ADReady, ADReadyNoLet, ADReadyEqs, ADReadyClasses, ADReadyEqsClasses
  , AllTargetShow, CommonTargetEqOrd
    -- * Helper functions
  , rtr, rflatten, str, sflatten, xtr, xflatten
  , tmapAccumR, tmapAccumL
  , rbuild, sbuild, xbuild
    -- * Helper classes and types
  , IntegralHAndIntElt, RealFloatAndFloatElt
  , TensorSupportsX, TensorSupportsS, TensorSupportsR, TensorSupports
  ) where

import Prelude

import Data.Foldable qualified as Foldable
import Data.Int (Int64)
import Data.Kind (Constraint, Type)
import Data.Maybe (fromMaybe)
import Data.Proxy (Proxy (Proxy))
import Data.Type.Equality (gcastWith, testEquality, (:~:) (Refl))
import Data.Vector.Generic qualified as V
import Data.Vector.Strict qualified as Data.Vector
import GHC.Exts (IsList (..))
import GHC.TypeLits (KnownNat, type (+), type (<=), type (<=?))
import Type.Reflection (typeRep)

import Data.Array.Nested (type (++))
import Data.Array.Nested qualified as Nested
import Data.Array.Nested.Convert (withShsFromShX)
import Data.Array.Nested.Lemmas
import Data.Array.Nested.Mixed.Shape
import Data.Array.Nested.Permutation qualified as Permutation
import Data.Array.Nested.Ranked.Shape
import Data.Array.Nested.Shaped.Shape
import Data.Array.Nested.Types (Init, unsafeCoerceRefl)

import HordeAd.Core.CarriersConcrete
import HordeAd.Core.ConvertTensor
import HordeAd.Core.TensorKind
import HordeAd.Core.Types

-- These user API functions are used in default definitions of methods,
-- so they have to be defined already here:

-- | Transpose a ranked tensor of rank at least 2 by applying the
-- permutation @[1, 0]@ with 'trtranspose', i.e., swapping its first
-- two dimensions.
rtr :: forall n x target. (KnownSTK x, BaseTensor target)
    => target (TKR2 (2 + n) x) -> target (TKR2 (2 + n) x)
rtr = trtranspose [1, 0]
-- | Reshape a ranked tensor of any rank into a rank-1 tensor.
-- The single dimension of the result is @rsize u@.
rflatten :: forall n x target. (KnownSTK x, BaseTensor target)
         => target (TKR2 n x) -> target (TKR2 1 x)
rflatten u = trreshape (rsize u :$: ZSR) u
-- | Transpose a shaped tensor by swapping its first two dimensions,
-- via 'tstranspose' with the permutation @[1, 0]@.
--
-- 'tstranspose' demands that the rank of the permutation does not exceed
-- the rank of the shape. The coercion below supplies that evidence
-- (@2 <= Rank (n ': m ': sh)@) without proof; the shape visibly has
-- at least two elements.
str :: forall n m sh x target. (KnownSTK x, BaseTensor target)
    => target (TKS2 (n ': m ': sh) x) -> target (TKS2 (m ': n ': sh) x)
str = gcastWith (unsafeCoerceRefl :: (2 <=? Rank (n ': m ': sh)) :~: True) $
      tstranspose (Permutation.makePerm @'[1, 0])
-- | Reshape a shaped tensor into a one-dimensional shaped tensor
-- of length @Product sh@, via 'tsreshape'.
--
-- Matching on the 'SNat' returned by 'shsProduct' presumably brings
-- @KnownNat (Product sh)@ into scope, so that 'knownShS' can produce
-- the target shape @'[Product sh]@.
sflatten :: (KnownShS sh, KnownSTK x, BaseTensor target)
         => target (TKS2 sh x) -> target (TKS2 '[Product sh] x)
sflatten @sh | SNat <- shsProduct (knownShS @sh) = tsreshape knownShS
-- | Transpose a mixed tensor whose first two dimensions are statically
-- known, swapping them via 'txtranspose' with the permutation @[1, 0]@.
--
-- 'txtranspose' demands that the rank of the permutation does not exceed
-- the rank of the shape. The coercion below supplies that evidence
-- (@2 <= Rank (Just n ': Just m ': sh)@) without proof; the shape visibly
-- has at least two elements.
xtr :: forall n m sh x target. (KnownSTK x, BaseTensor target)
    => target (TKX2 (Just n ': Just m ': sh) x)
    -> target (TKX2 (Just m ': Just n ': sh) x)
xtr = gcastWith (unsafeCoerceRefl
                 :: (2 <=? Rank (Just n ': Just m ': sh)) :~: True) $
      txtranspose (Permutation.makePerm @'[1, 0])
-- | Reshape a mixed tensor into a one-dimensional mixed tensor whose
-- single dimension is statically unknown (@Nothing@) and has the runtime
-- size @xsize u@.
xflatten :: forall sh x target. (KnownSTK x, BaseTensor target)
         => target (TKX2 sh x) -> target (TKX2 '[Nothing] x)
xflatten u = txreshape (Nested.SUnknown (xsize u) :$% ZSX) u

-- | Build a ranked tensor of rank @m + n@ from a function that maps
-- an index of length @m@ to a tensor of rank @n@.
--
-- Only the first @m@ dimensions of the given shape (taken with 'shrTake')
-- drive the construction: one 'trbuild1' is nested per such dimension
-- and the function is applied once the full index has been assembled.
rbuild :: (KnownNat m, KnownNat n, KnownSTK x, BaseTensor target)
       => IShR (m + n)  -- ^ the shape of the resulting tensor
       -> (IxROf target m -> target (TKR2 n x))
            -- ^ the function to build with
       -> target (TKR2 (m + n) x)
rbuild @m @n @x @target sh0 f0 =
  let buildSh :: IShR m1 -> (IxROf target m1 -> target (TKR2 n x))
              -> target (TKR2 (m1 + n) x)
      -- No outer dimensions left: apply the function to the empty index.
      buildSh ZSR f = f ZIR
      -- The pattern guard exposes the rank of the remaining shape as
      -- an 'SNat'; presumably this is what lets the 'KnownNat' constraint
      -- of 'trbuild1' be solved.
      buildSh (k :$: sh) f | SNat <- shrRank sh =
        let -- Fix the outermost index component and recurse on the rest.
            g i = buildSh sh (\ix -> f (i :.: ix))
        in trbuild1 k g
  in buildSh (shrTake @m @n sh0) f0
-- | Build a shaped tensor of shape @sh@ from a function that maps
-- an index into the first @m@ dimensions (@Take m sh@) to a tensor
-- holding the remaining dimensions (@Drop m sh@).
--
-- One 'tsbuild1' is nested per dimension of @Take m sh@ and the function
-- is applied once the full index has been assembled.
sbuild :: (KnownShS (Take m sh), KnownShS sh, KnownSTK x, BaseTensor target)
       => (IxSOf target (Take m sh) -> target (TKS2 (Drop m sh) x))
            -- ^ the function to build with
       -> target (TKS2 sh x)
sbuild @m @sh @x @target =
  let -- The first argument is the part of the shape still to be built over;
      -- the second is that part followed by @Drop m sh@, i.e., the shape
      -- of the tensor being produced at this level.
      buildSh
        :: forall sh1.
           ShS sh1 -> ShS (sh1 ++ Drop m sh)
        -> (IxSOf target sh1 -> target (TKS2 (Drop m sh) x))
        -> target (TKS2 (sh1 ++ Drop m sh) x)
      buildSh sh1 sh1m f = case (sh1, sh1m) of
        -- No outer dimensions left: apply the function to the empty index.
        (ZSS, _) -> f ZIS
        -- Matching on 'SNat' and calling 'withKnownShS' provide
        -- the 'KnownNat' and 'KnownShS' instances that 'tsbuild1' requires.
        (SNat :$$ sh2, _ :$$ sh2m) ->
          withKnownShS sh2m $
          let -- Fix the outermost index component and recurse on the rest.
              g i = buildSh sh2 sh2m (f . (i :.$))
          in tsbuild1 g
  -- The equality @sh ~ Take m sh ++ Drop m sh@ is asserted without proof,
  -- so that the result of 'buildSh' can be given the type @TKS2 sh x@.
  in gcastWith (unsafeCoerceRefl :: sh :~: Take m sh ++ Drop m sh)
     $ buildSh (knownShS @(Take m sh)) (knownShS @sh)
-- | Build a mixed tensor of shape @sh@ from a function that maps
-- an index into the first @m@ dimensions (@Take m sh@) to a tensor
-- holding the remaining dimensions (@Drop m sh@).
--
-- One 'txbuild1' is nested per dimension of @Take m sh@ and the function
-- is applied once the full index has been assembled.
xbuild :: ( KnownShX (Take m sh), KnownSTK x
          , BaseTensor target, ConvertTensor target )
       => IShX sh  -- ^ the shape of the resulting tensor
       -> (IxXOf target (Take m sh) -> target (TKX2 (Drop m sh) x))
            -- ^ the function to build with
       -> target (TKX2 sh x)
xbuild @m @sh @x @target sh0 f0 =
  let -- The first argument is the part of the shape still to be built over;
      -- the second is that part followed by @Drop m sh@, i.e., the shape
      -- of the tensor being produced at this level.
      buildSh :: IShX sh1 -> IShX (sh1 ++ Drop m sh)
              -> (IxXOf target sh1 -> target (TKX2 (Drop m sh) x))
              -> target (TKX2 (sh1 ++ Drop m sh) x)
      buildSh sh1 sh1m f = case (sh1, sh1m) of
        -- No outer dimensions left: apply the function to the empty index.
        (ZSX, _) -> f ZIX
        (k :$% sh2, _ :$% sh2m) ->
          -- 'txbuild1' requires 'KnownShX' for the shape of the inner tensor.
          withKnownShX (ssxFromShX sh2m) $
          let -- Fix the outermost index component and recurse on the rest.
              g i = buildSh sh2 sh2m (f . (i :.%))
          -- The runtime size of the outermost dimension is reflected
          -- to the type level, 'txbuild1' builds a tensor whose outermost
          -- dimension is @Just n@ and 'xmcast' converts the result
          -- to the static shape of @sh1m@.
          in withSNat (fromSMayNat' k) $ \(SNat @n) ->
               xmcast (ssxFromShX sh1m) $ txbuild1 @_ @n g
  -- The equality @sh ~ Take m sh ++ Drop m sh@ is asserted without proof,
  -- so that @sh0@ can be passed to 'buildSh' and its result can be given
  -- the type @TKX2 sh x@.
  in gcastWith (unsafeCoerceRefl :: sh :~: Take m sh ++ Drop m sh)
     $ buildSh (shxTakeSSX (Proxy @(Drop m sh)) (knownShX @(Take m sh)) sh0)
               sh0 f0

-- | A strict right mapAccum.
--
-- The curried step function is turned into a function on
-- the product of the accumulator and a single input and handed
-- to 'tmapAccumRDer' together with its forward ('tjvp')
-- and reverse ('tvjp') derivatives.
tmapAccumR
  :: forall accy by ey k target. BaseTensor target
  => Proxy target
  -> SNat k  -- ^ length of the input
  -> FullShapeTK accy  -- ^ shape of the accumulator
  -> FullShapeTK by  -- ^ shape of the output
  -> FullShapeTK ey  -- ^ shape of an individual input
  -> (forall f. ADReady f
      => f accy -> f ey -> f (TKProduct accy by))
       -- ^ the function to mapAccum with
  -> target accy  -- ^ the initial accumulator
  -> target (BuildTensorKind k ey)  -- ^ the inputs
  -> target (TKProduct accy (BuildTensorKind k by))
{-# INLINE tmapAccumR #-}  -- this doesn't want to specialize
tmapAccumR proxy !k !accftk !bftk !eftk f acc0 es =
  let -- Shape of the (accumulator, input) pair the uncurried function takes.
      xftk = FTKProduct accftk eftk
      -- The step function uncurried; the pair is let-bound with 'ttlet'
      -- before being projected twice.
      fl :: forall f. ADReady f
         => f (TKProduct accy ey)
         -> f (TKProduct accy by)
      fl !args = ttlet args $ \ !args1 ->
                   f (tproject1 args1) (tproject2 args1)
  in tmapAccumRDer proxy k accftk bftk eftk
                   (tlambda @target xftk (HFun fl))
                   (tjvp @target xftk $ HFun fl)
                   (tvjp @target xftk $ HFun fl)
                   acc0 es

-- | A strict left mapAccum.
--
-- The curried step function is turned into a function on
-- the product of the accumulator and a single input and handed
-- to 'tmapAccumLDer' together with its forward ('tjvp')
-- and reverse ('tvjp') derivatives.
tmapAccumL
  :: forall accy by ey k target. BaseTensor target
  => Proxy target
  -> SNat k  -- ^ length of the input
  -> FullShapeTK accy  -- ^ shape of the accumulator
  -> FullShapeTK by  -- ^ shape of the output
  -> FullShapeTK ey  -- ^ shape of an individual input
  -> (forall f. ADReady f
      => f accy -> f ey -> f (TKProduct accy by))
       -- ^ the function to mapAccum with
  -> target accy  -- ^ the initial accumulator
  -> target (BuildTensorKind k ey)  -- ^ the inputs
  -> target (TKProduct accy (BuildTensorKind k by))
{-# INLINE tmapAccumL #-}  -- this doesn't want to specialize
tmapAccumL proxy !k !accftk !bftk !eftk f acc0 es =
  let -- Shape of the (accumulator, input) pair the uncurried function takes.
      xftk = FTKProduct accftk eftk
      -- The step function uncurried; the pair is let-bound with 'ttlet'
      -- before being projected twice.
      fl :: forall f. ADReady f
         => f (TKProduct accy ey)
         -> f (TKProduct accy by)
      fl !args = ttlet args $ \ !args1 ->
                   f (tproject1 args1) (tproject2 args1)
  in tmapAccumLDer proxy k accftk bftk eftk
                   (tlambda @target xftk (HFun fl))
                   (tjvp @target xftk $ HFun fl)
                   (tvjp @target xftk $ HFun fl)
                   acc0 es

-- Quantified constraint: for every element type @a@ with @GoodScalar a@
-- and @cElt a@, the scalar tensor @target (TKScalar a)@ satisfies @cTensor@.
type TensorSupports :: (Type -> Constraint) -> (Type -> Constraint)
                    -> Target -> Constraint
type TensorSupports cElt cTensor target =
  forall a. GoodScalar a => cElt a => cTensor (target (TKScalar a))

-- Quantified constraint: for every element type @a@ with @GoodScalar a@
-- and @cElt a@, and every @rank@, the tensor @target (TKR rank a)@
-- satisfies @cTensor@.
type TensorSupportsR :: (Type -> Constraint) -> (Type -> Constraint)
                     -> Target -> Constraint
type TensorSupportsR cElt cTensor target =
  forall a rank. GoodScalar a => cElt a => cTensor (target (TKR rank a))

-- Quantified constraint: for every element type @a@ with @GoodScalar a@
-- and @cElt a@, and every @shape@, the tensor @target (TKS shape a)@
-- satisfies @cTensor@.
type TensorSupportsS :: (Type -> Constraint) -> (Type -> Constraint)
                     -> Target -> Constraint
type TensorSupportsS cElt cTensor target =
  forall a shape. GoodScalar a => cElt a => cTensor (target (TKS shape a))

-- Quantified constraint: for every element type @a@ with @GoodScalar a@
-- and @cElt a@, and every @shape@, the tensor @target (TKX shape a)@
-- satisfies @cTensor@.
type TensorSupportsX :: (Type -> Constraint) -> (Type -> Constraint)
                     -> Target -> Constraint
type TensorSupportsX cElt cTensor target =
  forall a shape. GoodScalar a => cElt a => cTensor (target (TKX shape a))

-- | The conjunction of 'RealFloatH' and 'Nested.FloatElt' packaged
-- as a single-parameter class. The blanket instance makes it hold
-- for every type that satisfies both superclasses.
class (RealFloatH a, Nested.FloatElt a) => RealFloatAndFloatElt a
instance (RealFloatH a, Nested.FloatElt a) => RealFloatAndFloatElt a

-- | The conjunction of 'IntegralH' and 'Nested.IntElt' packaged
-- as a single-parameter class. The blanket instance makes it hold
-- for every type that satisfies both superclasses.
class (IntegralH a, Nested.IntElt a) => IntegralHAndIntElt a
instance (IntegralH a, Nested.IntElt a) => IntegralHAndIntElt a

-- | Let-binding and sharing operations of a tensor target.
class LetTensor (target :: Target) where
  -- | @ttlet a k@ hands the term @a@ to the continuation @k@.
  -- NOTE(review): presumably the point is that @k@ may then use its
  -- argument several times without duplicating @a@ (cf. the comment
  -- in 'tD' about duplicable arguments) -- confirm against the instances.
  ttlet :: target x -> (target x -> target z) -> target z
  -- | Variant of 'ttlet' in which the bound term and the continuation's
  -- argument are of type @PrimalOf target x@.
  ttletPrimal :: PrimalOf target x -> (PrimalOf target x -> target z)
              -> target z
  -- | Convert a term to its @ShareOf target@ counterpart.
  toShare :: target y -> ShareOf target y
  -- | Convert a @ShareOf target@ term back to @target@.
  -- The default implementation only raises an error, so it must be
  -- overridden in every instance for which this method is ever called.
  tunshare :: ShareOf target y -> target y
  tunshare = error "tunshare: this instance should never be used"
  -- | Append two tensors built with 'BuildTensorKind', of lengths @m@
  -- and @n@, into one of length @m + n@. The definition goes by cases
  -- on the tensor kind singleton: scalars and shaped tensors use
  -- 'tsappend', ranked tensors 'trappend', mixed tensors 'txappend',
  -- and products are appended componentwise.
  tappend :: forall m n y. BaseTensor target
          => SNat m -> SNat n -> SingletonTK y
          -> target (BuildTensorKind m y) -> target (BuildTensorKind n y)
          -> target (BuildTensorKind (m + n) y)
  tappend msnat@SNat nsnat@SNat stk a b = case stk of
    STKScalar -> tsappend a b
    STKR _ x | Dict <- lemKnownSTK x -> trappend a b
    STKS _ x | Dict <- lemKnownSTK x -> tsappend a b
    STKX _ x | Dict <- lemKnownSTK x -> txappend a b
    STKProduct stk1 stk2 ->
      -- Both arguments are let-bound, because each is projected twice below.
      ttlet a $ \ !aShared -> ttlet b $ \ !bShared ->
        tpair (tappend msnat nsnat stk1 (tproject1 aShared) (tproject1 bShared))
              (tappend msnat nsnat stk2 (tproject2 aShared) (tproject2 bShared))
  -- | Combine a primal part and a dual part into a full tensor:
  -- the primal part is embedded with 'tfromPrimal', the dual part
  -- with 'tfromDual' and the two are added with 'taddTarget'.
  tD :: BaseTensor target
     => SingletonTK y -> PrimalOf target y -> DualOf target y
     -> target y
  tD stk p d =
    -- Lets needed, because taddTarget requires duplicable arguments.
    ttletPrimal p $ \pShared ->
    ttlet (tfromDual d) $ \dShared ->
      taddTarget stk (tfromPrimal stk pShared) dShared
  -- | A strict left fold.
  --
  -- Implemented with 'tmapAccumL': the step function is extended
  -- to additionally emit a dummy @Z1@ output and only the final
  -- accumulator is projected out of the result.
  tfold
    :: forall yn ym k. BaseTensor target
    => SNat k  -- ^ length of the input
    -> SingletonTK yn  -- ^ partial shape of the accumulator
    -> SingletonTK ym  -- ^ partial shape of an individual input
    -> (forall f. ADReady f => f yn -> f ym -> f yn)
         -- ^ the function to fold with
    -> target yn  -- ^ the initial accumulator
    -> target (BuildTensorKind k ym)  -- ^ the inputs
    -> target yn
  {-# INLINE tfold #-}  -- this doesn't want to specialize
  tfold k nstk mstk f acc0 es =
    tproject1
    $ tmapAccumL (Proxy @target)
       k
       (tftk nstk acc0)
       (FTKScalar @Z1)
       (razeFTK k mstk (tftk (buildSTK k mstk) es))
       (let g :: forall f. ADReady f
              => f yn -> f ym -> f (TKProduct yn TKUnit)
            g !acc !e = tpair (f acc e) (tkconcrete Z1)
        in g)
       acc0
       es
  -- | A strict left scan.
  --
  -- The result has @1 + k@ entries: the initial accumulator is placed first
  -- and is followed by the @k@ per-step outputs collected by 'tmapAccumL'.
  tscan
    :: forall yn ym k. BaseTensor target
    => SNat k  -- ^ length of the input
    -> SingletonTK yn  -- ^ partial shape of the accumulator
    -> SingletonTK ym  -- ^ partial shape of an individual input
    -> (forall f. ADReady f => f yn -> f ym -> f yn)
         -- ^ the function to scan with
    -> target yn  -- ^ the initial accumulator
    -> target (BuildTensorKind k ym)  -- ^ the inputs
    -> target (BuildTensorKind (1 + k) yn)
  {-# INLINE tscan #-}  -- this doesn't want to specialize
  tscan k nstk mstk f acc0 es =
    let bs :: target (BuildTensorKind k yn)
        -- Only the second component of the 'tmapAccumL' result (the stacked
        -- per-step outputs) is kept; the final accumulator is dropped.
        bs = tproject2
             $ tmapAccumL (Proxy @target)
                 k
                 (tftk nstk acc0)
                 (tftk nstk acc0)
                 (razeFTK k mstk (tftk (buildSTK k mstk) es))
                 (let g :: forall f. ADReady f
                        => f yn -> f ym -> f (TKProduct yn yn)
                      -- The step result is bound once with 'ttlet' and then
                      -- used both as the carried accumulator and as
                      -- the emitted output.
                      g !acc !e = ttlet (f acc e) $ \ !res -> tpair res res
                  in g)
                 acc0
                 es
    in tappend (SNat @1) k nstk
               (tfromVector (SNat @1) nstk (V.fromList [acc0])) bs

-- | Tensor operations that work with explicitly shared terms
-- (see 'tshare') instead of let-bindings.
class ShareTensor (target :: Target) where
  -- | Produce a shared version of the given term.
  -- NOTE(review): the exact sharing semantics is instance-specific
  -- and not visible here -- confirm against the instances.
  tshare :: target y -> target y
  -- | Split a product into its two components.
  tunpair :: target (TKProduct x z) -> (target x, target z)
  -- This would suffer from lack of sharing with LetTensor, because
  -- ttlet doesn't work over a list. With sharing it's fine.
  --
  -- | Unravel a tensor built with 'BuildTensorKind' along its outermost
  -- dimension (of length @k@) into a list of its components.
  -- The argument is first shared with 'tshare' (or split with 'tunpair'
  -- for products), so that the produced list elements all refer
  -- to the same shared term.
  tunravelToListShare :: forall y k. (BaseTensor target, ConvertTensor target)
                      => SNat k -> SingletonTK y
                      -> target (BuildTensorKind k y)
                      -> [target y]
  tunravelToListShare snat@SNat stk u = case stk of
    -- A vector of scalars is unravelled as a shaped tensor and each
    -- rank-0 element is then converted to a scalar.
    STKScalar -> let !uShared = tshare u
                 in map kfromS $ tsunravelToList uShared
    STKR SNat x | Dict <- lemKnownSTK x -> let !uShared = tshare u
                                           in trunravelToList uShared
    STKS sh x | Dict <- lemKnownSTK x -> let !uShared = tshare u
                                         in withKnownShS sh
                                            $ tsunravelToList uShared
    STKX sh x | Dict <- lemKnownSTK x -> let !uShared = tshare u
                                         in withKnownShX sh
                                            $ txunravelToList uShared
    -- Products are unravelled componentwise and the results are paired
    -- back up element by element.
    STKProduct stk1 stk2 ->
      let (!u1, !u2) = tunpair u
      in zipWith tpair (tunravelToListShare snat stk1 u1)
                       (tunravelToListShare snat stk2 u2)

-- | The superclasses indicate that it's not only a container array,
-- but also a mathematical tensor, sporting numeric operations.
class ( Num (IntOf target)
      , IntegralH (IntOf target)
      , TensorSupports Num Num target
      , TensorSupports RealFloatAndFloatElt Floating target
      , TensorSupports RealFloatAndFloatElt RealFloatH target
      , TensorSupports IntegralHAndIntElt IntegralH target
      , TensorSupportsR Num Num target
      , TensorSupportsR RealFloatAndFloatElt Floating target
      , TensorSupportsR RealFloatAndFloatElt RealFloatH target
      , TensorSupportsR IntegralHAndIntElt IntegralH target
      , TensorSupportsS Num Num target
      , TensorSupportsS RealFloatAndFloatElt Floating target
      , TensorSupportsS RealFloatAndFloatElt RealFloatH target
      , TensorSupportsS IntegralHAndIntElt IntegralH target
      , TensorSupportsX Num Num target
      , TensorSupportsX RealFloatAndFloatElt Floating target
      , TensorSupportsX RealFloatAndFloatElt RealFloatH target
      , TensorSupportsX IntegralHAndIntElt IntegralH target )
      => BaseTensor (target :: Target) where

  -- First type argument being @target@ is acceptable here, since these
  -- operations are mostly used when the shape is not known at the type level,
  -- so it can't be used as an explicit type argument.

  -- | The shape of a ranked tensor of rank @n@.
  rshape :: forall n x. KnownSTK x
         => target (TKR2 n x) -> IShR n
  -- | The result of applying 'shrLength' to the shape of a ranked tensor.
  -- NOTE(review): presumably this is the number of dimensions -- confirm
  -- against the definition of 'shrLength'.
  rlength :: forall n x. KnownSTK x
          => target (TKR2 n x) -> Int
  rlength = shrLength . rshape
  -- | The result of applying 'shrSize' to the shape of a ranked tensor.
  -- NOTE(review): presumably this is the total number of elements -- confirm
  -- against the definition of 'shrSize'.
  rsize :: forall n x. KnownSTK x
        => target (TKR2 n x) -> Int
  rsize = shrSize . rshape
  -- | The first (outermost) component of the shape of a ranked tensor
  -- of rank at least one.
  rwidth :: forall n x. KnownSTK x
          => target (TKR2 (1 + n) x) -> Int
  rwidth a = case rshape a of
    k :$: _ -> k

  -- | The shape of a shaped tensor, as a value-level singleton of @sh@.
  sshape :: forall sh x. KnownSTK x
         => target (TKS2 sh x) -> ShS sh
  -- | The result of applying 'shsLength' to the shape of a shaped tensor.
  -- NOTE(review): presumably this is the number of dimensions -- confirm
  -- against the definition of 'shsLength'.
  slength :: forall sh x. KnownSTK x
          => target (TKS2 sh x) -> Int
  slength = shsLength . sshape
  -- | The result of applying 'shsSize' to the shape of a shaped tensor.
  -- NOTE(review): presumably this is the total number of elements -- confirm
  -- against the definition of 'shsSize'.
  ssize :: forall sh x. KnownSTK x
        => target (TKS2 sh x) -> Int
  ssize = shsSize . sshape
  -- | The first (outermost) component of the shape of a non-scalar
  -- shaped tensor, converted to 'Int' with 'sNatValue'.
  swidth :: forall n sh x. KnownSTK x
          => target (TKS2 (n ': sh) x) -> Int
  swidth a = case sshape a of
    n :$$ _ -> sNatValue n

  -- | The shape of a mixed tensor.
  xshape :: forall sh x. KnownSTK x
         => target (TKX2 sh x) -> IShX sh
  -- | The result of applying 'shxLength' to the shape of a mixed tensor.
  -- NOTE(review): presumably this is the number of dimensions -- confirm
  -- against the definition of 'shxLength'.
  xlength :: forall sh x. KnownSTK x
          => target (TKX2 sh x) -> Int
  xlength = shxLength . xshape
  -- | The result of applying 'shxSize' to the shape of a mixed tensor.
  -- NOTE(review): presumably this is the total number of elements -- confirm
  -- against the definition of 'shxSize'.
  xsize :: forall sh x. KnownSTK x
        => target (TKX2 sh x) -> Int
  xsize = shxSize . xshape
  -- | The first (outermost) component of the shape of a non-scalar
  -- mixed tensor, converted to 'Int' with @fromSMayNat'@.
  xwidth :: forall mn sh x. KnownSTK x
          => target (TKX2 (mn ': sh) x) -> Int
  xwidth a = case xshape a of
    mn :$% _ -> fromSMayNat' mn

  -- | The size of a tensor of an arbitrary kind, by case analysis on
  -- its singleton: a scalar counts as 1, except for a @Z1@ scalar,
  -- which counts as 0; arrays defer to 'rsize', 'ssize' and 'xsize';
  -- the size of a product is the sum of the sizes of its components.
  tsize :: SingletonTK y -> target y -> Int
  tsize stk a = case stk of
    STKScalar @r -> case testEquality (typeRep @r) (typeRep @Z1) of
      Just Refl -> 0
      _ -> 1
    STKR _ x | Dict <- lemKnownSTK x -> rsize a
    STKS _ x | Dict <- lemKnownSTK x -> ssize a
    STKX _ x | Dict <- lemKnownSTK x -> xsize a
    STKProduct stk1 stk2 ->
      tsize stk1 (tproject1 a) + tsize stk2 (tproject2 a)
  -- | The full shape of a tensor, given the singleton of its kind.
  tftk :: SingletonTK y -> target y -> FullShapeTK y

  -- Unlikely to require type applications at all

  -- | Construct a product from its two components.
  tpair :: target x -> target z -> target (TKProduct x z)
  -- | The first component of a product.
  tproject1 :: target (TKProduct x z) -> target x
  -- | The second component of a product.
  tproject2 :: target (TKProduct x z) -> target z

  -----------
  -- Everything below is intended to be rarely used and usually there are
  -- more specific and/or more convenient functions that do the same job
  -- in other modules.
  -----------------

  -- | A conditional that takes a boolean and two tensors of the same kind
  -- and returns a tensor of that kind.
  -- The operation is potentially strict in all arguments.
  -- NOTE(review): presumably the first tensor is the result when
  -- the boolean is true -- confirm against the instances.
  tcond :: Boolean (BoolOf target)
        => SingletonTK y
        -> BoolOf target -> target y -> target y -> target y

  -- A more precise type would have `PrimalOf target`, but it would require
  -- the user to convert, so we leave that precision to the AST only,
  -- which means the AST instance will automatically insert such
  -- conversions as needed. The same holds for trfloor and many others.

  -- | Turn a concrete ranked array into a ranked tensor.
  trconcrete :: GoodScalar r
             => Nested.Ranked n r -> target (TKR n r)
  -- | Turn a concrete shaped array into a shaped tensor.
  tsconcrete :: GoodScalar r
             => Nested.Shaped sh r -> target (TKS sh r)
  -- | Turn a concrete mixed array into a mixed tensor.
  txconcrete :: GoodScalar r
             => Nested.Mixed sh r -> target (TKX sh r)
  -- | Turn a concrete scalar into a scalar tensor.
  tkconcrete :: GoodScalar r => r -> target (TKScalar r)
  -- | Turn a concrete value of an arbitrary kind into a tensor,
  -- given the full shape of the value.
  tconcrete :: FullShapeTK y -> Concrete y -> target y

  -- These nine methods can't be replaced by tfromVector, because the concrete
  -- instance has much faster implementations.
  --
  -- This is morally non-empty strict vectors:
  -- | Stack a vector of ranked tensors into a tensor of rank one higher.
  -- The default implementation reifies the length of the vector
  -- as a type-level natural and defers to 'tfromVector'.
  trfromVector :: (KnownNat n, KnownSTK x)
               => Data.Vector.Vector (target (TKR2 n x))
               -> target (TKR2 (1 + n) x)
  trfromVector v = withSNat (V.length v) $ \k ->
    tfromVector k (STKR SNat knownSTK) v
  -- | Build a ranked tensor of the given shape from a vector
  -- of rank-0 tensors.
  --
  -- An empty vector is handled separately: an empty concrete rank-1 array
  -- is injected with 'tconcrete' and reshaped. Otherwise the elements
  -- are stacked with 'trfromVector' and the result is reshaped.
  trfromVector0N :: forall n x. KnownSTK x
                 => IShR n -> Data.Vector.Vector (target (TKR2 0 x))
                 -> target (TKR2 n x)
  trfromVector0N sh v | Dict <- eltDictRep (knownSTK @x) =
    if V.null v
    then let arr = Nested.remptyArray
         in trreshape sh $ tconcrete (tftkG knownSTK arr) (Concrete arr)
    else trreshape sh $ trfromVector v
  -- | Unravel a ranked tensor along its outermost dimension into a list
  -- of tensors of rank one lower. The default implementation indexes
  -- the tensor at each position from @0@ to @rwidth t - 1@.
  trunravelToList :: (KnownNat n, KnownSTK x)
                  => target (TKR2 (1 + n) x) -> [target (TKR2 n x)]
  trunravelToList @n @x t =
    let f :: Int -> target (TKR2 n x)
        f i = trindex t (fromIntegral i :.: ZIR)
    in map f [0 .. rwidth t - 1]

  -- | Stack a vector of shaped tensors into a shaped tensor with
  -- an extra outermost dimension @n@. The default implementation
  -- defers to 'tfromVector'.
  tsfromVector :: (KnownNat n, KnownShS sh, KnownSTK x)
               => Data.Vector.Vector (target (TKS2 sh x))
               -> target (TKS2 (n ': sh) x)
  tsfromVector = tfromVector SNat (STKS knownShS knownSTK)
  -- | Build a shaped tensor of shape @sh@ from a flat vector of rank-0
  -- (shape @'[]@) tensors.
  --
  -- A non-empty vector is first turned into a rank-1 tensor of length
  -- @Product sh@ with 'tsfromVector' and then reshaped to @sh@.
  -- An empty vector is handled separately: a concrete empty array of
  -- shape @'[0]@ is created and reshaped to @sh@.
  tsfromVector0N :: (KnownShS sh, KnownSTK x)
                 => Data.Vector.Vector (target (TKS2 '[] x))
                 -> target (TKS2 sh x)
  tsfromVector0N @sh @x elems
    | Dict <- eltDictRep (knownSTK @x)
    , SNat <- shsProduct (knownShS @sh) =
        if V.null elems
          then
            -- The equality @Product sh ~ 0@ is asserted via
            -- 'unsafeCoerceRefl', not checked; it is what lets the
            -- @'[0]@-shaped array be reshaped to @sh@.
            gcastWith (unsafeCoerceRefl :: Product sh :~: 0)
              ( let emptyArr = Nested.semptyArray ZSS
                in tsreshape knownShS
                     (tconcrete (tftkG knownSTK emptyArr)
                                (Concrete emptyArr)) )
          else tsreshape (knownShS @sh) (tsfromVector elems)
  -- | Split a shaped tensor of shape @n ': sh@ into a list of tensors
  -- of shape @sh@.
  --
  -- The default indexes the tensor with every one-element index from @0@
  -- up to @swidth t - 1@, in increasing order.
  tsunravelToList :: (KnownNat n, KnownShS sh, KnownSTK x)
                  => target (TKS2 (n ': sh) x) -> [target (TKS2 sh x)]
  tsunravelToList @_ @sh @x tensor =
    let sliceAt :: Int -> target (TKS2 sh x)
        sliceAt ix = tsindex tensor (fromIntegral ix :.$ ZIS)
    in [ sliceAt ix | ix <- [0 .. swidth tensor - 1] ]

  -- | Build a mixed-shape tensor of shape @Just n ': sh@ from a vector
  -- of tensors of shape @sh@.
  --
  -- The default delegates to 'tfromVector', passing the singletons
  -- recovered from the 'KnownNat', 'KnownShX' and 'KnownSTK' constraints.
  -- NOTE(review): presumably the vector is expected to hold exactly @n@
  -- elements; nothing in this definition checks that.
  txfromVector :: (KnownNat n, KnownShX sh, KnownSTK x)
               => Data.Vector.Vector (target (TKX2 sh x))
               -> target (TKX2 (Just n ': sh) x)
  txfromVector = tfromVector SNat $ STKX knownShX knownSTK
  -- | Build a mixed-shape tensor of the given shape from a flat vector
  -- of rank-0 (shape @'[]@) tensors.
  --
  -- A non-empty vector is turned into a rank-1 tensor with 'txfromVector',
  -- with its length taken from @shxSize sh@, and then reshaped to @sh@.
  -- An empty vector is handled separately: a concrete empty array of
  -- shape @'[Just 0]@ is created and reshaped to @sh@.
  -- NOTE(review): neither branch checks here that the vector length
  -- agrees with @shxSize sh@ — presumably the caller guarantees it.
  txfromVector0N :: forall sh x. KnownSTK x
                 => IShX sh -> Data.Vector.Vector (target (TKX2 '[] x))
                 -> target (TKX2 sh x)
  txfromVector0N sh elems
    | Dict <- eltDictRep (knownSTK @x) =
        if V.null elems
          then
            let emptyArr = Nested.memptyArray ZSX
            in txreshape sh
                 (tconcrete (tftkG knownSTK emptyArr) (Concrete emptyArr))
          else
            -- Promote the runtime size to a type-level natural so that
            -- the intermediate rank-1 tensor gets a static length.
            withSNat (shxSize sh) $ \(SNat @n) ->
              txreshape @_ @'[Just n] sh (txfromVector elems)
  -- | Split a mixed-shape tensor of shape @Just n ': sh@ into a list
  -- of tensors of shape @sh@.
  --
  -- The default indexes the tensor with every one-element index from @0@
  -- up to @xwidth t - 1@, in increasing order.
  txunravelToList :: (KnownNat n, KnownShX sh, KnownSTK x)
                  => target (TKX2 (Just n ': sh) x) -> [target (TKX2 sh x)]
  txunravelToList @_ @sh @x tensor =
    let sliceAt :: Int -> target (TKX2 sh x)
        sliceAt ix = txindex tensor (fromIntegral ix :.% ZIX)
    in [ sliceAt ix | ix <- [0 .. xwidth tensor - 1] ]

  -- | Build a value of kind @BuildTensorKind k y@ from a vector of values
  -- of kind @y@, given the singleton for the count @k@ and the singleton
  -- for the element kind @y@.
  --
  -- No default implementation is given here; the shaped and mixed
  -- @ts/txfromVector@ defaults and 'tfromListR' are defined in terms of it.
  -- NOTE(review): presumably the vector must contain exactly @k@ elements;
  -- the signature does not enforce that.
  tfromVector
    :: forall y k.
       SNat k -> SingletonTK y -> Data.Vector.Vector (target y)
    -> target (BuildTensorKind k y)
  -- | Variant of 'tfromVector' that takes a length-indexed list.
  --
  -- The count singleton is obtained from the list via 'listrRank', and
  -- the elements are converted to a vector before being handed to
  -- 'tfromVector'.
  tfromListR :: SingletonTK y -> ListR k (target y)
             -> target (BuildTensorKind k y)
  tfromListR elemKind items =
    tfromVector (listrRank items) elemKind
                (V.fromList (Foldable.toList items))

  -- A number suffix in the name may indicate the rank of the codomain,
  -- if bounded. Suffix 1 may also mean the operation builds up the codomain
  -- by 1 dimension.
  -- | Reduce a ranked tensor by one dimension: the argument has rank
  -- @1 + n@ and the result has rank @n@.
  --
  -- No default implementation is given here.
  -- NOTE(review): presumably this sums over the outermost dimension;
  -- confirm against the instances.
  trsum :: (KnownNat n, KnownSTK x)
        => target (TKR2 (1 + n) x) -> target (TKR2 n x)
  -- Having this operation (and its Delta constructor) is worthwhile,
  -- because flattening is sometimes O(n), unlike transpose, etc.
  trsum0 :: (KnownNat n, KnownSTK x)
         => target (TKR2 n x) -> target (TKR2 0 x)
  -- Default: flatten to rank 1 with 'rflatten', then reduce with 'trsum'.
  trsum0 t = trsum (rflatten t)
  -- | Reduce two ranked tensors of the same rank to a rank-0 tensor.
  --
  -- The default multiplies the arguments with @(*)@, flattens the
  -- product to rank 1 and reduces it with 'trsum'.
  trdot0 :: (KnownNat n, GoodScalar r)
         => target (TKR n r) -> target (TKR n r) -> target (TKR 0 r)
  trdot0 lhs rhs = trsum $ rflatten $ lhs * rhs
  -- | Combine two ranked tensors of rank @1 + n@ into a tensor of rank @n@.
  --
  -- The default multiplies the arguments with @(*)@, transposes the
  -- product by the permutation @permCycle (1 + n)@ and reduces the result
  -- with 'trsum'.
  -- NOTE(review): presumably the cycle moves the innermost dimension to
  -- the front, so that the reduction runs over it; confirm against
  -- 'permCycle'.
  trdot1In :: (KnownNat n, GoodScalar r)
           => target (TKR (1 + n) r) -> target (TKR (1 + n) r)
           -> target (TKR n r)
  trdot1In @n lhs rhs =
    let cyclePerm = permCycle (1 + valueOf @n)
    in trsum (trtranspose cyclePerm (lhs * rhs))
  -- | Combine a rank-2 tensor and a rank-1 tensor into a rank-1 tensor.
  trmatvecmul :: GoodScalar r
              => target (TKR 2 r) -> target (TKR 1 r) -> target (TKR 1 r)
  -- Open question (#69): how should this be generalized? The few
  -- straightforward generalizations differ in types, but all of them
  -- are far from matmul2.
  -- An alternative formulation:
  -- rmatvecmul m v = rflatten $ rmap1 (rreplicate 1 . rdot0 v) m
  --
  -- Default: for each index @i@ below @rwidth m@, take 'trdot0' of @v@
  -- with the slice of @m@ at @[i]@.
  trmatvecmul m v =
    trbuild1 (rwidth m) $ \i -> trdot0 v (trindex m [i])
  -- | Combine two rank-2 tensors into a rank-2 tensor.
  trmatmul2 :: GoodScalar r
            => target (TKR 2 r) -> target (TKR 2 r) -> target (TKR 2 r)
  -- Open question (#69): how should this be generalized to tmatmul?
  -- Perhaps just rmatmul2 the two outermost dimensions?
  -- An alternative formulation:
  -- rmatmul2 m1 m2 = rmap1 (rmatvecmul (rtr m2)) m1
  --
  -- Default: for each index @i@ below @rwidth m1@, apply 'trmatvecmul'
  -- to @rtr m2@ and the slice of @m1@ at @[i]@.
  trmatmul2 m1 m2 =
    trbuild1 (rwidth m1) $ \i -> trmatvecmul (rtr m2) (trindex m1 [i])
  -- | Turn a ranked tensor of rank @n@ into one of rank @1 + n@, given
  -- an 'Int'.
  --
  -- No default implementation is given here.
  -- NOTE(review): presumably the 'Int' is the size of the new outermost
  -- dimension, which is filled with copies of the argument; confirm
  -- against the instances.
  trreplicate :: (KnownNat n, KnownSTK x)
              => Int -> target (TKR2 n x) -> target (TKR2 (1 + n) x)
  -- | Turn a rank-0 tensor into a tensor of the given shape.
  --
  -- The default calls 'trreplicate' with @shrSize sh@ to obtain a rank-1
  -- tensor and then reshapes that to @sh@.
  trreplicate0N :: (KnownNat n, KnownSTK x)
                => IShR n -> target (TKR2 0 x) -> target (TKR2 n x)
  trreplicate0N sh scalar =
    trreshape sh (trreplicate (shrSize sh) scalar)

  -- | Reduce a shaped tensor by its leading dimension: the argument has
  -- shape @n ': sh@ and the result has shape @sh@.
  --
  -- No default implementation is given here.
  -- NOTE(review): presumably this sums over the leading dimension;
  -- confirm against the instances.
  tssum :: (KnownNat n, KnownShS sh, KnownSTK x)
        => target (TKS2 (n ': sh) x) -> target (TKS2 sh x)
  -- | Reduce a shaped tensor of any shape to a tensor of shape @'[]@.
  --
  -- The default flattens the argument to shape @'[Product sh]@ and
  -- reduces it with 'tssum'. Matching on the 'SNat' returned by
  -- 'shsProduct' brings the @KnownNat (Product sh)@ evidence that 'tssum'
  -- requires into scope.
  tssum0 :: (KnownShS sh, KnownSTK x)
         => target (TKS2 sh x) -> target (TKS2 '[] x)
  tssum0 @sh
    | SNat <- shsProduct (knownShS @sh) = \t -> tssum (sflatten t)
  -- | Reduce two shaped tensors of the same shape to a tensor of
  -- shape @'[]@.
  --
  -- The default multiplies the arguments with @(*)@, flattens the
  -- product to shape @'[Product sh]@ and reduces it with 'tssum'.
  -- Matching on the 'SNat' returned by 'shsProduct' brings the
  -- @KnownNat (Product sh)@ evidence that 'tssum' requires into scope.
  tsdot0 :: (KnownShS sh, GoodScalar r)
         => target (TKS sh r) -> target (TKS sh r) -> target (TKS '[] r)
  tsdot0 @sh lhs rhs
    | SNat <- shsProduct (knownShS @sh) = tssum $ sflatten $ lhs * rhs
  tsdot1In :: (KnownShS sh, GoodScalar r)
           => SNat n -> target (TKS (sh ++ '[n]) r)
           -> target (TKS (sh ++ '[n]) r)
           -> target (TKS sh r)
  tsdot1In @sh (SNat @n) target (TKS ((++) @Natural sh ((':) @Natural n ('[] @Natural))) r)
t target (TKS ((++) @Natural sh ((':) @Natural n ('[] @Natural))) r)
u =
    let cpermR :: PermR
cpermR = Int -> PermR
permCycle (Int -> PermR) -> Int -> PermR
forall a b. (a -> b) -> a -> b
$ Int
1 Int -> Int -> Int
forall a. Num a => a -> a -> a
+ ShS sh -> Int
forall (sh :: [Natural]). ShS sh -> Int
shsLength (forall (sh :: [Natural]). KnownShS sh => ShS sh
knownShS @sh)
    in PermR
-> (forall (list :: [Natural]).
    Perm list -> target (TKS2 sh (TKScalar r)))
-> target (TKS2 sh (TKScalar r))
forall r.
PermR -> (forall (list :: [Natural]). Perm list -> r) -> r
Permutation.permFromList PermR
cpermR ((forall (list :: [Natural]).
  Perm list -> target (TKS2 sh (TKScalar r)))
 -> target (TKS2 sh (TKScalar r)))
-> (forall (list :: [Natural]).
    Perm list -> target (TKS2 sh (TKScalar r)))
-> target (TKS2 sh (TKScalar r))
forall a b. (a -> b) -> a -> b
$ \(Perm list
cperm :: Permutation.Perm cperm) ->
         (:~:)
  @Natural
  (Rank @Natural list)
  (Rank @Natural ((++) @Natural sh ((':) @Natural n ('[] @Natural))))
-> (((Rank @Natural list :: Natural)
     ~ (Rank
          @Natural
          ((++) @Natural sh ((':) @Natural n ('[] @Natural))) :: Natural)) =>
    target (TKS2 sh (TKScalar r)))
-> target (TKS2 sh (TKScalar r))
forall {k} (a :: k) (b :: k) r.
(:~:) @k a b -> (((a :: k) ~ (b :: k)) => r) -> r
gcastWith ((:~:)
  @Natural
  (Rank @Natural list)
  (Rank @Natural ((++) @Natural sh ((':) @Natural n ('[] @Natural))))
forall {k} (a :: k) (b :: k). (:~:) @k a b
unsafeCoerceRefl :: Rank cperm :~: Rank (sh ++ '[n])) ((((Rank @Natural list :: Natural)
   ~ (Rank
        @Natural
        ((++) @Natural sh ((':) @Natural n ('[] @Natural))) :: Natural)) =>
  target (TKS2 sh (TKScalar r)))
 -> target (TKS2 sh (TKScalar r)))
-> (((Rank @Natural list :: Natural)
     ~ (Rank
          @Natural
          ((++) @Natural sh ((':) @Natural n ('[] @Natural))) :: Natural)) =>
    target (TKS2 sh (TKScalar r)))
-> target (TKS2 sh (TKScalar r))
forall a b. (a -> b) -> a -> b
$
         (:~:)
  @[Natural]
  ((++)
     @Natural
     (Permute
        @Natural
        list
        (TakeLen
           @Natural
           @Natural
           list
           ((++) @Natural sh ((':) @Natural n ('[] @Natural)))))
     (DropLen
        @Natural
        @Natural
        list
        ((++) @Natural sh ((':) @Natural n ('[] @Natural)))))
  ((':) @Natural n sh)
-> ((((++)
        @Natural
        (Permute
           @Natural
           list
           (TakeLen
              @Natural
              @Natural
              list
              ((++) @Natural sh ((':) @Natural n ('[] @Natural)))))
        (DropLen
           @Natural
           @Natural
           list
           ((++) @Natural sh ((':) @Natural n ('[] @Natural)))) :: [Natural])
     ~ ((':) @Natural n sh :: [Natural])) =>
    target (TKS2 sh (TKScalar r)))
-> target (TKS2 sh (TKScalar r))
forall {k} (a :: k) (b :: k) r.
(:~:) @k a b -> (((a :: k) ~ (b :: k)) => r) -> r
gcastWith ((:~:)
  @[Natural]
  ((++)
     @Natural
     (Permute
        @Natural
        list
        (TakeLen
           @Natural
           @Natural
           list
           ((++) @Natural sh ((':) @Natural n ('[] @Natural)))))
     (DropLen
        @Natural
        @Natural
        list
        ((++) @Natural sh ((':) @Natural n ('[] @Natural)))))
  ((':) @Natural n sh)
forall {k} (a :: k) (b :: k). (:~:) @k a b
unsafeCoerceRefl
                    :: Permutation.PermutePrefix cperm (sh ++ '[n])
                       :~: n : sh) (((((++)
      @Natural
      (Permute
         @Natural
         list
         (TakeLen
            @Natural
            @Natural
            list
            ((++) @Natural sh ((':) @Natural n ('[] @Natural)))))
      (DropLen
         @Natural
         @Natural
         list
         ((++) @Natural sh ((':) @Natural n ('[] @Natural)))) :: [Natural])
   ~ ((':) @Natural n sh :: [Natural])) =>
  target (TKS2 sh (TKScalar r)))
 -> target (TKS2 sh (TKScalar r)))
-> ((((++)
        @Natural
        (Permute
           @Natural
           list
           (TakeLen
              @Natural
              @Natural
              list
              ((++) @Natural sh ((':) @Natural n ('[] @Natural)))))
        (DropLen
           @Natural
           @Natural
           list
           ((++) @Natural sh ((':) @Natural n ('[] @Natural)))) :: [Natural])
     ~ ((':) @Natural n sh :: [Natural])) =>
    target (TKS2 sh (TKScalar r)))
-> target (TKS2 sh (TKScalar r))
forall a b. (a -> b) -> a -> b
$
         target (TKS2 sh (TKScalar r))
-> Maybe (target (TKS2 sh (TKScalar r)))
-> target (TKS2 sh (TKScalar r))
forall a. a -> Maybe a -> a
fromMaybe ([Char] -> target (TKS2 sh (TKScalar r))
forall a. HasCallStack => [Char] -> a
error [Char]
"tsdot1In: impossible non-permutation")
         (Maybe (target (TKS2 sh (TKScalar r)))
 -> target (TKS2 sh (TKScalar r)))
-> Maybe (target (TKS2 sh (TKScalar r)))
-> target (TKS2 sh (TKScalar r))
forall a b. (a -> b) -> a -> b
$ Perm list
-> (IsPermutation list => target (TKS2 sh (TKScalar r)))
-> Maybe (target (TKS2 sh (TKScalar r)))
forall r (list :: [Natural]).
Perm list -> (IsPermutation list => r) -> Maybe r
Permutation.permCheckPermutation Perm list
cperm
         ((IsPermutation list => target (TKS2 sh (TKScalar r)))
 -> Maybe (target (TKS2 sh (TKScalar r))))
-> (IsPermutation list => target (TKS2 sh (TKScalar r)))
-> Maybe (target (TKS2 sh (TKScalar r)))
forall a b. (a -> b) -> a -> b
$ target (TKS2 ((':) @Natural n sh) (TKScalar r))
-> target (TKS2 sh (TKScalar r))
forall (n :: Natural) (sh :: [Natural]) (x :: TK).
(KnownNat n, KnownShS sh, KnownSTK x) =>
target (TKS2 ((':) @Natural n sh) x) -> target (TKS2 sh x)
forall (target :: Target) (n :: Natural) (sh :: [Natural])
       (x :: TK).
(BaseTensor target, KnownNat n, KnownShS sh, KnownSTK x) =>
target (TKS2 ((':) @Natural n sh) x) -> target (TKS2 sh x)
tssum (target (TKS2 ((':) @Natural n sh) (TKScalar r))
 -> target (TKS2 sh (TKScalar r)))
-> target (TKS2 ((':) @Natural n sh) (TKScalar r))
-> target (TKS2 sh (TKScalar r))
forall a b. (a -> b) -> a -> b
$ Perm list
-> target
     (TKS ((++) @Natural sh ((':) @Natural n ('[] @Natural))) r)
-> target
     (TKS2
        ((++)
           @Natural
           (Permute
              @Natural
              list
              (TakeLen
                 @Natural
                 @Natural
                 list
                 ((++) @Natural sh ((':) @Natural n ('[] @Natural)))))
           (DropLen
              @Natural
              @Natural
              list
              ((++) @Natural sh ((':) @Natural n ('[] @Natural)))))
        (TKScalar r))
forall (perm :: [Natural]) (sh :: [Natural]) (x :: TK).
(IsPermutation perm,
 (<=) @Natural (Rank @Natural perm) (Rank @Natural sh),
 KnownSTK x) =>
Perm perm
-> target (TKS2 sh x)
-> target (TKS2 (PermutePrefix @Natural perm sh) x)
forall (target :: Target) (perm :: [Natural]) (sh :: [Natural])
       (x :: TK).
(BaseTensor target, IsPermutation perm,
 (<=) @Natural (Rank @Natural perm) (Rank @Natural sh),
 KnownSTK x) =>
Perm perm
-> target (TKS2 sh x)
-> target (TKS2 (PermutePrefix @Natural perm sh) x)
tstranspose Perm list
cperm (target (TKS ((++) @Natural sh ((':) @Natural n ('[] @Natural))) r)
t target (TKS ((++) @Natural sh ((':) @Natural n ('[] @Natural))) r)
-> target
     (TKS ((++) @Natural sh ((':) @Natural n ('[] @Natural))) r)
-> target
     (TKS ((++) @Natural sh ((':) @Natural n ('[] @Natural))) r)
forall a. Num a => a -> a -> a
* target (TKS ((++) @Natural sh ((':) @Natural n ('[] @Natural))) r)
u)
  -- | Default matrix-vector product for shaped tensors.
  --
  -- The result is built with 'tsbuild1' over the outer dimension @m@:
  -- element @i@ is the 'tsdot0' of the vector @v@ with the slice of the
  -- matrix selected by indexing it at @i@ (via 'tsindex').
  tsmatvecmul :: (KnownNat m, KnownNat n, GoodScalar r)
              => target (TKS '[m, n] r) -> target (TKS '[n] r)
              -> target (TKS '[m] r)
  tsmatvecmul @m m v = tsbuild1 @_ @m (\i -> tsdot0 v (m `tsindex` (i :.$ ZIS)))
  -- | Default matrix-matrix product for shaped rank-2 tensors.
  --
  -- The result is built with 'tsbuild1' over the outer dimension @m@:
  -- for each index @i@, the slice of @m1@ at @i@ (a vector of length @n@)
  -- is multiplied, using 'tsmatvecmul', by @m2@ with its two dimensions
  -- swapped ('str'), which yields a vector of length @p@.
  tsmatmul2 :: (KnownNat m, KnownNat n, KnownNat p, GoodScalar r)
            => target (TKS '[m, n] r) -> target (TKS '[n, p] r)
            -> target (TKS '[m, p] r)
  tsmatmul2 @m m1 m2 =
    tsbuild1 @_ @m (\i -> tsmatvecmul (str m2) (m1 `tsindex` (i :.$ ZIS)))
  -- | Given a singleton for @k@, the shape @sh@ and a tensor of shape @sh@,
  -- produce a tensor of shape @k ': sh@, i.e., with a new outermost
  -- dimension of size @k@. No default implementation is given here.
  -- NOTE(review): judging by the name, the new dimension presumably holds
  -- @k@ copies of the argument; confirm against the instances.
  tsreplicate :: forall sh k x. KnownSTK x
              => SNat k -> ShS sh -> target (TKS2 sh x)
              -> target (TKS2 (k ': sh) x)
  -- | Default way of expanding a rank-0 tensor to an arbitrary shape @sh@.
  --
  -- The rank-0 argument is first passed to 'tsreplicate' with the count
  -- @shsProduct sh@, giving a rank-1 tensor of that length, which is then
  -- reshaped to @sh@ with 'tsreshape'.
  tsreplicate0N :: forall sh x. KnownSTK x
                => ShS sh -> target (TKS2 '[] x)
                -> target (TKS2 sh x)
  tsreplicate0N sh = tsreshape sh . tsreplicate (shsProduct sh) ZSS

  -- The choice in BuildTensorKind makes it hard to support this one,
  -- due to DeltaSum and AstSum being typed with BuildTensorKind:
  -- xsum :: (KnownShX sh, KnownShX (mn ': sh), KnownSTK x)
  --     => target (TKX2 (mn ': sh) x) -> target (TKX2 sh x)
  -- | Eliminate the outermost dimension of a mixed-shape tensor.
  -- The signature requires that dimension to be statically known
  -- (@Just n@); the result has the remaining shape @sh@.
  -- No default implementation is given here.
  -- NOTE(review): judging by the name, the elimination is by summation;
  -- confirm against the instances.
  txsum :: (KnownNat n, KnownShX sh, KnownSTK x)
        => target (TKX2 (Just n ': sh) x) -> target (TKX2 sh x)
  -- | Default reduction of a mixed-shape tensor to rank 0 via 'txsum'.
  --
  -- The tensor is flattened with 'xflatten', which gives a rank-1 tensor
  -- whose length is not known statically. The total size, computed at
  -- runtime as @shxSize (xshape t)@, is turned into a singleton with
  -- 'withSNat' and used by 'xmcast' to give the flattened tensor
  -- a statically known length, as required by 'txsum'.
  txsum0 :: (KnownShX sh, KnownSTK x, ConvertTensor target)
         => target (TKX2 sh x) -> target (TKX2 '[] x)
  txsum0 t = withSNat (shxSize $ xshape t) $ \snat ->
    txsum (xmcast (Nested.SKnown snat :!% ZKX) $ xflatten t)
  -- | Default dot product of two mixed-shape tensors of the same shape,
  -- producing a rank-0 tensor.
  --
  -- The arguments are multiplied elementwise with @(*)@ and the product
  -- is flattened with 'xflatten'. The total size, computed at runtime
  -- as @shxSize (xshape t)@, is turned into a singleton with 'withSNat'
  -- and used by 'xmcast' to give the flattened tensor a statically known
  -- length, as required by 'txsum', which performs the final reduction.
  txdot0 :: (KnownShX sh, GoodScalar r, ConvertTensor target)
         => target (TKX sh r) -> target (TKX sh r) -> target (TKX '[] r)
  txdot0 t u = withSNat (shxSize $ xshape t) $ \snat ->
    txsum (xmcast (Nested.SKnown snat :!% ZKX) $ xflatten (t * u))
  -- | Default dot product over the innermost dimension (of statically
  -- known size @n@) of two mixed-shape tensors; the result has shape @sh@.
  --
  -- The arguments are multiplied elementwise with @(*)@, the product is
  -- transposed with 'txtranspose' using the permutation obtained from
  -- @permCycle (1 + rank of sh)@, and the outermost dimension of the
  -- transposed tensor is then eliminated with 'txsum'.
  --
  -- The permutation is only known at runtime, so two type equalities are
  -- asserted with @unsafeCoerceRefl@ rather than proved: that the rank
  -- of the permutation equals the rank of @sh ++ '[Just n]@, and that
  -- permuting @sh ++ '[Just n]@ gives @Just n : sh@. The call to 'error'
  -- is reached only if @permCheckPermutation@ returns 'Nothing'.
  txdot1In :: (KnownShX sh, GoodScalar r)
           => SNat n -> target (TKX (sh ++ '[Just n]) r)
           -> target (TKX (sh ++ '[Just n]) r)
           -> target (TKX sh r)
  txdot1In @sh (SNat @n) t u =
    let cpermR = permCycle $ 1 + sNatValue (ssxRank (knownShX @sh))
    in Permutation.permFromList cpermR $ \(cperm :: Permutation.Perm cperm) ->
         gcastWith (unsafeCoerceRefl :: Rank cperm :~: Rank (sh ++ '[Just n])) $
         gcastWith (unsafeCoerceRefl
                    :: Permutation.PermutePrefix cperm (sh ++ '[Just n])
                       :~: Just n : sh) $
         fromMaybe (error "txdot1In: impossible non-permutation")
         $ Permutation.permCheckPermutation cperm
         $ txsum $ txtranspose cperm (t * u)
  -- | Matrix-vector product for mixed-shape (@TKX@) tensors.
  --
  -- The result is built with 'txbuild1' over @k@ positions, where @k@ is
  -- obtained from @fromSMayNat' mm@; position @i@ holds 'txdot0' of the
  -- vector @v@ and the slice of @m@ selected by the one-element index
  -- @i :.% ZIX@. Because 'txbuild1' produces a statically sized outer
  -- dimension (@Just k@), the result is cast back to the requested shape
  -- @'[mm]@ with 'xmcast'.
  --
  -- The two 'withKnownShX' calls bring the @KnownShX@ instances
  -- for @'[mm]@ and @'[mn]@ into scope for the body.
  txmatvecmul :: forall mm mn r. (GoodScalar r, ConvertTensor target)
              => Nested.SMayNat Int SNat mm -> Nested.SMayNat Int SNat mn
              -> target (TKX '[mm, mn] r) -> target (TKX '[mn] r)
              -> target (TKX '[mm] r)
  txmatvecmul mm mn m v =
    withKnownShX (ssxFromShX $ mm :$% ZSX) $
    withKnownShX (ssxFromShX $ mn :$% ZSX) $
    withSNat (fromSMayNat' mm) $ \(SNat @k) ->
      xmcast (ssxFromShX $ mm :$% ZSX)
      $ txbuild1 @_ @k (\i -> txdot0 v (m `txindex` (i :.% ZIX)))
  -- | Product of two mixed-shape matrices with statically known sizes.
  --
  -- The result is built with 'txbuild1' over the @m@ outer positions;
  -- position @i@ is 'txmatvecmul' applied to the transposed second
  -- matrix (@xtr m2@, of shape @[p, n]@) and the slice of @m1@ selected
  -- by the one-element index @i :.% ZIX@ (of shape @[n]@).
  txmatmul2 :: ( KnownNat m, KnownNat n, KnownNat p, GoodScalar r
               , ConvertTensor target )
            => target (TKX '[Just m, Just n] r)
            -> target (TKX '[Just n, Just p] r)
            -> target (TKX '[Just m, Just p] r)
  txmatmul2 @m @n @p m1 m2 =
    txbuild1 @_ @m (\i ->
      txmatvecmul (Nested.SKnown (SNat @p)) (Nested.SKnown (SNat @n))
                  (xtr m2) (m1 `txindex` (i :.% ZIX)))
  -- | Class method without a default implementation in this span.
  -- Per the signature, it takes a singleton for the new dimension size @k@,
  -- the static shape of the argument and a tensor of shape @sh@, and yields
  -- a tensor whose shape has @Just k@ prepended to @sh@.
  -- NOTE(review): presumably every outer slice is a copy of the argument;
  -- confirm against the instances.
  txreplicate :: forall sh k x. KnownSTK x
              => SNat k -> StaticShX sh -> target (TKX2 sh x)
              -> target (TKX2 (Just k ': sh) x)
  -- | Produce a tensor of shape @sh@ from a rank-0 tensor.
  --
  -- The default implementation converts @shxSize sh@ to a type-level
  -- natural with 'withSNat', uses 'txreplicate' to make a rank-1 tensor of
  -- that length from the rank-0 argument, and then 'txreshape's the result
  -- to @sh@.
  txreplicate0N :: (KnownShX sh, KnownSTK x)
                => IShX sh -> target (TKX2 '[] x) -> target (TKX2 sh x)
  txreplicate0N sh = withSNat (shxSize sh) $ \snat ->
    txreshape sh . txreplicate snat knownShX

  -- | Class method without a default implementation in this span.
  -- Per the signature, it takes a ranked tensor of rank @m + n@ and an
  -- index of length @m@ and yields a tensor of rank @n@.
  -- NOTE(review): out-of-bounds behaviour is not visible here; confirm
  -- against the instances.
  trindex :: (KnownNat m, KnownNat n, KnownSTK x)
          => target (TKR2 (m + n) x) -> IxROf target m -> target (TKR2 n x)
  -- | Variant of 'trindex' in which the index covers every dimension of
  -- the tensor, so the result has rank 0. The default implementation is
  -- 'trindex' itself, used at @n ~ 0@.
  trindex0 :: (KnownNat m, KnownSTK x)
           => target (TKR2 m x) -> IxROf target m -> target (TKR2 0 x)
  trindex0 = trindex
  -- | Build a tensor of shape @sh@ appended with @rshape v@ from the
  -- tensor @v@ and the index @ix@ of length @m@.
  --
  -- Two cases, selected by the singleton of the element kind @x@:
  --
  -- * For scalar elements ('STKScalar') the result is a 'trscatter' of
  --   @v@ whose index function ignores its (empty) argument and always
  --   returns @ix@.
  --
  -- * Otherwise the result is constructed with 'rbuild': at every full
  --   index @ix2@ a 'tcond' picks @trindex0 v (ixrDrop ix2)@ when the
  --   components of @ix@ are pairwise equal to the leading components of
  --   @ix2@ (the 'zip' stops at the shorter list, which is @ix@), and
  --   the value given by 'tdefTarget' for the element shape of @v@
  --   otherwise.
  troneHot :: ( KnownNat m, KnownNat n, KnownSTK x
              , BoolOf (PrimalOf target) ~ BoolOf target
              , EqH (PrimalOf target) (TKScalar Int64))
           => IShR m -> target (TKR2 n x) -> IxROf target m
           -> target (TKR2 (m + n) x)
  {-# INLINE troneHot #-}
  troneHot @_ @_ @x sh v ix = case knownSTK @x of
    STKScalar ->
      trscatter @_ @0 (shrAppend sh (rshape v)) v (const ix)
    _ -> case tftk knownSTK v of
      FTKR _ ftk2 ->
        -- TODO: def at out of bounds
        let f ix2 = tcond knownSTK
                          (foldl' (\ !acc (!i, !i2) -> acc &&* i ==. i2) true
                           $ zip (toList ix) (toList ix2))
                          (trindex0 v (ixrDrop ix2))
                          (tdefTarget (FTKR ZSR ftk2))
        in rbuild (shrAppend sh (rshape v)) f
           -- TODO: if this is used often, maybe express this as the gather that
           -- would come out of vectorization, making sure it simplifies well
  -- | Class method without a default implementation in this span.
  -- Per the signature, it takes the shape of the result (rank @p + n@),
  -- a source tensor of rank @m + n@ and a function from indexes of length
  -- @m@ to indexes of length @p@, and yields a tensor of rank @p + n@.
  -- NOTE(review): how colliding or out-of-bounds target indexes are
  -- treated is not visible here; confirm against the instances.
  trscatter :: (KnownNat m, KnownNat n, KnownNat p, KnownSTK x)
            => IShR (p + n) -> target (TKR2 (m + n) x)
            -> (IxROf target m -> IxROf target p)
            -> target (TKR2 (p + n) x)
  -- | Variant of 'trscatter' for a source index of length one: the index
  -- function receives the single index component directly instead of a
  -- one-element index. The default implementation calls 'trscatter' at
  -- @m ~ 1@ and unpacks the one-element index with the pattern
  -- @i :.: ZIR@.
  trscatter1 :: (KnownNat n, KnownNat p, KnownSTK x)
             => IShR (p + n) -> target (TKR2 (1 + n) x)
             -> (IntOf target -> IxROf target p)
             -> target (TKR2 (p + n) x)
  trscatter1 sh v f = trscatter @target @1 sh v (\(i :.: ZIR) -> f i)
  -- | Class method without a default implementation in this span.
  -- Per the signature, it takes the shape of the result (rank @m + n@),
  -- a source tensor of rank @p + n@ and a function from indexes of length
  -- @m@ to indexes of length @p@, and yields a tensor of rank @m + n@.
  -- NOTE(review): out-of-bounds behaviour is not visible here; confirm
  -- against the instances.
  trgather :: (KnownNat m, KnownNat n, KnownNat p, KnownSTK x)
           => IShR (m + n) -> target (TKR2 (p + n) x)
           -> (IxROf target m -> IxROf target p)
           -> target (TKR2 (m + n) x)
  -- | Variant of 'trgather' for a result with a single new outer
  -- dimension of size @k@: the index function receives the single index
  -- component directly instead of a one-element index.
  --
  -- The default implementation calls 'trgather' at @m ~ 1@ with the
  -- result shape @k :$: shrDrop (rshape v)@, that is, @k@ followed by the
  -- trailing @n@ dimensions of @v@, and unpacks the one-element index
  -- with the pattern @i :.: ZIR@.
  trgather1 :: (KnownNat n, KnownNat p, KnownSTK x)
            => Int -> target (TKR2 (p + n) x)
            -> (IntOf target -> IxROf target p)
            -> target (TKR2 (1 + n) x)
  trgather1 k v f = trgather @target @1
                             (k :$: shrDrop (rshape v)) v
                             (\(i :.: ZIR) -> f i)

  -- | Class method without a default implementation in this span.
  -- Per the signature, it takes a shaped tensor of shape @shm ++ shn@ and
  -- an index over @shm@ and yields a tensor of shape @shn@.
  -- NOTE(review): out-of-bounds behaviour is not visible here; confirm
  -- against the instances.
  tsindex :: (KnownShS shm, KnownShS shn, KnownSTK x)
          => target (TKS2 (shm ++ shn) x) -> IxSOf target shm
          -> target (TKS2 shn x)
  -- | Variant of 'tsindex' in which the index covers the whole shape of
  -- the tensor, so the result has the empty shape. The default
  -- implementation is 'tsindex' used at @shn ~ '[]@; the pattern guard on
  -- @lemAppNil \@sh1@ supplies the equality between @sh1 ++ '[]@ and @sh1@
  -- that makes the argument type match.
  tsindex0 :: (KnownShS sh1, KnownSTK x)
           => target (TKS2 sh1 x) -> IxSOf target sh1
           -> target (TKS2 '[] x)
  tsindex0 @sh1 | Refl <- lemAppNil @sh1 = tsindex
  tsoneHot :: ( KnownShS sh1, KnownShS sh2, KnownSTK x
              , BoolOf (PrimalOf target) ~ BoolOf target
              , EqH (PrimalOf target) (TKScalar Int64) )
           => target (TKS2 sh2 x) -> IxSOf target sh1
           -> target (TKS2 (sh1 ++ sh2) x)
  {-# INLINE tsoneHot #-}  -- this doesn't want to specialize
  tsoneHot @sh1 @sh2 @x target (TKS2 sh2 x)
v IxSOf target sh1
ix
   | SNat (Rank @Natural sh1)
SNat <- ShS sh1 -> SNat (Rank @Natural sh1)
forall (sh :: [Natural]). ShS sh -> SNat (Rank @Natural sh)
shsRank (forall (sh :: [Natural]). KnownShS sh => ShS sh
knownShS @sh1) = case forall (y :: TK). KnownSTK y => SingletonTK y
knownSTK @x of
    SingletonTK x
STKScalar ->
      (:~:)
  @[Natural]
  (Take @Natural (Rank @Natural sh1) ((++) @Natural sh1 sh2))
  sh1
-> (((Take
        @Natural (Rank @Natural sh1) ((++) @Natural sh1 sh2) :: [Natural])
     ~ (sh1 :: [Natural])) =>
    target (TKS2 ((++) @Natural sh1 sh2) x))
-> target (TKS2 ((++) @Natural sh1 sh2) x)
forall {k} (a :: k) (b :: k) r.
(:~:) @k a b -> (((a :: k) ~ (b :: k)) => r) -> r
gcastWith ((:~:)
  @[Natural]
  (Take @Natural (Rank @Natural sh1) ((++) @Natural sh1 sh2))
  sh1
forall {k} (a :: k) (b :: k). (:~:) @k a b
unsafeCoerceRefl :: Take (Rank sh1) (sh1 ++ sh2) :~: sh1) ((((Take
      @Natural (Rank @Natural sh1) ((++) @Natural sh1 sh2) :: [Natural])
   ~ (sh1 :: [Natural])) =>
  target (TKS2 ((++) @Natural sh1 sh2) x))
 -> target (TKS2 ((++) @Natural sh1 sh2) x))
-> (((Take
        @Natural (Rank @Natural sh1) ((++) @Natural sh1 sh2) :: [Natural])
     ~ (sh1 :: [Natural])) =>
    target (TKS2 ((++) @Natural sh1 sh2) x))
-> target (TKS2 ((++) @Natural sh1 sh2) x)
forall a b. (a -> b) -> a -> b
$
      (:~:)
  @[Natural]
  (Drop @Natural (Rank @Natural sh1) ((++) @Natural sh1 sh2))
  sh2
-> (((Drop
        @Natural (Rank @Natural sh1) ((++) @Natural sh1 sh2) :: [Natural])
     ~ (sh2 :: [Natural])) =>
    target (TKS2 ((++) @Natural sh1 sh2) x))
-> target (TKS2 ((++) @Natural sh1 sh2) x)
forall {k} (a :: k) (b :: k) r.
(:~:) @k a b -> (((a :: k) ~ (b :: k)) => r) -> r
gcastWith ((:~:)
  @[Natural]
  (Drop @Natural (Rank @Natural sh1) ((++) @Natural sh1 sh2))
  sh2
forall {k} (a :: k) (b :: k). (:~:) @k a b
unsafeCoerceRefl :: Drop (Rank sh1) (sh1 ++ sh2) :~: sh2) ((((Drop
      @Natural (Rank @Natural sh1) ((++) @Natural sh1 sh2) :: [Natural])
   ~ (sh2 :: [Natural])) =>
  target (TKS2 ((++) @Natural sh1 sh2) x))
 -> target (TKS2 ((++) @Natural sh1 sh2) x))
-> (((Drop
        @Natural (Rank @Natural sh1) ((++) @Natural sh1 sh2) :: [Natural])
     ~ (sh2 :: [Natural])) =>
    target (TKS2 ((++) @Natural sh1 sh2) x))
-> target (TKS2 ((++) @Natural sh1 sh2) x)
forall a b. (a -> b) -> a -> b
$
      forall (target :: Target) (shm :: [Natural]) (shn :: [Natural])
       (shp :: [Natural]) (x :: TK).
(BaseTensor target, KnownShS shm, KnownShS shn, KnownShS shp,
 KnownSTK x) =>
target (TKS2 ((++) @Natural shm shn) x)
-> (IxSOf target shm -> IxSOf target shp)
-> target (TKS2 ((++) @Natural shp shn) x)
tsscatter @_ @'[] @_ @sh1 target (TKS2 sh2 x)
target (TKS2 ((++) @Natural ('[] @Natural) sh2) x)
v (IxSOf target sh1
-> IxS ('[] @Natural) (IntOf target) -> IxSOf target sh1
forall a b. a -> b -> a
const IxSOf target sh1
ix)
    SingletonTK x
_ -> case SingletonTK (TKS2 sh2 x)
-> target (TKS2 sh2 x) -> FullShapeTK (TKS2 sh2 x)
forall (y :: TK). SingletonTK y -> target y -> FullShapeTK y
forall (target :: Target) (y :: TK).
BaseTensor target =>
SingletonTK y -> target y -> FullShapeTK y
tftk SingletonTK (TKS2 sh2 x)
forall (y :: TK). KnownSTK y => SingletonTK y
knownSTK target (TKS2 sh2 x)
v of
      FTKS ShS sh
_ FullShapeTK x
ftk2 ->
        -- TODO: def at out of bounds
        (:~:)
  @[Natural]
  (Drop
     @Natural
     (Rank @Natural ((++) @Natural sh1 sh2))
     ((++) @Natural sh1 sh2))
  ('[] @Natural)
-> (((Drop
        @Natural
        (Rank @Natural ((++) @Natural sh1 sh2))
        ((++) @Natural sh1 sh2) :: [Natural])
     ~ ('[] @Natural :: [Natural])) =>
    target (TKS2 ((++) @Natural sh1 sh2) x))
-> target (TKS2 ((++) @Natural sh1 sh2) x)
forall {k} (a :: k) (b :: k) r.
(:~:) @k a b -> (((a :: k) ~ (b :: k)) => r) -> r
gcastWith ((:~:)
  @[Natural]
  (Drop
     @Natural
     (Rank @Natural ((++) @Natural sh1 sh2))
     ((++) @Natural sh1 sh2))
  ('[] @Natural)
forall {k} (a :: k) (b :: k). (:~:) @k a b
unsafeCoerceRefl
                   :: Drop (Rank (sh1 ++ sh2)) (sh1 ++ sh2) :~: '[]) ((((Drop
      @Natural
      (Rank @Natural ((++) @Natural sh1 sh2))
      ((++) @Natural sh1 sh2) :: [Natural])
   ~ ('[] @Natural :: [Natural])) =>
  target (TKS2 ((++) @Natural sh1 sh2) x))
 -> target (TKS2 ((++) @Natural sh1 sh2) x))
-> (((Drop
        @Natural
        (Rank @Natural ((++) @Natural sh1 sh2))
        ((++) @Natural sh1 sh2) :: [Natural])
     ~ ('[] @Natural :: [Natural])) =>
    target (TKS2 ((++) @Natural sh1 sh2) x))
-> target (TKS2 ((++) @Natural sh1 sh2) x)
forall a b. (a -> b) -> a -> b
$
        (:~:)
  @[Natural]
  (Take
     @Natural
     (Rank @Natural ((++) @Natural sh1 sh2))
     ((++) @Natural sh1 sh2))
  ((++) @Natural sh1 sh2)
-> (((Take
        @Natural
        (Rank @Natural ((++) @Natural sh1 sh2))
        ((++) @Natural sh1 sh2) :: [Natural])
     ~ ((++) @Natural sh1 sh2 :: [Natural])) =>
    target (TKS2 ((++) @Natural sh1 sh2) x))
-> target (TKS2 ((++) @Natural sh1 sh2) x)
forall {k} (a :: k) (b :: k) r.
(:~:) @k a b -> (((a :: k) ~ (b :: k)) => r) -> r
gcastWith ((:~:)
  @[Natural]
  (Take
     @Natural
     (Rank @Natural ((++) @Natural sh1 sh2))
     ((++) @Natural sh1 sh2))
  ((++) @Natural sh1 sh2)
forall {k} (a :: k) (b :: k). (:~:) @k a b
unsafeCoerceRefl
                   :: Take (Rank (sh1 ++ sh2)) (sh1 ++ sh2) :~: (sh1 ++ sh2)) ((((Take
      @Natural
      (Rank @Natural ((++) @Natural sh1 sh2))
      ((++) @Natural sh1 sh2) :: [Natural])
   ~ ((++) @Natural sh1 sh2 :: [Natural])) =>
  target (TKS2 ((++) @Natural sh1 sh2) x))
 -> target (TKS2 ((++) @Natural sh1 sh2) x))
-> (((Take
        @Natural
        (Rank @Natural ((++) @Natural sh1 sh2))
        ((++) @Natural sh1 sh2) :: [Natural])
     ~ ((++) @Natural sh1 sh2 :: [Natural])) =>
    target (TKS2 ((++) @Natural sh1 sh2) x))
-> target (TKS2 ((++) @Natural sh1 sh2) x)
forall a b. (a -> b) -> a -> b
$
        (:~:)
  @[Natural]
  (Drop @Natural (Rank @Natural sh1) ((++) @Natural sh1 sh2))
  sh2
-> (((Drop
        @Natural (Rank @Natural sh1) ((++) @Natural sh1 sh2) :: [Natural])
     ~ (sh2 :: [Natural])) =>
    target (TKS2 ((++) @Natural sh1 sh2) x))
-> target (TKS2 ((++) @Natural sh1 sh2) x)
forall {k} (a :: k) (b :: k) r.
(:~:) @k a b -> (((a :: k) ~ (b :: k)) => r) -> r
gcastWith ((:~:)
  @[Natural]
  (Drop @Natural (Rank @Natural sh1) ((++) @Natural sh1 sh2))
  sh2
forall {k} (a :: k) (b :: k). (:~:) @k a b
unsafeCoerceRefl
                   :: Drop (Rank sh1) (sh1 ++ sh2) :~: sh2) ((((Drop
      @Natural (Rank @Natural sh1) ((++) @Natural sh1 sh2) :: [Natural])
   ~ (sh2 :: [Natural])) =>
  target (TKS2 ((++) @Natural sh1 sh2) x))
 -> target (TKS2 ((++) @Natural sh1 sh2) x))
-> (((Drop
        @Natural (Rank @Natural sh1) ((++) @Natural sh1 sh2) :: [Natural])
     ~ (sh2 :: [Natural])) =>
    target (TKS2 ((++) @Natural sh1 sh2) x))
-> target (TKS2 ((++) @Natural sh1 sh2) x)
forall a b. (a -> b) -> a -> b
$
        ShS ((++) @Natural sh1 sh2)
-> (KnownShS ((++) @Natural sh1 sh2) =>
    target (TKS2 ((++) @Natural sh1 sh2) x))
-> target (TKS2 ((++) @Natural sh1 sh2) x)
forall (sh :: [Natural]) r. ShS sh -> (KnownShS sh => r) -> r
withKnownShS (forall (sh :: [Natural]). KnownShS sh => ShS sh
knownShS @sh1 ShS sh1 -> ShS sh2 -> ShS ((++) @Natural sh1 sh2)
forall (sh :: [Natural]) (sh' :: [Natural]).
ShS sh -> ShS sh' -> ShS ((++) @Natural sh sh')
`shsAppend` forall (sh :: [Natural]). KnownShS sh => ShS sh
knownShS @sh2) ((KnownShS ((++) @Natural sh1 sh2) =>
  target (TKS2 ((++) @Natural sh1 sh2) x))
 -> target (TKS2 ((++) @Natural sh1 sh2) x))
-> (KnownShS ((++) @Natural sh1 sh2) =>
    target (TKS2 ((++) @Natural sh1 sh2) x))
-> target (TKS2 ((++) @Natural sh1 sh2) x)
forall a b. (a -> b) -> a -> b
$
        let f :: IxS ((++) @Natural sh1 sh2) (IntOf target)
-> target (TKS2 ('[] @Natural) x)
f IxS ((++) @Natural sh1 sh2) (IntOf target)
ix2 = SingletonTK (TKS2 ('[] @Natural) x)
-> BoolOf target
-> target (TKS2 ('[] @Natural) x)
-> target (TKS2 ('[] @Natural) x)
-> target (TKS2 ('[] @Natural) x)
forall (y :: TK).
Boolean (BoolOf target) =>
SingletonTK y -> BoolOf target -> target y -> target y -> target y
forall (target :: Target) (y :: TK).
(BaseTensor target, Boolean (BoolOf target)) =>
SingletonTK y -> BoolOf target -> target y -> target y -> target y
tcond SingletonTK (TKS2 ('[] @Natural) x)
forall (y :: TK). KnownSTK y => SingletonTK y
knownSTK
                          ((BoolOf target -> (IntOf target, IntOf target) -> BoolOf target)
-> BoolOf target -> [(IntOf target, IntOf target)] -> BoolOf target
forall b a. (b -> a -> b) -> b -> [a] -> b
forall (t :: Type -> Type) b a.
Foldable t =>
(b -> a -> b) -> b -> t a -> b
foldl' (\ !BoolOf target
acc (!IntOf target
i, !IntOf target
i2) -> BoolOf target
acc BoolOf target -> BoolOf target -> BoolOf target
forall b. Boolean b => b -> b -> b
&&* IntOf target
i IntOf target -> IntOf target -> BoolOf (PrimalOf target)
forall (f :: Target) (y :: TK). EqH f y => f y -> f y -> BoolOf f
==. IntOf target
i2) BoolOf target
forall b. Boolean b => b
true
                           ([(IntOf target, IntOf target)] -> BoolOf target)
-> [(IntOf target, IntOf target)] -> BoolOf target
forall a b. (a -> b) -> a -> b
$ [IntOf target] -> [IntOf target] -> [(IntOf target, IntOf target)]
forall a b. [a] -> [b] -> [(a, b)]
zip (IxSOf target sh1 -> [IntOf target]
forall a. IxS sh1 a -> [a]
forall (t :: Type -> Type) a. Foldable t => t a -> [a]
Foldable.toList IxSOf target sh1
ix) (IxS ((++) @Natural sh1 sh2) (IntOf target) -> [IntOf target]
forall a. IxS ((++) @Natural sh1 sh2) a -> [a]
forall (t :: Type -> Type) a. Foldable t => t a -> [a]
Foldable.toList IxS ((++) @Natural sh1 sh2) (IntOf target)
ix2))
                          (target (TKS2 sh2 x)
-> IxSOf target sh2 -> target (TKS2 ('[] @Natural) x)
forall (sh1 :: [Natural]) (x :: TK).
(KnownShS sh1, KnownSTK x) =>
target (TKS2 sh1 x)
-> IxSOf target sh1 -> target (TKS2 ('[] @Natural) x)
forall (target :: Target) (sh1 :: [Natural]) (x :: TK).
(BaseTensor target, KnownShS sh1, KnownSTK x) =>
target (TKS2 sh1 x)
-> IxSOf target sh1 -> target (TKS2 ('[] @Natural) x)
tsindex0 target (TKS2 sh2 x)
v (forall (len :: Natural) (sh :: [Natural]) i.
(KnownShS sh, KnownNat len, KnownShS (Drop @Natural len sh)) =>
IxS sh i -> IxS (Drop @Natural len sh) i
ixsDrop @(Rank sh1) IxS ((++) @Natural sh1 sh2) (IntOf target)
ix2))
                          (FullShapeTK (TKS2 ('[] @Natural) x)
-> target (TKS2 ('[] @Natural) x)
forall (y :: TK). FullShapeTK y -> target y
forall (target :: Target) (y :: TK).
BaseTensor target =>
FullShapeTK y -> target y
tdefTarget (ShS ('[] @Natural)
-> FullShapeTK x -> FullShapeTK (TKS2 ('[] @Natural) x)
forall (sh :: [Natural]) (x :: TK).
ShS sh -> FullShapeTK x -> FullShapeTK (TKS2 sh x)
FTKS ShS ('[] @Natural)
forall (sh :: [Natural]).
((sh :: [Natural]) ~ ('[] @Natural :: [Natural])) =>
ShS sh
ZSS FullShapeTK x
FullShapeTK x
ftk2))
        in forall (m :: Natural) (sh :: [Natural]) (x :: TK)
       (target :: Target).
(KnownShS (Take @Natural m sh), KnownShS sh, KnownSTK x,
 BaseTensor target) =>
(IxSOf target (Take @Natural m sh)
 -> target (TKS2 (Drop @Natural m sh) x))
-> target (TKS2 sh x)
sbuild @(Rank (sh1 ++ sh2)) IxS ((++) @Natural sh1 sh2) (IntOf target)
-> target (TKS2 ('[] @Natural) x)
IxSOf
  target
  (Take
     @Natural
     (Rank @Natural ((++) @Natural sh1 sh2))
     ((++) @Natural sh1 sh2))
-> target
     (TKS2
        (Drop
           @Natural
           (Rank @Natural ((++) @Natural sh1 sh2))
           ((++) @Natural sh1 sh2))
        x)
f
  -- Scatter on shaped tensors. The argument has shape @shm ++ shn@,
  -- the function maps @shm@-indexes to @shp@-indexes and the result
  -- has shape @shp ++ shn@.
  -- NOTE(review): how several source positions mapped to the same target
  -- index are combined is not visible in this signature; confirm
  -- against the instances.
  tsscatter
     :: (KnownShS shm, KnownShS shn, KnownShS shp, KnownSTK x)
     => target (TKS2 (shm ++ shn) x)
     -> (IxSOf target shm -> IxSOf target shp)
     -> target (TKS2 (shp ++ shn) x)
  -- | A variant of 'tsscatter' in which the source index has a single
  -- component. The default implementation instantiates @shm@ to @'[n2]@
  -- and passes the only component of that index to the supplied function.
  tsscatter1
     :: (KnownNat n2, KnownShS shn, KnownShS shp, KnownSTK x)
     => target (TKS2 (n2 ': shn) x)
     -> (IntOf target -> IxSOf target shp)
     -> target (TKS2 (shp ++ shn) x)
  tsscatter1 @n2 v f = tsscatter @_ @'[n2] v (\(i :.$ _) -> f i)
  -- Gather on shaped tensors. The argument has shape @shp ++ shn@,
  -- the function maps @shm@-indexes to @shp@-indexes and the result
  -- has shape @shm ++ shn@.
  -- NOTE(review): the treatment of out-of-bounds indexes is not visible
  -- in this signature; confirm against the instances.
  tsgather
     :: (KnownShS shm, KnownShS shn, KnownShS shp, KnownSTK x)
     => target (TKS2 (shp ++ shn) x)
     -> (IxSOf target shm -> IxSOf target shp)
     -> target (TKS2 (shm ++ shn) x)
  -- | A variant of 'tsgather' in which the result index has a single
  -- component. The default implementation instantiates @shm@ to @'[n2]@
  -- and passes the only component of that index to the supplied function.
  tsgather1
     :: (KnownNat n2, KnownShS shn, KnownShS shp, KnownSTK x)
     => target (TKS2 (shp ++ shn) x)
     -> (IntOf target -> IxSOf target shp)
     -> target (TKS2 (n2 ': shn) x)
  tsgather1 @n2 v f = tsgather @target @'[n2] v (\(i :.$ _) -> f i)

  -- Indexing of mixed-shape tensors: given a tensor of shape
  -- @sh1 ++ sh2@ and an index for the @sh1@ prefix, the result
  -- has the remaining shape @sh2@.
  txindex :: (KnownShX sh1, KnownShX sh2, KnownSTK x)
          => target (TKX2 (sh1 ++ sh2) x) -> IxXOf target sh1
          -> target (TKX2 sh2 x)
  -- | Indexing with a full index: the index covers the whole shape @sh1@,
  -- so the result has the empty shape.
  --
  -- The default implementation is 'txindex' with an empty remaining shape;
  -- the pattern guard on 'lemAppNil' supplies the equality
  -- @sh1 ++ '[] ~ sh1@ needed to make the types line up.
  txindex0 :: (KnownShX sh1, KnownSTK x)
           => target (TKX2 sh1 x) -> IxXOf target sh1
           -> target (TKX2 '[] x)
  txindex0 @sh1 | Refl <- lemAppNil @sh1 = txindex
  -- | Embed the tensor @v@ (of shape @sh2@) in a tensor of shape
  -- @sh1 ++ sh2@ at the outer position @ix@. The first argument is
  -- the concrete outer shape @sh1@; the inner shape is read off @v@
  -- with 'xshape'.
  --
  -- For scalar elements this is a scatter of @v@ from the empty index
  -- to @ix@. For all other element types the result is built position
  -- by position: where the leading components of the position are equal
  -- to @ix@, the matching element of @v@ is used, and everywhere else
  -- the 'tdefTarget' value for the element shape.
  txoneHot :: ( KnownShX sh1, KnownShX sh2, KnownSTK x
              , BoolOf (PrimalOf target) ~ BoolOf target
              , EqH (PrimalOf target) (TKScalar Int64), ConvertTensor target )
           => IShX sh1 -> target (TKX2 sh2 x) -> IxXOf target sh1
           -> target (TKX2 (sh1 ++ sh2) x)
  {-# INLINE txoneHot #-}
  txoneHot @sh1 @sh2 @x sh1 v ix
   -- Matching on SNat brings the rank of sh1 into scope for the type
   -- applications @(Rank sh1) and @(Rank (sh1 ++ sh2)) below.
   | SNat <- ssxRank (knownShX @sh1) = case knownSTK @x of
    STKScalar ->
      -- The equalities about Take, Drop and (++) are asserted with
      -- unsafeCoerceRefl rather than proven.
      gcastWith (unsafeCoerceRefl :: Take (Rank sh1) (sh1 ++ sh2) :~: sh1) $
      gcastWith (unsafeCoerceRefl :: Drop (Rank sh1) (sh1 ++ sh2) :~: sh2) $
      txscatter @_ @'[] @_ @sh1 (shxAppend sh1 (xshape v)) v (const ix)
    _ -> case tftk knownSTK v of
      FTKX _ ftk2 ->
        -- TODO: def at out of bounds
        gcastWith (unsafeCoerceRefl
                   :: Drop (Rank (sh1 ++ sh2)) (sh1 ++ sh2) :~: '[]) $
        gcastWith (unsafeCoerceRefl
                   :: Take (Rank (sh1 ++ sh2)) (sh1 ++ sh2) :~: (sh1 ++ sh2)) $
        gcastWith (unsafeCoerceRefl
                   :: Drop (Rank sh1) (sh1 ++ sh2) :~: sh2) $
        withKnownShX (knownShX @sh1 `ssxAppend` knownShX @sh2) $
        -- f computes the element at the full index ix2. The zip stops at
        -- the shorter list, so only the leading components of ix2 that
        -- have a counterpart in ix are compared.
        let f ix2 = tcond knownSTK
                          (foldl' (\ !acc (!i, !i2) -> acc &&* i ==. i2) true
                           $ zip (Foldable.toList ix) (Foldable.toList ix2))
                          (txindex0 v (ixxDrop' @(Rank sh1) ix2))
                          (tdefTarget (FTKX ZSX ftk2))
        in xbuild @(Rank (sh1 ++ sh2)) (shxAppend sh1 (xshape v)) f
  -- Scatter on mixed-shape tensors. The first argument is the concrete
  -- shape of the result, @shp ++ shn@; the tensor argument has shape
  -- @shm ++ shn@ and the function maps @shm@-indexes to @shp@-indexes.
  -- NOTE(review): how several source positions mapped to the same target
  -- index are combined is not visible in this signature; confirm
  -- against the instances.
  txscatter :: (KnownShX shm, KnownShX shn, KnownShX shp, KnownSTK x)
            => IShX (shp ++ shn) -> target (TKX2 (shm ++ shn) x)
            -> (IxXOf target shm -> IxXOf target shp)
            -> target (TKX2 (shp ++ shn) x)
  -- A variant of 'txscatter' in which the source index has a single
  -- component of statically known size. The default implementation
  -- instantiates @shm@ to @'[Just n2]@ and passes the only component
  -- of that index to the supplied function.
  -- TODO: generalize this to non-Just types.
  txscatter1 :: (KnownNat n2, KnownShX shn, KnownShX shp, KnownSTK x)
             => IShX (shp ++ shn) -> target (TKX2 (Just n2 ': shn) x)
             -> (IntOf target -> IxXOf target shp)
             -> target (TKX2 (shp ++ shn) x)
  txscatter1 @n2 @_ @shp @x sh v f = txscatter @_ @'[Just n2] @_ @shp @x sh v
                                               (\(i :.% _) -> f i)
  -- Gather on mixed-shape tensors. The first argument is the concrete
  -- shape of the result, @shm ++ shn@; the tensor argument has shape
  -- @shp ++ shn@ and the function maps @shm@-indexes to @shp@-indexes.
  -- NOTE(review): the treatment of out-of-bounds indexes is not visible
  -- in this signature; confirm against the instances.
  txgather :: (KnownShX shm, KnownShX shn, KnownShX shp, KnownSTK x)
           => IShX (shm ++ shn)
           -> target (TKX2 (shp ++ shn) x)
           -> (IxXOf target shm -> IxXOf target shp)
           -> target (TKX2 (shm ++ shn) x)
  -- | A variant of 'txgather' in which the result index has a single
  -- component of statically known size @n2@.
  --
  -- The default implementation instantiates @shm@ to @'[Just n2]@.
  -- The result shape handed to 'txgather' is the known size @k@ followed
  -- by the shape of @v@ with its @shp@ prefix dropped.
  txgather1 :: (KnownNat n2, KnownShX shn, KnownShX shp, KnownSTK x)
            => SNat n2 -> target (TKX2 (shp ++ shn) x)
            -> (IntOf target -> IxXOf target shp)
            -> target (TKX2 (Just n2 ': shn) x)
  txgather1 @n2 @_ @shp k v f =
    txgather @target @'[Just n2]
             (Nested.SKnown k :$% shxDropSSX (knownShX @shp) (xshape v)) v
             (\(i :.% ZIX) -> f i)

  -- Operations on ranked tensors that change the scalar type.
  -- For the first three the rank @n@ is the same in the argument
  -- and in the result.

  -- Source scalars must be 'RealFrac', result scalars 'Integral'.
  trfloor :: (GoodScalar r, RealFrac r, GoodScalar r2, Integral r2)
          => target (TKR n r) -> target (TKR n r2)
  -- Source scalars must be 'Integral'; the result scalar type is
  -- any 'GoodScalar'.
  trfromIntegral :: (GoodScalar r1, Integral r1, GoodScalar r2)
                 => target (TKR n r1) -> target (TKR n r2)
  -- Both the source and the result scalars must be 'RealFrac'.
  trcast :: (RealFrac r1, GoodScalar r1, RealFrac r2, GoodScalar r2)
         => target (TKR n r1) -> target (TKR n r2)
  -- The result has rank @n@ while the argument has rank @1 + n@,
  -- and the result has its own scalar type @r2@.
  -- NOTE(review): which dimension the index is computed along and
  -- what exactly makes these partial is not visible here; confirm
  -- against the instances.
  trminIndex, trmaxIndex  -- partial
    :: forall n r r2. (GoodScalar r, GoodScalar r2)
    => target (TKR (1 + n) r) -> target (TKR n r2)
  -- NOTE(review): presumably the 'Int' argument is the @n@ mentioned
  -- in the trailing comment; confirm.
  triota :: GoodScalar r => Int -> target (TKR 1 r)  -- from 0 to n - 1

  -- Operations on shaped tensors that change the scalar type.
  -- For the first three the shape @sh@ is the same in the argument
  -- and in the result.

  -- Source scalars must be 'RealFrac', result scalars 'Integral'.
  tsfloor :: (GoodScalar r, RealFrac r, GoodScalar r2, Integral r2)
          => target (TKS sh r) -> target (TKS sh r2)
  -- Source scalars must be 'Integral'; the result scalar type is
  -- any 'GoodScalar'.
  tsfromIntegral :: (GoodScalar r1, Integral r1, GoodScalar r2)
                 => target (TKS sh r1) -> target (TKS sh r2)
  -- Both the source and the result scalars must be 'RealFrac'.
  tscast :: (RealFrac r1, GoodScalar r1, RealFrac r2, GoodScalar r2)
         => target (TKS sh r1) -> target (TKS sh r2)
  -- The result shape is @Init@ of the argument shape and the result
  -- has its own scalar type @r2@.
  -- NOTE(review): what exactly makes these partial is not visible
  -- here; confirm against the instances.
  tsminIndex, tsmaxIndex  -- partial
    :: forall n sh r r2. (GoodScalar r, GoodScalar r2)
    => target (TKS (n ': sh) r) -> target (TKS (Init (n ': sh)) r2)
  -- The length @n@ comes from the result type; there is no value argument.
  tsiota :: (KnownNat n, GoodScalar r)
         => target (TKS '[n] r)  -- from 0 to n - 1

  -- | Convert a mixed-shape tensor with 'RealFrac' elements to a tensor of
  -- the same shape with 'Integral' elements (presumably an elementwise
  -- 'floor', going by the name).
  txfloor :: (GoodScalar r, RealFrac r, GoodScalar r2, Integral r2)
          => target (TKX sh r) -> target (TKX sh r2)
  -- | Convert a mixed-shape tensor with 'Integral' elements to a tensor of
  -- the same shape with another element type.
  txfromIntegral :: (GoodScalar r1, Integral r1, GoodScalar r2)
                 => target (TKX sh r1) -> target (TKX sh r2)
  -- | Convert between two 'RealFrac' element types, keeping the shape.
  txcast :: (RealFrac r1, GoodScalar r1, RealFrac r2, GoodScalar r2)
         => target (TKX sh r1) -> target (TKX sh r2)
  -- | Index of the minimal (respectively maximal) element. The result shape
  -- is @Init@ of the argument shape, i.e., the argument shape without its
  -- last dimension.
  -- NOTE(review): what makes the operation partial is not visible from
  -- the signature -- confirm against instances.
  txminIndex, txmaxIndex  -- partial
    :: forall mn sh r r2. (GoodScalar r, GoodScalar r2)
    => target (TKX (mn ': sh) r) -> target (TKX (Init (mn ': sh)) r2)
  -- | A one-dimensional tensor of statically known length @n@ enumerating
  -- consecutive values.
  txiota :: (KnownNat n, GoodScalar r)
         => target (TKX '[Just n] r)  -- from 0 to n - 1

  -- | Convert a scalar of a 'RealFrac' type to a scalar of an 'Integral' type
  -- (presumably 'floor', going by the name).
  tkfloor :: (GoodScalar r, RealFrac r, GoodScalar r2, Integral r2)
          => target (TKScalar r) -> target (TKScalar r2)
  -- | Convert a scalar of an 'Integral' type to a scalar of another type.
  tkfromIntegral :: (GoodScalar r1, Integral r1, GoodScalar r2)
                 => target (TKScalar r1) -> target (TKScalar r2)
  -- | Convert a scalar between two 'RealFrac' types.
  tkcast :: (RealFrac r1, GoodScalar r1, RealFrac r2, GoodScalar r2)
         => target (TKScalar r1) -> target (TKScalar r2)

  -- | Combine two ranked tensors of the same non-zero rank into one tensor
  -- of that rank (presumably by concatenating along the outermost
  -- dimension -- confirm against instances).
  trappend :: forall n x. KnownSTK x
           => target (TKR2 (1 + n) x) -> target (TKR2 (1 + n) x)
           -> target (TKR2 (1 + n) x)
  -- | Take a slice of a ranked tensor of non-zero rank; the rank is kept.
  -- NOTE(review): the two 'Int' arguments are presumably the starting offset
  -- and the length along the outermost dimension -- confirm.
  trslice :: forall n x. KnownSTK x
          => Int -> Int -> target (TKR2 (1 + n) x) -> target (TKR2 (1 + n) x)
  -- | Reverse a ranked tensor of non-zero rank (presumably along
  -- the outermost dimension); the rank is kept.
  trreverse :: forall n x. KnownSTK x
            => target (TKR2 (1 + n) x) -> target (TKR2 (1 + n) x)
  -- | Transpose a ranked tensor according to the given permutation;
  -- the rank is kept.
  trtranspose :: forall n x. KnownSTK x
              => Permutation.PermR -> target (TKR2 n x) -> target (TKR2 n x)
  -- | Reshape a ranked tensor to the given shape, possibly changing the rank.
  -- NOTE(review): nothing in the type relates the sizes of the old and new
  -- shapes, so any size mismatch can only be detected at runtime -- confirm.
  trreshape :: forall n m x. KnownSTK x
            => IShR m -> target (TKR2 n x) -> target (TKR2 m x)

  -- | Combine two shaped tensors that agree on all but the outermost
  -- dimension; the outermost dimension of the result is the sum @m + n@
  -- of the outermost dimensions of the arguments.
  tsappend :: forall m n sh x. KnownSTK x
           => target (TKS2 (m ': sh) x) -> target (TKS2 (n ': sh) x)
           -> target (TKS2 ((m + n) ': sh) x)
  -- | Take a slice of length @n@ along the outermost dimension of a tensor
  -- whose outermost dimension is @i + n + k@ (presumably skipping the first
  -- @i@ and dropping the last @k@ elements -- confirm against instances).
  tsslice :: forall i n k sh x. KnownSTK x
          => SNat i -> SNat n -> SNat k
          -> target (TKS2 (i + n + k ': sh) x) -> target (TKS2 (n ': sh) x)
  -- | Reverse a shaped tensor (presumably along the outermost dimension);
  -- the shape is kept.
  tsreverse :: forall n sh x. KnownSTK x
            => target (TKS2 (n ': sh) x) -> target (TKS2 (n ': sh) x)
  -- | Transpose a shaped tensor according to a statically known permutation
  -- that is no longer than the rank of the tensor; the result shape is
  -- the argument shape with its prefix permuted.
  tstranspose :: ( Permutation.IsPermutation perm
                 , Rank perm <= Rank sh, KnownSTK x )
              => Permutation.Perm perm -> target (TKS2 sh x)
              -> target (TKS2 (Permutation.PermutePrefix perm sh) x)
  -- | Reshape a shaped tensor to the given shape; the constraint requires
  -- that the products of the dimensions of both shapes are equal.
  tsreshape :: (Product sh ~ Product sh2, KnownSTK x)
            => ShS sh2 -> target (TKS2 sh x) -> target (TKS2 sh2 x)

  -- | Combine two mixed-shape tensors with statically known outermost
  -- dimensions that agree on the remaining dimensions; the outermost
  -- dimension of the result is the sum @m + n@.
  txappend :: forall m n sh x. KnownSTK x
           => target (TKX2 (Just m ': sh) x) -> target (TKX2 (Just n ': sh) x)
           -> target (TKX2 (Just (m + n) ': sh) x)
  -- | Take a slice of length @n@ along the outermost dimension of a tensor
  -- whose outermost dimension is statically known to be @i + n + k@
  -- (presumably skipping the first @i@ and dropping the last @k@ elements --
  -- confirm against instances).
  txslice :: forall i n k sh x. KnownSTK x
          => SNat i -> SNat n -> SNat k
          -> target (TKX2 (Just (i + n + k) ': sh) x)
          -> target (TKX2 (Just n ': sh) x)
  -- | Reverse a mixed-shape tensor (presumably along the outermost
  -- dimension); the shape is kept.
  txreverse :: forall mn sh x. KnownSTK x
            => target (TKX2 (mn ': sh) x) -> target (TKX2 (mn ': sh) x)
  -- | Transpose a mixed-shape tensor according to a statically known
  -- permutation that is no longer than the rank of the tensor; the result
  -- shape is the argument shape with its prefix permuted.
  txtranspose :: ( Permutation.IsPermutation perm
                 , Rank perm <= Rank sh, KnownSTK x )
              => Permutation.Perm perm -> target (TKX2 sh x)
              -> target (TKX2 (Permutation.PermutePrefix perm sh) x)
  -- | Reshape a mixed-shape tensor to the given shape.
  -- NOTE(review): unlike 'tsreshape', nothing in the type relates the sizes
  -- of the old and new shapes -- presumably checked at runtime, confirm.
  txreshape :: forall sh sh2 x. KnownSTK x
            => IShX sh2 -> target (TKX2 sh x) -> target (TKX2 sh2 x)

  -- | Build a ranked tensor of rank @1 + n@ from a function producing
  -- a rank @n@ tensor for an index. The 'Int' argument is presumably
  -- the length of the new outermost dimension ('tbuild1' passes the value
  -- of its @SNat k@ argument here).
  trbuild1 :: (KnownNat n, KnownSTK x)
           => Int -> (IntOf target -> target (TKR2 n x))
           -> target (TKR2 (1 + n) x)
  -- | Map a function on rank-0 tensors over all elements of a ranked tensor.
  --
  -- The default implementation rebuilds, with 'rbuild', a tensor of the
  -- shape of the argument @v@, computing the element at index @ix@
  -- as @f (trindex0 v ix)@.
  trmap0N :: (KnownNat n, KnownSTK x, KnownSTK x1)
          => (target (TKR2 0 x1) -> target (TKR2 0 x)) -> target (TKR2 n x1)
          -> target (TKR2 n x)
  trmap0N f v = rbuild (rshape v) (f . trindex0 v)
  -- | Combine two ranked tensors of the same rank elementwise with a function
  -- on rank-0 tensors.
  --
  -- The default implementation rebuilds, with 'rbuild', a tensor whose shape
  -- is taken from the second tensor argument @v@, computing the element
  -- at index @ix@ as @f (trindex0 u ix) (trindex0 v ix)@.
  trzipWith0N :: (KnownNat n, KnownSTK x, KnownSTK x1, KnownSTK x2)
              => (target (TKR2 0 x1) -> target (TKR2 0 x2) -> target (TKR2 0 x))
              -> target (TKR2 n x1) -> target (TKR2 n x2) -> target (TKR2 n x)
  trzipWith0N f u v =
    rbuild (rshape v) (\ix -> f (trindex0 u ix) (trindex0 v ix))

  -- | Build a shaped tensor of shape @k ': sh@ from a function producing
  -- a tensor of shape @sh@ for an index; @k@ is taken from the result type
  -- via the 'KnownNat' constraint.
  tsbuild1 :: (KnownNat k, KnownShS sh, KnownSTK x)
           => (IntOf target -> target (TKS2 sh x))
           -> target (TKS2 (k ': sh) x)
  -- | Map a function on shape-@'[]@ tensors over all elements of a shaped
  -- tensor.
  --
  -- The default implementation rebuilds the whole tensor with
  -- @sbuild \@(Rank sh)@, computing the element at index @ix@
  -- as @f (tsindex0 v ix)@.
  tsmap0N :: (KnownShS sh, KnownSTK x, KnownSTK x1)
          => (target (TKS2 '[] x1) -> target (TKS2 '[] x))
          -> target (TKS2 sh x1)
          -> target (TKS2 sh x)
  tsmap0N @sh f v =
    -- Dropping (taking) as many elements as the list has leaves nothing
    -- (everything). Both equalities hold for every @sh@, but GHC cannot
    -- derive them for an abstract @sh@, hence the unsafe coercions.
    gcastWith (unsafeCoerceRefl :: Drop (Rank sh) sh :~: '[])
    $ gcastWith (unsafeCoerceRefl :: Take (Rank sh) sh :~: sh)
    $ sbuild @(Rank sh) (f . tsindex0 v)
  -- | Combine two shaped tensors of the same shape elementwise with
  -- a function on shape-@'[]@ tensors.
  --
  -- The default implementation rebuilds the whole tensor with
  -- @sbuild \@(Rank sh)@, computing the element at index @ix@
  -- as @f (tsindex0 u ix) (tsindex0 v ix)@.
  tszipWith0N :: (KnownShS sh, KnownSTK x, KnownSTK x1, KnownSTK x2)
              => (target (TKS2 '[] x1) -> target (TKS2 '[] x2)
                  -> target (TKS2 '[] x))
              -> target (TKS2 sh x1) -> target (TKS2 sh x2)
              -> target (TKS2 sh x)
  tszipWith0N @sh f u v =
    -- Dropping (taking) as many elements as the list has leaves nothing
    -- (everything). Both equalities hold for every @sh@, but GHC cannot
    -- derive them for an abstract @sh@, hence the unsafe coercions.
    gcastWith (unsafeCoerceRefl :: Drop (Rank sh) sh :~: '[])
    $ gcastWith (unsafeCoerceRefl :: Take (Rank sh) sh :~: sh)
    $ sbuild @(Rank sh) (\ix -> f (tsindex0 u ix) (tsindex0 v ix))

  -- | Build a mixed-shape tensor of shape @Just k ': sh@ from a function
  -- producing a tensor of shape @sh@ for an index; @k@ is taken from
  -- the result type via the 'KnownNat' constraint.
  txbuild1 :: (KnownNat k, KnownShX sh, KnownSTK x)
           => (IntOf target -> target (TKX2 sh x))
           -> target (TKX2 (Just k ': sh) x)

  -- | Build a value of kind @BuildTensorKind k y@ from a function producing
  -- a value of kind @y@ for an index, for an arbitrary tensor kind @y@
  -- described by the singleton argument.
  --
  -- The default implementation dispatches on the singleton:
  --
  -- * scalars are converted with 'sfromK' and built with 'tsbuild1',
  -- * ranked, shaped and mixed-shape tensors are built with 'trbuild1'
  --   (given the value of @k@ as the length), 'tsbuild1' and 'txbuild1',
  --   respectively,
  -- * products are built componentwise and paired with 'tpair'.
  tbuild1 :: forall y k. ConvertTensor target
               -- y comes first, because k easy to set via SNat
          => SNat k -> SingletonTK y -> (IntOf target -> target y)
          -> target (BuildTensorKind k y)
  tbuild1 snat@SNat stk0 f =
    let replSTK :: SingletonTK z -> (IntOf target -> target z)
                -> target (BuildTensorKind k z)
        replSTK stk g = case stk of
          STKScalar -> tsbuild1 (sfromK . g)
          STKR SNat x | Dict <- lemKnownSTK x -> trbuild1 (sNatValue snat) g
          STKS sh x | Dict <- lemKnownSTK x -> withKnownShS sh $ tsbuild1 g
          STKX sh x | Dict <- lemKnownSTK x -> withKnownShX sh $ txbuild1 g
          STKProduct @z1 @z2 stk1 stk2 ->
              -- Each component applies @g@ on its own, so @g i@ appears
              -- once per component of the product.
              let f1 i = tproject1 @_ @z1 @z2 $ g i
                  f2 i = tproject2 @_ @z1 @z2 $ g i
                    -- TODO: looks expensive, but hard to do better,
                    -- so let's hope g is full of variables
              in tpair (replSTK stk1 f1) (replSTK stk2 f2)
    in replSTK stk0 f

  -- | A strict right mapAccum.
  --
  -- The applications of 'tjvp' and 'tvjp' performed already at this point
  -- ensure that the computation of a derivative is not repeated
  -- and that its result is shared. However, most of the time
  -- the computation is unneeded, so the AST instance uses a non-strict
  -- constructor 'HordeAd.Core.Ast.AstLambda' for its instance of 'HFunOf'.
  --
  -- If the same argument functions are passed to many mapAccum calls, as in
  -- > let f = ... in ... (tmapAccumR ... f ...) ... (tmapAccumL ... f ...)
  -- extra care is needed to prevent double derivative computation.
  -- One needs to use tmapAccumRDer manually as in (simplified)
  -- > let f = ...; df = tjvp f; rf = tvjp f
  -- > in ... (tmapAccumRDer f df rf ...) ... (tmapAccumLDer f df rf ...)
  tmapAccumRDer
    :: forall accy by ey k.
       Proxy target
    -> SNat k  -- ^ length of the input
    -> FullShapeTK accy  -- ^ shape of the accumulator
    -> FullShapeTK by  -- ^ shape of the output
    -> FullShapeTK ey  -- ^ shape of an individual input
    -> HFunOf target (TKProduct accy ey) (TKProduct accy by)
         -- ^ the function to mapAccum with
    -> HFunOf target (TKProduct (ADTensorKind (TKProduct accy ey))
                                (TKProduct accy ey))
                     (ADTensorKind (TKProduct accy by))
         -- ^ the derivative of the function to mapAccum with
    -> HFunOf target (TKProduct (ADTensorKind (TKProduct accy by))
                                (TKProduct accy ey))
                     (ADTensorKind (TKProduct accy ey))
         -- ^ the reverse derivative of the function to mapAccum with
    -> target accy  -- ^ the initial accumulator
    -> target (BuildTensorKind k ey)  -- ^ the inputs
    -> target (TKProduct accy (BuildTensorKind k by))
  -- | A strict left mapAccum.
  --
  -- It takes the same arguments, in the same order, as 'tmapAccumRDer'
  -- and has the same result type.
  tmapAccumLDer
    :: forall accy by ey k.
       Proxy target
    -> SNat k  -- ^ length of the input
    -> FullShapeTK accy  -- ^ shape of the accumulator
    -> FullShapeTK by  -- ^ shape of the output
    -> FullShapeTK ey  -- ^ shape of an individual input
    -> HFunOf target (TKProduct accy ey) (TKProduct accy by)
         -- ^ the function to mapAccum with
    -> HFunOf target (TKProduct (ADTensorKind (TKProduct accy ey))
                                (TKProduct accy ey))
                     (ADTensorKind (TKProduct accy by))
         -- ^ the derivative of the function to mapAccum with
    -> HFunOf target (TKProduct (ADTensorKind (TKProduct accy by))
                                (TKProduct accy ey))
                     (ADTensorKind (TKProduct accy ey))
         -- ^ the reverse derivative of the function to mapAccum with
    -> target accy  -- ^ the initial accumulator
    -> target (BuildTensorKind k ey)  -- ^ the inputs
    -> target (TKProduct accy (BuildTensorKind k by))
  -- | Apply a function in the representation used by @target@ to an argument.
  tApply :: HFunOf target x z -> target x -> target z
  -- | Turn an 'HFun' into the function representation used by @target@,
  -- given the full shape of the argument of the function.
  tlambda :: FullShapeTK x -> HFun x z -> HFunOf target x z

  -- | Reverse derivative.
  --
  -- The following methods (and tlambda) are exactly what is needed as arguments
  -- of tmapAccumRDer.
  tgrad
    :: FullShapeTK x  -- ^ shape of x and dx
    -> HFun x (TKScalar r)  -- ^ x |-> TKScalar r
    -> HFunOf target x (ADTensorKind x)  -- ^ x |-> dx
  -- Like 'tgrad', but for a function with an arbitrary codomain @z@;
  -- the resulting function additionally takes the incoming cotangent @dz@.
  tvjp
    :: FullShapeTK x  -- ^ shape of x and dx
    -> HFun x z  -- ^ x |-> z
    -> HFunOf target (TKProduct (ADTensorKind z) x) (ADTensorKind x)
         -- ^ (dz, x) |-> dx
  -- The forward counterpart: the resulting function maps a perturbation @dx@
  -- and the point @x@ to the perturbation @dz@ of the result.
  tjvp
    :: FullShapeTK x  -- ^ shape of x and dx
    -> HFun x z  -- ^ x |-> z
    -> HFunOf target (TKProduct (ADTensorKind x) x) (ADTensorKind z)
         -- ^ (dx, x) |-> dz

  -- | Project a value to its primal representation.
  tprimalPart :: target y -> PrimalOf target y
  -- | Project a value of the kind described by the singleton to its dual
  -- representation.
  tdualPart :: SingletonTK y -> target y -> DualOf target y
  -- | Inject a primal value of the kind described by the singleton
  -- into @target@.
  tfromPrimal :: SingletonTK y -> PrimalOf target y -> target y
  -- | Inject a dual value into @target@.
  tfromDual :: DualOf target y -> target y
  -- | Scale a dual value by a primal value.
  --
  -- The default implementation injects both arguments into @target@
  -- (with 'tfromPrimal' and 'tfromDual'), multiplies them there
  -- and takes the dual part of the product.
  tScale :: (Num (target y), Num (PrimalOf target y))
         => SingletonTK y -> PrimalOf target y -> DualOf target y
         -> DualOf target y
  tScale stk s t =
    tdualPart stk $ tfromPrimal @target stk s * tfromDual t

  -- General operations that use ShareTensor if available, LetTensor otherwise
  -- | Sum a value of kind @BuildTensorKind k z@ down to a value of kind @z@,
  -- for an arbitrary tensor kind @z@ described by the singleton argument.
  --
  -- The default implementation dispatches on the singleton:
  --
  -- * for scalars the argument is summed with 'tssum' and converted
  --   with 'kfromS',
  -- * ranked, shaped and mixed-shape tensors are summed with 'trsum',
  --   'tssum' and 'txsum', respectively,
  -- * products are split with 'tunpair', summed componentwise and paired
  --   again with 'tpair'.
  --
  -- The use of 'tunpair' is why the default requires 'ShareTensor'.
  tsum
    :: forall z k. ConvertTensor target
    => SNat k -> SingletonTK z -> target (BuildTensorKind k z)
    -> target z
  default tsum
    :: forall z k. (ShareTensor target, ConvertTensor target)
    => SNat k -> SingletonTK z -> target (BuildTensorKind k z)
    -> target z
  tsum snat@SNat stk u = case stk of
    STKScalar -> kfromS $ tssum u
    STKR SNat x | Dict <- lemKnownSTK x -> trsum u
    STKS sh x | Dict <- lemKnownSTK x -> withKnownShS sh $ tssum u
    STKX sh x | Dict <- lemKnownSTK x -> withKnownShX sh $ txsum u
    STKProduct stk1 stk2 ->
      let (u1, u2) = tunpair u
      in tpair (tsum snat stk1 u1)
               (tsum snat stk2 u2)
  -- | Replicate the argument @k@ times, producing a value of kind
  -- @BuildTensorKind k z@.
  --
  -- The default implementation proceeds by cases on the kind singleton:
  -- scalars are first converted to rank-0 shaped tensors with 'sfromK'
  -- and then replicated, ranked, shaped and mixed tensors are delegated
  -- to 'trreplicate', 'tsreplicate' and 'txreplicate' respectively,
  -- and products are replicated component-wise.
  treplicate
    :: forall z k. ConvertTensor target
    => SNat k -> SingletonTK z -> target z
    -> target (BuildTensorKind k z)
  -- The default additionally needs 'ShareTensor', because the product case
  -- takes the pair apart with 'tunpair'.
  default treplicate
    :: forall z k. (ShareTensor target, ConvertTensor target)
    => SNat k -> SingletonTK z -> target z
    -> target (BuildTensorKind k z)
  treplicate snat@SNat stk u = case stk of
    STKScalar -> tsreplicate snat ZSS $ sfromK u
    -- In the cases below, the 'Dict' match brings @KnownSTK x@ into scope,
    -- which the replicate methods require.
    STKR SNat x | Dict <- lemKnownSTK x -> trreplicate (sNatValue snat) u
    STKS sh x | Dict <- lemKnownSTK x -> tsreplicate snat sh u
    STKX sh x | Dict <- lemKnownSTK x -> txreplicate snat sh u
    STKProduct stk1 stk2 ->
      let (u1, u2) = tunpair u
      in tpair (treplicate snat stk1 u1)
               (treplicate snat stk2 u2)
  -- | Project out the element at the given index along the outermost
  -- dimension (of size @k@) of a value of kind @BuildTensorKind k z@.
  --
  -- The default implementation proceeds by cases on the kind singleton:
  -- the scalar case indexes the underlying shaped tensor and converts
  -- the rank-0 result back with 'kfromS', ranked, shaped and mixed tensors
  -- are indexed with 'trindex', 'tsindex' and 'txindex' using
  -- a one-element index, and products are indexed component-wise.
  tindexBuild
    :: forall z k. ConvertTensor target
    => SNat k -> SingletonTK z -> target (BuildTensorKind k z) -> IntOf target
    -> target z
  -- The default additionally needs 'ShareTensor', because the product case
  -- takes the pair apart with 'tunpair'.
  default tindexBuild
    :: forall z k. (ShareTensor target, ConvertTensor target)
    => SNat k -> SingletonTK z -> target (BuildTensorKind k z) -> IntOf target
    -> target z
  tindexBuild snat@SNat stk u i = case stk of
    STKScalar -> kfromS $ tsindex u (i :.$ ZIS)
    -- In the cases below, the 'Dict' match brings @KnownSTK x@ into scope
    -- and @withKnownSh*@ supplies the shape constraint of the result,
    -- both of which the indexing methods require.
    STKR SNat x | Dict <- lemKnownSTK x -> trindex u (i :.: ZIR)
    STKS sh x | Dict <- lemKnownSTK x -> withKnownShS sh $ tsindex u (i :.$ ZIS)
    STKX sh x | Dict <- lemKnownSTK x -> withKnownShX sh $ txindex u (i :.% ZIX)
    STKProduct stk1 stk2 ->
      let (u1, u2) = tunpair u
      in tpair (tindexBuild snat stk1 u1 i)
               (tindexBuild snat stk2 u2 i)

  -- Unwinding methods, needed mostly to split off the Unwind module.
  -- | Construct tensors with the given constant in each cell.
  --
  -- The constant is polymorphic in the scalar type (any 'GoodScalar'),
  -- and the structure of the result is described by the 'FullShapeTK'
  -- argument.
  treplTarget :: (forall r. GoodScalar r => r) -> FullShapeTK y -> target y
  -- | Construct tensors with @def@ in each cell, with the structure
  -- of the result described by the 'FullShapeTK' argument.
  tdefTarget :: FullShapeTK y -> target y
  -- | Add pointwise all corresponding tensors within nested products, if any.
  --
  -- Requires duplicable arguments or a ShareTensor instance.
  taddTarget :: SingletonTK y -> target y -> target y -> target y
  -- | Multiply pointwise all corresponding tensors within nested products,
  -- if any.
  --
  -- Requires duplicable arguments or a ShareTensor instance.
  tmultTarget :: SingletonTK y -> target y -> target y -> target y
  -- | Sum all dimensions of each component and then sum it all. Ignore all
  -- tensors with non-differentiable elements. The result is a single
  -- @Double@ scalar.
  --
  -- Requires duplicable arguments or a ShareTensor instance.
  tsum0Target :: FullShapeTK y -> target y
              -> target (TKScalar Double)
  -- | Dot product each component and then sum it all. Ignore all
  -- tensors with non-differentiable elements. The result is a single
  -- @Double@ scalar.
  --
  -- Requires duplicable arguments or a ShareTensor instance.
  tdot0Target :: FullShapeTK y -> target y -> target y
              -> target (TKScalar Double)

  -- TODO: express without ConvertTensor or move there
  -- | Cast a mixed tensor to another static mixed shape of the same rank.
  --
  -- The default implementation obtains the full shape of the argument
  -- with 'tftk', turns it into a shaped-tensor shape with 'withShsFromShX',
  -- converts the argument to a shaped tensor with 'sfromX' and then
  -- converts that to the requested mixed shape with 'xfromS'.
  xmcast :: (KnownSTK x, KnownShX sh, Rank sh ~ Rank sh2, ConvertTensor target)
         => StaticShX sh2 -> target (TKX2 sh x) -> target (TKX2 sh2 x)
  xmcast sh2 a = case tftk knownSTK a of
    FTKX sh' _ ->
      -- The pattern signature binds the existential shaped shape
      -- as a type variable, so that it can be passed to 'sfromX' below.
      withShsFromShX sh' $ \(sh :: ShS sh) ->
        withKnownShX sh2 $
        withKnownShS sh $
        xfromS $ sfromX @_ @sh a

-- These are user-accessible, so the constraint is `ADReady`, which means
-- lets, but no shares.
--
-- HFun wraps a function from @f x@ to @f z@ that is polymorphic
-- in the target @f@, for any @f@ satisfying `ADReady`.
type role HFun nominal nominal
newtype HFun (x :: TK) (z :: TK) =
  HFun {unHFun :: forall f. ADReady f
               => f x -> f z}

-- The wrapped function is not inspected; every 'HFun' is shown
-- as the same fixed placeholder string.
instance Show (HFun x y) where
  show _ = "<lambda>"


-- * The mega-constraint

-- | Everything in 'ADReadyNoLet' plus a 'LetTensor' instance
-- for the target itself.
type ADReady target =
  ( ADReadyNoLet target
  , LetTensor target
-- The following can't be added, because we have instances like ADVal (AstRaw),
-- so AstRaw would need to have a LetTensor instance:
--  , LetTensor (PrimalOf target)
  )

-- | 'ADReadyEqsClasses' for both the target and its @ShareOf@ counterpart,
-- 'ShareTensor' for the @ShareOf@ counterpart and for its primal part,
-- and the requirement that applying @ShareOf@ a second time changes nothing.
-- No 'LetTensor' instance is demanded here.
type ADReadyNoLet target =
  ( ADReadyEqsClasses target
  , ADReadyEqsClasses (ShareOf target)
  , ShareTensor (ShareOf target)
  , ShareTensor (PrimalOf (ShareOf target))
  , ShareOf (ShareOf target) ~ ShareOf target
  )

-- | The type equalities of 'ADReadyEqs' together with the class constraints
-- of 'ADReadyClasses' for both the target and its primal part.
type ADReadyEqsClasses target =
  ( ADReadyEqs target
  , ADReadyClasses target
  , ADReadyClasses (PrimalOf target)
  )

-- | The type equality required of a target: its primal part uses
-- the same boolean type as the target itself.
type ADReadyEqs target =
  ( BoolOf (PrimalOf target) ~ BoolOf target
  )

-- | The class constraints required of a single target: the tensor
-- operation classes, a 'Boolean' instance for its boolean type,
-- and the 'Show' and comparison bundles defined in this module.
type ADReadyClasses target =
  ( BaseTensor target
  , ConvertTensor target
  , Boolean (BoolOf target)
  , AllTargetShow target
  , CommonTargetEqOrd target
  )

-- This is illegal:
-- type AllTargetShow target = forall y. KnownSTK y => Show (target y)
-- so the quantified constraint is instead packaged as the superclass
-- of a method-less class that has a single, fully general instance.

-- | Holds exactly when @target y@ has a 'Show' instance for every @y@
-- with a known kind singleton.
type AllTargetShow :: Target -> Constraint
class (forall y. KnownSTK y => Show (target y))
      => AllTargetShow target where
-- The only instance: any target satisfying the quantified constraint
-- is automatically in the class.
instance
      (forall y. KnownSTK y => Show (target y))
      => AllTargetShow target where

-- | Bundles the 'EqH' and 'OrdH' instances of a target at scalar, ranked,
-- shaped and mixed tensor kinds, for every 'GoodScalar' element type.
-- The quantified constraints are packaged as superclasses of a method-less
-- class that has a single, fully general instance.
type CommonTargetEqOrd :: Target -> Constraint
class ( forall r. GoodScalar r => EqH target (TKScalar r)
      , forall r. GoodScalar r => OrdH target (TKScalar r)
      , forall r n. GoodScalar r => EqH target (TKR n r)
      , forall r n. GoodScalar r => OrdH target (TKR n r)
      , forall r sh. GoodScalar r => EqH target (TKS sh r)
      , forall r sh. GoodScalar r => OrdH target (TKS sh r)
      , forall r sh. GoodScalar r => EqH target (TKX sh r)
      , forall r sh. GoodScalar r => OrdH target (TKX sh r) )
      => CommonTargetEqOrd target where
-- The only instance: any target satisfying all of the quantified
-- constraints above is automatically in the class.
instance
      ( forall r. GoodScalar r => EqH target (TKScalar r)
      , forall r. GoodScalar r => OrdH target (TKScalar r)
      , forall r n. GoodScalar r => EqH target (TKR n r)
      , forall r n. GoodScalar r => OrdH target (TKR n r)
      , forall r sh. GoodScalar r => EqH target (TKS sh r)
      , forall r sh. GoodScalar r => OrdH target (TKS sh r)
      , forall r sh. GoodScalar r => EqH target (TKX sh r)
      , forall r sh. GoodScalar r => OrdH target (TKX sh r) )
      => CommonTargetEqOrd target where