{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE RecordWildCards #-}
module Torch.Autograd where
import Foreign.ForeignPtr
import GHC.Generics
import System.IO.Unsafe
import Torch.Internal.Cast
import Torch.Internal.Class
import qualified Torch.Internal.Managed.Autograd
import qualified Torch.Internal.Managed.Type.Tensor as ATen
import qualified Torch.Internal.Type as ATen
import Torch.Tensor
import Data.Default.Class
-- | A tensor treated as an independent variable (a leaf) for gradient
-- computation with 'grad' / 'gradWithOptions'.
--
-- Values are normally obtained through 'makeIndependent' or
-- 'makeIndependentWithRequiresGrad', which route the tensor through the
-- backend's @makeIndependent@ binding before wrapping it. Use 'toDependent'
-- to get the underlying 'Tensor' back for use in ordinary computations.
newtype IndependentTensor = IndependentTensor
  { toDependent :: Tensor
  -- ^ The wrapped tensor.
  }
  deriving (Show, Generic)
-- | Flags controlling a gradient computation performed by 'gradWithOptions'.
--
-- 'gradWithOptions' forwards the three flags to the backend positionally,
-- in declaration order. Their exact semantics are defined by the backend.
-- NOTE(review): the names mirror PyTorch's @retain_graph@ / @create_graph@ /
-- gradient-accumulation options; confirm against the C++ binding.
data GradOptions = GradOptions
  { keepGraph :: Bool
  -- ^ Presumably: keep the autograd graph alive after the call.
  , createGraph :: Bool
  -- ^ Presumably: build a graph for the gradients themselves
  -- (for higher-order derivatives).
  , accumulateGrad :: Bool
  -- ^ Presumably: accumulate into existing gradients rather than
  -- returning fresh ones.
  }
  deriving (Show)
-- | Default options: @keepGraph = True@, @createGraph = False@,
-- @accumulateGrad = False@.
instance Default GradOptions where
  def =
    GradOptions
      { keepGraph = True
      , createGraph = False
      , accumulateGrad = False
      }
-- | Compute gradients of @y@ with respect to @inputs@.
--
-- Each input is unwrapped with 'toDependent'. The backend's @grad@ binding
-- is then called with @y@ and the resulting tensor list. The returned list
-- is whatever that binding produces (presumably one gradient per input, in
-- input order — TODO confirm against the C++ side).
--
-- Exposed as a pure function via 'unsafePerformIO'. Graph-retention
-- behaviour is fixed by the backend binding here. Use 'gradWithOptions'
-- to control it explicitly.
grad :: Tensor -> [IndependentTensor] -> [Tensor]
grad y inputs =
  unsafePerformIO $
    cast2 Torch.Internal.Managed.Autograd.grad y (map toDependent inputs)
-- | Like 'grad', but with explicit 'GradOptions'.
--
-- The backend's @gradWithOptions@ binding receives its arguments in this
-- order:
--
-- 1. 'keepGraph'
-- 2. 'createGraph'
-- 3. 'accumulateGrad'
-- 4. @y@
-- 5. the unwrapped @inputs@
--
-- The options are matched by field name (not position), so reordering the
-- record's fields cannot silently swap the three Bool flags.
--
-- Exposed as a pure function via 'unsafePerformIO'.
gradWithOptions :: GradOptions -> Tensor -> [IndependentTensor] -> [Tensor]
gradWithOptions
  GradOptions {keepGraph = retain, createGraph = create, accumulateGrad = accumulate}
  y
  inputs =
    unsafePerformIO $
      cast5
        Torch.Internal.Managed.Autograd.gradWithOptions
        retain
        create
        accumulate
        y
        (map toDependent inputs)
-- | Whether the tensor is flagged as requiring gradients.
--
-- Reads the flag through the @ATen.tensor_requires_grad@ binding and
-- marshals the resulting C boolean via 'cast1'. Exposed as a pure function
-- via 'unsafePerformIO'.
requiresGrad :: Tensor -> Bool
requiresGrad t = unsafePerformIO $ cast1 ATen.tensor_requires_grad t
-- | Return the tensor with its requires-grad flag set to @flag@.
--
-- Calls the @ATen.tensor_set_requires_grad_b@ binding with the tensor and
-- the flag, and returns the tensor that binding produces. Exposed as a pure
-- function via 'unsafePerformIO'.
--
-- NOTE(review): if the binding wraps libtorch's @set_requires_grad@, it
-- updates the tensor in place. In that case the input @t@ is also affected,
-- and this function is not referentially transparent — confirm before
-- relying on @t@ keeping its old flag.
setRequiresGrad :: Bool -> Tensor -> Tensor
setRequiresGrad flag t =
  unsafePerformIO $ cast2 ATen.tensor_set_requires_grad_b t flag
-- | Wrap a tensor as an 'IndependentTensor' that requires gradients.
--
-- Equivalent to @'makeIndependentWithRequiresGrad' tensor True@.
makeIndependent :: Tensor -> IO IndependentTensor
makeIndependent tensor = makeIndependentWithRequiresGrad tensor True
makeIndependentWithRequiresGrad :: Tensor -> Bool -> IO IndependentTensor
makeIndependentWithRequiresGrad :: Tensor -> Bool -> IO IndependentTensor
makeIndependentWithRequiresGrad Tensor
tensor Bool
requires_grad = Tensor -> IndependentTensor
IndependentTensor (Tensor -> IndependentTensor) -> IO Tensor -> IO IndependentTensor
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> (ForeignPtr Tensor -> CBool -> IO (ForeignPtr Tensor))
-> Tensor -> Bool -> IO Tensor
forall a ca x1 cx1 y cy.
(Castable a ca, Castable x1 cx1, Castable y cy) =>
(ca -> cx1 -> IO cy) -> a -> x1 -> IO y
cast2 ForeignPtr Tensor -> CBool -> IO (ForeignPtr Tensor)
Torch.Internal.Managed.Autograd.makeIndependent Tensor
tensor Bool
requires_grad