{-# LANGUAGE ViewPatterns #-}
{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-}
module HordeAd.Core.AstTools
(
ftkAst, isTensorInt
, varInAst, varInAstBool, varInIxS, varNameInAst, varNameInIxS
, astIsSmall, ixIsSmall
, bounds
, liftRFromS1, liftRFromS2, liftXFromS1, liftXFromS2
, cAstConvert, cAstSFromR, cAstSFromX, cAstXFromS
, pattern AstFromS', checkAstFromS, checkFtkAstFromS, checkFtkAstSFrom
, cAstFromS, cAstSFrom, convFromS, convSFrom
, setTotalSharing
) where
import Prelude hiding (foldl')
import Control.Exception.Assert.Sugar
import Data.Int (Int64)
import Data.IORef
import Data.Maybe (isJust)
import Data.Proxy (Proxy (Proxy))
import Data.Type.Equality (testEquality, (:~:) (Refl))
import Data.Vector.Generic qualified as V
import System.IO.Unsafe (unsafePerformIO)
import Type.Reflection (typeRep)
import Data.Array.Nested qualified as Nested
import Data.Array.Nested.Convert (withShsFromShR, withShsFromShX)
import Data.Array.Nested.Lemmas
import Data.Array.Nested.Mixed.Shape
import Data.Array.Nested.Ranked.Shape
import Data.Array.Nested.Shaped.Shape
import Data.Array.Nested.Types (snatPlus)
import HordeAd.Core.Ast
import HordeAd.Core.TensorKind
import HordeAd.Core.Types
-- | Compute the full shape ('FullShapeTK') of an AST term.
--
-- This is a cheap structural computation, not a shape checker. For most
-- constructors only one representative sub-term is consulted, and the
-- remaining ones are trusted to agree via the type index. Examples: the
-- @then@ branch of 'AstCond', the first element of 'AstFromVector', the
-- left operand of 'AstPlusS', the body of 'AstLet'.
--
-- Shapes of variables ('AstVar', 'AstShare') are read from the variable
-- name via 'varNameToFTK'.
--
-- The inner @case ftkAst v of FTKProduct …@ / @FTKS …@ matches list a
-- single alternative. They are presumably exhaustive because the term's
-- type index fixes which 'FullShapeTK' constructor can occur. That relies
-- on the GADT definition of 'FullShapeTK', which is outside this chunk.
--
-- Calls 'error' on an 'AstFromVector' with an empty vector, since no
-- element is available to read the element shape from.
ftkAst :: forall s y ms. AstTensor ms s y -> FullShapeTK y
ftkAst t = case t of
  -- Products and projections.
  AstPair t1 t2 -> FTKProduct (ftkAst t1) (ftkAst t2)
  AstProject1 v -> case ftkAst v of
    FTKProduct ftk1 _ -> ftk1
  AstProject2 v -> case ftkAst v of
    FTKProduct _ ftk2 -> ftk2

  -- Vector-like constructors: shape of one element, extended by @snat@.
  AstFromVector snat _ l -> case V.uncons l of
    Nothing -> error "ftkAst: empty vector"
    Just (v, _) -> buildFTK snat (ftkAst v)
  AstSum snat stk v -> razeFTK snat stk (ftkAst v)
  AstReplicate snat _ v -> buildFTK snat (ftkAst v)

  -- Both fold directions yield (final accumulator, k stacked outputs).
  AstMapAccumRDer k bftk _eftk _f _df _rf acc0 _es ->
    FTKProduct (ftkAst acc0) (buildFTK k bftk)
  AstMapAccumLDer k bftk _eftk _f _df _rf acc0 _es ->
    FTKProduct (ftkAst acc0) (buildFTK k bftk)

  -- Binding and control forms.
  AstApply (AstLambda _ l) _ -> ftkAst l
  AstVar var -> varNameToFTK var
  AstCond _b v _w -> ftkAst v
  AstBuild1 snat _ (_var, v) -> buildFTK snat (ftkAst v)

  -- Sharing.
  AstLet _ _ v -> ftkAst v
  AstShare var _ -> varNameToFTK var
  AstToShare v -> ftkAst v

  -- Span conversions do not change the shape.
  AstPrimalPart a -> ftkAst a
  AstDualPart a -> ftkAst a
  AstFromPrimal a -> ftkAst a
  AstFromDual a -> ftkAst a

  -- Scalar operations.
  AstPlusK{} -> FTKScalar
  AstTimesK{} -> FTKScalar
  AstN1K{} -> FTKScalar
  AstR1K{} -> FTKScalar
  AstR2K{} -> FTKScalar
  AstI2K{} -> FTKScalar
  AstConcreteK _ -> FTKScalar
  AstFloorK{} -> FTKScalar
  AstFromIntegralK{} -> FTKScalar
  AstCastK{} -> FTKScalar

  -- Elementwise shaped operations keep the operand's shape.
  AstPlusS v _ -> ftkAst v
  AstTimesS v _ -> ftkAst v
  AstN1S _ v -> ftkAst v
  AstR1S _ v -> ftkAst v
  AstR2S _ v _ -> ftkAst v
  AstI2S _ v _ -> ftkAst v
  AstConcreteS a -> FTKS (Nested.sshape a) FTKScalar

  -- Shaped operations that only change the scalar type keep the shape.
  AstFloorS v -> case ftkAst v of
    FTKS sh FTKScalar -> FTKS sh FTKScalar
  AstFromIntegralS v -> case ftkAst v of
    FTKS sh FTKScalar -> FTKS sh FTKScalar
  AstCastS v -> case ftkAst v of
    FTKS sh FTKScalar -> FTKS sh FTKScalar

  -- Indexing, scatter and gather.
  AstIndexS shn v _ix -> case ftkAst v of
    FTKS _ x -> FTKS shn x
  AstScatterS shn v (_, ix) -> case ftkAst v of
    FTKS _ x -> FTKS (shsFromIxS ix `shsAppend` shn) x
  AstGatherS shn v (vars, _) -> case ftkAst v of
    FTKS _ x -> FTKS (shsFromListS vars `shsAppend` shn) x

  -- Reductions over the innermost dimension drop it ('shsInit').
  AstMinIndexS v -> case ftkAst v of
    FTKS sh FTKScalar -> FTKS (shsInit sh) FTKScalar
  AstMaxIndexS v -> case ftkAst v of
    FTKS sh FTKScalar -> FTKS (shsInit sh) FTKScalar

  -- Structural shape changes.
  AstIotaS n@SNat -> FTKS (n :$$ ZSS) FTKScalar
  AstAppendS a b -> case (ftkAst a, ftkAst b) of
    (FTKS (m :$$ sh) x, FTKS (n :$$ _) _) -> FTKS (snatPlus m n :$$ sh) x
  AstSliceS _ n@SNat _ a -> case ftkAst a of
    FTKS (_ :$$ sh) x -> FTKS (n :$$ sh) x
  AstReverseS v -> ftkAst v
  AstTransposeS perm v -> case ftkAst v of
    FTKS sh x -> FTKS (shsPermutePrefix perm sh) x
  AstReshapeS sh2 v -> case ftkAst v of
    FTKS _ x -> FTKS sh2 x
  AstConvert c u -> convertFTK c $ ftkAst u

  -- Contractions.
  AstSum0S v -> case ftkAst v of
    FTKS _ x -> FTKS ZSS x
  AstDot0S _u _v -> FTKS ZSS FTKScalar
  AstDot1InS sh _ _u _v -> FTKS sh FTKScalar
  AstMatmul2S m@SNat _ p@SNat _u _v -> FTKS (m :$$ p :$$ ZSS) FTKScalar
-- | Test whether terms of span @s@ and full shape @ftk@ have the type
-- @AstInt ms@.
--
-- Returns @Just Refl@ only when both of these hold:
--
-- * @ftk@ is a scalar shape ('FTKScalar') whose element type is 'Int64'.
--   This is decided at runtime by comparing 'typeRep's with
--   'testEquality'.
-- * @s@ is 'PrimalSpan', as decided by 'sameAstSpan'.
--
-- Every other shape or span yields 'Nothing'. The 'Proxy' argument is not
-- inspected; it only fixes the type variable @s@ for callers.
isTensorInt :: forall s y ms. AstSpan s
            => Proxy s -> FullShapeTK y
            -> Maybe (AstTensor ms s y :~: AstInt ms)
isTensorInt _ ftk = case ftk of
  FTKScalar @r -> case ( testEquality (typeRep @r) (typeRep @Int64)
                       , sameAstSpan @s @PrimalSpan ) of
    -- Both equalities in scope let GHC accept the final 'Refl'.
    (Just Refl, Just Refl) -> Just Refl
    _ -> Nothing
  _ -> Nothing
varInAst :: AstVarId -> AstTensor ms s y -> Bool
varInAst :: forall (ms :: AstMethodOfSharing) (s :: AstSpanType) (y :: TK).
AstVarId -> AstTensor ms s y -> Bool
varInAst AstVarId
var = \case
AstPair AstTensor ms s y
t1 AstTensor ms s z
t2 -> AstVarId -> AstTensor ms s y -> Bool
forall (ms :: AstMethodOfSharing) (s :: AstSpanType) (y :: TK).
AstVarId -> AstTensor ms s y -> Bool
varInAst AstVarId
var AstTensor ms s y
t1 Bool -> Bool -> Bool
|| AstVarId -> AstTensor ms s z -> Bool
forall (ms :: AstMethodOfSharing) (s :: AstSpanType) (y :: TK).
AstVarId -> AstTensor ms s y -> Bool
varInAst AstVarId
var AstTensor ms s z
t2
AstProject1 AstTensor ms s (TKProduct y z)
t -> AstVarId -> AstTensor ms s (TKProduct y z) -> Bool
forall (ms :: AstMethodOfSharing) (s :: AstSpanType) (y :: TK).
AstVarId -> AstTensor ms s y -> Bool
varInAst AstVarId
var AstTensor ms s (TKProduct y z)
t
AstProject2 AstTensor ms s (TKProduct y y)
t -> AstVarId -> AstTensor ms s (TKProduct y y) -> Bool
forall (ms :: AstMethodOfSharing) (s :: AstSpanType) (y :: TK).
AstVarId -> AstTensor ms s y -> Bool
varInAst AstVarId
var AstTensor ms s (TKProduct y y)
t
AstFromVector SNat k
_ SingletonTK y
_ Vector (AstTensor ms s y)
vl -> (AstTensor ms s y -> Bool) -> Vector (AstTensor ms s y) -> Bool
forall (t :: Type -> Type) a.
Foldable t =>
(a -> Bool) -> t a -> Bool
any (AstVarId -> AstTensor ms s y -> Bool
forall (ms :: AstMethodOfSharing) (s :: AstSpanType) (y :: TK).
AstVarId -> AstTensor ms s y -> Bool
varInAst AstVarId
var) Vector (AstTensor ms s y)
vl
AstSum SNat k
_ SingletonTK y
_ AstTensor ms s (BuildTensorKind k y)
v -> AstVarId -> AstTensor ms s (BuildTensorKind k y) -> Bool
forall (ms :: AstMethodOfSharing) (s :: AstSpanType) (y :: TK).
AstVarId -> AstTensor ms s y -> Bool
varInAst AstVarId
var AstTensor ms s (BuildTensorKind k y)
v
AstReplicate SNat k
_ SingletonTK y
_ AstTensor ms s y
v -> AstVarId -> AstTensor ms s y -> Bool
forall (ms :: AstMethodOfSharing) (s :: AstSpanType) (y :: TK).
AstVarId -> AstTensor ms s y -> Bool
varInAst AstVarId
var AstTensor ms s y
v
AstMapAccumRDer SNat k
_k FullShapeTK by
_bftk FullShapeTK ey
_eftk AstHFun s s (TKProduct accy ey) (TKProduct accy by)
_f AstHFun
s
s
(TKProduct (ADTensorKind (TKProduct accy ey)) (TKProduct accy ey))
(ADTensorKind (TKProduct accy by))
_df AstHFun
s
s
(TKProduct (ADTensorKind (TKProduct accy by)) (TKProduct accy ey))
(ADTensorKind (TKProduct accy ey))
_rf AstTensor ms s accy
acc0 AstTensor ms s (BuildTensorKind k ey)
es ->
AstVarId -> AstTensor ms s accy -> Bool
forall (ms :: AstMethodOfSharing) (s :: AstSpanType) (y :: TK).
AstVarId -> AstTensor ms s y -> Bool
varInAst AstVarId
var AstTensor ms s accy
acc0 Bool -> Bool -> Bool
|| AstVarId -> AstTensor ms s (BuildTensorKind k ey) -> Bool
forall (ms :: AstMethodOfSharing) (s :: AstSpanType) (y :: TK).
AstVarId -> AstTensor ms s y -> Bool
varInAst AstVarId
var AstTensor ms s (BuildTensorKind k ey)
es
AstMapAccumLDer SNat k
_k FullShapeTK by
_bftk FullShapeTK ey
_eftk AstHFun s s (TKProduct accy ey) (TKProduct accy by)
_f AstHFun
s
s
(TKProduct (ADTensorKind (TKProduct accy ey)) (TKProduct accy ey))
(ADTensorKind (TKProduct accy by))
_df AstHFun
s
s
(TKProduct (ADTensorKind (TKProduct accy by)) (TKProduct accy ey))
(ADTensorKind (TKProduct accy ey))
_rf AstTensor ms s accy
acc0 AstTensor ms s (BuildTensorKind k ey)
es ->
AstVarId -> AstTensor ms s accy -> Bool
forall (ms :: AstMethodOfSharing) (s :: AstSpanType) (y :: TK).
AstVarId -> AstTensor ms s y -> Bool
varInAst AstVarId
var AstTensor ms s accy
acc0 Bool -> Bool -> Bool
|| AstVarId -> AstTensor ms s (BuildTensorKind k ey) -> Bool
forall (ms :: AstMethodOfSharing) (s :: AstSpanType) (y :: TK).
AstVarId -> AstTensor ms s y -> Bool
varInAst AstVarId
var AstTensor ms s (BuildTensorKind k ey)
es
AstApply AstHFun s1 s x y
t AstTensor ms s1 x
ll -> AstVarId -> AstHFun s1 s x y -> Bool
forall (s :: AstSpanType) (s2 :: AstSpanType) (x :: TK) (y :: TK).
AstVarId -> AstHFun s s2 x y -> Bool
varInAstHFun AstVarId
var AstHFun s1 s x y
t Bool -> Bool -> Bool
|| AstVarId -> AstTensor ms s1 x -> Bool
forall (ms :: AstMethodOfSharing) (s :: AstSpanType) (y :: TK).
AstVarId -> AstTensor ms s y -> Bool
varInAst AstVarId
var AstTensor ms s1 x
ll
AstVar AstVarName s y
var2 -> AstVarId
var AstVarId -> AstVarId -> Bool
forall a. Eq a => a -> a -> Bool
== AstVarName s y -> AstVarId
forall (s :: AstSpanType) (y :: TK). AstVarName s y -> AstVarId
varNameToAstVarId AstVarName s y
var2
AstCond AstBool ms
b AstTensor ms s y
v AstTensor ms s y
w -> AstVarId -> AstBool ms -> Bool
forall (ms :: AstMethodOfSharing). AstVarId -> AstBool ms -> Bool
varInAstBool AstVarId
var AstBool ms
b Bool -> Bool -> Bool
|| AstVarId -> AstTensor ms s y -> Bool
forall (ms :: AstMethodOfSharing) (s :: AstSpanType) (y :: TK).
AstVarId -> AstTensor ms s y -> Bool
varInAst AstVarId
var AstTensor ms s y
v Bool -> Bool -> Bool
|| AstVarId -> AstTensor ms s y -> Bool
forall (ms :: AstMethodOfSharing) (s :: AstSpanType) (y :: TK).
AstVarId -> AstTensor ms s y -> Bool
varInAst AstVarId
var AstTensor ms s y
w
AstBuild1 SNat k
_ SingletonTK y
_ (IntVarName
var2, AstTensor ms s y
v) ->
Bool -> Bool -> Bool
forall a. HasCallStack => Bool -> a -> a
assert (IntVarName -> AstVarId
forall (s :: AstSpanType) (y :: TK). AstVarName s y -> AstVarId
varNameToAstVarId IntVarName
var2 AstVarId -> AstVarId -> Bool
forall a. Eq a => a -> a -> Bool
/= AstVarId
var) (Bool -> Bool) -> Bool -> Bool
forall a b. (a -> b) -> a -> b
$
AstVarId -> AstTensor ms s y -> Bool
forall (ms :: AstMethodOfSharing) (s :: AstSpanType) (y :: TK).
AstVarId -> AstTensor ms s y -> Bool
varInAst AstVarId
var AstTensor ms s y
v
AstLet AstVarName s y
_ AstTensor AstMethodLet s y
u AstTensor AstMethodLet s y
v -> AstVarId -> AstTensor AstMethodLet s y -> Bool
forall (ms :: AstMethodOfSharing) (s :: AstSpanType) (y :: TK).
AstVarId -> AstTensor ms s y -> Bool
varInAst AstVarId
var AstTensor AstMethodLet s y
u Bool -> Bool -> Bool
|| AstVarId -> AstTensor AstMethodLet s y -> Bool
forall (ms :: AstMethodOfSharing) (s :: AstSpanType) (y :: TK).
AstVarId -> AstTensor ms s y -> Bool
varInAst AstVarId
var AstTensor AstMethodLet s y
v
AstShare AstVarName s y
_ AstTensor AstMethodShare s y
v -> AstVarId -> AstTensor AstMethodShare s y -> Bool
forall (ms :: AstMethodOfSharing) (s :: AstSpanType) (y :: TK).
AstVarId -> AstTensor ms s y -> Bool
varInAst AstVarId
var AstTensor AstMethodShare s y
v
AstToShare AstTensor AstMethodLet s y
v -> AstVarId -> AstTensor AstMethodLet s y -> Bool
forall (ms :: AstMethodOfSharing) (s :: AstSpanType) (y :: TK).
AstVarId -> AstTensor ms s y -> Bool
varInAst AstVarId
var AstTensor AstMethodLet s y
v
AstPrimalPart AstTensor ms FullSpan y
a -> AstVarId -> AstTensor ms FullSpan y -> Bool
forall (ms :: AstMethodOfSharing) (s :: AstSpanType) (y :: TK).
AstVarId -> AstTensor ms s y -> Bool
varInAst AstVarId
var AstTensor ms FullSpan y
a
AstDualPart AstTensor ms FullSpan y
a -> AstVarId -> AstTensor ms FullSpan y -> Bool
forall (ms :: AstMethodOfSharing) (s :: AstSpanType) (y :: TK).
AstVarId -> AstTensor ms s y -> Bool
varInAst AstVarId
var AstTensor ms FullSpan y
a
AstFromPrimal AstTensor ms PrimalSpan y
v -> AstVarId -> AstTensor ms PrimalSpan y -> Bool
forall (ms :: AstMethodOfSharing) (s :: AstSpanType) (y :: TK).
AstVarId -> AstTensor ms s y -> Bool
varInAst AstVarId
var AstTensor ms PrimalSpan y
v
AstFromDual AstTensor ms DualSpan y
v -> AstVarId -> AstTensor ms DualSpan y -> Bool
forall (ms :: AstMethodOfSharing) (s :: AstSpanType) (y :: TK).
AstVarId -> AstTensor ms s y -> Bool
varInAst AstVarId
var AstTensor ms DualSpan y
v
AstPlusK AstTensor ms s (TKScalar r)
t AstTensor ms s (TKScalar r)
u -> AstVarId -> AstTensor ms s (TKScalar r) -> Bool
forall (ms :: AstMethodOfSharing) (s :: AstSpanType) (y :: TK).
AstVarId -> AstTensor ms s y -> Bool
varInAst AstVarId
var AstTensor ms s (TKScalar r)
t Bool -> Bool -> Bool
|| AstVarId -> AstTensor ms s (TKScalar r) -> Bool
forall (ms :: AstMethodOfSharing) (s :: AstSpanType) (y :: TK).
AstVarId -> AstTensor ms s y -> Bool
varInAst AstVarId
var AstTensor ms s (TKScalar r)
u
AstTimesK AstTensor ms s (TKScalar r)
t AstTensor ms s (TKScalar r)
u -> AstVarId -> AstTensor ms s (TKScalar r) -> Bool
forall (ms :: AstMethodOfSharing) (s :: AstSpanType) (y :: TK).
AstVarId -> AstTensor ms s y -> Bool
varInAst AstVarId
var AstTensor ms s (TKScalar r)
t Bool -> Bool -> Bool
|| AstVarId -> AstTensor ms s (TKScalar r) -> Bool
forall (ms :: AstMethodOfSharing) (s :: AstSpanType) (y :: TK).
AstVarId -> AstTensor ms s y -> Bool
varInAst AstVarId
var AstTensor ms s (TKScalar r)
u
AstN1K OpCodeNum1
_ AstTensor ms s (TKScalar r)
t -> AstVarId -> AstTensor ms s (TKScalar r) -> Bool
forall (ms :: AstMethodOfSharing) (s :: AstSpanType) (y :: TK).
AstVarId -> AstTensor ms s y -> Bool
varInAst AstVarId
var AstTensor ms s (TKScalar r)
t
AstR1K OpCode1
_ AstTensor ms s (TKScalar r)
t -> AstVarId -> AstTensor ms s (TKScalar r) -> Bool
forall (ms :: AstMethodOfSharing) (s :: AstSpanType) (y :: TK).
AstVarId -> AstTensor ms s y -> Bool
varInAst AstVarId
var AstTensor ms s (TKScalar r)
t
AstR2K OpCode2
_ AstTensor ms s (TKScalar r)
t AstTensor ms s (TKScalar r)
u -> AstVarId -> AstTensor ms s (TKScalar r) -> Bool
forall (ms :: AstMethodOfSharing) (s :: AstSpanType) (y :: TK).
AstVarId -> AstTensor ms s y -> Bool
varInAst AstVarId
var AstTensor ms s (TKScalar r)
t Bool -> Bool -> Bool
|| AstVarId -> AstTensor ms s (TKScalar r) -> Bool
forall (ms :: AstMethodOfSharing) (s :: AstSpanType) (y :: TK).
AstVarId -> AstTensor ms s y -> Bool
varInAst AstVarId
var AstTensor ms s (TKScalar r)
u
AstI2K OpCodeIntegral2
_ AstTensor ms s (TKScalar r)
t AstTensor ms s (TKScalar r)
u -> AstVarId -> AstTensor ms s (TKScalar r) -> Bool
forall (ms :: AstMethodOfSharing) (s :: AstSpanType) (y :: TK).
AstVarId -> AstTensor ms s y -> Bool
varInAst AstVarId
var AstTensor ms s (TKScalar r)
t Bool -> Bool -> Bool
|| AstVarId -> AstTensor ms s (TKScalar r) -> Bool
forall (ms :: AstMethodOfSharing) (s :: AstSpanType) (y :: TK).
AstVarId -> AstTensor ms s y -> Bool
varInAst AstVarId
var AstTensor ms s (TKScalar r)
u
AstConcreteK{} -> Bool
False
AstFloorK AstTensor ms PrimalSpan (TKScalar r1)
a -> AstVarId -> AstTensor ms PrimalSpan (TKScalar r1) -> Bool
forall (ms :: AstMethodOfSharing) (s :: AstSpanType) (y :: TK).
AstVarId -> AstTensor ms s y -> Bool
varInAst AstVarId
var AstTensor ms PrimalSpan (TKScalar r1)
a
AstFromIntegralK AstTensor ms PrimalSpan (TKScalar r1)
t -> AstVarId -> AstTensor ms PrimalSpan (TKScalar r1) -> Bool
forall (ms :: AstMethodOfSharing) (s :: AstSpanType) (y :: TK).
AstVarId -> AstTensor ms s y -> Bool
varInAst AstVarId
var AstTensor ms PrimalSpan (TKScalar r1)
t
AstCastK AstTensor ms s (TKScalar r1)
t -> AstVarId -> AstTensor ms s (TKScalar r1) -> Bool
forall (ms :: AstMethodOfSharing) (s :: AstSpanType) (y :: TK).
AstVarId -> AstTensor ms s y -> Bool
varInAst AstVarId
var AstTensor ms s (TKScalar r1)
t
AstPlusS AstTensor ms s (TKS2 sh (TKScalar r))
t AstTensor ms s (TKS2 sh (TKScalar r))
u -> AstVarId -> AstTensor ms s (TKS2 sh (TKScalar r)) -> Bool
forall (ms :: AstMethodOfSharing) (s :: AstSpanType) (y :: TK).
AstVarId -> AstTensor ms s y -> Bool
varInAst AstVarId
var AstTensor ms s (TKS2 sh (TKScalar r))
t Bool -> Bool -> Bool
|| AstVarId -> AstTensor ms s (TKS2 sh (TKScalar r)) -> Bool
forall (ms :: AstMethodOfSharing) (s :: AstSpanType) (y :: TK).
AstVarId -> AstTensor ms s y -> Bool
varInAst AstVarId
var AstTensor ms s (TKS2 sh (TKScalar r))
u
AstTimesS AstTensor ms s (TKS2 sh (TKScalar r))
t AstTensor ms s (TKS2 sh (TKScalar r))
u -> AstVarId -> AstTensor ms s (TKS2 sh (TKScalar r)) -> Bool
forall (ms :: AstMethodOfSharing) (s :: AstSpanType) (y :: TK).
AstVarId -> AstTensor ms s y -> Bool
varInAst AstVarId
var AstTensor ms s (TKS2 sh (TKScalar r))
t Bool -> Bool -> Bool
|| AstVarId -> AstTensor ms s (TKS2 sh (TKScalar r)) -> Bool
forall (ms :: AstMethodOfSharing) (s :: AstSpanType) (y :: TK).
AstVarId -> AstTensor ms s y -> Bool
varInAst AstVarId
var AstTensor ms s (TKS2 sh (TKScalar r))
u
AstN1S OpCodeNum1
_ AstTensor ms s (TKS2 sh (TKScalar r))
t -> AstVarId -> AstTensor ms s (TKS2 sh (TKScalar r)) -> Bool
forall (ms :: AstMethodOfSharing) (s :: AstSpanType) (y :: TK).
AstVarId -> AstTensor ms s y -> Bool
varInAst AstVarId
var AstTensor ms s (TKS2 sh (TKScalar r))
t
AstR1S OpCode1
_ AstTensor ms s (TKS2 sh (TKScalar r))
t -> AstVarId -> AstTensor ms s (TKS2 sh (TKScalar r)) -> Bool
forall (ms :: AstMethodOfSharing) (s :: AstSpanType) (y :: TK).
AstVarId -> AstTensor ms s y -> Bool
varInAst AstVarId
var AstTensor ms s (TKS2 sh (TKScalar r))
t
AstR2S OpCode2
_ AstTensor ms s (TKS2 sh (TKScalar r))
t AstTensor ms s (TKS2 sh (TKScalar r))
u -> AstVarId -> AstTensor ms s (TKS2 sh (TKScalar r)) -> Bool
forall (ms :: AstMethodOfSharing) (s :: AstSpanType) (y :: TK).
AstVarId -> AstTensor ms s y -> Bool
varInAst AstVarId
var AstTensor ms s (TKS2 sh (TKScalar r))
t Bool -> Bool -> Bool
|| AstVarId -> AstTensor ms s (TKS2 sh (TKScalar r)) -> Bool
forall (ms :: AstMethodOfSharing) (s :: AstSpanType) (y :: TK).
AstVarId -> AstTensor ms s y -> Bool
varInAst AstVarId
var AstTensor ms s (TKS2 sh (TKScalar r))
u
AstI2S OpCodeIntegral2
_ AstTensor ms s (TKS2 sh (TKScalar r))
t AstTensor ms s (TKS2 sh (TKScalar r))
u -> AstVarId -> AstTensor ms s (TKS2 sh (TKScalar r)) -> Bool
forall (ms :: AstMethodOfSharing) (s :: AstSpanType) (y :: TK).
AstVarId -> AstTensor ms s y -> Bool
varInAst AstVarId
var AstTensor ms s (TKS2 sh (TKScalar r))
t Bool -> Bool -> Bool
|| AstVarId -> AstTensor ms s (TKS2 sh (TKScalar r)) -> Bool
forall (ms :: AstMethodOfSharing) (s :: AstSpanType) (y :: TK).
AstVarId -> AstTensor ms s y -> Bool
varInAst AstVarId
var AstTensor ms s (TKS2 sh (TKScalar r))
u
AstConcreteS{} -> Bool
False
AstFloorS AstTensor ms PrimalSpan (TKS sh r1)
a -> AstVarId -> AstTensor ms PrimalSpan (TKS sh r1) -> Bool
forall (ms :: AstMethodOfSharing) (s :: AstSpanType) (y :: TK).
AstVarId -> AstTensor ms s y -> Bool
varInAst AstVarId
var AstTensor ms PrimalSpan (TKS sh r1)
a
AstFromIntegralS AstTensor ms PrimalSpan (TKS sh r1)
a -> AstVarId -> AstTensor ms PrimalSpan (TKS sh r1) -> Bool
forall (ms :: AstMethodOfSharing) (s :: AstSpanType) (y :: TK).
AstVarId -> AstTensor ms s y -> Bool
varInAst AstVarId
var AstTensor ms PrimalSpan (TKS sh r1)
a
AstCastS AstTensor ms s (TKS sh r1)
t -> AstVarId -> AstTensor ms s (TKS sh r1) -> Bool
forall (ms :: AstMethodOfSharing) (s :: AstSpanType) (y :: TK).
AstVarId -> AstTensor ms s y -> Bool
varInAst AstVarId
var AstTensor ms s (TKS sh r1)
t
AstIndexS ShS shn
_ AstTensor ms s (TKS2 ((++) @Nat shm shn) x)
v AstIxS ms shm
ix -> AstVarId -> AstTensor ms s (TKS2 ((++) @Nat shm shn) x) -> Bool
forall (ms :: AstMethodOfSharing) (s :: AstSpanType) (y :: TK).
AstVarId -> AstTensor ms s y -> Bool
varInAst AstVarId
var AstTensor ms s (TKS2 ((++) @Nat shm shn) x)
v Bool -> Bool -> Bool
|| AstVarId -> AstIxS ms shm -> Bool
forall (ms :: AstMethodOfSharing) (sh :: [Nat]).
AstVarId -> AstIxS ms sh -> Bool
varInIxS AstVarId
var AstIxS ms shm
ix
AstScatterS ShS shn
_ AstTensor ms s (TKS2 ((++) @Nat shm shn) x)
v (AstVarListS shm
_vars, AstIxS ms shp
ix) -> AstVarId -> AstIxS ms shp -> Bool
forall (ms :: AstMethodOfSharing) (sh :: [Nat]).
AstVarId -> AstIxS ms sh -> Bool
varInIxS AstVarId
var AstIxS ms shp
ix Bool -> Bool -> Bool
|| AstVarId -> AstTensor ms s (TKS2 ((++) @Nat shm shn) x) -> Bool
forall (ms :: AstMethodOfSharing) (s :: AstSpanType) (y :: TK).
AstVarId -> AstTensor ms s y -> Bool
varInAst AstVarId
var AstTensor ms s (TKS2 ((++) @Nat shm shn) x)
v
AstGatherS ShS shn
_ AstTensor ms s (TKS2 ((++) @Nat shp shn) x)
v (AstVarListS shm
_vars, AstIxS ms shp
ix) -> AstVarId -> AstIxS ms shp -> Bool
forall (ms :: AstMethodOfSharing) (sh :: [Nat]).
AstVarId -> AstIxS ms sh -> Bool
varInIxS AstVarId
var AstIxS ms shp
ix Bool -> Bool -> Bool
|| AstVarId -> AstTensor ms s (TKS2 ((++) @Nat shp shn) x) -> Bool
forall (ms :: AstMethodOfSharing) (s :: AstSpanType) (y :: TK).
AstVarId -> AstTensor ms s y -> Bool
varInAst AstVarId
var AstTensor ms s (TKS2 ((++) @Nat shp shn) x)
v
AstMinIndexS AstTensor ms PrimalSpan (TKS ((':) @Nat n sh) r)
a -> AstVarId
-> AstTensor ms PrimalSpan (TKS ((':) @Nat n sh) r) -> Bool
forall (ms :: AstMethodOfSharing) (s :: AstSpanType) (y :: TK).
AstVarId -> AstTensor ms s y -> Bool
varInAst AstVarId
var AstTensor ms PrimalSpan (TKS ((':) @Nat n sh) r)
a
AstMaxIndexS AstTensor ms PrimalSpan (TKS ((':) @Nat n sh) r)
a -> AstVarId
-> AstTensor ms PrimalSpan (TKS ((':) @Nat n sh) r) -> Bool
forall (ms :: AstMethodOfSharing) (s :: AstSpanType) (y :: TK).
AstVarId -> AstTensor ms s y -> Bool
varInAst AstVarId
var AstTensor ms PrimalSpan (TKS ((':) @Nat n sh) r)
a
AstIotaS{} -> Bool
False
AstAppendS AstTensor ms s (TKS2 ((':) @Nat m sh) x)
v AstTensor ms s (TKS2 ((':) @Nat n sh) x)
u -> AstVarId -> AstTensor ms s (TKS2 ((':) @Nat m sh) x) -> Bool
forall (ms :: AstMethodOfSharing) (s :: AstSpanType) (y :: TK).
AstVarId -> AstTensor ms s y -> Bool
varInAst AstVarId
var AstTensor ms s (TKS2 ((':) @Nat m sh) x)
v Bool -> Bool -> Bool
|| AstVarId -> AstTensor ms s (TKS2 ((':) @Nat n sh) x) -> Bool
forall (ms :: AstMethodOfSharing) (s :: AstSpanType) (y :: TK).
AstVarId -> AstTensor ms s y -> Bool
varInAst AstVarId
var AstTensor ms s (TKS2 ((':) @Nat n sh) x)
u
AstSliceS SNat i
_ SNat n
_ SNat k
_ AstTensor ms s (TKS2 ((':) @Nat ((i + n) + k) sh) x)
v -> AstVarId
-> AstTensor ms s (TKS2 ((':) @Nat ((i + n) + k) sh) x) -> Bool
forall (ms :: AstMethodOfSharing) (s :: AstSpanType) (y :: TK).
AstVarId -> AstTensor ms s y -> Bool
varInAst AstVarId
var AstTensor ms s (TKS2 ((':) @Nat ((i + n) + k) sh) x)
v
AstReverseS AstTensor ms s (TKS2 ((':) @Nat n sh) x)
v -> AstVarId -> AstTensor ms s (TKS2 ((':) @Nat n sh) x) -> Bool
forall (ms :: AstMethodOfSharing) (s :: AstSpanType) (y :: TK).
AstVarId -> AstTensor ms s y -> Bool
varInAst AstVarId
var AstTensor ms s (TKS2 ((':) @Nat n sh) x)
v
AstTransposeS Perm perm
_perm AstTensor ms s (TKS2 sh x)
v -> AstVarId -> AstTensor ms s (TKS2 sh x) -> Bool
forall (ms :: AstMethodOfSharing) (s :: AstSpanType) (y :: TK).
AstVarId -> AstTensor ms s y -> Bool
varInAst AstVarId
var AstTensor ms s (TKS2 sh x)
v
AstReshapeS ShS sh2
_ AstTensor ms s (TKS2 sh x)
v -> AstVarId -> AstTensor ms s (TKS2 sh x) -> Bool
forall (ms :: AstMethodOfSharing) (s :: AstSpanType) (y :: TK).
AstVarId -> AstTensor ms s y -> Bool
varInAst AstVarId
var AstTensor ms s (TKS2 sh x)
v
AstConvert TKConversion a1 y
_ AstTensor ms s a1
v -> AstVarId -> AstTensor ms s a1 -> Bool
forall (ms :: AstMethodOfSharing) (s :: AstSpanType) (y :: TK).
AstVarId -> AstTensor ms s y -> Bool
varInAst AstVarId
var AstTensor ms s a1
v
AstSum0S AstTensor ms s (TKS2 sh x)
v -> AstVarId -> AstTensor ms s (TKS2 sh x) -> Bool
forall (ms :: AstMethodOfSharing) (s :: AstSpanType) (y :: TK).
AstVarId -> AstTensor ms s y -> Bool
varInAst AstVarId
var AstTensor ms s (TKS2 sh x)
v
AstDot0S AstTensor ms s (TKS sh r)
u AstTensor ms s (TKS sh r)
v -> AstVarId -> AstTensor ms s (TKS sh r) -> Bool
forall (ms :: AstMethodOfSharing) (s :: AstSpanType) (y :: TK).
AstVarId -> AstTensor ms s y -> Bool
varInAst AstVarId
var AstTensor ms s (TKS sh r)
u Bool -> Bool -> Bool
|| AstVarId -> AstTensor ms s (TKS sh r) -> Bool
forall (ms :: AstMethodOfSharing) (s :: AstSpanType) (y :: TK).
AstVarId -> AstTensor ms s y -> Bool
varInAst AstVarId
var AstTensor ms s (TKS sh r)
v
AstDot1InS ShS sh
_ SNat n
_ AstTensor ms s (TKS ((++) @Nat sh ((':) @Nat n ('[] @Nat))) r)
u AstTensor ms s (TKS ((++) @Nat sh ((':) @Nat n ('[] @Nat))) r)
v -> AstVarId
-> AstTensor ms s (TKS ((++) @Nat sh ((':) @Nat n ('[] @Nat))) r)
-> Bool
forall (ms :: AstMethodOfSharing) (s :: AstSpanType) (y :: TK).
AstVarId -> AstTensor ms s y -> Bool
varInAst AstVarId
var AstTensor ms s (TKS ((++) @Nat sh ((':) @Nat n ('[] @Nat))) r)
u Bool -> Bool -> Bool
|| AstVarId
-> AstTensor ms s (TKS ((++) @Nat sh ((':) @Nat n ('[] @Nat))) r)
-> Bool
forall (ms :: AstMethodOfSharing) (s :: AstSpanType) (y :: TK).
AstVarId -> AstTensor ms s y -> Bool
varInAst AstVarId
var AstTensor ms s (TKS ((++) @Nat sh ((':) @Nat n ('[] @Nat))) r)
v
AstMatmul2S SNat m
_ SNat n
_ SNat p
_ AstTensor ms s (TKS ((':) @Nat m ((':) @Nat n ('[] @Nat))) r)
u AstTensor ms s (TKS ((':) @Nat n ((':) @Nat p ('[] @Nat))) r)
v -> AstVarId
-> AstTensor ms s (TKS ((':) @Nat m ((':) @Nat n ('[] @Nat))) r)
-> Bool
forall (ms :: AstMethodOfSharing) (s :: AstSpanType) (y :: TK).
AstVarId -> AstTensor ms s y -> Bool
varInAst AstVarId
var AstTensor ms s (TKS ((':) @Nat m ((':) @Nat n ('[] @Nat))) r)
u Bool -> Bool -> Bool
|| AstVarId
-> AstTensor ms s (TKS ((':) @Nat n ((':) @Nat p ('[] @Nat))) r)
-> Bool
forall (ms :: AstMethodOfSharing) (s :: AstSpanType) (y :: TK).
AstVarId -> AstTensor ms s y -> Bool
varInAst AstVarId
var AstTensor ms s (TKS ((':) @Nat n ((':) @Nat p ('[] @Nat))) r)
v
-- | Does the variable with the given id occur in some component of the
-- shaped index? Every index component is an integer AST term and is
-- examined with 'varInAst'.
varInIxS :: AstVarId -> AstIxS ms sh -> Bool
varInIxS var ix = any (varInAst var) ix
-- | Occurrence check for function terms. Every 'AstLambda' is reported as
-- not mentioning the variable.
-- NOTE(review): this presumably relies on lambda terms being closed;
-- confirm that an 'AstLambda' body can never refer to an outer variable.
varInAstHFun :: AstVarId -> AstHFun s s2 x y -> Bool
varInAstHFun _var hfun = case hfun of
  AstLambda{} -> False
-- | Does the variable with the given id occur in a boolean AST?
-- Boolean connectives are traversed structurally. Comparisons delegate
-- to 'varInAst' on both of their (primal) operands.
varInAstBool :: AstVarId -> AstBool ms -> Bool
varInAstBool _ AstBoolConst{} = False
varInAstBool var (AstBoolNot b) = varInAstBool var b
varInAstBool var (AstBoolAnd arg1 arg2) =
  varInAstBool var arg1 || varInAstBool var arg2
varInAstBool var (AstLeqK arg1 arg2) =
  varInAst var arg1 || varInAst var arg2
varInAstBool var (AstLeqS arg1 arg2) =
  varInAst var arg1 || varInAst var arg2
-- | Variant of 'varInAst' that takes a typed variable name. Only the
-- name's underlying 'AstVarId' is used for the occurrence check.
varNameInAst :: AstVarName f y -> AstTensor ms s2 y2 -> Bool
varNameInAst = varInAst . varNameToAstVarId
-- | Variant of 'varInIxS' that takes a typed variable name. Only the
-- name's underlying 'AstVarId' is used for the occurrence check.
varNameInIxS :: AstVarName f y -> AstIxS ms sh -> Bool
varNameInIxS = varInIxS . varNameToAstVarId
-- | Process-wide flag that is read by 'astIsSmall'. It starts as 'False'.
-- The NOINLINE pragma keeps GHC from inlining the 'unsafePerformIO'
-- call, so that exactly one 'IORef' is ever allocated and shared by all
-- readers and writers.
unsafeTotalSharingRef :: IORef Bool
{-# NOINLINE unsafeTotalSharingRef #-}
unsafeTotalSharingRef = unsafePerformIO (newIORef False)
-- | Atomically overwrite the global total-sharing flag. While it is 'True',
-- 'astIsSmall' reports only variables, shares and rank-0 concrete values
-- as small.
setTotalSharing :: Bool -> IO ()
setTotalSharing = atomicWriteIORef unsafeTotalSharingRef
-- | Is the term small?
--
-- The following are always small, whatever the global flag says:
-- variables, shares, scalar concrete values, and rank-0 shaped concrete
-- values. For any other term:
--
-- * if the total-sharing flag (see 'setTotalSharing') is set, the answer
--   is 'False';
-- * otherwise the term is walked by 'astIsSmallN' with a budget of 50
--   (when @lax@) or 20 (otherwise). It is small if any budget is left.
--
-- NOTE(review): callers presumably use this to decide between duplicating
-- a term and sharing it; confirm at the use sites.
astIsSmall :: Bool -> AstTensor ms s y -> Bool
astIsSmall lax t = case t of
  AstVar{} -> True
  AstShare{} -> True
  AstConcreteK{} -> True
  AstConcreteS a | sNatValue (Nested.srank a) == 0 -> True
  _ -> unsafePerformIO $ do
    -- The global flag is consulted when this thunk is evaluated.
    totalSharing <- readIORef unsafeTotalSharingRef
    let budget = if lax then 50 else 20
    return $! not totalSharing && astIsSmallN budget t > 0
-- | Budgeted size walk over a tensor term. It returns the budget remaining
-- after the walk. A result of 0 means one of two things: the budget ran
-- out, or the term contains a constructor that is never treated as small.
--
-- * Leaves ('AstVar', 'AstShare', concrete values, 'AstIotaS') cost
--   nothing and return the budget unchanged.
-- * Every other accepted node costs one unit.
-- * Nodes with several children thread the remaining budget from one child
--   to the next.
astIsSmallN :: Int -> AstTensor ms s y -> Int
astIsSmallN budget t0
  | budget <= 0 = 0
  | otherwise = case t0 of
      AstPair t1 t2 -> astIsSmallN (astIsSmallN rest t1) t2
      AstProject1 t -> astIsSmallN rest t
      AstProject2 t -> astIsSmallN rest t
      -- Only length-1 vectors and sums over a length-1 dimension qualify.
      AstFromVector (SNat' @1) _ v -> astIsSmallN rest (v V.! 0)
      AstSum (SNat' @1) _ v -> astIsSmallN rest v
      AstReplicate _ _ v -> astIsSmallN rest v
      AstVar{} -> budget
      AstCond b u v ->
        astIsSmallN (astIsSmallN (astBoolIsSmallN rest b) u) v
      AstConcreteK _ -> budget
      AstConcreteS _ -> budget
      AstShare{} -> budget
      AstPrimalPart v -> astIsSmallN rest v
      AstDualPart v -> astIsSmallN rest v
      AstFromPrimal v -> astIsSmallN rest v
      AstFromDual v -> astIsSmallN rest v
      AstIotaS{} -> budget
      -- Slices and transpositions are accepted only while more than 20
      -- units of budget remain.
      AstSliceS _ _ _ v -> if budget <= 20 then 0 else astIsSmallN rest v
      AstReverseS v -> astIsSmallN rest v
      AstTransposeS _perm v ->
        if budget <= 20 then 0 else astIsSmallN rest v
      AstConvert _ v -> astIsSmallN rest v
      _ -> 0
 where
  rest = budget - 1
-- | Boolean counterpart of 'astIsSmallN'. It returns the budget remaining
-- after walking the boolean term, or 0 once the budget is exhausted.
-- Constants are free. Every connective or comparison costs one unit, and
-- the budget is threaded through its operands.
astBoolIsSmallN :: Int -> AstBool ms -> Int
astBoolIsSmallN budget b0
  | budget <= 0 = 0
  | otherwise = case b0 of
      AstBoolConst{} -> budget
      AstBoolNot v -> astBoolIsSmallN rest v
      AstBoolAnd u v -> astBoolIsSmallN (astBoolIsSmallN rest u) v
      AstLeqK u v -> astIsSmallN (astIsSmallN rest u) v
      AstLeqS u v -> astIsSmallN (astIsSmallN rest u) v
 where
  rest = budget - 1
-- | An index is small when every one of its components is small under the
-- lax ('True') setting of 'astIsSmall'.
ixIsSmall :: AstIxS ms sh -> Bool
ixIsSmall ix = all (astIsSmall True) ix
bounds :: GoodScalar r => AstTensor ms s (TKScalar r) -> (r, r)
bounds :: forall r (ms :: AstMethodOfSharing) (s :: AstSpanType).
GoodScalar r =>
AstTensor ms s (TKScalar r) -> (r, r)
bounds (AstConcreteK r
u) = (r
r
u, r
r
u)
bounds (AstVar AstVarName s (TKScalar r)
var) = case AstVarName s (TKScalar r) -> Maybe (Int64, Int64)
forall (s :: AstSpanType) (y :: TK).
AstVarName s y -> Maybe (Int64, Int64)
varNameToBounds AstVarName s (TKScalar r)
var of
Maybe (Int64, Int64)
Nothing -> (-r
1000000000, r
1000000000)
Just (Int64
u1, Int64
u2) -> (Int64 -> r
forall a b. (Integral a, Num b) => a -> b
fromIntegral Int64
u1, Int64 -> r
forall a b. (Integral a, Num b) => a -> b
fromIntegral Int64
u2)
bounds (AstFromPrimal AstTensor ms PrimalSpan (TKScalar r)
u) = AstTensor ms PrimalSpan (TKScalar r) -> (r, r)
forall r (ms :: AstMethodOfSharing) (s :: AstSpanType).
GoodScalar r =>
AstTensor ms s (TKScalar r) -> (r, r)
bounds AstTensor ms PrimalSpan (TKScalar r)
u
bounds (AstPrimalPart AstTensor ms FullSpan (TKScalar r)
u) = AstTensor ms FullSpan (TKScalar r) -> (r, r)
forall r (ms :: AstMethodOfSharing) (s :: AstSpanType).
GoodScalar r =>
AstTensor ms s (TKScalar r) -> (r, r)
bounds AstTensor ms FullSpan (TKScalar r)
u
bounds (AstCond AstBool ms
_b AstTensor ms s (TKScalar r)
u AstTensor ms s (TKScalar r)
v) = let (r
u1, r
u2) = AstTensor ms s (TKScalar r) -> (r, r)
forall r (ms :: AstMethodOfSharing) (s :: AstSpanType).
GoodScalar r =>
AstTensor ms s (TKScalar r) -> (r, r)
bounds AstTensor ms s (TKScalar r)
u
(r
v1, r
v2) = AstTensor ms s (TKScalar r) -> (r, r)
forall r (ms :: AstMethodOfSharing) (s :: AstSpanType).
GoodScalar r =>
AstTensor ms s (TKScalar r) -> (r, r)
bounds AstTensor ms s (TKScalar r)
v
in (r -> r -> r
forall a. Ord a => a -> a -> a
min r
u1 r
v1, r -> r -> r
forall a. Ord a => a -> a -> a
max r
u2 r
v2)
bounds (AstLet AstVarName s y
_ AstTensor AstMethodLet s y
_ AstTensor AstMethodLet s (TKScalar r)
u) = AstTensor AstMethodLet s (TKScalar r) -> (r, r)
forall r (ms :: AstMethodOfSharing) (s :: AstSpanType).
GoodScalar r =>
AstTensor ms s (TKScalar r) -> (r, r)
bounds AstTensor AstMethodLet s (TKScalar r)
u
bounds (AstPlusK AstTensor ms s (TKScalar r)
u AstTensor ms s (TKScalar r)
v) = let (r
u1, r
u2) = AstTensor ms s (TKScalar r) -> (r, r)
forall r (ms :: AstMethodOfSharing) (s :: AstSpanType).
GoodScalar r =>
AstTensor ms s (TKScalar r) -> (r, r)
bounds AstTensor ms s (TKScalar r)
u
(r
v1, r
v2) = AstTensor ms s (TKScalar r) -> (r, r)
forall r (ms :: AstMethodOfSharing) (s :: AstSpanType).
GoodScalar r =>
AstTensor ms s (TKScalar r) -> (r, r)
bounds AstTensor ms s (TKScalar r)
v
in (r
r
u1 r -> r -> r
forall a. Num a => a -> a -> a
+ r
r
v1, r
r
u2 r -> r -> r
forall a. Num a => a -> a -> a
+ r
r
v2)
bounds (AstN1K OpCodeNum1
NegateOp AstTensor ms s (TKScalar r)
u) = let (r
u1, r
u2) = AstTensor ms s (TKScalar r) -> (r, r)
forall r (ms :: AstMethodOfSharing) (s :: AstSpanType).
GoodScalar r =>
AstTensor ms s (TKScalar r) -> (r, r)
bounds AstTensor ms s (TKScalar r)
u in (- r
r
u2, - r
r
u1)
bounds (AstTimesK AstTensor ms s (TKScalar r)
u AstTensor ms s (TKScalar r)
v) =
let (r
u1, r
u2) = AstTensor ms s (TKScalar r) -> (r, r)
forall r (ms :: AstMethodOfSharing) (s :: AstSpanType).
GoodScalar r =>
AstTensor ms s (TKScalar r) -> (r, r)
bounds AstTensor ms s (TKScalar r)
u
(r
v1, r
v2) = AstTensor ms s (TKScalar r) -> (r, r)
forall r (ms :: AstMethodOfSharing) (s :: AstSpanType).
GoodScalar r =>
AstTensor ms s (TKScalar r) -> (r, r)
bounds AstTensor ms s (TKScalar r)
v
l :: [r]
l = [r
u1 r -> r -> r
forall a. Num a => a -> a -> a
* r
v1, r
u1 r -> r -> r
forall a. Num a => a -> a -> a
* r
v2, r
u2 r -> r -> r
forall a. Num a => a -> a -> a
* r
v1, r
u2 r -> r -> r
forall a. Num a => a -> a -> a
* r
v2]
in ([r] -> r
forall a. Ord a => [a] -> a
forall (t :: Type -> Type) a. (Foldable t, Ord a) => t a -> a
minimum [r]
[r]
l, [r] -> r
forall a. Ord a => [a] -> a
forall (t :: Type -> Type) a. (Foldable t, Ord a) => t a -> a
maximum [r]
[r]
l)
bounds (AstI2K OpCodeIntegral2
QuotOp AstTensor ms s (TKScalar r)
u (AstConcreteK r
v)) | r
v r -> r -> Bool
forall a. Ord a => a -> a -> Bool
> r
0 =
let (r
u1, r
u2) = AstTensor ms s (TKScalar r) -> (r, r)
forall r (ms :: AstMethodOfSharing) (s :: AstSpanType).
GoodScalar r =>
AstTensor ms s (TKScalar r) -> (r, r)
bounds AstTensor ms s (TKScalar r)
u
in (r
r
u1 r -> r -> r
forall a. IntegralH a => a -> a -> a
`quotH` r
r
v, r
r
u2 r -> r -> r
forall a. IntegralH a => a -> a -> a
`quotH` r
r
v)
bounds (AstI2K OpCodeIntegral2
RemOp AstTensor ms s (TKScalar r)
u (AstConcreteK r
v)) | r
v r -> r -> Bool
forall a. Ord a => a -> a -> Bool
> r
0 =
let (r
u1, r
u2) = AstTensor ms s (TKScalar r) -> (r, r)
forall r (ms :: AstMethodOfSharing) (s :: AstSpanType).
GoodScalar r =>
AstTensor ms s (TKScalar r) -> (r, r)
bounds AstTensor ms s (TKScalar r)
u
in if | r
u1 r -> r -> Bool
forall a. Ord a => a -> a -> Bool
>= r
0 -> (r
0, r -> r -> r
forall a. Ord a => a -> a -> a
min r
r
u2 (r
r
v r -> r -> r
forall a. Num a => a -> a -> a
- r
1))
| r
u2 r -> r -> Bool
forall a. Ord a => a -> a -> Bool
<= r
0 -> (r -> r -> r
forall a. Ord a => a -> a -> a
max r
r
u1 (- r
r
v r -> r -> r
forall a. Num a => a -> a -> a
+ r
1), r
0)
| Bool
otherwise -> (- r
r
v r -> r -> r
forall a. Num a => a -> a -> a
+ r
1, r
r
v r -> r -> r
forall a. Num a => a -> a -> a
- r
1)
bounds AstTensor ms s (TKScalar r)
_ = (-r
1000000000, r
1000000000)
-- | Lift a shape-polymorphic transformation of shaped tensors to ranked
-- tensors.
--
-- The runtime shape of the argument (read off with 'ftkAst') is reified as
-- a type-level shape @sh@. The argument is then:
--
-- 1. converted to a shaped tensor of that shape with 'cAstSFromR',
-- 2. transformed by @f@,
-- 3. converted back with 'cAstFromS' to the argument's original full shape.
liftRFromS1 :: forall n x ms s.
               (forall sh.
                   AstTensor ms s (TKS2 sh x)
                -> AstTensor ms s (TKS2 sh x))
            -> AstTensor ms s (TKR2 n x)
            -> AstTensor ms s (TKR2 n x)
liftRFromS1 f a = case ftkAst a of
  ftk@(FTKR shr _) ->
    withShsFromShR shr $ \(sh :: ShS sh) ->
      -- Within this continuation @Rank sh ~ n@ is available, which is what
      -- lets 'cAstSFromR' accept the ranked argument.
      let aShaped = cAstSFromR @sh sh a
      in cAstFromS @(TKS2 sh x) ftk (f aShaped)
-- | Lift a shape-polymorphic binary operation on shaped tensors to ranked
-- tensors.
--
-- The type-level shape @sh@ is reified from the runtime shape of the FIRST
-- argument only. Both arguments are converted to shaped tensors of that
-- shape, combined with @f@, and the result is converted back to the first
-- argument's full shape.
--
-- NOTE(review): the second argument's runtime shape is not consulted or
-- checked here; callers presumably pass same-shaped operands.
liftRFromS2 :: forall n x ms s.
               (forall sh.
                   AstTensor ms s (TKS2 sh x) -> AstTensor ms s (TKS2 sh x)
                -> AstTensor ms s (TKS2 sh x))
            -> AstTensor ms s (TKR2 n x) -> AstTensor ms s (TKR2 n x)
            -> AstTensor ms s (TKR2 n x)
liftRFromS2 f a b = case ftkAst a of
  ftk@(FTKR shr _) ->
    withShsFromShR shr $ \(sh :: ShS sh) ->
      -- @Rank sh ~ n@ holds here, so the same conversion fits both operands.
      let toShaped :: AstTensor ms s (TKR2 n x) -> AstTensor ms s (TKS2 sh x)
          toShaped = cAstSFromR @sh sh
      in cAstFromS @(TKS2 sh x) ftk (f (toShaped a) (toShaped b))
-- | Lift a shape-polymorphic transformation of shaped tensors to mixed
-- tensors.
--
-- The runtime mixed shape of the argument (from 'ftkAst') is reified as a
-- fully static shape @sh@ of equal rank. The argument is converted to a
-- shaped tensor with 'cAstSFromX', transformed by @f@, and converted back
-- with 'cAstFromS' to the argument's original full shape.
liftXFromS1 :: forall sh' x ms s.
               (forall sh.
                   AstTensor ms s (TKS2 sh x)
                -> AstTensor ms s (TKS2 sh x))
            -> AstTensor ms s (TKX2 sh' x)
            -> AstTensor ms s (TKX2 sh' x)
liftXFromS1 f a = case ftkAst a of
  ftk@(FTKX shx _) ->
    withShsFromShX shx $ \(sh :: ShS sh) ->
      -- The continuation provides @Rank sh ~ Rank sh'@, required by
      -- 'cAstSFromX'.
      let aShaped = cAstSFromX @sh @sh' sh a
      in cAstFromS @(TKS2 sh x) ftk (f aShaped)
-- | Lift a shape-polymorphic binary operation on shaped tensors to mixed
-- tensors.
--
-- The static shape @sh@ is reified from the runtime shape of the FIRST
-- argument only. Both arguments are converted to shaped tensors of that
-- shape, combined with @f@, and the result is converted back to the first
-- argument's full shape.
--
-- NOTE(review): the second argument's runtime shape is not consulted or
-- checked here; callers presumably pass same-shaped operands.
liftXFromS2 :: forall sh' x ms s.
               (forall sh.
                   AstTensor ms s (TKS2 sh x) -> AstTensor ms s (TKS2 sh x)
                -> AstTensor ms s (TKS2 sh x))
            -> AstTensor ms s (TKX2 sh' x) -> AstTensor ms s (TKX2 sh' x)
            -> AstTensor ms s (TKX2 sh' x)
liftXFromS2 f a b = case ftkAst a of
  ftk@(FTKX shx _) ->
    withShsFromShX shx $ \(sh :: ShS sh) ->
      -- @Rank sh ~ Rank sh'@ holds here, so one conversion serves both.
      let toShaped :: AstTensor ms s (TKX2 sh' x) -> AstTensor ms s (TKS2 sh x)
          toShaped = cAstSFromX @sh @sh' sh
      in cAstFromS @(TKS2 sh x) ftk (f (toShaped a) (toShaped b))
-- | Smart constructor for 'AstConvert'.
--
-- * If applying the conversion to the term's full shape yields the same
--   full shape (checked with 'matchingFTK'), the term is returned as is.
-- * A conversion applied to another 'AstConvert' is fused into one
--   composed conversion ('ConvCmp').
-- * The same fusion happens under an 'AstFromPrimal' wrapper, which is kept
--   on the outside.
-- * Otherwise a plain 'AstConvert' node is built.
cAstConvert :: TKConversion x z -> AstTensor ms s x -> AstTensor ms s z
cAstConvert c t
  | Just Refl <- matchingFTK ftk (convertFTK c ftk) = t
  | otherwise = case t of
      AstConvert cInner tInner -> cAstConvert (ConvCmp c cInner) tInner
      AstFromPrimal (AstConvert cInner tInner) ->
        AstFromPrimal (cAstConvert (ConvCmp c cInner) tInner)
      _ -> AstConvert c t
  where
    ftk = ftkAst t
-- | Convert a ranked tensor to a shaped tensor of the given shape.
--
-- The conversion is built as ranked-to-mixed ('ConvRX') followed by
-- mixed-to-shaped ('ConvXS''), and is passed through 'cAstConvert'.
-- 'lemRankReplicate' supplies the type-level fact
-- @Rank (Replicate (Rank sh) Nothing) ~ Rank sh@ that 'ConvXS'' requires.
cAstSFromR :: forall sh x ms s.
              ShS sh -> AstTensor ms s (TKR2 (Rank sh) x)
           -> AstTensor ms s (TKS2 sh x)
cAstSFromR sh v = case ftkAst v of
  FTKR _ xftk
    | Refl <- lemRankReplicate (Proxy @(Rank sh)) ->
      let conv :: TKConversion (TKR2 (Rank sh) x) (TKS2 sh x)
          conv = ConvCmp (ConvXS' (FTKS sh xftk)) ConvRX
      in cAstConvert conv v
-- | Convert a mixed tensor to a shaped tensor of the given shape, which
-- must have the same rank.
--
-- The conversion is a single 'ConvXS'' targeting @FTKS sh@ with the
-- element full shape read from the argument, passed through 'cAstConvert'.
cAstSFromX :: forall sh sh' x ms s. Rank sh ~ Rank sh'
           => ShS sh -> AstTensor ms s (TKX2 sh' x)
           -> AstTensor ms s (TKS2 sh x)
cAstSFromX sh v = case ftkAst v of
  FTKX _ xftk -> cAstConvert (ConvXS' (FTKS sh xftk)) v
-- | Convert a shaped tensor to a mixed tensor with static shape information
-- @ssx@ of the same rank.
--
-- The target mixed shape is obtained with 'shCastSX' from @ssx@ and the
-- argument's shaped shape. The conversion is shaped-to-mixed ('ConvSX')
-- followed by a mixed-to-mixed cast ('ConvXX''). 'lemRankMapJust' supplies
-- @Rank (MapJust sh) ~ Rank sh@, needed for that cast.
cAstXFromS :: forall sh sh' x ms s. Rank sh ~ Rank sh'
           => StaticShX sh' -> AstTensor ms s (TKS2 sh x)
           -> AstTensor ms s (TKX2 sh' x)
cAstXFromS ssx v = case ftkAst v of
  FTKS shs xftk
    | Refl <- lemRankMapJust shs ->
      let shx :: IShX sh'
          shx = shCastSX ssx shs
          conv :: TKConversion (TKS2 sh x) (TKX2 sh' x)
          conv = ConvCmp (ConvXX' (FTKX shx xftk)) ConvSX
      in cAstConvert conv v
-- | Match-only pattern for an 'AstConvert' node whose conversion is accepted
-- by 'checkPatternAstFromS'.
--
-- It binds the full shape the conversion produces and the unconverted
-- subterm.
pattern AstFromS' :: forall {z1} {ms1} {s1}.
                     forall y z ms s. (z ~ z1, ms ~ ms1, s ~ s1)
                  => FullShapeTK z -> AstTensor ms s y
                  -> AstTensor ms1 s1 z1
pattern AstFromS' targetFtk inner <-
  AstConvert conv (checkPatternAstFromS conv -> Just (targetFtk, inner))
-- | Compute the full shape that conversion @c@ gives to term @t@.
--
-- Returns @Just (targetShape, t)@ when 'checkFtkAstFromS' accepts the pair
-- (source full shape, target full shape), and 'Nothing' otherwise.
checkPatternAstFromS :: TKConversion y z -> AstTensor ms s y
                     -> Maybe (FullShapeTK z, AstTensor ms s y)
checkPatternAstFromS c t
  | checkFtkAstFromS srcFtk dstFtk = Just (dstFtk, t)
  | otherwise = Nothing
  where
    srcFtk = ftkAst t
    dstFtk = convertFTK c srcFtk
-- | Boolean form of 'checkPatternAstFromS'.
--
-- True exactly when the conversion applied to the term passes the
-- from-shaped full-shape check.
checkAstFromS :: TKConversion a b -> AstTensor ms s a -> Bool
checkAstFromS c = isJust . checkPatternAstFromS c
-- | Decide whether going from full shape @yftk@ to @zftk@ is accepted as a
-- "from shaped" conversion.
--
-- Cases are tried in order:
--
-- * identical full shapes ('matchingFTK' succeeds): accepted;
-- * shaped to shaped, but different: rejected;
-- * shaped to anything else: accepted;
-- * product to product: accepted iff both components are accepted;
-- * anything else: rejected.
checkFtkAstFromS :: FullShapeTK y -> FullShapeTK z -> Bool
checkFtkAstFromS yftk zftk = case (yftk, zftk) of
  _ | Just Refl <- matchingFTK yftk zftk -> True
  (FTKS{}, FTKS{}) -> False
  (FTKS{}, _) -> True
  (FTKProduct yl yr, FTKProduct zl zr) ->
    checkFtkAstFromS yl zl && checkFtkAstFromS yr zr
  _ -> False
-- | Decide whether going from full shape @yftk@ to @zftk@ is accepted as a
-- "to shaped" conversion (the mirror image of 'checkFtkAstFromS').
--
-- Cases are tried in order:
--
-- * identical full shapes ('matchingFTK' succeeds): accepted;
-- * shaped to shaped, but different: rejected;
-- * anything to shaped: accepted;
-- * product to product: accepted iff both components are accepted;
-- * anything else: rejected.
checkFtkAstSFrom :: FullShapeTK y -> FullShapeTK z -> Bool
checkFtkAstSFrom yftk zftk = case (yftk, zftk) of
  _ | Just Refl <- matchingFTK yftk zftk -> True
  (FTKS{}, FTKS{}) -> False
  (_, FTKS{}) -> True
  (FTKProduct yl yr, FTKProduct zl zr) ->
    checkFtkAstSFrom yl zl && checkFtkAstSFrom yr zr
  _ -> False
-- | Convert a term to the target full shape @zftk@.
--
-- The conversion is computed by 'convFromS' from the term's own full shape,
-- and the result goes through the 'cAstConvert' smart constructor.
cAstFromS :: forall y z ms s.
             FullShapeTK z -> AstTensor ms s y
          -> AstTensor ms s z
cAstFromS zftk t =
  let conv = convFromS (ftkAst t) zftk
  in cAstConvert conv t
-- | Convert a term to the target full shape @zftk@.
--
-- The conversion is computed by 'convSFrom' from the term's own full shape
-- and the singleton ('ftkToSTK') of the target, and the result goes through
-- the 'cAstConvert' smart constructor.
cAstSFrom :: forall y z ms s.
             FullShapeTK z -> AstTensor ms s y
          -> AstTensor ms s z
cAstSFrom zftk t =
  let conv = convSFrom (ftkAst t) (ftkToSTK zftk)
  in cAstConvert conv t
convFromS :: FullShapeTK y0 -> FullShapeTK z0 -> TKConversion y0 z0
convFromS :: forall (y0 :: TK) (z0 :: TK).
FullShapeTK y0 -> FullShapeTK z0 -> TKConversion y0 z0
convFromS FullShapeTK y0
yftk0 FullShapeTK z0
zftk0 = case (FullShapeTK y0
yftk0, FullShapeTK z0
zftk0) of
(FullShapeTK y0, FullShapeTK z0)
_ | Just (:~:) @TK y0 z0
Refl <- FullShapeTK y0 -> FullShapeTK z0 -> Maybe ((:~:) @TK y0 z0)
forall (y1 :: TK) (y2 :: TK).
FullShapeTK y1 -> FullShapeTK y2 -> Maybe ((:~:) @TK y1 y2)
matchingFTK FullShapeTK y0
yftk0 FullShapeTK z0
zftk0 -> TKConversion y0 y0
TKConversion y0 z0
forall (a :: TK). TKConversion a a
ConvId
(FTKS ShS sh
ZSS (FTKScalar @ry), FTKScalar @rz)
| Just (:~:) @Type r r
Refl <- TypeRep @Type r -> TypeRep @Type r -> Maybe ((:~:) @Type r r)
forall a b.
TypeRep @Type a -> TypeRep @Type b -> Maybe ((:~:) @Type a b)
forall {k} (f :: k -> Type) (a :: k) (b :: k).
TestEquality @k f =>
f a -> f b -> Maybe ((:~:) @k a b)
testEquality (forall a. Typeable @Type a => TypeRep @Type a
forall {k} (a :: k). Typeable @k a => TypeRep @k a
typeRep @ry) (forall a. Typeable @Type a => TypeRep @Type a
forall {k} (a :: k). Typeable @k a => TypeRep @k a
typeRep @rz) ->
TKConversion (TKX2 ('[] @(Maybe Nat)) z0) z0
-> TKConversion y0 (TKX2 ('[] @(Maybe Nat)) z0)
-> TKConversion y0 z0
forall (b1 :: TK) (b :: TK) (a :: TK).
TKConversion b1 b -> TKConversion a b1 -> TKConversion a b
ConvCmp TKConversion (TKX2 ('[] @(Maybe Nat)) z0) z0
forall (b :: TK). TKConversion (TKX2 ('[] @(Maybe Nat)) b) b
ConvX0 TKConversion y0 (TKX2 ('[] @(Maybe Nat)) z0)
TKConversion
(TKS2 ('[] @Nat) z0) (TKX2 (MapJust @Nat ('[] @Nat)) z0)
forall (sh :: [Nat]) (a1 :: TK).
TKConversion (TKS2 sh a1) (TKX2 (MapJust @Nat sh) a1)
ConvSX
(FTKS ShS sh
sh FullShapeTK x
x, FTKR IShR n
rsh FullShapeTK x
rx)
| Just (:~:) @TK x x
Refl <- FullShapeTK x -> FullShapeTK x -> Maybe ((:~:) @TK x x)
forall (y1 :: TK) (y2 :: TK).
FullShapeTK y1 -> FullShapeTK y2 -> Maybe ((:~:) @TK y1 y2)
matchingFTK FullShapeTK x
x FullShapeTK x
rx
, Just (:~:) @Nat (Rank @Nat sh) n
Refl <- SNat (Rank @Nat sh)
-> SNat n -> Maybe ((:~:) @Nat (Rank @Nat sh) n)
forall (a :: Nat) (b :: Nat).
SNat a -> SNat b -> Maybe ((:~:) @Nat a b)
forall {k} (f :: k -> Type) (a :: k) (b :: k).
TestEquality @k f =>
f a -> f b -> Maybe ((:~:) @k a b)
testEquality (ShS sh -> SNat (Rank @Nat sh)
forall (sh :: [Nat]). ShS sh -> SNat (Rank @Nat sh)
shsRank ShS sh
sh) (IShR n -> SNat n
forall (n :: Nat) i. ShR n i -> SNat n
shrRank IShR n
rsh)
, (:~:) @Nat (Rank @(Maybe Nat) (MapJust @Nat sh)) (Rank @Nat sh)
Refl <- ShS sh
-> (:~:) @Nat (Rank @(Maybe Nat) (MapJust @Nat sh)) (Rank @Nat sh)
forall (sh :: [Nat]).
ShS sh
-> (:~:) @Nat (Rank @(Maybe Nat) (MapJust @Nat sh)) (Rank @Nat sh)
lemRankMapJust ShS sh
sh ->
TKConversion (TKX2 (MapJust @Nat sh) x) z0
-> TKConversion y0 (TKX2 (MapJust @Nat sh) x) -> TKConversion y0 z0
forall (b1 :: TK) (b :: TK) (a :: TK).
TKConversion b1 b -> TKConversion a b1 -> TKConversion a b
ConvCmp (SingletonTK x
-> TKConversion
(TKX2 (MapJust @Nat sh) x)
(TKR2 (Rank @(Maybe Nat) (MapJust @Nat sh)) x)
forall (a1 :: TK) (sh :: [Maybe Nat]).
SingletonTK a1
-> TKConversion (TKX2 sh a1) (TKR2 (Rank @(Maybe Nat) sh) a1)
ConvXR (FullShapeTK x -> SingletonTK x
forall (y :: TK). FullShapeTK y -> SingletonTK y
ftkToSTK FullShapeTK x
x)) TKConversion y0 (TKX2 (MapJust @Nat sh) x)
TKConversion (TKS2 sh x) (TKX2 (MapJust @Nat sh) x)
forall (sh :: [Nat]) (a1 :: TK).
TKConversion (TKS2 sh a1) (TKX2 (MapJust @Nat sh) a1)
ConvSX
(FTKS ShS sh
sh FullShapeTK x
x, FTKX IShX sh
xsh FullShapeTK x
xx)
| Just (:~:) @TK x x
Refl <- FullShapeTK x -> FullShapeTK x -> Maybe ((:~:) @TK x x)
forall (y1 :: TK) (y2 :: TK).
FullShapeTK y1 -> FullShapeTK y2 -> Maybe ((:~:) @TK y1 y2)
matchingFTK FullShapeTK x
x FullShapeTK x
xx
, Just (:~:) @Nat (Rank @Nat sh) (Rank @(Maybe Nat) sh)
Refl <- SNat (Rank @Nat sh)
-> SNat (Rank @(Maybe Nat) sh)
-> Maybe ((:~:) @Nat (Rank @Nat sh) (Rank @(Maybe Nat) sh))
forall (a :: Nat) (b :: Nat).
SNat a -> SNat b -> Maybe ((:~:) @Nat a b)
forall {k} (f :: k -> Type) (a :: k) (b :: k).
TestEquality @k f =>
f a -> f b -> Maybe ((:~:) @k a b)
testEquality (ShS sh -> SNat (Rank @Nat sh)
forall (sh :: [Nat]). ShS sh -> SNat (Rank @Nat sh)
shsRank ShS sh
sh) (IShX sh -> SNat (Rank @(Maybe Nat) sh)
forall (sh :: [Maybe Nat]) i.
ShX sh i -> SNat (Rank @(Maybe Nat) sh)
shxRank IShX sh
xsh)
, (:~:) @Nat (Rank @(Maybe Nat) (MapJust @Nat sh)) (Rank @Nat sh)
Refl <- ShS sh
-> (:~:) @Nat (Rank @(Maybe Nat) (MapJust @Nat sh)) (Rank @Nat sh)
forall (sh :: [Nat]).
ShS sh
-> (:~:) @Nat (Rank @(Maybe Nat) (MapJust @Nat sh)) (Rank @Nat sh)
lemRankMapJust ShS sh
sh ->
TKConversion (TKX2 (MapJust @Nat sh) x) z0
-> TKConversion y0 (TKX2 (MapJust @Nat sh) x) -> TKConversion y0 z0
forall (b1 :: TK) (b :: TK) (a :: TK).
TKConversion b1 b -> TKConversion a b1 -> TKConversion a b
ConvCmp (FullShapeTK (TKX2 sh x)
-> TKConversion (TKX2 (MapJust @Nat sh) x) (TKX2 sh x)
forall (sh :: [Maybe Nat]) (sh' :: [Maybe Nat]) (a1 :: TK).
((Rank @(Maybe Nat) sh :: Nat) ~ (Rank @(Maybe Nat) sh' :: Nat)) =>
FullShapeTK (TKX2 sh' a1)
-> TKConversion (TKX2 sh a1) (TKX2 sh' a1)
ConvXX' FullShapeTK z0
FullShapeTK (TKX2 sh x)
zftk0) TKConversion y0 (TKX2 (MapJust @Nat sh) x)
TKConversion (TKS2 sh x) (TKX2 (MapJust @Nat sh) x)
forall (sh :: [Nat]) (a1 :: TK).
TKConversion (TKS2 sh a1) (TKX2 (MapJust @Nat sh) a1)
ConvSX
(FTKProduct FullShapeTK y1
yftk1 FullShapeTK z
yftk2, FTKProduct FullShapeTK y1
zftk1 FullShapeTK z
zftk2) ->
TKConversion y1 y1
-> TKConversion z z
-> TKConversion (TKProduct y1 z) (TKProduct y1 z)
forall (a1 :: TK) (a' :: TK) (b1 :: TK) (b' :: TK).
TKConversion a1 a'
-> TKConversion b1 b'
-> TKConversion (TKProduct a1 b1) (TKProduct a' b')
ConvT2 (FullShapeTK y1 -> FullShapeTK y1 -> TKConversion y1 y1
forall (y0 :: TK) (z0 :: TK).
FullShapeTK y0 -> FullShapeTK z0 -> TKConversion y0 z0
convFromS FullShapeTK y1
yftk1 FullShapeTK y1
zftk1) (FullShapeTK z -> FullShapeTK z -> TKConversion z z
forall (y0 :: TK) (z0 :: TK).
FullShapeTK y0 -> FullShapeTK z0 -> TKConversion y0 z0
convFromS FullShapeTK z
yftk2 FullShapeTK z
zftk2)
(FullShapeTK y0, FullShapeTK z0)
_ -> [Char] -> TKConversion y0 z0
forall a. HasCallStack => [Char] -> a
error ([Char] -> TKConversion y0 z0) -> [Char] -> TKConversion y0 z0
forall a b. (a -> b) -> a -> b
$ [Char]
"convFromS': unexpected types "
[Char] -> [Char] -> [Char]
forall a. [a] -> [a] -> [a]
++ [Char]
"(" [Char] -> [Char] -> [Char]
forall a. [a] -> [a] -> [a]
++ FullShapeTK y0 -> [Char]
forall a. Show a => a -> [Char]
show FullShapeTK y0
yftk0 [Char] -> [Char] -> [Char]
forall a. [a] -> [a] -> [a]
++ [Char]
", " [Char] -> [Char] -> [Char]
forall a. [a] -> [a] -> [a]
++ FullShapeTK z0 -> [Char]
forall a. Show a => a -> [Char]
show FullShapeTK z0
zftk0 [Char] -> [Char] -> [Char]
forall a. [a] -> [a] -> [a]
++ [Char]
")"
convSFrom :: FullShapeTK y0 -> SingletonTK z0 -> TKConversion y0 z0
convSFrom :: forall (y0 :: TK) (z0 :: TK).
FullShapeTK y0 -> SingletonTK z0 -> TKConversion y0 z0
convSFrom FullShapeTK y0
yftk0 SingletonTK z0
zstk0 = case (SingletonTK z0
zstk0, FullShapeTK y0
yftk0) of
(SingletonTK z0, FullShapeTK y0)
_ | Just (:~:) @TK y0 z0
Refl <- SingletonTK y0 -> SingletonTK z0 -> Maybe ((:~:) @TK y0 z0)
forall (y1 :: TK) (y2 :: TK).
SingletonTK y1 -> SingletonTK y2 -> Maybe ((:~:) @TK y1 y2)
sameSTK (FullShapeTK y0 -> SingletonTK y0
forall (y :: TK). FullShapeTK y -> SingletonTK y
ftkToSTK FullShapeTK y0
yftk0) SingletonTK z0
zstk0 -> TKConversion y0 y0
TKConversion y0 z0
forall (a :: TK). TKConversion a a
ConvId
(STKS ShS sh
ZSS (STKScalar @ry), FTKScalar @rz)
| Just (:~:) @Type r r
Refl <- TypeRep @Type r -> TypeRep @Type r -> Maybe ((:~:) @Type r r)
forall a b.
TypeRep @Type a -> TypeRep @Type b -> Maybe ((:~:) @Type a b)
forall {k} (f :: k -> Type) (a :: k) (b :: k).
TestEquality @k f =>
f a -> f b -> Maybe ((:~:) @k a b)
testEquality (forall a. Typeable @Type a => TypeRep @Type a
forall {k} (a :: k). Typeable @k a => TypeRep @k a
typeRep @ry) (forall a. Typeable @Type a => TypeRep @Type a
forall {k} (a :: k). Typeable @k a => TypeRep @k a
typeRep @rz) ->
TKConversion (TKX2 ('[] @(Maybe Nat)) y0) z0
-> TKConversion y0 (TKX2 ('[] @(Maybe Nat)) y0)
-> TKConversion y0 z0
forall (b1 :: TK) (b :: TK) (a :: TK).
TKConversion b1 b -> TKConversion a b1 -> TKConversion a b
ConvCmp TKConversion (TKX2 ('[] @(Maybe Nat)) y0) z0
TKConversion
(TKX2 (MapJust @Nat ('[] @Nat)) y0) (TKS2 ('[] @Nat) y0)
forall (sh :: [Nat]) (a1 :: TK).
TKConversion (TKX2 (MapJust @Nat sh) a1) (TKS2 sh a1)
ConvXS (SingletonTK y0 -> TKConversion y0 (TKX2 ('[] @(Maybe Nat)) y0)
forall (a :: TK).
SingletonTK a -> TKConversion a (TKX2 ('[] @(Maybe Nat)) a)
Conv0X SingletonTK y0
SingletonTK (TKScalar r)
forall r. GoodScalar r => SingletonTK (TKScalar r)
STKScalar)
(STKS @sh ShS sh
sh SingletonTK x
x, FTKR IShR n
rsh FullShapeTK x
rx)
| Just (:~:) @TK x x
Refl <- SingletonTK x -> SingletonTK x -> Maybe ((:~:) @TK x x)
forall (y1 :: TK) (y2 :: TK).
SingletonTK y1 -> SingletonTK y2 -> Maybe ((:~:) @TK y1 y2)
sameSTK SingletonTK x
x (FullShapeTK x -> SingletonTK x
forall (y :: TK). FullShapeTK y -> SingletonTK y
ftkToSTK FullShapeTK x
rx)
, Just (:~:) @Nat (Rank @Nat sh) n
Refl <- SNat (Rank @Nat sh)
-> SNat n -> Maybe ((:~:) @Nat (Rank @Nat sh) n)
forall (a :: Nat) (b :: Nat).
SNat a -> SNat b -> Maybe ((:~:) @Nat a b)
forall {k} (f :: k -> Type) (a :: k) (b :: k).
TestEquality @k f =>
f a -> f b -> Maybe ((:~:) @k a b)
testEquality (ShS sh -> SNat (Rank @Nat sh)
forall (sh :: [Nat]). ShS sh -> SNat (Rank @Nat sh)
shsRank ShS sh
sh) (IShR n -> SNat n
forall (n :: Nat) i. ShR n i -> SNat n
shrRank IShR n
rsh)
, (:~:)
@Nat
(Rank @(Maybe Nat) (Replicate @(Maybe Nat) n ('Nothing @Nat)))
n
Refl <- Proxy @Nat n
-> (:~:)
@Nat
(Rank @(Maybe Nat) (Replicate @(Maybe Nat) n ('Nothing @Nat)))
n
forall (proxy :: Nat -> Type) (n :: Nat).
proxy n
-> (:~:)
@Nat
(Rank @(Maybe Nat) (Replicate @(Maybe Nat) n ('Nothing @Nat)))
n
lemRankReplicate (forall (t :: Nat). Proxy @Nat t
forall {k} (t :: k). Proxy @k t
Proxy @(Rank sh)) ->
TKConversion (TKX2 (Replicate @(Maybe Nat) n ('Nothing @Nat)) x) z0
-> TKConversion
y0 (TKX2 (Replicate @(Maybe Nat) n ('Nothing @Nat)) x)
-> TKConversion y0 z0
forall (b1 :: TK) (b :: TK) (a :: TK).
TKConversion b1 b -> TKConversion a b1 -> TKConversion a b
ConvCmp (FullShapeTK (TKS2 sh x)
-> TKConversion
(TKX2 (Replicate @(Maybe Nat) n ('Nothing @Nat)) x) (TKS2 sh x)
forall (sh :: [Maybe Nat]) (sh' :: [Nat]) (a1 :: TK).
((Rank @(Maybe Nat) sh :: Nat) ~ (Rank @Nat sh' :: Nat)) =>
FullShapeTK (TKS2 sh' a1)
-> TKConversion (TKX2 sh a1) (TKS2 sh' a1)
ConvXS' (ShS sh -> FullShapeTK x -> FullShapeTK (TKS2 sh x)
forall (sh :: [Nat]) (x :: TK).
ShS sh -> FullShapeTK x -> FullShapeTK (TKS2 sh x)
FTKS ShS sh
sh FullShapeTK x
rx)) TKConversion y0 (TKX2 (Replicate @(Maybe Nat) n ('Nothing @Nat)) x)
TKConversion
(TKR2 n x) (TKX2 (Replicate @(Maybe Nat) n ('Nothing @Nat)) x)
forall (n :: Nat) (a1 :: TK).
TKConversion
(TKR2 n a1) (TKX2 (Replicate @(Maybe Nat) n ('Nothing @Nat)) a1)
ConvRX
(STKS ShS sh
sh SingletonTK x
x, FTKX IShX sh
xsh FullShapeTK x
xx)
| Just (:~:) @TK x x
Refl <- SingletonTK x -> SingletonTK x -> Maybe ((:~:) @TK x x)
forall (y1 :: TK) (y2 :: TK).
SingletonTK y1 -> SingletonTK y2 -> Maybe ((:~:) @TK y1 y2)
sameSTK SingletonTK x
x (FullShapeTK x -> SingletonTK x
forall (y :: TK). FullShapeTK y -> SingletonTK y
ftkToSTK FullShapeTK x
xx)
, Just (:~:) @Nat (Rank @Nat sh) (Rank @(Maybe Nat) sh)
Refl <- SNat (Rank @Nat sh)
-> SNat (Rank @(Maybe Nat) sh)
-> Maybe ((:~:) @Nat (Rank @Nat sh) (Rank @(Maybe Nat) sh))
forall (a :: Nat) (b :: Nat).
SNat a -> SNat b -> Maybe ((:~:) @Nat a b)
forall {k} (f :: k -> Type) (a :: k) (b :: k).
TestEquality @k f =>
f a -> f b -> Maybe ((:~:) @k a b)
testEquality (ShS sh -> SNat (Rank @Nat sh)
forall (sh :: [Nat]). ShS sh -> SNat (Rank @Nat sh)
shsRank ShS sh
sh) (IShX sh -> SNat (Rank @(Maybe Nat) sh)
forall (sh :: [Maybe Nat]) i.
ShX sh i -> SNat (Rank @(Maybe Nat) sh)
shxRank IShX sh
xsh)
, (:~:) @Nat (Rank @(Maybe Nat) (MapJust @Nat sh)) (Rank @Nat sh)
Refl <- ShS sh
-> (:~:) @Nat (Rank @(Maybe Nat) (MapJust @Nat sh)) (Rank @Nat sh)
forall (sh :: [Nat]).
ShS sh
-> (:~:) @Nat (Rank @(Maybe Nat) (MapJust @Nat sh)) (Rank @Nat sh)
lemRankMapJust ShS sh
sh ->
FullShapeTK (TKS2 sh x) -> TKConversion (TKX2 sh x) (TKS2 sh x)
forall (sh :: [Maybe Nat]) (sh' :: [Nat]) (a1 :: TK).
((Rank @(Maybe Nat) sh :: Nat) ~ (Rank @Nat sh' :: Nat)) =>
FullShapeTK (TKS2 sh' a1)
-> TKConversion (TKX2 sh a1) (TKS2 sh' a1)
ConvXS' (ShS sh -> FullShapeTK x -> FullShapeTK (TKS2 sh x)
forall (sh :: [Nat]) (x :: TK).
ShS sh -> FullShapeTK x -> FullShapeTK (TKS2 sh x)
FTKS ShS sh
sh FullShapeTK x
xx)
(STKProduct SingletonTK y1
zstk1 SingletonTK z
zstk2, FTKProduct FullShapeTK y1
yftk1 FullShapeTK z
yftk2) ->
TKConversion y1 y1
-> TKConversion z z
-> TKConversion (TKProduct y1 z) (TKProduct y1 z)
forall (a1 :: TK) (a' :: TK) (b1 :: TK) (b' :: TK).
TKConversion a1 a'
-> TKConversion b1 b'
-> TKConversion (TKProduct a1 b1) (TKProduct a' b')
ConvT2 (FullShapeTK y1 -> SingletonTK y1 -> TKConversion y1 y1
forall (y0 :: TK) (z0 :: TK).
FullShapeTK y0 -> SingletonTK z0 -> TKConversion y0 z0
convSFrom FullShapeTK y1
yftk1 SingletonTK y1
zstk1) (FullShapeTK z -> SingletonTK z -> TKConversion z z
forall (y0 :: TK) (z0 :: TK).
FullShapeTK y0 -> SingletonTK z0 -> TKConversion y0 z0
convSFrom FullShapeTK z
yftk2 SingletonTK z
zstk2)
(SingletonTK z0, FullShapeTK y0)
_ -> [Char] -> TKConversion y0 z0
forall a. HasCallStack => [Char] -> a
error ([Char] -> TKConversion y0 z0) -> [Char] -> TKConversion y0 z0
forall a b. (a -> b) -> a -> b
$ [Char]
"convSFrom': unexpected types "
[Char] -> [Char] -> [Char]
forall a. [a] -> [a] -> [a]
++ [Char]
"(" [Char] -> [Char] -> [Char]
forall a. [a] -> [a] -> [a]
++ FullShapeTK y0 -> [Char]
forall a. Show a => a -> [Char]
show FullShapeTK y0
yftk0 [Char] -> [Char] -> [Char]
forall a. [a] -> [a] -> [a]
++ [Char]
", " [Char] -> [Char] -> [Char]
forall a. [a] -> [a] -> [a]
++ SingletonTK z0 -> [Char]
forall a. Show a => a -> [Char]
show SingletonTK z0
zstk0 [Char] -> [Char] -> [Char]
forall a. [a] -> [a] -> [a]
++ [Char]
")"