{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE MultiParamTypeClasses #-}
module Grisette.Internal.Core.Data.Class.SafeLogBase
( SafeLogBase (..),
LogBaseOr (..),
logBaseOrZero,
)
where
import Control.Exception (ArithException (RatioZeroDenominator))
import Control.Monad.Error.Class (MonadError (throwError))
import Grisette.Internal.Core.Control.Monad.Class.Union (MonadUnion)
import Grisette.Internal.Core.Data.Class.AsKey (AsKey (AsKey))
import Grisette.Internal.Core.Data.Class.ITEOp (ITEOp (symIte))
import Grisette.Internal.Core.Data.Class.Mergeable (Mergeable)
import Grisette.Internal.Core.Data.Class.SimpleMergeable (mrgIf)
import Grisette.Internal.Core.Data.Class.SymEq (SymEq ((.==)))
import Grisette.Internal.Core.Data.Class.TryMerge (TryMerge, mrgSingle)
import Grisette.Internal.SymPrim.SymAlgReal (SymAlgReal)
-- | Types that provide a total 'logBase': instead of failing, the logarithm
-- falls back to a caller-supplied value on inputs the instance treats as
-- undefined. The exact fallback condition is instance-specific (for
-- 'SymAlgReal' in this module it is a base equal to 1).
class LogBaseOr a where
  -- | @'logBaseOr' fallback base x@ is @'logBase' base x@, or @fallback@
  -- when the instance considers the logarithm undefined.
  logBaseOr ::
    a    -- ^ fallback value
    -> a -- ^ base
    -> a -- ^ argument
    -> a
-- | 'logBaseOr' with a zero fallback: @'logBaseOrZero' base x@ is the
-- logarithm of @x@ in base @base@, or zero where the 'LogBaseOr' instance
-- treats the logarithm as undefined.
logBaseOrZero :: (LogBaseOr a, Num a) => a -> a -> a
logBaseOrZero base = logBaseOr zero base
  where
    -- NOTE(review): zero is derived from the base rather than written as a
    -- literal, presumably so it inherits any runtime shape of the base
    -- (e.g. a dynamic bit-width) that 'fromInteger' could not know.
    -- TODO confirm before replacing with 0.
    zero = base - base
{-# INLINE logBaseOrZero #-}
-- | Monadic, failure-aware version of 'logBase'.
--
-- @'safeLogBase' base x@ computes the logarithm of @x@ in base @base@ inside
-- the monad @m@, reporting failures through @'MonadError' e m@. Which inputs
-- count as failures is decided by each instance; the 'SymAlgReal' instance in
-- this module throws 'RatioZeroDenominator' when @base@ equals 1.
class (MonadError e m, TryMerge m, Mergeable a) => SafeLogBase e a m where
  -- | Safe logarithm: base first, then argument.
  safeLogBase :: a -> a -> m a
  -- The default exists only so that instances omitting the method still
  -- compile. Instead of the anonymous 'undefined', it fails with a message
  -- naming the method. The MINIMAL pragma below makes GHC warn about such
  -- instances at compile time.
  safeLogBase =
    error
      "Grisette.Internal.Core.Data.Class.SafeLogBase.safeLogBase: \
      \this SafeLogBase instance does not implement safeLogBase"
  {-# INLINE safeLogBase #-}
  {-# MINIMAL safeLogBase #-}
-- | Symbolic algebraic-real logarithm that yields the fallback when the base
-- is (symbolically) equal to 1, via a symbolic if-then-else.
instance LogBaseOr SymAlgReal where
  logBaseOr fallback base x = symIte baseIsOne fallback result
    where
      -- A base of 1 makes log base 1 = 0 the divisor of logBase.
      baseIsOne = base .== 1
      result = logBase base x
  {-# INLINE logBaseOr #-}
-- | Safe symbolic algebraic-real logarithm: on the path where the base is
-- equal to 1 it throws 'RatioZeroDenominator'; otherwise it returns
-- @'logBase' base x@. Both paths are combined with 'mrgIf'.
instance
  (MonadError ArithException m, MonadUnion m) =>
  SafeLogBase ArithException SymAlgReal m
  where
  safeLogBase base x =
    mrgIf baseIsOne (throwError RatioZeroDenominator) (pure result)
    where
      baseIsOne = base .== 1
      result = logBase base x
  {-# INLINE safeLogBase #-}
-- | Lifts 'LogBaseOr' through the 'AsKey' wrapper: unwraps all three
-- arguments, delegates to the underlying instance and rewraps the result.
instance (LogBaseOr a) => LogBaseOr (AsKey a) where
  logBaseOr (AsKey fallback) (AsKey base) (AsKey x) =
    AsKey (logBaseOr fallback base x)
  {-# INLINE logBaseOr #-}
-- | Lifts 'SafeLogBase' through the 'AsKey' wrapper: runs the underlying
-- safe logarithm on the unwrapped values, then rewraps the result and returns
-- it with 'mrgSingle'.
instance (SafeLogBase e a m) => SafeLogBase e (AsKey a) m where
  safeLogBase (AsKey base) (AsKey x) =
    safeLogBase base x >>= mrgSingle . AsKey
  {-# INLINE safeLogBase #-}