module Control.Monad.Exception (
    E.Exception(..),
    E.SomeException,
    MonadException(..),
    onException,
    MonadAsyncException(..),
    bracket,
    bracket_,
    ExceptionT(..),
    mapExceptionT,
    liftException
  ) where
#if !MIN_VERSION_base(4,6,0)
import Prelude hiding (catch)
#endif /*!MIN_VERSION_base(4,6,0) */
import Control.Applicative
import qualified Control.Exception as E (Exception(..),
                                         SomeException,
                                         catch,
                                         throw,
                                         throwIO,
                                         finally)
#if __GLASGOW_HASKELL__ >= 700
import qualified Control.Exception as E (mask)
#else /* __GLASGOW_HASKELL__ < 700 */
import qualified Control.Exception as E (block,
                                         blocked,
                                         unblock)
#endif /* __GLASGOW_HASKELL__ < 700 */
import Control.Monad (MonadPlus(..))
#if MIN_VERSION_base(4,9,0)
import qualified Control.Monad.Fail as Fail
#endif /* MIN_VERSION_base(4,9,0) */
import Control.Monad.Fix (MonadFix(..))
import Control.Monad.IO.Class (MonadIO(..))
import Control.Monad.Trans.Class (MonadTrans(..))
import Control.Monad.Trans.Error (Error(..),
                                  ErrorT(..),
                                  mapErrorT,
                                  runErrorT)
import Control.Monad.Trans.Except (ExceptT(..),
                                   mapExceptT,
                                   runExceptT)
import Control.Monad.Trans.Identity (IdentityT(..),
                                     mapIdentityT,
                                     runIdentityT)
import Control.Monad.Trans.List (ListT(..),
                                 mapListT,
                                 runListT)
import Control.Monad.Trans.Maybe (MaybeT(..),
                                  mapMaybeT,
                                  runMaybeT)
import Control.Monad.Trans.RWS.Lazy as Lazy (RWST(..),
                                             mapRWST,
                                             runRWST)
import Control.Monad.Trans.RWS.Strict as Strict (RWST(..),
                                                 mapRWST,
                                                 runRWST)
import Control.Monad.Trans.Reader (ReaderT(..),
                                   mapReaderT)
import Control.Monad.Trans.State.Lazy as Lazy (StateT(..),
                                               mapStateT,
                                               runStateT)
import Control.Monad.Trans.State.Strict as Strict (StateT(..),
                                                   mapStateT,
                                                   runStateT)
import Control.Monad.Trans.Writer.Lazy as Lazy (WriterT(..),
                                                mapWriterT,
                                                runWriterT)
import Control.Monad.Trans.Writer.Strict as Strict (WriterT(..),
                                                    mapWriterT,
                                                    runWriterT)
#if !MIN_VERSION_base(4,8,0)
import Data.Monoid (Monoid)
#endif /* !MIN_VERSION_base(4,8,0) */
import GHC.Conc.Sync (STM(..),
                      catchSTM,
                      throwSTM)
-- | Monads in which exceptions of any 'E.Exception' type can be thrown and
-- caught.
class (Monad m) => MonadException m where
    -- | Throw an exception.
    throw :: E.Exception e => e -> m a

    -- | Run a computation. If it throws an exception whose type matches the
    -- handler's argument type, run the handler on that exception instead.
    catch :: E.Exception e
          => m a        -- ^ The computation to run
          -> (e -> m a) -- ^ Handler invoked if an exception is raised
          -> m a

    -- | Run a computation, then run a second computation afterwards. The
    -- second computation also runs if the first throws an exception, which
    -- is then rethrown. The second computation's result is discarded.
    --
    -- The default is built from 'onException' and does not mask
    -- asynchronous exceptions. Because it sequences the two computations
    -- with '>>=', it does not run the second one when the first
    -- short-circuits without throwing. Instances may override it.
    finally :: m a -- ^ The computation to run
            -> m b -- ^ Computation to run afterward (even on an exception)
            -> m a
    finally act sequel = do
        result <- onException act sequel
        _ <- sequel
        return result
