-- | State monads that also keep a stack of saved states.
--
-- With 'StateStackT', 'save' pushes a copy of the current state onto the
-- stack and 'restore' pops the most recently saved state back into place,
-- doing nothing when the stack is empty.
module Control.Monad.StateStack
       (
         -- * The MonadStateStack class
         MonadStateStack(..)

         -- * The StateStackT transformer
       , StateStackT(..), StateStack

         -- * Running StateStackT and StateStack computations
       , runStateStackT, evalStateStackT, execStateStackT
       , runStateStack,  evalStateStack,  execStateStack

         -- * Lifting ordinary state computations
       , liftState
       ) where
#if __GLASGOW_HASKELL__ < 710
import Data.Monoid
import Control.Applicative
#endif
import Control.Arrow (second)
import Control.Monad.Identity
import qualified Control.Monad.State as St
import Control.Arrow (first, (&&&))
import Control.Monad.Trans
import Control.Monad.Trans.Cont
import Control.Monad.Trans.Except
import Control.Monad.Trans.Identity
import Control.Monad.Trans.List
import Control.Monad.Trans.Maybe
import Control.Monad.Trans.Reader (ReaderT)
import Control.Monad.Trans.State.Lazy as Lazy
import Control.Monad.Trans.State.Strict as Strict
import Control.Monad.Trans.Writer.Lazy as Lazy
import Control.Monad.Trans.Writer.Strict as Strict
import qualified Control.Monad.Cont.Class as CC
import qualified Control.Monad.State.Class as StC
import qualified Control.Monad.IO.Class as IC
-- | A state transformer that carries a current state of type @s@ together
-- with a stack of previously saved states (a list whose head is the most
-- recently saved one).  The representation is an 'St.StateT' over the pair
-- @(current state, saved stack)@.
newtype StateStackT s m a = StateStackT
  { unStateStackT :: St.StateT (s, [s]) m a
    -- ^ The underlying state computation over @(current, saved stack)@.
  }
  deriving (Functor, Applicative, Monad, MonadTrans, IC.MonadIO)
-- | Monads whose state (of type @s@, shared with the 'St.MonadState'
-- superclass) can be saved and later restored.
class St.MonadState s m => MonadStateStack s m where
  -- | Save the current state.  For 'StateStackT' this pushes a copy of the
  -- current state onto the saved-state stack.
  save    :: m ()
  -- | Restore a previously saved state.  For 'StateStackT' this pops the
  -- most recently saved state into the current state, or does nothing if
  -- the stack is empty.
  restore :: m ()
-- | 'St.get' and 'St.put' see only the current state; the stack of saved
-- states is carried along untouched.
instance Monad m => St.MonadState s (StateStackT s m) where
  get     = StateStackT (St.gets (\ ~(cur, _) -> cur))
  put new = StateStackT (St.modify (\ ~(_, saved) -> (new, saved)))
-- | 'save' pushes a copy of the current state onto the stack.  'restore'
-- pops the top saved state back into the current state, and leaves
-- everything unchanged when nothing has been saved.
instance Monad m => MonadStateStack s (StateStackT s m) where
  save = StateStackT . St.modify $ \ ~(cur, saved) -> (cur, cur : saved)
  restore = StateStackT (St.modify pop)
    where
      -- Empty stack: nothing to restore, keep the pair as it is.
      pop (cur, [])       = (cur, [])
      -- Otherwise the top of the stack becomes the current state.
      pop (_, top : rest) = (top, rest)
-- | Run a 'StateStackT' computation from the given initial state with an
-- empty stack of saved states.  Returns the result and the final current
-- state; anything still on the stack is discarded.
runStateStackT :: Monad m => StateStackT s m a -> s -> m (a, s)
runStateStackT m s = liftM dropStack (St.runStateT (unStateStackT m) (s, []))
  where
    -- Lazy pattern: the outer pair is not forced until a component is.
    dropStack ~(a, st) = (a, fst st)
-- | Like 'runStateStackT', but return only the computation's result.
evalStateStackT :: Monad m => StateStackT s m a -> s -> m a
evalStateStackT m = liftM fst . runStateStackT m
-- | Like 'runStateStackT', but return only the final current state.
execStateStackT :: Monad m => StateStackT s m a -> s -> m s
execStateStackT m = liftM snd . runStateStackT m
-- | The pure state-with-stack monad: 'StateStackT' over 'Identity'.
--
-- Declared without the result-type parameter (as mtl does for @State@) so
-- that @StateStack s@ is itself a saturated type of kind @* -> *@ and can be
-- used on its own, e.g. @MonadStateStack s (StateStack s)@.  Fully applied
-- uses such as @StateStack s a@ are unaffected.
type StateStack s = StateStackT s Identity
-- | Run a pure 'StateStack' computation from an initial state, returning
-- the result together with the final current state.
runStateStack :: StateStack s a -> s -> (a,s)
runStateStack m = runIdentity . runStateStackT m
-- | Run a pure 'StateStack' computation and keep only its result.
evalStateStack :: StateStack s a -> s -> a
evalStateStack m = runIdentity . evalStateStackT m
-- | Run a pure 'StateStack' computation and keep only the final state.
execStateStack :: StateStack s a -> s -> s
execStateStack m = runIdentity . execStateStackT m
-- | Lift an ordinary 'St.StateT' computation into 'StateStackT'.  It runs
-- against the current state only; the stack of saved states is passed
-- through unchanged.
liftState :: Monad m => St.StateT s m a -> StateStackT s m a
liftState st = StateStackT . St.StateT $ \(cur, saved) ->
  let reattach ~(a, cur') = (a, (cur', saved))
  in  liftM reattach (St.runStateT st cur)
-- Lifting instances: each transformer delegates 'save' and 'restore' to
-- the underlying monad via 'lift'.
instance MonadStateStack s m => MonadStateStack s (ContT r m) where
  restore = lift restore
  save    = lift save

instance MonadStateStack s m => MonadStateStack s (ExceptT e m) where
  restore = lift restore
  save    = lift save

instance MonadStateStack s m => MonadStateStack s (IdentityT m) where
  restore = lift restore
  save    = lift save

instance MonadStateStack s m => MonadStateStack s (ListT m) where
  restore = lift restore
  save    = lift save

instance MonadStateStack s m => MonadStateStack s (MaybeT m) where
  restore = lift restore
  save    = lift save

instance MonadStateStack s m => MonadStateStack s (ReaderT r m) where
  restore = lift restore
  save    = lift save

-- NOTE(review): the MonadState superclass for the two StateT instances
-- presumably comes from mtl's own StateT instance, so get/put would act on
-- the StateT's state while save/restore (lifted) act on the inner monad's
-- stack -- confirm this split is intended.
instance MonadStateStack s m => MonadStateStack s (Lazy.StateT s m) where
  restore = lift restore
  save    = lift save

instance MonadStateStack s m => MonadStateStack s (Strict.StateT s m) where
  restore = lift restore
  save    = lift save

instance (Monoid w, MonadStateStack s m) => MonadStateStack s (Lazy.WriterT w m) where
  restore = lift restore
  save    = lift save

instance (Monoid w, MonadStateStack s m) => MonadStateStack s (Strict.WriterT w m) where
  restore = lift restore
  save    = lift save
-- | 'CC.callCC' is delegated to the underlying 'St.StateT'.  The escape
-- continuation it provides is re-wrapped in 'StateStackT' before being
-- handed to the user's body.
instance CC.MonadCont m => CC.MonadCont (StateStackT s m) where
  callCC body = StateStackT . CC.callCC $ \escape ->
    unStateStackT (body (StateStackT . escape))