{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE MultiParamTypeClasses #-}

-- |
-- Module      :   Grisette.Internal.Core.Data.Class.SafeLogBase
-- Copyright   :   (c) Sirui Lu 2024
-- License     :   BSD-3-Clause (see the LICENSE file)
--
-- Maintainer  :   siruilu@cs.washington.edu
-- Stability   :   Experimental
-- Portability :   GHC only
module Grisette.Internal.Core.Data.Class.SafeLogBase
  ( SafeLogBase (..),
    LogBaseOr (..),
    logBaseOrZero,
  )
where

import Control.Exception (ArithException (RatioZeroDenominator))
import Control.Monad.Error.Class (MonadError (throwError))
import Grisette.Internal.Core.Control.Monad.Class.Union (MonadUnion)
import Grisette.Internal.Core.Data.Class.AsKey (AsKey (AsKey))
import Grisette.Internal.Core.Data.Class.ITEOp (ITEOp (symIte))
import Grisette.Internal.Core.Data.Class.Mergeable (Mergeable)
import Grisette.Internal.Core.Data.Class.SimpleMergeable (mrgIf)
import Grisette.Internal.Core.Data.Class.SymEq (SymEq ((.==)))
import Grisette.Internal.Core.Data.Class.TryMerge (TryMerge, mrgSingle)
import Grisette.Internal.SymPrim.SymAlgReal (SymAlgReal)

-- $setup
-- >>> import Grisette.Core
-- >>> import Grisette.SymPrim
-- >>> import Control.Monad.Except
-- >>> import Control.Exception

-- | Types offering a 'logBase' that yields a caller-chosen fallback value
-- instead of failing when the logarithm cannot be taken. Which inputs count
-- as failing is decided by each instance.
class LogBaseOr v where
  -- | Safe 'logBase' with default values returned on exception.
  --
  -- >>> logBaseOr "d" "base" "val" :: SymAlgReal
  -- (ite (= base 1.0) d (fdiv (log val) (log base)))
  logBaseOr ::
    -- | Fallback returned in the exceptional case.
    v ->
    -- | Base of the logarithm.
    v ->
    -- | Value whose logarithm is taken.
    v ->
    v

-- | Safe 'logBase' with 0 returned on exception.
--
-- @logBaseOrZero base val@ is @'logBaseOr' (base - base) base val@: the
-- logarithm of @val@ in base @base@, or zero in the exceptional case.
--
-- NOTE(review): the zero is built as @base - base@ rather than the literal
-- @0@, presumably so it carries the same shape as @base@ for types whose
-- literals cannot be built on their own (e.g. runtime-sized values) —
-- confirm before simplifying.
logBaseOrZero :: (LogBaseOr a, Num a) => a -> a -> a
logBaseOrZero base = logBaseOr (base - base) base
{-# INLINE logBaseOrZero #-}

-- | Safe 'logBase' with monadic error handling in multi-path execution.
-- Instances report the exceptional case (for example, a base equal to 1)
-- through 'MonadError' instead of producing an undefined value, so the
-- result monad must be able to carry errors of type @e@.
class (MonadError e m, TryMerge m, Mergeable a) => SafeLogBase e a m where
  -- | Safe 'logBase' with monadic error handling in multi-path execution.
  --
  -- >>> safeLogBase (ssym "base") (ssym "val") :: ExceptT ArithException Union SymAlgReal
  -- ExceptT {If (= base 1.0) (Left Ratio has zero denominator) (Right (fdiv (log val) (log base)))}
  safeLogBase ::
    -- | Base of the logarithm.
    a ->
    -- | Value whose logarithm is taken.
    a ->
    m a
  -- The default exists only for backward compatibility. Fail with a
  -- descriptive message rather than a bare 'undefined', so a missing
  -- instance implementation is easy to diagnose.
  safeLogBase _ _ =
    error
      "Grisette.Internal.Core.Data.Class.SafeLogBase.safeLogBase: \
      \no implementation provided for this instance"
  {-# INLINE safeLogBase #-}

-- | Yields the fallback @d@ when @base@ equals 1, and @'logBase' base a@
-- otherwise. The choice is encoded as a symbolic if-then-else ('symIte'),
-- so both alternatives appear in the resulting term.
--
-- Only @base == 1@ is guarded. Other inputs on which the logarithm is
-- undefined mathematically (e.g. non-positive values) are passed straight
-- to 'logBase'.
instance LogBaseOr SymAlgReal where
  logBaseOr d base a = symIte (base .== 1) d $ logBase base a
  {-# INLINE logBaseOr #-}

-- | On paths where the base equals 1, fails with 'RatioZeroDenominator'.
-- On all other paths, returns @'logBase' base a@. The two branches are
-- combined with 'mrgIf'.
instance
  (MonadError ArithException m, MonadUnion m) =>
  SafeLogBase ArithException SymAlgReal m
  where
  safeLogBase base a =
    mrgIf
      (base .== 1)
      -- log 1 is 0, so the change-of-base division (log a / log base)
      -- would be a division by zero.
      (throwError RatioZeroDenominator)
      (pure $ logBase base a)
  {-# INLINE safeLogBase #-}

-- | Unwraps the 'AsKey' arguments, delegates to the underlying 'LogBaseOr'
-- instance, and wraps the result back in 'AsKey'.
instance (LogBaseOr a) => LogBaseOr (AsKey a) where
  logBaseOr (AsKey d) (AsKey base) (AsKey a) = AsKey $ logBaseOr d base a
  {-# INLINE logBaseOr #-}

-- | Unwraps the 'AsKey' arguments and runs the underlying 'safeLogBase'.
-- Errors from that computation propagate unchanged. A successful result is
-- wrapped back in 'AsKey' and returned with 'mrgSingle'.
instance (SafeLogBase e a m) => SafeLogBase e (AsKey a) m where
  safeLogBase (AsKey base) (AsKey a) = do
    r <- safeLogBase base a
    mrgSingle $ AsKey r
  {-# INLINE safeLogBase #-}