-- | Run a computation. If it throws any exception, run a second
-- computation and then rethrow the original exception. The second
-- computation's result is discarded.
onException :: MonadException m
            => m a -- ^ The computation to run
            -> m b -- ^ Computation to run if an exception is raised
            -> m a
onException act what = act `catch` rethrowAfter
  where
    rethrowAfter e = what >> throw (e :: E.SomeException)
-- | Monads that can control the delivery of asynchronous exceptions, in
-- addition to throwing and catching them.
class (MonadIO m, MonadException m) => MonadAsyncException m where
    -- | Run a computation with asynchronous exceptions masked. The
    -- computation receives a function that restores the masking state in
    -- effect when 'mask' was entered. Wrap the parts that should remain
    -- interruptible with it.
    mask
        :: ((forall a. m a -> m a) -> m b)
        -> m b
-- | Acquire a resource, use it, and release it. Acquisition and release run
-- with asynchronous exceptions masked. Only the in-between computation runs
-- with the caller's masking state restored. Release is sequenced with
-- 'finally', so it also runs if the in-between computation throws.
bracket :: MonadAsyncException m
        => m a         -- ^ Computation to run first ("acquire resource")
        -> (a -> m b)  -- ^ Computation to run last ("release resource")
        -> (a -> m c)  -- ^ Computation to run in-between
        -> m c         -- ^ Result of the in-between computation
bracket before after thing =
    mask $ \unmask ->
        before >>= \resource ->
            unmask (thing resource) `finally` after resource
-- | Variant of 'bracket' for an in-between computation that does not need
-- the acquired value.
bracket_ :: MonadAsyncException m
         => m a -- ^ Computation to run first
         -> m b -- ^ Computation to run last
         -> m c -- ^ Computation to run in-between
         -> m c
bracket_ before after thing =
    bracket before (\_ -> after) (\_ -> thing)
-- | Monad transformer adding exceptions to an underlying monad. Running the
-- computation yields either the 'E.SomeException' it failed with ('Left')
-- or its result ('Right').
newtype ExceptionT m a = ExceptionT
    { runExceptionT :: m (Either E.SomeException a) }
-- | Transform the unwrapped computation of an 'ExceptionT'. This can change
-- both the underlying monad and the result type.
mapExceptionT :: (m (Either E.SomeException a) -> n (Either E.SomeException b))
              -> ExceptionT m a
              -> ExceptionT n b
mapExceptionT f m = ExceptionT (f (runExceptionT m))
-- | Re-raise a captured result in any 'MonadException'. A 'Left' is thrown;
-- a 'Right' is returned.
liftException :: MonadException m => Either E.SomeException a -> m a
liftException = either throw return
-- | Lifting wraps the underlying result in 'Right'. A lifted computation
-- never fails at the 'ExceptionT' level.
instance MonadTrans ExceptionT where
    lift m = ExceptionT (m >>= return . Right)
-- | Maps over successful results; exceptions pass through unchanged.
instance (Functor m) => Functor (ExceptionT m) where
    fmap f (ExceptionT inner) = ExceptionT (fmap (fmap f) inner)
-- | Sequencing short-circuits on the first exception ('Left').
instance (Monad m) => Monad (ExceptionT m) where
    return a = ExceptionT $ return (Right a)
    m >>= k  = ExceptionT $ do
        a <- runExceptionT m
        case a of
          Left l  -> return (Left l)
          Right r -> runExceptionT (k r)
#if !MIN_VERSION_base(4,13,0)
    -- 'fail' is only a 'Monad' method before base 4.13; defining it
    -- unconditionally is a compile error on newer compilers.
    fail msg = ExceptionT $ return (Left (E.toException (userError msg)))
#endif /* !MIN_VERSION_base(4,13,0) */

