{-# LANGUAGE InstanceSigs #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE StandaloneKindSignatures #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}

-- |
-- Module      : Sel.Internal.Scoped
-- Description : Continuation-passing utilities
-- Copyright   : (c) Jack Henahan, 2024
-- License     : BSD-3-Clause
-- Maintainer  : The Haskell Cryptography Group
-- Portability : GHC only
--
-- This module implements a version of @Codensity@, which models delimited
-- continuations. It is useful for avoiding extreme rightward drift in
-- chains of nested @withForeignPtr@ calls and similar bracket-style
-- functions.
module Sel.Internal.Scoped where

import Control.Monad (ap, void)
import Control.Monad.IO.Class (MonadIO (liftIO))
import Control.Monad.Trans.Class (MonadTrans (lift))
import Data.Kind (Type)
import Data.Type.Equality (type (~~))
import GHC.Exts (RuntimeRep, TYPE)

-- | A computation in continuation-passing style over a base @m@: given a
-- continuation @a -> m b@ (for any @b@), it produces an @m b@. This is the
-- same shape as @Codensity m a@.
--
-- A bracket-style function such as @withForeignPtr fp@ already has the
-- shape of 'runScoped', so it can be wrapped directly in the 'Scoped'
-- constructor. Several such scopes can then be sequenced with @do@
-- notation instead of nested lambdas. The whole chain is run with 'use'
-- or one of its variants.
--
-- @since 0.0.3.0
type Scoped :: forall {k} {rep :: RuntimeRep}. (k -> TYPE rep) -> Type -> Type
newtype Scoped m a = Scoped {runScoped :: forall b. (a -> m b) -> m b}

-- | Maps over the value handed to the continuation by pre-composing the
-- mapping function onto that continuation. The underlying computation is
-- otherwise unchanged.
--
-- @since 0.0.3.0
instance Functor (Scoped f) where
  fmap f (Scoped m) = Scoped $ \k -> m (k . f)
  {-# INLINE fmap #-}

-- | 'pure' passes its value straight to the continuation. '<*>' is
-- derived from the 'Monad' instance via 'ap', so effects run in
-- left-to-right order.
--
-- @since 0.0.3.0
instance Applicative (Scoped f) where
  pure a = Scoped $ \k -> k a
  {-# INLINE pure #-}

  (<*>) = ap
  {-# INLINE (<*>) #-}

-- | Sequencing nests scopes. The continuation given to the first
-- computation builds the second computation (@f a@) and runs it with the
-- outer continuation @k@. The second scope is therefore entered inside
-- the first one.
--
-- @since 0.0.3.0
instance Monad (Scoped f) where
  Scoped m >>= f = Scoped $ \k ->
    m $ \a -> runScoped (f a) k
  {-# INLINE (>>=) #-}

-- | Lifts an 'IO' action into the base monad with the base monad's
-- 'liftIO', then into 'Scoped' with 'lift'.
--
-- The instance head is @Scoped m@ for an arbitrary @m@. The requirement
-- that @m@ be a 'MonadIO' is expressed through the heterogeneous equality
-- @m' ~~ m@ in the context, so the instance is selected first and the
-- constraint is only checked afterwards.
-- NOTE(review): presumably this is needed because 'Scoped' is
-- kind-polymorphic in @m@ — confirm.
--
-- @since 0.0.3.0
instance (MonadIO m', m' ~~ m) => MonadIO (Scoped m) where
  liftIO = lift . liftIO
  {-# INLINE liftIO #-}

-- | Lifts a base computation by binding its result to the continuation.
--
-- @since 0.0.3.0
instance MonadTrans Scoped where
  lift m = Scoped (m >>=)
  {-# INLINE lift #-}

-- | Delimits a scoped computation. It runs the computation to completion
-- in the base monad with 'use', then lifts the result back into 'Scoped'.
-- Any scopes opened by the argument are therefore exited before the
-- continuation following the 'reset' runs.
--
-- @since 0.0.3.0
reset :: Monad m => Scoped m a -> Scoped m a
reset = lift . use

-- | Captures the continuation. @f@ receives the continuation @a -> m b@
-- that this computation is run with, and returns a 'Scoped' computation.
-- That computation is then run with 'use' to produce the final @m b@.
--
-- @since 0.0.3.0
shift :: Applicative m => (forall b. (a -> m b) -> Scoped m b) -> Scoped m a
shift f = Scoped $ use . f

-- | Runs a scoped computation in the base monad by supplying 'pure' as the
-- final continuation.
--
-- Because the final continuation is only 'pure', the value is returned
-- after any bracket-style scope wrapped in the computation has finished.
-- Do not return resources that are only valid inside such a scope (for
-- example, a pointer from @withForeignPtr@). Use 'useM' to run an action
-- inside the scopes instead.
--
-- @since 0.0.3.0
use :: Applicative m => Scoped m a -> m a
use (Scoped m) = m pure

-- | Like 'use', for a scoped computation whose result is itself an action
-- in the base monad. That action is bound via 'lift' inside the
-- computation, so it runs before any of the opened scopes are exited.
-- Its result is then returned.
--
-- @since 0.0.3.0
useM :: Monad m => Scoped m (m a) -> m a
useM f = use $ f >>= lift

-- | Like 'use', but discards the result.
--
-- @since 0.0.3.0
use_ :: Applicative m => Scoped m a -> m ()
use_ = void . use

-- | Like 'useM', but discards the result of the inner action.
--
-- @since 0.0.3.0
useM_ :: Monad m => Scoped m (m a) -> m ()
useM_ = void . useM