{-# LANGUAGE CPP, RankNTypes #-}
{-# OPTIONS_GHC -funbox-strict-fields #-}
module GhcMonad (
        
        GhcMonad(..),
        Ghc(..),
        GhcT(..), liftGhcT,
        reflectGhc, reifyGhc,
        getSessionDynFlags,
        liftIO,
        Session(..), withSession, modifySession, withTempSession,
        
        logWarnings, printException,
        WarnErrLogger, defaultWarnErrLogger
  ) where
import MonadUtils
import HscTypes
import DynFlags
import Exception
import ErrUtils
import Control.Monad
import Data.IORef
-- | A monad in which GHC API calls can run.
--
-- Besides the superclass capabilities (lifting 'IO' via 'MonadIO', catching
-- and masking exceptions via 'ExceptionMonad', and reporting 'DynFlags' via
-- 'HasDynFlags'), an instance carries a current 'HscEnv' session that can be
-- read with 'getSession' and replaced with 'setSession'.
class ( Functor m
      , MonadIO m
      , ExceptionMonad m
      , HasDynFlags m
      ) => GhcMonad m where
  -- | Read the current session.
  getSession :: m HscEnv
  -- | Replace the current session with the given one.
  setSession :: HscEnv -> m ()
-- | Fetch the current session and hand it to the given continuation.
withSession :: GhcMonad m => (HscEnv -> m a) -> m a
withSession k = do
  env <- getSession
  k env
-- | The 'DynFlags' of the current session (its 'hsc_dflags' field).
getSessionDynFlags :: GhcMonad m => m DynFlags
getSessionDynFlags = do
  env <- getSession
  return (hsc_dflags env)
-- | Replace the current session with the result of applying the given
-- function to it.
--
-- The new session is forced to weak head normal form (@$!@) before it is
-- stored, so repeated modifications do not stack up as a chain of thunks.
modifySession :: GhcMonad m => (HscEnv -> HscEnv) -> m ()
modifySession upd =
  getSession >>= \env -> setSession $! upd env
-- | Run an action and then reinstate whatever session was current before it
-- started. The reinstatement happens in a 'gfinally', so it also runs when
-- the action throws.
withSavedSession :: GhcMonad m => m a -> m a
withSavedSession act =
  getSession >>= \saved -> act `gfinally` setSession saved
-- | Run an action under a session modified by the given function. The
-- previous session is put back afterwards, via 'withSavedSession'.
withTempSession :: GhcMonad m => (HscEnv -> HscEnv) -> m a -> m a
withTempSession upd act = withSavedSession (modifySession upd >> act)
-- | Pass a bag of warnings, together with the session's 'DynFlags', to
-- 'printOrThrowWarnings'.
logWarnings :: GhcMonad m => WarningMessages -> m ()
logWarnings warnings =
  getSessionDynFlags >>= \dflags ->
    liftIO (printOrThrowWarnings dflags warnings)
-- | A concrete 'GhcMonad' over 'IO': a function from the shared 'Session' to
-- an 'IO' action.
newtype Ghc a = Ghc
  { unGhc :: Session -> IO a
  }
-- | A mutable reference to the current 'HscEnv'. The single field is strict.
data Session
  = Session !(IORef HscEnv)
-- Apply the function to the result of the underlying IO action. The session
-- is passed through unchanged.
instance Functor Ghc where
  fmap f (Ghc run) = Ghc (fmap f . run)
-- 'pure' ignores the session. '<*>' runs the function-producing action
-- first, then the argument-producing action, both with the same session.
instance Applicative Ghc where
  pure x = Ghc (const (return x))
  Ghc runF <*> Ghc runX = Ghc $ \s -> runF s <*> runX s
-- Sequencing threads the same session into both the first action and the
-- continuation.
instance Monad Ghc where
  Ghc run >>= k = Ghc $ \s -> run s >>= \x -> unGhc (k x) s
-- A lifted IO action simply ignores the session.
instance MonadIO Ghc where
  liftIO = Ghc . const
-- Value recursion delegates to 'IO''s 'mfix', with the session fixed for
-- every unfolding.
instance MonadFix Ghc where
  mfix k = Ghc (\sess -> mfix (flip unGhc sess . k))
-- | Exception handling for 'Ghc' is delegated to the 'ExceptionMonad'
-- operations on the underlying 'IO' actions. Both the protected action and
-- the handler receive the same 'Session'.
instance ExceptionMonad Ghc where
  -- The parameter is named 'handler' so that it does not shadow
  -- Control.Exception.handle.
  gcatch act handler =
      Ghc $ \s -> unGhc act s `gcatch` \e -> unGhc (handler e) s
  gmask f =
      Ghc $ \s -> gmask $ \io_restore ->
                             let
                                -- Lift IO's restore function into Ghc. The
                                -- restored action runs with the session it
                                -- is given when it is run (s'), not the
                                -- outer s; the distinct name makes that
                                -- explicit and avoids a shadowing warning.
                                g_restore (Ghc m) = Ghc $ \s' -> io_restore (m s')
                             in
                                unGhc (f g_restore) s
-- The flags are read from the current session's 'hsc_dflags' field.
instance HasDynFlags Ghc where
  getDynFlags = getSession >>= \env -> return (hsc_dflags env)
-- The session lives in the IORef inside 'Session'. Reading and replacing it
-- are plain 'readIORef' and 'writeIORef' calls.
instance GhcMonad Ghc where
  getSession = Ghc readSession
    where readSession (Session ref) = readIORef ref
  setSession env = Ghc writeSession
    where writeSession (Session ref) = writeIORef ref env
-- | Run a 'Ghc' computation against an explicitly supplied 'Session'.
reflectGhc :: Ghc a -> Session -> IO a
reflectGhc (Ghc run) = run
-- | Wrap a raw @Session -> IO a@ function as a 'Ghc' computation. This is the
-- inverse of 'reflectGhc'.
reifyGhc :: (Session -> IO a) -> Ghc a
reifyGhc = Ghc
-- | A transformer that gives an underlying monad @m@ access to a shared
-- 'Session'. It is a function from that session to an action in @m@.
newtype GhcT m a = GhcT
  { unGhcT :: Session -> m a
  }
-- | Lift an action of the underlying monad into 'GhcT'. The lifted action
-- ignores the session.
liftGhcT :: m a -> GhcT m a
liftGhcT = GhcT . const
-- Apply the function inside the underlying functor. The session is passed
-- through unchanged.
instance Functor m => Functor (GhcT m) where
  fmap f (GhcT run) = GhcT (fmap f . run)
-- Both sides of '<*>' receive the same session. The effects are combined by
-- the underlying applicative.
instance Applicative m => Applicative (GhcT m) where
  pure = GhcT . const . pure
  GhcT runF <*> GhcT runX = GhcT $ \s -> runF s <*> runX s
-- Sequencing threads one session through the first action and the
-- continuation.
instance Monad m => Monad (GhcT m) where
  GhcT run >>= k = GhcT $ \s -> run s >>= \x -> unGhcT (k x) s
-- IO is lifted through the underlying monad. The session is ignored.
instance MonadIO m => MonadIO (GhcT m) where
  liftIO = GhcT . const . liftIO
-- | Exception handling for 'GhcT' is delegated to the underlying monad's
-- 'ExceptionMonad' instance. Both the protected action and the handler
-- receive the same 'Session'.
instance ExceptionMonad m => ExceptionMonad (GhcT m) where
  -- The parameter is named 'handler' so that it does not shadow
  -- Control.Exception.handle.
  gcatch act handler =
      GhcT $ \s -> unGhcT act s `gcatch` \e -> unGhcT (handler e) s
  gmask f =
      GhcT $ \s -> gmask $ \m_restore ->
                           let
                              -- Lift the underlying monad's restore function
                              -- into GhcT. The restored action runs with the
                              -- session it is given when it is run (s'), not
                              -- the outer s; the distinct name makes that
                              -- explicit and avoids a shadowing warning.
                              g_restore (GhcT m) = GhcT $ \s' -> m_restore (m s')
                           in
                              unGhcT (f g_restore) s
-- Read the session from its IORef (lifted through the underlying monad) and
-- project out the 'hsc_dflags' field.
instance MonadIO m => HasDynFlags (GhcT m) where
  getDynFlags = GhcT $ \(Session ref) -> do
    env <- liftIO (readIORef ref)
    return (hsc_dflags env)
-- The session lives in the IORef inside 'Session'. Reading and writing it
-- are IO operations lifted into the underlying monad.
instance ExceptionMonad m => GhcMonad (GhcT m) where
  getSession = GhcT fetch
    where fetch (Session ref) = liftIO (readIORef ref)
  setSession env = GhcT store
    where store (Session ref) = liftIO (writeIORef ref env)
-- | Pass the error messages carried by a 'SourceError' ('srcErrorMessages'),
-- together with the session's 'DynFlags', to 'printBagOfErrors'.
printException :: GhcMonad m => SourceError -> m ()
printException err =
  getSessionDynFlags >>= \dflags ->
    liftIO (printBagOfErrors dflags (srcErrorMessages err))
-- | A callback that receives an optional 'SourceError'. It must work in every
-- 'GhcMonad'; this is why the type is rank-N.
type WarnErrLogger
  = forall m. GhcMonad m => Maybe SourceError -> m ()
-- | The default logger does nothing for 'Nothing' and hands a 'SourceError'
-- to 'printException'.
defaultWarnErrLogger :: WarnErrLogger
defaultWarnErrLogger = maybe (return ()) printException