#if MIN_VERSION_base(4,9,0)
-- | 'fail' raises an 'IOError' built with 'userError' as an 'ExceptionT'
-- exception. This matches the pre-base-4.13 'Monad' 'fail'.
instance (Monad m) => Fail.MonadFail (ExceptionT m) where
    fail msg = ExceptionT $ return (Left (E.toException (userError msg)))
#endif /* MIN_VERSION_base(4,9,0) */
-- | 'mzero' fails with an empty 'userError'. 'mplus' runs the second
-- computation if the first fails with any exception, discarding that
-- exception.
instance (Monad m) => MonadPlus (ExceptionT m) where
    mzero = ExceptionT (return (Left (E.toException (userError ""))))
    mplus m n = ExceptionT $
        runExceptionT m >>= either (const (runExceptionT n)) (return . Right)
-- | Runs the function computation first. If it fails, the argument
-- computation is never run.
instance (Functor m, Monad m) => Applicative (ExceptionT m) where
    pure = ExceptionT . return . Right
    ExceptionT mf <*> ExceptionT mv = ExceptionT $
        mf >>= either (return . Left) (\k ->
            mv >>= either (return . Left) (return . Right . k))
-- | Same behavior as the 'MonadPlus' instance.
instance (Functor m, Monad m) => Alternative (ExceptionT m) where
    empty   = mzero
    l <|> r = l `mplus` r
-- | Fixed points are taken in the underlying monad. Forcing the recursive
-- argument when the result was an exception raises an 'error'.
instance (MonadFix m) => MonadFix (ExceptionT m) where
    mfix f = ExceptionT (mfix (runExceptionT . f . fromRightOrBomb))
      where
        fromRightOrBomb (Right r) = r
        fromRightOrBomb _         = error "empty mfix argument"
-- | Exceptions are values in the 'Left' branch. 'catch' only handles
-- exceptions whose type matches the handler (via 'E.fromException'); other
-- exceptions are passed on untouched.
instance (Monad m) => MonadException (ExceptionT m) where
    throw = ExceptionT . return . Left . E.toException
    catch act handler = ExceptionT $ runExceptionT act >>= \outcome ->
        case outcome of
          Right r -> return (Right r)
          Left l  -> maybe (return (Left l))
                           (runExceptionT . handler)
                           (E.fromException l)
-- | Lifting an 'IO' action also captures any exception it throws, turning
-- it into an 'ExceptionT' exception ('Left').
--
-- NOTE(review): the handler is at type 'E.SomeException', so asynchronous
-- exceptions (e.g. thread kills) are captured too rather than propagating —
-- confirm this is intended.
instance (MonadIO m) => MonadIO (ExceptionT m) where
    liftIO act = ExceptionT (liftIO (E.catch (fmap Right act) asLeft))
      where
        asLeft e = return (Left (e :: E.SomeException))
-- | Masks in the underlying monad. The restore function is lifted through
-- 'ExceptionT' with 'mapExceptionT'.
instance (MonadAsyncException m) => MonadAsyncException (ExceptionT m) where
    mask act =
        ExceptionT (mask (\unmask -> runExceptionT (act (mapExceptionT unmask))))
-- | 'IO' uses the runtime's native exception machinery.
instance MonadException IO where
    catch   = E.catch
    -- 'E.throwIO' rather than 'E.throw': the exception is raised when the
    -- action is executed, so it is ordered with respect to the surrounding
    -- IO actions. 'E.throw' makes the action value itself bottom, so it
    -- fires whenever the action is evaluated.
    throw   = E.throwIO
    -- 'E.finally' runs the sequel with asynchronous exceptions masked.
    finally = E.finally
-- | 'IO' masking is the runtime's native masking. Before GHC 7.0, where
-- 'E.mask' does not exist, it is emulated with the old blocking
-- primitives. If exceptions are already blocked, the restore function is
-- 'id', so "restoring" leaves them blocked. Otherwise it is 'E.unblock'.
instance MonadAsyncException IO where
#if __GLASGOW_HASKELL__ >= 700
    mask = E.mask
