{-# LANGUAGE UndecidableInstances #-}

module Control.Monad.Trans.Codensity
  ( CodensityT (..)
  ) where

import Control.Monad.IO.Class (MonadIO (..))
import Control.Monad.Reader (MonadReader (ask, local))
import Control.Monad.State (MonadState (state))
import Control.Monad.Trans.Class (MonadTrans (..))
import Control.Monad.Yield.Class (MonadYield (..))
import Data.Kind (Type)
import Prelude (Applicative (..), Functor (..), Monad (..), const, ($), (.))

-- | The codensity monad transformer, adapted from the @kan-extensions@
--   package. Following the approach used by @conduit@, wrapping a monad in
--   'CodensityT' turns '>>=' into plain composition of continuations, which
--   is what makes monadic bind efficient for this type.
newtype CodensityT (m :: Type -> Type) (a :: Type) = CodensityT
  { runCodensity :: forall b. (a -> m b) -> m b
  -- ^ Run the computation by supplying the continuation that consumes its
  --   result; the continuation's result type @b@ is chosen by the caller.
  }

instance Functor (CodensityT m) where
  -- Mapping never runs anything: the function is applied to the result on
  -- its way into whatever continuation is eventually supplied.
  fmap f (CodensityT run) = CodensityT (\next -> run (\x -> next (f x)))

instance Applicative (CodensityT m) where
  -- A pure value is handed straight to the continuation.
  pure value = CodensityT (\next -> next value)

  -- Run the function-producing computation first, then the argument
  -- computation, and pass the applied result on to the final continuation.
  CodensityT runFn <*> CodensityT runArg =
    CodensityT (\next -> runFn (\fn -> runArg (next . fn)))

instance Monad (CodensityT m) where
  -- Bind only nests continuations; it never calls the base monad's own
  -- '>>=', which is why no constraint on @m@ is needed here.
  action >>= next =
    CodensityT $ \finish ->
      runCodensity action $ \result ->
        runCodensity (next result) finish

instance MonadTrans CodensityT where
  -- Embed a base computation by binding its result into the continuation.
  lift base = CodensityT (\next -> base >>= next)

instance MonadIO m => MonadIO (CodensityT m) where
  -- Lift the IO action into the base monad first, then into the transformer.
  liftIO io = lift (liftIO io)

instance MonadYield a m => MonadYield a (CodensityT m) where
  -- The base monad performs the yield; the transformer merely lifts it.
  yield value = lift (yield value)

instance MonadReader r m => MonadReader r (CodensityT m) where
  -- Read the environment from the base monad and feed it to the continuation.
  ask = CodensityT (\next -> ask >>= next)

  -- The continuation @next@ stands for everything that follows the 'local'
  -- call, so it must not see the modified environment. Capture the outer
  -- environment first, run the wrapped computation under @modify@, and
  -- reinstate the captured environment around each call of the continuation.
  local modify action = CodensityT $ \next -> do
    outer <- ask
    local modify (runCodensity action (\x -> local (\_ -> outer) (next x)))

instance MonadState s m => MonadState s (CodensityT m) where
  -- The base monad performs the state transition; the result is lifted.
  state transition = lift (state transition)