{-# LANGUAGE DataKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE FunctionalDependencies #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}
{-# LANGUAGE NoStarIsType #-}
{-# LANGUAGE RecordWildCards #-}

module Torch.Typed.Autograd
  ( Torch.Typed.Autograd.HasGrad,
    Torch.Typed.Autograd.grad,
    Torch.Typed.Autograd.gradWithOptions,
  )
where

import Data.Kind
import GHC.TypeLits
import System.IO.Unsafe
import qualified Torch.DType as D
import qualified Torch.Device as D
import Torch.HList
import qualified Torch.Internal.Cast as ATen
import qualified Torch.Internal.Class as ATen
import qualified Torch.Internal.Managed.Autograd as LibTorch
import qualified Torch.Tensor as D
import Torch.Typed.Parameter
import Torch.Typed.Tensor
import Torch.Autograd (GradOptions(..))


class HasGrad a b | a -> b where
  -- | Calculate gradients of a zero-dimensional (scalar) tensor with respect to a parameter or a list of parameters.
  grad :: forall dtype device. Tensor device dtype '[] -> a -> b

  -- | Like 'grad', but with explicit 'GradOptions' controlling graph retention, graph creation, and gradient accumulation.
  gradWithOptions :: forall dtype device. GradOptions -> Tensor device dtype '[] -> a -> b

  -- | Convert a parameter (or list of parameters) to the underlying dependent tensor(s).
  toDependent :: a -> b

-- instance HasGrad (Tensor device dtype shape) (Tensor device dtype shape) where
--   grad loss input = head . unsafePerformIO $ ATen.cast2
--     Torch.Managed.Autograd.grad
--     loss
--     [Torch.Typed.Autograd.toDependent input]
--   toDependent = id
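-- Usage sketch (added illustration, not part of the upstream module): 'grad'
-- differentiates a scalar loss with respect to a parameter; the 'ones' and
-- 'sumAll' helpers are assumed to come from Torch.Typed.Factories and
-- Torch.Typed.Functional, which this module does not import, so the sketch is
-- kept as a comment:
--
-- > example :: IO (Tensor '( 'D.CPU, 0) 'D.Float '[2, 3])
-- > example = do
-- >   w <- makeIndependent (ones :: Tensor '( 'D.CPU, 0) 'D.Float '[2, 3])
-- >   let loss = sumAll (Torch.Typed.Parameter.toDependent w)
-- >   pure $ grad loss w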

instance HasGrad (Parameter device dtype shape) (Tensor device dtype shape) where
  grad loss input =
    head . unsafePerformIO $
      ATen.cast2
        LibTorch.grad
        loss
        [Torch.Typed.Autograd.toDependent input]
  gradWithOptions GradOptions {..} loss input =
    head . unsafePerformIO $
      ATen.cast5
        LibTorch.gradWithOptions
        keepGraph
        createGraph
        accumulateGrad
        loss
        [Torch.Typed.Autograd.toDependent input]
  toDependent = Torch.Typed.Parameter.toDependent
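
-- A hedged sketch of 'gradWithOptions' (added illustration): judging by the
-- fields matched above, 'GradOptions' carries 'keepGraph', 'createGraph', and
-- 'accumulateGrad' flags that are passed straight through to
-- LibTorch.gradWithOptions. Retaining the graph for higher-order gradients
-- might look like:
--
-- > let opts = GradOptions {keepGraph = True, createGraph = True, accumulateGrad = False}
-- >     g = gradWithOptions opts loss w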

instance HasGrad (HList ('[] :: [Type])) (HList ('[] :: [Type])) where
  grad _ = id
  gradWithOptions _ _ = id -- no parameters, hence no gradients, regardless of options
  toDependent = id

instance
  ( HasGrad a b,
    HasGrad (HList as) (HList bs),
    ATen.Castable (HList (b ': bs)) [D.ATenTensor]
  ) =>
  HasGrad (HList (a ': as)) (HList (b ': bs))
  where
  grad loss inputs =
    unsafePerformIO $
      ATen.cast2
        LibTorch.grad
        loss
        (Torch.Typed.Autograd.toDependent inputs)
  gradWithOptions GradOptions {..} loss inputs =
    unsafePerformIO $
      ATen.cast5
        LibTorch.gradWithOptions
        keepGraph
        createGraph
        accumulateGrad
        loss
        (Torch.Typed.Autograd.toDependent inputs)
  toDependent (a :. as) =
    Torch.Typed.Autograd.toDependent a :. Torch.Typed.Autograd.toDependent as
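
-- A hedged sketch of the 'HList' instances (added illustration): gradients for
-- several parameters of different shapes can be taken in one call by collecting
-- the parameters in a heterogeneous list; the result is an 'HList' of gradient
-- tensors in the same order. The names @loss@, @w@, and @b@ are assumed
-- placeholders for a scalar loss and two parameters:
--
-- > let gradients = grad loss (w :. b :. HNil)
-- > -- gradients :: HList '[Tensor device dtype wShape, Tensor device dtype bShape]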