#else /* __GLASGOW_HASKELL__ < 700 */
    -- Requires E.blocked / E.block / E.unblock to be imported for this
    -- compiler range.
    mask act = do
        b <- E.blocked
        if b
          then act id
          else E.block $ act E.unblock
#endif /* __GLASGOW_HASKELL__ < 700 */
-- | 'STM' uses the STM-specific primitives. 'finally' falls back to the
-- class default.
instance MonadException STM where
    throw             = throwSTM
    catch act handler = catchSTM act handler
-- | Exceptions pass through 'ErrorT' to the underlying monad. 'finally'
-- delegates to the underlying monad's 'finally'. An 'ErrorT' error is an
-- ordinary 'Left' result there, so the sequel still runs after it.
instance (MonadException m, Error e) =>
    MonadException (ErrorT e m) where
    throw = lift . throw
    catch act handler = mapErrorT (`catch` (runErrorT . handler)) act
    finally act sequel = mapErrorT (`finally` runErrorT sequel) act
-- | Exceptions pass through 'ExceptT' to the underlying monad. 'finally'
-- delegates to the underlying monad's 'finally'. An 'ExceptT' error is an
-- ordinary 'Left' result there, so the sequel still runs after it.
instance (MonadException m) =>
    MonadException (ExceptT e' m) where
    throw = lift . throw
    catch act handler = mapExceptT (`catch` (runExceptT . handler)) act
    finally act sequel = mapExceptT (`finally` runExceptT sequel) act
-- | Exceptions pass through 'IdentityT' unchanged.
--
-- 'finally' delegates to the underlying monad's 'finally' instead of using
-- the class default, so that overrides there are preserved. Examples:
-- 'IO' uses 'E.finally', which masks; 'ExceptT', 'ErrorT' and 'MaybeT'
-- run the sequel even after a short-circuiting 'Left'/'Nothing'.
-- The class default loses both behaviors.
instance (MonadException m) =>
    MonadException (IdentityT m) where
    throw       = lift . throw
    m `catch` h = mapIdentityT (\m' -> m' `catch` \e -> runIdentityT (h e)) m
    act `finally` sequel =
        mapIdentityT (\act' -> act' `finally` runIdentityT sequel) act
-- | Exceptions pass through 'ListT' to the underlying monad.
--
-- 'finally' is overridden because the class default binds @sequel@ once
-- per result of @act@. That runs its effects zero times when @act@ yields
-- nothing and several times when @act@ yields several results. It also
-- multiplies (or empties) the result list by @sequel@'s own results.
-- Delegating to the underlying 'finally' instead runs @sequel@ once, after
-- the whole of @act@, and keeps @act@'s result list unchanged.
instance MonadException m =>
    MonadException (ListT m) where
    throw       = lift . throw
    m `catch` h = mapListT (\m' -> m' `catch` \e -> runListT (h e)) m
    act `finally` sequel =
        mapListT (\act' -> act' `finally` runListT sequel) act
-- | Exceptions pass through 'MaybeT' to the underlying monad. 'finally'
-- delegates to the underlying monad's 'finally'. A 'MaybeT' failure is an
-- ordinary 'Nothing' result there, so the sequel still runs after it.
instance (MonadException m) =>
    MonadException (MaybeT m) where
    throw = lift . throw
    catch act handler = mapMaybeT (`catch` (runMaybeT . handler)) act
    finally act sequel = mapMaybeT (`finally` runMaybeT sequel) act
-- | Exceptions pass through the lazy 'Lazy.RWST' to the underlying monad.
-- The handler restarts from the environment and state current when 'catch'
-- was entered. State and output of the failed computation are discarded.
-- 'finally' uses the class default.
instance (Monoid w, MonadException m) =>
    MonadException (Lazy.RWST r w s m) where
    throw = lift . throw
    catch act handler = Lazy.RWST $ \env st ->
        catch (Lazy.runRWST act env st)
              (\e -> Lazy.runRWST (handler e) env st)
-- | Exceptions pass through the strict 'Strict.RWST' to the underlying
-- monad. The handler restarts from the environment and state current when
-- 'catch' was entered. State and output of the failed computation are
-- discarded. 'finally' uses the class default.
instance (Monoid w, MonadException m) =>
    MonadException (Strict.RWST r w s m) where
    throw = lift . throw
    catch act handler = Strict.RWST $ \env st ->
        catch (Strict.runRWST act env st)
              (\e -> Strict.runRWST (handler e) env st)
-- | Exceptions pass through 'ReaderT' to the underlying monad. The handler
-- sees the same environment as the protected computation.
--
-- 'finally' delegates to the underlying monad's 'finally' instead of using
-- the class default, so that overrides there are preserved. Examples:
-- masking in 'IO' via 'E.finally', and running the sequel after an
-- 'ExceptT'/'ErrorT'/'MaybeT' short-circuit. The class default loses both.
instance (MonadException m) =>
    MonadException (ReaderT r m) where
    throw       = lift . throw
    m `catch` h = ReaderT $ \r ->
                  runReaderT m r `catch` \e -> runReaderT (h e) r
    act `finally` sequel = ReaderT $ \r ->
                           runReaderT act r `finally` runReaderT sequel r
-- | Exceptions pass through the lazy 'Lazy.StateT' to the underlying monad.
-- The handler restarts from the state current when 'catch' was entered.
-- 'finally' uses the class default, so @sequel@ sees @act@'s final state.
-- NOTE(review): the default does not run @sequel@ if @act@ short-circuits
-- in the underlying monad without throwing — confirm this is acceptable.
instance (MonadException m) =>
    MonadException (Lazy.StateT s m) where
    throw = lift . throw
    catch act handler = Lazy.StateT $ \st ->
        catch (Lazy.runStateT act st) (\e -> Lazy.runStateT (handler e) st)
-- | Exceptions pass through the strict 'Strict.StateT' to the underlying
-- monad. The handler restarts from the state current when 'catch' was
-- entered. 'finally' uses the class default, so @sequel@ sees @act@'s
-- final state.
-- NOTE(review): the default does not run @sequel@ if @act@ short-circuits
-- in the underlying monad without throwing — confirm this is acceptable.
instance (MonadException m) =>
    MonadException (Strict.StateT s m) where
    throw = lift . throw
    catch act handler = Strict.StateT $ \st ->
        catch (Strict.runStateT act st) (\e -> Strict.runStateT (handler e) st)
-- | Exceptions pass through the lazy 'Lazy.WriterT' to the underlying
-- monad. Output written by a computation that throws is discarded; only
-- the handler's output is kept. 'finally' uses the class default.
instance (Monoid w, MonadException m) =>
    MonadException (Lazy.WriterT w m) where
    throw = lift . throw
    catch act handler =
        Lazy.WriterT (catch (Lazy.runWriterT act) (Lazy.runWriterT . handler))
-- | Exceptions pass through the strict 'Strict.WriterT' to the underlying
-- monad. Output written by a computation that throws is discarded; only
-- the handler's output is kept. 'finally' uses the class default.
instance (Monoid w, MonadException m) =>
    MonadException (Strict.WriterT w m) where
    throw = lift . throw
    catch act handler =
        Strict.WriterT (catch (Strict.runWriterT act) (Strict.runWriterT . handler))
-- | Masks in the underlying monad. The restore function is lifted through
-- 'ErrorT' with 'mapErrorT'.
instance (MonadAsyncException m, Error e) =>
    MonadAsyncException (ErrorT e m) where
    mask act =
        ErrorT (mask (\unmask -> runErrorT (act (mapErrorT unmask))))
-- | Masks in the underlying monad. The restore function is lifted through
-- 'ExceptT' with 'mapExceptT'.
instance (MonadAsyncException m) =>
    MonadAsyncException (ExceptT e' m) where
    mask act =
        ExceptT (mask (\unmask -> runExceptT (act (mapExceptT unmask))))
-- | Masks in the underlying monad. The restore function is lifted through
-- 'IdentityT' with 'mapIdentityT'.
instance (MonadAsyncException m) =>
    MonadAsyncException (IdentityT m) where
    mask act =
        IdentityT (mask (\unmask -> runIdentityT (act (mapIdentityT unmask))))
-- | Masks in the underlying monad. The restore function is lifted through
-- 'ListT' with 'mapListT'.
instance (MonadAsyncException m) =>
    MonadAsyncException (ListT m) where
    mask act =
        ListT (mask (\unmask -> runListT (act (mapListT unmask))))
-- | Masks in the underlying monad. The restore function is lifted through
-- 'MaybeT' with 'mapMaybeT'.
instance (MonadAsyncException m) =>
    MonadAsyncException (MaybeT m) where
    mask act =
        MaybeT (mask (\unmask -> runMaybeT (act (mapMaybeT unmask))))
-- | Masks in the underlying monad for the given environment and state. The
-- restore function is lifted through the lazy 'Lazy.RWST' with
-- 'Lazy.mapRWST'.
instance (Monoid w, MonadAsyncException m) =>
    MonadAsyncException (Lazy.RWST r w s m) where
    mask act = Lazy.RWST (\env st ->
        mask (\unmask -> Lazy.runRWST (act (Lazy.mapRWST unmask)) env st))
-- | Masks in the underlying monad for the given environment and state. The
-- restore function is lifted through the strict 'Strict.RWST' with
-- 'Strict.mapRWST'.
instance (Monoid w, MonadAsyncException m) =>
    MonadAsyncException (Strict.RWST r w s m) where
    mask act = Strict.RWST (\env st ->
        mask (\unmask -> Strict.runRWST (act (Strict.mapRWST unmask)) env st))
-- | Masks in the underlying monad for the given environment. The restore
-- function is lifted through 'ReaderT' with 'mapReaderT'.
instance (MonadAsyncException m) =>
    MonadAsyncException (ReaderT r m) where
    mask act = ReaderT (\env ->
        mask (\unmask -> runReaderT (act (mapReaderT unmask)) env))
-- | Masks in the underlying monad for the given state. The restore function
-- is lifted through the lazy 'Lazy.StateT' with 'Lazy.mapStateT'.
instance (MonadAsyncException m) =>
    MonadAsyncException (Lazy.StateT s m) where
    mask act = Lazy.StateT (\st ->
        mask (\unmask -> Lazy.runStateT (act (Lazy.mapStateT unmask)) st))
-- | Masks in the underlying monad for the given state. The restore function
-- is lifted through the strict 'Strict.StateT' with 'Strict.mapStateT'.
instance (MonadAsyncException m) =>
    MonadAsyncException (Strict.StateT s m) where
    mask act = Strict.StateT (\st ->
        mask (\unmask -> Strict.runStateT (act (Strict.mapStateT unmask)) st))
-- | Masks in the underlying monad. The restore function is lifted through
-- the lazy 'Lazy.WriterT' with 'Lazy.mapWriterT'.
instance (Monoid w, MonadAsyncException m) =>
    MonadAsyncException (Lazy.WriterT w m) where
    mask act = Lazy.WriterT
        (mask (\unmask -> Lazy.runWriterT (act (Lazy.mapWriterT unmask))))
-- | Masks in the underlying monad. The restore function is lifted through
-- the strict 'Strict.WriterT' with 'Strict.mapWriterT'.
instance (Monoid w, MonadAsyncException m) =>
    MonadAsyncException (Strict.WriterT w m) where
    mask act = Strict.WriterT
        (mask (\unmask -> Strict.runWriterT (act (Strict.mapWriterT unmask))))