{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE RecordWildCards #-}

module Torch.Autograd where

import Foreign.ForeignPtr
import GHC.Generics
import System.IO.Unsafe
import Torch.Internal.Cast
import Torch.Internal.Class
import qualified Torch.Internal.Managed.Autograd
import qualified Torch.Internal.Managed.Type.Tensor as ATen
import qualified Torch.Internal.Type as ATen
import Torch.Tensor
import Data.Default.Class


-- | Note: to create an `IndependentTensor`, use `makeIndependent`;
-- otherwise, Torch will complain that the parameter does not require a gradient.
newtype IndependentTensor = IndependentTensor
  { -- | The wrapped tensor. 'grad' and 'gradWithOptions' unwrap their
    -- inputs with this accessor before calling into libtorch.
    toDependent :: Tensor
  }
  deriving (Generic, Show)

-- | Flags forwarded, in field order, to the libtorch gradient FFI call made
-- by 'gradWithOptions'.
data GradOptions = GradOptions
  { keepGraph :: Bool
  -- ^ First flag passed to the FFI call (presumably libtorch's
  -- @retain_graph@ — confirm against the binding).
  , createGraph :: Bool
  -- ^ Second flag passed to the FFI call (presumably libtorch's
  -- @create_graph@ — confirm against the binding).
  , accumulateGrad :: Bool
  -- ^ Third flag passed to the FFI call; its exact libtorch meaning is
  -- defined by the binding (TODO confirm).
  }
  deriving Show

-- | Default options: 'keepGraph' enabled, 'createGraph' and
-- 'accumulateGrad' disabled.
instance Default GradOptions where
  def =
    GradOptions
      { keepGraph = True
      , createGraph = False
      , accumulateGrad = False
      }

-- | Gradients of the tensor @y@ with respect to the given independent
-- tensors, computed by the libtorch autograd FFI call.
--
-- Returns the list produced by that call (presumably one gradient per
-- input, in input order — confirm against the binding). The FFI action is
-- run with 'unsafePerformIO' so it can be exposed as a pure function.
grad :: Tensor -> [IndependentTensor] -> [Tensor]
grad output independents =
  unsafePerformIO (cast2 Torch.Internal.Managed.Autograd.grad output wrt)
  where
    -- Unwrap the inputs to plain tensors for the FFI layer.
    wrt = [toDependent x | x <- independents]

-- | Like 'grad', but forwards the three 'GradOptions' flags ('keepGraph',
-- 'createGraph', 'accumulateGrad', in that order) to the libtorch FFI call.
--
-- Run with 'unsafePerformIO' so it can be exposed as a pure function.
gradWithOptions :: GradOptions -> Tensor -> [IndependentTensor] -> [Tensor]
gradWithOptions (GradOptions retain create accumulate) output independents =
  unsafePerformIO $
    cast5
      Torch.Internal.Managed.Autograd.gradWithOptions
      retain
      create
      accumulate
      output
      (fmap toDependent independents)

-- | Whether libtorch reports the requires-grad flag as set on this tensor.
-- The FFI query is run with 'unsafePerformIO'.
requiresGrad :: Tensor -> Bool
requiresGrad tensor = unsafePerformIO query
  where
    query = cast1 ATen.tensor_requires_grad tensor

-- | Set the requires-grad flag of a tensor to @flag@. Returns the tensor
-- produced by the libtorch FFI call.
--
-- NOTE(review): libtorch's @set_requires_grad@ normally mutates the tensor
-- in place. If this binding does the same, the input tensor is also
-- affected despite this function's pure type. Confirm before relying on
-- the input being unchanged.
setRequiresGrad :: Bool -> Tensor -> Tensor
setRequiresGrad flag tensor = unsafePerformIO update
  where
    -- The FFI function takes the tensor first and the flag second.
    update = cast2 ATen.tensor_set_requires_grad_b tensor flag

-- | Wrap a tensor as an 'IndependentTensor', passing @True@ as the
-- requires-grad flag. Equivalent to
-- @'makeIndependentWithRequiresGrad' tensor True@.
makeIndependent :: Tensor -> IO IndependentTensor
makeIndependent = (`makeIndependentWithRequiresGrad` True)

-- | Wrap a tensor as an 'IndependentTensor' via the libtorch FFI
-- @makeIndependent@ call, passing @needsGrad@ as its requires-grad flag.
--
-- The wrapped tensor is the one returned by the FFI call (presumably a new
-- tensor rather than the argument itself — confirm against the binding).
makeIndependentWithRequiresGrad :: Tensor -> Bool -> IO IndependentTensor
makeIndependentWithRequiresGrad tensor needsGrad = do
  result <- cast2 Torch.Internal.Managed.Autograd.makeIndependent tensor needsGrad
  pure (IndependentTensor result)