module Control.Monad.Ology.Data.Ref where

import Control.Monad.Ology.Data.Param
import Control.Monad.Ology.Data.Prod
import Control.Monad.Ology.General
import Control.Monad.Ology.Specific.StateT
import Control.Monad.ST.Lazy qualified as Lazy
import Control.Monad.ST.Strict qualified as Strict
import Data.IORef
import Data.STRef.Lazy qualified as Lazy
import Data.STRef.Strict qualified as Strict
import Import

-- | A mutable cell in the monad @m@, given as a getter paired with a setter.
-- This is the same shape as the state threaded by a 'StateT'.
data Ref m a = MkRef
    { refGet :: m a
      -- ^ Read the current value.
    , refPut :: a -> m ()
      -- ^ Overwrite the current value.
    }

-- | Transport a reference along a pair of conversion functions.
-- The forward function is applied to every value read, and the backward
-- function to every value written.
instance Functor m => Invariant (Ref m) where
    invmap :: forall a b. (a -> b) -> (b -> a) -> Ref m a -> Ref m b
    invmap forwards backwards (MkRef getter putter) =
        MkRef (fmap forwards getter) (putter . backwards)

-- | Products of references.
--
-- * The unit reference reads @()@ and performs no effect on write.
-- * A pair reference reads both components.
-- * A pair reference writes the first component, then the second.
instance Applicative m => Productable (Ref m) where
    rUnit :: Ref m ()
    rUnit = MkRef unitAction (\() -> unitAction)
      where
        unitAction = pure ()
    (<***>) :: forall a b. Ref m a -> Ref m b -> Ref m (a, b)
    refA <***> refB = let
        getBoth = liftA2 (,) (refGet refA) (refGet refB)
        -- Writes are sequenced left component first.
        putBoth (a, b) = refPut refA a *> refPut refB b
        in MkRef getBoth putBoth

-- | Read the reference, apply a pure function, and write the result back.
refModify :: Monad m => Ref m a -> (a -> a) -> m ()
refModify ref f = refGet ref >>= refPut ref . f

-- | Read the reference and run a monadic update on the value.
-- The update's result is then written back.
refModifyM :: Monad m => Ref m a -> (a -> m a) -> m ()
refModifyM ref f = refGet ref >>= \current -> f current >>= refPut ref

-- | Run an action, then write back the value the reference held before it.
-- The old value is captured by the acquire step of 'bracketNoMask' and
-- rewritten by its release step.
refRestore :: MonadException m => Ref m a -> m --> m
refRestore ref action = bracketNoMask (refGet ref) (refPut ref) (\_ -> action)

-- | Set the reference to the given value for the duration of an action.
-- Afterwards, 'refRestore' puts the previous value back.
refPutRestore :: MonadException m => Ref m a -> a -> m --> m
refPutRestore ref a action = refRestore ref (refPut ref a >> action)

-- | Focus a reference on part of its value using a van Laarhoven lens.
--
-- * Reading views the part through the lens.
-- * Writing reads the whole value, replaces the part, and writes the whole
--   value back.
lensMapRef ::
       forall m a b. Monad m
    => Lens' a b
    -> Ref m a
    -> Ref m b
lensMapRef l ref = MkRef viewPart setPart
  where
    viewPart :: m b
    viewPart = fmap (getConst . l Const) (refGet ref)
    setPart :: b -> m ()
    setPart part = do
        whole <- refGet ref
        refPut ref (runIdentity (l (\_ -> Identity part) whole))

-- | Move a reference under a monad transformer by lifting both its read and
-- its write.
liftRef :: (MonadTrans t, Monad m) => Ref m --> Ref (t m)
liftRef (MkRef getter putter) = MkRef (lift getter) (lift . putter)

-- | The state of a 'StateT', viewed as a reference.
stateRef :: Monad m => Ref (StateT s m) s
stateRef = MkRef {refGet = get, refPut = put}

-- | Run a state monad over this reference.
-- The reference's current value seeds the state. The final state is written
-- back once 'runStateT' has returned. No special handling is done if the
-- computation fails in @m@; in that case the write-back is simply not
-- reached.
refRunState :: Monad m => Ref m s -> StateT s m --> m
refRunState ref action = do
    (result, final) <- refGet ref >>= runStateT action
    refPut ref final
    return result

-- | View an 'IORef' as a reference in 'IO', using 'readIORef' and 'writeIORef'.
ioRef :: IORef a -> Ref IO a
ioRef r = MkRef {refGet = readIORef r, refPut = writeIORef r}

-- | View an 'Strict.STRef' as a reference in the strict 'Strict.ST' monad.
strictSTRef :: Strict.STRef s a -> Ref (Strict.ST s) a
strictSTRef r = MkRef {refGet = Strict.readSTRef r, refPut = Strict.writeSTRef r}

-- | View a lazy 'Lazy.STRef' as a reference in the lazy 'Lazy.ST' monad.
lazySTRef :: Lazy.STRef s a -> Ref (Lazy.ST s) a
lazySTRef r = MkRef {refGet = Lazy.readSTRef r, refPut = Lazy.writeSTRef r}

-- | Use a reference as a parameter.
--
-- * Asking reads the reference.
-- * @paramWith a@ runs an action with the reference temporarily set to @a@;
--   see 'refPutRestore'.
refParam ::
       forall m a. MonadException m
    => Ref m a
    -> Param m a
refParam ref = MkParam {paramAsk = refGet ref, paramWith = refPutRestore ref}

-- | Use a reference as a product, i.e. a Writer-like accumulator.
--
-- * 'prodTell' appends the given value to the right of the reference's
--   current contents. Successive tells therefore accumulate in the order
--   they were issued, as with @tell@ in @WriterT@.
-- * 'prodCollect' runs an action with the reference temporarily reset to
--   'mempty' and returns the action's result together with what accumulated
--   in the reference during it. The reference's previous contents are then
--   put back by 'refPutRestore', so the collected value is not added to the
--   outer accumulation.
refProd ::
       forall m a. (MonadException m, Monoid a)
    => Ref m a
    -> Prod m a
refProd ref = let
    prodTell :: a -> m ()
    -- Append on the right (old <> a). The previous @(<>) a@ computed
    -- @a <> old@, which reversed the order of successive tells.
    prodTell a = refModify ref (<> a)
    prodCollect :: forall r. m r -> m (r, a)
    prodCollect mr =
        refPutRestore ref mempty $ do
            r <- mr
            collected <- refGet ref
            return (r, collected)
    in MkProd {..}