{-# LANGUAGE UndecidableInstances #-}
{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-}
{-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-}
-- | Adaptors for working with types of collections of tensors,
-- e.g., tuples, sized lists and user types of statically known size,
-- as long as they have the proper instances defined.
-- The collections are used as representations of the domains
-- of objective functions that become the codomains of the reverse
-- derivative functions and also to handle multiple arguments
-- and results of fold-like operations.
module HordeAd.Core.Adaptor
  ( AdaptableTarget(..), TermValue(..), DualNumberValue(..)
  , ForgetShape(..), RandomValue(..)
  , stkOfListR
    -- * Helper classes and types
  , Tups, NoShapeTensorKind
  ) where

import Prelude

import Data.Default
import Data.Proxy (Proxy (Proxy))
import Data.Type.Equality (gcastWith, (:~:))
import Data.Vector.Generic qualified as V
import Data.Vector.Strict qualified as Data.Vector
import GHC.TypeLits (KnownNat, OrderingI (..), cmpNat, type (-), type (<=?))
import System.Random

import Data.Array.Nested qualified as Nested
import Data.Array.Nested.Mixed.Shape
import Data.Array.Nested.Ranked.Shape
import Data.Array.Nested.Shaped.Shape
import Data.Array.Nested.Types (unsafeCoerceRefl)

import HordeAd.Core.Ast
import HordeAd.Core.CarriersADVal
import HordeAd.Core.CarriersConcrete
import HordeAd.Core.ConvertTensor
import HordeAd.Core.Ops
import HordeAd.Core.OpsAst ()
import HordeAd.Core.TensorKind
import HordeAd.Core.Types

-- * Adaptor classes

-- Inspired by adaptors from @tomjaguarpaw's branch.
--
-- | The class that makes it possible to treat @vals@ (e.g., a tuple of tensors)
-- as a @target@-based (e.g., concrete or symbolic) value
-- of tensor kind @X vals@.
class AdaptableTarget (target :: Target) vals where
  type X vals :: TK  -- ^ what tensor kind represents the collection
  toTarget :: vals -> target (X vals)
    -- ^ represents a collection of tensors as a single @target@ value
    --   of tensor kind @X vals@
  fromTarget :: target (X vals) -> vals
    -- ^ recovers a collection of tensors from its canonical representation;
    --   requires a duplicable argument

-- | An embedding of a concrete collection of tensors into a non-concrete
-- counterpart of the same shape and containing the same data.
class TermValue vals where
  type Value vals = result | result -> vals
    -- ^ a helper type, with the same general shape,
    -- but possibly more concrete, e.g., arrays instead of terms,
    -- where the injectivity is crucial to limit the number
    -- of type applications the library user has to supply
  fromValue :: Value vals -> vals  -- ^ an embedding

-- | An embedding of a concrete collection of tensors into a non-concrete
-- counterpart of the same shape and containing the same data.
-- This variant can be defined more often than 'TermValue',
-- but the associated type family is not injective.
class DualNumberValue vals where
  type DValue vals
    -- ^ a helper type, with the same general shape,
    -- but possibly more concrete, e.g., arrays instead of terms,
    -- where the injectivity is hard to obtain, but is not so important,
    -- because the type is not used in the best pipeline
  fromDValue :: DValue vals -> vals  -- ^ an embedding

-- | A helper class for converting all tensors inside a type
-- from shaped to ranked. It's useful when a collection of parameters
-- is defined as shaped tensors for 'RandomValue' but then is going
-- to be used as ranked tensors to make type reconstruction easier.
class ForgetShape vals where
  type NoShape vals
    -- ^ the type of the collection after the conversion
  forgetShape :: vals -> NoShape vals
    -- ^ the conversion itself

-- | A helper class for randomly generating initial parameters.
-- Only instances for collections of shaped tensors and scalars are possible,
-- because only then the shapes of the tensors to generate are known
-- from their types.
class RandomValue vals where
  -- | Given a range parameter and a generator, produce a value together
  -- with a generator to be used for subsequent draws.
  randomValue :: Double -> StdGen -> (vals, StdGen)


-- * Base instances

-- | The trivial adaptor: a single tensor is its own canonical
-- representation, so both conversions are the identity.
instance AdaptableTarget target (target y) where
  type X (target y) = y
  toTarget = id
  fromTarget t = t
  {-# SPECIALIZE instance AdaptableTarget Concrete (Concrete (TKS sh Double)) #-}
  {-# SPECIALIZE instance AdaptableTarget Concrete (Concrete (TKS sh Float)) #-}
    -- a failed attempt to specialize without -fpolymorphic-specialisation

-- | A symbolic term takes its value from a concrete tensor of the same
-- tensor kind.
instance KnownSTK y
         => TermValue (AstTensor AstMethodLet FullSpan y) where
  type Value (AstTensor AstMethodLet FullSpan y) = Concrete y
  -- The full shape required by 'tconcrete' is computed by 'tftkG'
  -- from the singleton for @y@ and the unwrapped concrete value.
  fromValue t = tconcrete (tftkG (knownSTK @y) $ unConcrete t) t

-- | Any single tensor can be obtained from a concrete tensor of the same
-- tensor kind.
instance (BaseTensor target, BaseTensor (PrimalOf target), KnownSTK y)
         => DualNumberValue (target y) where
  type DValue (target y) = Concrete y
  -- The concrete value is first injected into the primal part of @target@
  -- (hence the @BaseTensor (PrimalOf target)@ constraint) and then lifted
  -- to @target@ with 'tfromPrimal'. The full shape required by 'tconcrete'
  -- is computed by 'tftkG' from the unwrapped concrete value.
  fromDValue t = tfromPrimal (knownSTK @y)
                 $ tconcrete (tftkG (knownSTK @y) $ unConcrete t) t

-- | Scalars are left unchanged.
instance ForgetShape (target (TKScalar r)) where
  type NoShape (target (TKScalar r)) = target (TKScalar r)
  forgetShape = id

-- | Ranked tensors are left unchanged.
instance ForgetShape (target (TKR n r)) where
  type NoShape (target (TKR n r)) = target (TKR n r)
  forgetShape = id

-- | A shaped tensor is converted, with 'rfromS', to the ranked tensor
-- whose rank is the length of the shape.
instance (KnownShS sh, GoodScalar r, ConvertTensor target)
         => ForgetShape (target (TKS sh r)) where
  type NoShape (target (TKS sh r)) = target (TKR (Rank sh) r)
  forgetShape = rfromS

-- | Mixed tensors are left unchanged.
instance ForgetShape (target (TKX sh r)) where
  type NoShape (target (TKX sh r)) = target (TKX sh r)
  forgetShape = id

-- | The tensor kind obtained by replacing every shaped tensor kind
-- (@TKS2 sh r@) with the ranked tensor kind of rank @Rank sh@.
-- Scalar, ranked and mixed kinds are kept as they are and products
-- are traversed component-wise. Only the outermost tensor constructor
-- is rewritten; the element kind @r@ is not traversed.
type family NoShapeTensorKind tk where
  NoShapeTensorKind (TKScalar r) = TKScalar r
  NoShapeTensorKind (TKR2 n r) = TKR2 n r
  NoShapeTensorKind (TKS2 sh r) = TKR2 (Rank sh) r
  NoShapeTensorKind (TKX2 sh r) = TKX2 sh r
  NoShapeTensorKind (TKProduct y z) =
    TKProduct (NoShapeTensorKind y) (NoShapeTensorKind z)

-- | Shapes are forgotten in both components of a product.
-- The two equality constraints tie the 'NoShape' results for the components
-- to 'NoShapeTensorKind', so that the result is again a product
-- in the same @target@.
instance ( ForgetShape (target a)
         , ForgetShape (target b)
         , target (NoShapeTensorKind a) ~ NoShape (target a)
         , target (NoShapeTensorKind b) ~ NoShape (target b)
         , BaseTensor target, LetTensor target )
         => ForgetShape (target (TKProduct a b)) where
  type NoShape (target (TKProduct a b)) =
    target (NoShapeTensorKind (TKProduct a b))
  -- The argument is bound with 'ttlet', because it's used twice
  -- (once per projection).
  forgetShape ab =
    ttlet ab $ \abShared ->
      tpair (forgetShape (tproject1 abShared))
            (forgetShape (tproject2 abShared))

-- | A random scalar. In the branch of 'ifDifferentiable' that may assume
-- @Differentiable r@, a value @r@ drawn with 'random' is mapped to
-- @2 * range * (r - 0.5)@ and the advanced generator is returned.
-- In the other branch the result is 'def' and the generator is returned
-- unchanged.
instance forall r target. (GoodScalar r, BaseTensor target)
         => RandomValue (target (TKScalar r)) where
  randomValue range g =
    ifDifferentiable @r
      (let (r, g2) = random g
           m = 2 * realToFrac range * (r - 0.5)
       in (tkconcrete m, g2))
      (tkconcrete def, g)

-- | A random shaped tensor; the shape is taken from the type.
-- In the branch of 'ifDifferentiable' that may assume @Differentiable r@,
-- the generator is split: one half feeds 'randoms' to fill the tensor
-- and the other half is returned. Each drawn element @x@ ends up
-- as @2 * range * (x - 0.5)@. In the other branch the tensor is filled
-- with 'def' and the generator is returned unchanged.
instance forall sh r target. (KnownShS sh, GoodScalar r, BaseTensor target)
         => RandomValue (target (TKS sh r)) where
  randomValue range g =
    ifDifferentiable @r
      (let -- Takes the number of elements to draw and the generator
           -- to draw them from.
           createRandomVector :: Int -> StdGen -> target (TKS sh r)
           createRandomVector n seed =
             srepl (2 * realToFrac range)
             * (tsconcrete
                  (Nested.sfromVector knownShS (V.fromListN n (randoms seed)))
                - srepl 0.5)
           (g1, g2) = splitGen g
           arr = createRandomVector (shsSize (knownShS @sh)) g1
       in (arr, g2))
      (srepl def, g)
   where
    -- A tensor of shape @sh@ with all elements equal to the given scalar.
    srepl = tsconcrete . Nested.sreplicateScal knownShS
  -- {-# SPECIALIZE instance (KnownShS sh, GoodScalar r, Fractional r, Random r) => RandomValue (Concrete (TKS sh r)) #-}
  {-# SPECIALIZE instance KnownShS sh => RandomValue (Concrete (TKS sh Double)) #-}
  {-# SPECIALIZE instance KnownShS sh => RandomValue (Concrete (TKS sh Float)) #-}

-- | A random product: the components are generated one after another,
-- with the generator threaded from the first to the second.
instance (RandomValue (target a), RandomValue (target b), BaseTensor target)
         => RandomValue (target (TKProduct a b)) where
  randomValue range g =
    let (v1, g1) = randomValue range g
        (v2, g2) = randomValue range g1
    in (tpair v1 v2, g2)

-- These instances are messy and hard to use, but we probably can't do better.
-- | A plain 'Double' is obtained by unwrapping a concrete scalar.
instance DualNumberValue Double where
  type DValue Double = Concrete (TKScalar Double)
  fromDValue (Concrete d) = d

-- | A plain 'Float' is obtained by unwrapping a concrete scalar.
instance DualNumberValue Float where
  type DValue Float = Concrete (TKScalar Float)
  fromDValue (Concrete d) = d

-- | A concrete scalar is obtained by wrapping a plain 'Double'.
instance TermValue (Concrete (TKScalar Double)) where
  type Value (Concrete (TKScalar Double)) = Double
  fromValue = Concrete

-- | A concrete scalar is obtained by wrapping a plain 'Float'.
instance TermValue (Concrete (TKScalar Float)) where
  type Value (Concrete (TKScalar Float)) = Float
  fromValue = Concrete


-- * Compound instances

-- | A list of scalars is represented as a rank-1 tensor.
instance (BaseTensor target, ConvertTensor target, GoodScalar r)
         => AdaptableTarget target [target (TKScalar r)] where
  type X [target (TKScalar r)] = TKR 1 r
  -- The empty list is mapped directly to the concrete empty array
  -- instead of going through 'trfromVector'
  -- (presumably because the latter can't handle an empty vector -- TODO
  -- confirm). Otherwise every scalar is turned into a rank-0 tensor
  -- and the tensors are stacked.
  toTarget l = if null l
               then trconcrete Nested.remptyArray
               else trfromVector $ V.fromList $ map rfromK l
  fromTarget = map kfromR . trunravelToList
                              -- inefficient, but we probably can't do better

-- | A boxed vector of scalars is represented as a rank-1 tensor.
instance (BaseTensor target, ConvertTensor target, GoodScalar r)
         => AdaptableTarget target
                            (Data.Vector.Vector (target (TKScalar r))) where
  type X (Data.Vector.Vector (target (TKScalar r))) = TKR 1 r
  -- The empty vector is mapped directly to the concrete empty array
  -- instead of going through 'trfromVector'
  -- (presumably because the latter can't handle an empty vector -- TODO
  -- confirm). Otherwise every scalar is turned into a rank-0 tensor
  -- and the tensors are stacked.
  toTarget v = if V.null v
               then trconcrete Nested.remptyArray
               else trfromVector $ V.map rfromK v
  fromTarget =
    V.fromList . map kfromR . trunravelToList
                                -- inefficient, but we probably can't do better

-- | The tensor kind of a right-nested product of @n@ copies of @t@,
-- terminated with @TKUnit@, e.g.,
-- @Tups 2 t = TKProduct t (TKProduct t TKUnit)@.
type family Tups n t where
  Tups 0 t = TKUnit
  Tups n t = TKProduct t (Tups (n - 1) t)

-- | Given the singleton for tensor kind @t@ and the length @n@,
-- construct the singleton for the tensor kind @Tups n t@.
stkOfListR :: forall t n.
              SingletonTK t -> SNat n -> SingletonTK (Tups n t)
stkOfListR _ (SNat' @0) = stkUnit
stkOfListR stk SNat =
  -- The zero case is handled by the equation above, so here @n@ is at least 1
  -- and the second equation of 'Tups' applies, but the type checker
  -- is not able to deduce either fact on its own, hence the unchecked
  -- equality proofs.
  gcastWith (unsafeCoerceRefl :: (1 <=? n) :~: True) $
  gcastWith (unsafeCoerceRefl :: Tups n t :~: TKProduct t (Tups (n - 1) t)) $
  STKProduct stk (stkOfListR stk (SNat @(n - 1)))

-- | A sized list of @n@ collections is represented as a right-nested
-- product, terminated with unit, of the representations of the elements.
instance (BaseTensor target, KnownNat n, AdaptableTarget target a)
         => AdaptableTarget target (ListR n a) where
  type X (ListR n a) = Tups n (X a)
  toTarget ZR = tkconcrete Z1
  toTarget ((:::) @n1 a rest) =
    -- Here @n@ is @n1 + 1@, so the second equation of 'Tups' applies,
    -- but the type checker can't deduce that, hence the unchecked
    -- equality proof.
    gcastWith (unsafeCoerceRefl
               :: X (ListR n a) :~: TKProduct (X a) (X (ListR n1 a))) $
    let a1 = toTarget a
        rest1 = toTarget rest
    in tpair a1 rest1
  -- Note that the argument is used twice in the non-empty case
  -- (once per projection).
  fromTarget tups = case SNat @n of
    SNat' @0 -> ZR
    _ ->
      -- The zero case is handled above, so here @n@ is at least 1
      -- and the second equation of 'Tups' applies, but the type checker
      -- can't deduce either fact, hence the unchecked equality proofs.
      gcastWith (unsafeCoerceRefl :: (1 <=? n) :~: True) $
      gcastWith (unsafeCoerceRefl
                 :: X (ListR n a) :~: TKProduct (X a) (X (ListR (n - 1) a))) $
      let (a1, rest1) = (tproject1 tups, tproject2 tups)
          a = fromTarget a1
          rest = fromTarget rest1
      in (a ::: rest)
  {-# SPECIALIZE instance (KnownNat n, AdaptableTarget (AstTensor AstMethodLet FullSpan) a) => AdaptableTarget (AstTensor AstMethodLet FullSpan) (ListR n a) #-}
  {-# SPECIALIZE instance (KnownNat n, AdaptableTarget (ADVal Concrete) a) => AdaptableTarget (ADVal Concrete) (ListR n a) #-}

-- | Lists are embedded element-wise.
instance TermValue a => TermValue [a] where
  type Value [a] = [Value a]
  fromValue = map fromValue

-- | Boxed vectors are embedded element-wise.
instance TermValue a => TermValue (Data.Vector.Vector a) where
  type Value (Data.Vector.Vector a) = Data.Vector.Vector (Value a)
  fromValue = V.map fromValue

-- | A sized list is converted element by element; the length index @n@
-- is the same on both sides.
instance TermValue a => TermValue (ListR n a) where
  type Value (ListR n a) = ListR n (Value a)
  fromValue ZR = ZR
  fromValue (a ::: rest) = fromValue a ::: fromValue rest

-- | A list is converted with 'fromDValue' element by element.
instance DualNumberValue a => DualNumberValue [a] where
  type DValue [a] = [DValue a]
  fromDValue = map fromDValue

-- | A strict boxed vector is converted with 'fromDValue' element by element
-- via 'V.map'.
instance DualNumberValue a => DualNumberValue (Data.Vector.Vector a) where
  type DValue (Data.Vector.Vector a) = Data.Vector.Vector (DValue a)
  fromDValue = V.map fromDValue

-- | A sized list is converted with 'fromDValue' element by element;
-- the length index @n@ is the same on both sides.
instance DualNumberValue a => DualNumberValue (ListR n a) where
  type DValue (ListR n a) = ListR n (DValue a)
  fromDValue ZR = ZR
  fromDValue (a ::: rest) = fromDValue a ::: fromDValue rest

-- | Lists are left untouched: 'forgetShape' is the identity and the elements
-- are not traversed (note the absence of a @ForgetShape a@ constraint).
instance ForgetShape [a] where
  type NoShape [a] = [a]
  forgetShape = id

-- | Vectors are left untouched: 'forgetShape' is the identity and the elements
-- are not traversed (note the absence of a @ForgetShape a@ constraint).
instance ForgetShape (Data.Vector.Vector a) where
  type NoShape (Data.Vector.Vector a) = Data.Vector.Vector a
  forgetShape = id

-- | For a sized list, 'forgetShape' is applied to every element;
-- the length index @n@ is kept.
instance ForgetShape a => ForgetShape (ListR n a) where
  type NoShape (ListR n a) = ListR n (NoShape a)
  forgetShape ZR = ZR
  forgetShape (a ::: rest) = forgetShape a ::: forgetShape rest

-- | Generates a sized list of @n@ random elements, threading the generator
-- from the head to the tail. The same @range@ is passed to every element.
instance (RandomValue a, KnownNat n) => RandomValue (ListR n a) where
  randomValue range g = case cmpNat (Proxy @n) (Proxy @0)  of
    -- A type-level natural can't be smaller than zero.
    LTI -> error "randomValue: impossible"
    -- Here @n ~ 0@, so the empty list is the only inhabitant.
    EQI -> (ZR, g)
    -- The evidence @1 <= n@ is asserted via 'unsafeCoerceRefl';
    -- NOTE(review): presumably because it can't be derived from
    -- the 'GTI' match automatically -- confirm.
    GTI -> gcastWith (unsafeCoerceRefl :: (1 <=? n) :~: True) $
           let (v, g1) = randomValue range g
               (rest, g2) = randomValue @(ListR (n - 1) a) range g1
           in (v ::: rest, g2)


-- * Tuple instances

-- | A pair is represented by the product ('TKProduct') of the representations
-- of its components.
instance ( BaseTensor target
         , AdaptableTarget target a
         , AdaptableTarget target b )
         => AdaptableTarget target (a, b) where
  type X (a, b) = TKProduct (X a) (X b)
  toTarget (a, b) =
    let a1 = toTarget a
        b1 = toTarget b
    in tpair a1 b1
  fromTarget ab =
    let a = fromTarget $ tproject1 ab
        b = fromTarget $ tproject2 ab
    in (a, b)
  {-# SPECIALIZE instance (AdaptableTarget (AstTensor AstMethodLet FullSpan) a, AdaptableTarget (AstTensor AstMethodLet FullSpan) b) => AdaptableTarget (AstTensor AstMethodLet FullSpan) (a, b) #-}

-- | A pair is converted componentwise.
instance (TermValue a, TermValue b) => TermValue (a, b) where
  type Value (a, b) = (Value a, Value b)
  fromValue (va, vb) = (fromValue va, fromValue vb)

-- | A pair is converted with 'fromDValue' componentwise.
instance (DualNumberValue a, DualNumberValue b) => DualNumberValue (a, b) where
  type DValue (a, b) = (DValue a, DValue b)
  fromDValue (va, vb) = (fromDValue va, fromDValue vb)

-- | For a pair, 'forgetShape' is applied componentwise.
instance ( ForgetShape a
         , ForgetShape b ) => ForgetShape (a, b) where
  type NoShape (a, b) = (NoShape a, NoShape b)
  forgetShape (a, b) = (forgetShape a, forgetShape b)

-- | Generates the components left to right, threading the generator through
-- and passing the same @range@ to each of them.
instance ( RandomValue a
         , RandomValue b ) => RandomValue (a, b) where
  randomValue range g =
    let (v1, g1) = randomValue range g
        (v2, g2) = randomValue range g1
    in ((v1, v2), g2)

-- | A triple is represented by the left-nested product @((a, b), c)@.
instance ( BaseTensor target
         , AdaptableTarget target a
         , AdaptableTarget target b
         , AdaptableTarget target c )
         => AdaptableTarget target (a, b, c) where
  type X (a, b, c) = TKProduct (TKProduct (X a) (X b)) (X c)
  toTarget (a, b, c) =
    let a1 = toTarget a
        b1 = toTarget b
        c1 = toTarget c
    in tpair (tpair a1 b1) c1
  fromTarget abc =
    -- The projections mirror the nesting chosen in 'X' above.
    let a = fromTarget $ tproject1 $ tproject1 abc
        b = fromTarget $ tproject2 $ tproject1 abc
        c = fromTarget $ tproject2 abc
    in (a, b, c)
  {-# SPECIALIZE instance (AdaptableTarget (AstTensor AstMethodLet FullSpan) a, AdaptableTarget (AstTensor AstMethodLet FullSpan) b, AdaptableTarget (AstTensor AstMethodLet FullSpan) c) => AdaptableTarget (AstTensor AstMethodLet FullSpan) (a, b, c) #-}

-- | A triple is converted componentwise.
instance (TermValue a, TermValue b, TermValue c)
         => TermValue (a, b, c) where
  type Value (a, b, c) = (Value a, Value b, Value c)
  fromValue (va, vb, vc) = (fromValue va, fromValue vb, fromValue vc)

-- | A triple is converted with 'fromDValue' componentwise.
instance (DualNumberValue a, DualNumberValue b, DualNumberValue c)
         => DualNumberValue (a, b, c) where
  type DValue (a, b, c) = (DValue a, DValue b, DValue c)
  fromDValue (va, vb, vc) = (fromDValue va, fromDValue vb, fromDValue vc)

-- | For a triple, 'forgetShape' is applied componentwise.
instance ( ForgetShape a
         , ForgetShape b
         , ForgetShape c ) => ForgetShape (a, b, c) where
  type NoShape (a, b, c) = (NoShape a, NoShape b, NoShape c)
  forgetShape (a, b, c) = (forgetShape a, forgetShape b, forgetShape c)

-- | Generates the components left to right, threading the generator through
-- and passing the same @range@ to each of them.
instance ( RandomValue a
         , RandomValue b
         , RandomValue c ) => RandomValue (a, b, c) where
  randomValue range g =
    let (v1, g1) = randomValue range g
        (v2, g2) = randomValue range g1
        (v3, g3) = randomValue range g2
    in ((v1, v2, v3), g3)

-- | A quadruple is represented by the balanced product @((a, b), (c, d))@.
instance ( BaseTensor target
         , AdaptableTarget target a
         , AdaptableTarget target b
         , AdaptableTarget target c
         , AdaptableTarget target d)
         => AdaptableTarget target (a, b, c, d) where
  type X (a, b, c, d) = TKProduct (TKProduct (X a) (X b))
                                  (TKProduct (X c) (X d))
  toTarget (a, b, c, d) =
    let a1 = toTarget a
        b1 = toTarget b
        c1 = toTarget c
        d1 = toTarget d
    in  tpair (tpair a1 b1) (tpair c1 d1)
  fromTarget abcd =
    -- The projections mirror the nesting chosen in 'X' above.
    let a = fromTarget $ tproject1 $ tproject1 abcd
        b = fromTarget $ tproject2 $ tproject1 abcd
        c = fromTarget $ tproject1 $ tproject2 abcd
        d = fromTarget $ tproject2 $ tproject2 abcd
    in (a, b, c, d)
  {-# SPECIALIZE instance (AdaptableTarget (AstTensor AstMethodLet FullSpan) a, AdaptableTarget (AstTensor AstMethodLet FullSpan) b, AdaptableTarget (AstTensor AstMethodLet FullSpan) c, AdaptableTarget (AstTensor AstMethodLet FullSpan) d) => AdaptableTarget (AstTensor AstMethodLet FullSpan) (a, b, c, d) #-}

-- | A quadruple is converted componentwise.
instance (TermValue a, TermValue b, TermValue c, TermValue d)
         => TermValue (a, b, c, d) where
  type Value (a, b, c, d) = (Value a, Value b, Value c, Value d)
  fromValue (va, vb, vc, vd) =
    (fromValue va, fromValue vb, fromValue vc, fromValue vd)

-- | A quadruple is converted with 'fromDValue' componentwise.
instance ( DualNumberValue a, DualNumberValue b, DualNumberValue c
         , DualNumberValue d )
         => DualNumberValue (a, b, c, d) where
  type DValue (a, b, c, d) = (DValue a, DValue b, DValue c, DValue d)
  fromDValue (va, vb, vc, vd) =
    (fromDValue va, fromDValue vb, fromDValue vc, fromDValue vd)

-- | For a quadruple, 'forgetShape' is applied componentwise.
instance ( ForgetShape a
         , ForgetShape b
         , ForgetShape c
         , ForgetShape d ) => ForgetShape (a, b, c, d) where
  type NoShape (a, b, c, d) =
    (NoShape a, NoShape b, NoShape c, NoShape d)
  forgetShape (a, b, c, d) =
    (forgetShape a, forgetShape b, forgetShape c, forgetShape d)

-- | Generates the components left to right, threading the generator through
-- and passing the same @range@ to each of them.
instance ( RandomValue a
         , RandomValue b
         , RandomValue c
         , RandomValue d ) => RandomValue (a, b, c, d) where
  randomValue range g =
    let (v1, g1) = randomValue range g
        (v2, g2) = randomValue range g1
        (v3, g3) = randomValue range g2
        (v4, g4) = randomValue range g3
    in ((v1, v2, v3, v4), g4)

-- | A quintuple is represented by the product @(((a, b), c), (d, e))@.
instance ( BaseTensor target
         , AdaptableTarget target a
         , AdaptableTarget target b
         , AdaptableTarget target c
         , AdaptableTarget target d
         , AdaptableTarget target e)
         => AdaptableTarget target (a, b, c, d, e) where
  type X (a, b, c, d, e) = TKProduct (TKProduct (TKProduct (X a) (X b)) (X c))
                                     (TKProduct (X d) (X e))
  toTarget (a, b, c, d, e) =
    let a1 = toTarget a
        b1 = toTarget b
        c1 = toTarget c
        d1 = toTarget d
        e1 = toTarget e
    in tpair (tpair (tpair a1 b1) c1) (tpair d1 e1)
  fromTarget abcde =
    -- The projections mirror the nesting chosen in 'X' above.
    let a = fromTarget $ tproject1 $ tproject1 $ tproject1 abcde
        b = fromTarget $ tproject2 $ tproject1 $ tproject1 abcde
        c = fromTarget $ tproject2 $ tproject1 abcde
        d = fromTarget $ tproject1 $ tproject2 abcde
        e = fromTarget $ tproject2 $ tproject2 abcde
    in (a, b, c, d, e)
  {-# SPECIALIZE instance (AdaptableTarget (AstTensor AstMethodLet FullSpan) a, AdaptableTarget (AstTensor AstMethodLet FullSpan) b, AdaptableTarget (AstTensor AstMethodLet FullSpan) c, AdaptableTarget (AstTensor AstMethodLet FullSpan) d, AdaptableTarget (AstTensor AstMethodLet FullSpan) e) => AdaptableTarget (AstTensor AstMethodLet FullSpan) (a, b, c, d, e) #-}

-- | Component-wise lifting of 'TermValue' to 5-tuples.
--
-- The associated 'Value' of a 5-tuple is the 5-tuple of the components'
-- 'Value's, and 'fromValue' converts each component independently with
-- that component's own 'fromValue'.
instance (TermValue a, TermValue b, TermValue c, TermValue d, TermValue e)
         => TermValue (a, b, c, d, e) where
  type Value (a, b, c, d, e) = (Value a, Value b, Value c, Value d, Value e)
  -- fromValue :: Value (a, b, c, d, e) -> (a, b, c, d, e)
  fromValue (va, vb, vc, vd, ve) =
    (fromValue va, fromValue vb, fromValue vc, fromValue vd, fromValue ve)

-- | Component-wise lifting of 'DualNumberValue' to 5-tuples.
--
-- The associated 'DValue' of a 5-tuple is the 5-tuple of the components'
-- 'DValue's, and 'fromDValue' converts each component independently with
-- that component's own 'fromDValue'.
instance ( DualNumberValue a, DualNumberValue b, DualNumberValue c
         , DualNumberValue d, DualNumberValue e )
         => DualNumberValue (a, b, c, d, e) where
  type DValue (a, b, c, d, e) =
    (DValue a, DValue b, DValue c, DValue d, DValue e)
  -- fromDValue :: DValue (a, b, c, d, e) -> (a, b, c, d, e)
  fromDValue (va, vb, vc, vd, ve) =
    ( fromDValue va
    , fromDValue vb
    , fromDValue vc
    , fromDValue vd
    , fromDValue ve )

-- | Component-wise lifting of 'ForgetShape' to 5-tuples.
--
-- The associated 'NoShape' of a 5-tuple is the 5-tuple of the components'
-- 'NoShape's, and 'forgetShape' is applied to each component independently.
instance ( ForgetShape a
         , ForgetShape b
         , ForgetShape c
         , ForgetShape d
         , ForgetShape e ) => ForgetShape (a, b, c, d, e) where
  type NoShape (a, b, c, d, e) =
    (NoShape a, NoShape b, NoShape c, NoShape d, NoShape e)
  -- forgetShape :: (a, b, c, d, e) -> NoShape (a, b, c, d, e)
  forgetShape (a, b, c, d, e) =
    ( forgetShape a
    , forgetShape b
    , forgetShape c
    , forgetShape d
    , forgetShape e )

-- | Component-wise lifting of 'RandomValue' to 5-tuples.
--
-- Every component is generated with the same @range@ argument. The
-- generator is threaded left to right: the generator returned after
-- producing one component is the input for the next, and the generator
-- returned after the fifth component is the one handed back to the caller.
instance ( RandomValue a
         , RandomValue b
         , RandomValue c
         , RandomValue d
         , RandomValue e ) => RandomValue (a, b, c, d, e) where
  -- randomValue :: Double -> StdGen -> ((a, b, c, d, e), StdGen)
  randomValue range g =
    let (v1, g1) = randomValue range g
        (v2, g2) = randomValue range g1
        (v3, g3) = randomValue range g2
        (v4, g4) = randomValue range g3
        (v5, g5) = randomValue range g4
    in ((v1, v2, v3, v4, v5), g5)