{-# language FlexibleContexts #-}

module TPDB.DP.TCap (tcap) where

import TPDB.Data
import TPDB.Pretty

import TPDB.DP.Unify

import Control.Monad (forM)
import Control.Monad.State.Strict 
import Control.Applicative


-- |  This function keeps only those parts of the input term which cannot be reduced,
-- even if the term is instantiated. All other parts are replaced by fresh variables.
-- Def 4.4 in http://cl-informatik.uibk.ac.at/users/griff/publications/Sternagel-Thiemann-RTA10.pdf

-- Entry point: rebuild the term with 'walk' inside a 'State' counter that
-- hands out fresh variable indices, starting the counter at 0.
--
-- * @dp@ is the list of rules whose left-hand sides are tested against
--   the rebuilt subterms (see 'walk').
-- * @t@ is the term to be capped.
--
-- The result uses 'Int' variables only: every variable in the output is one
-- produced by 'fresh_var'.
tcap :: (Ord v, Eq c, TermC v c) => [Rule (Term v c)] -> Term v c -> Term Int c
tcap dp t = evalState (walk dp t) 0

-- Produce a variable term that carries the current counter value and
-- advance the counter by one. The new counter value is forced with '$!'
-- before it is stored, so no chain of unevaluated 'succ' applications
-- accumulates in the state.
fresh_var :: TermC Int c => State Int ( Term Int c )
fresh_var = do
  i <- get
  put $! succ i
  return (Var i)

-- Bottom-up traversal used by 'tcap'.
--
-- For a function application the arguments are rebuilt first (left to
-- right, threading the fresh-variable counter), then the rebuilt node is
-- tested against the left-hand side of every rule in @dp@ whose 'strict'
-- flag is False. If none of those left-hand sides unifies with the rebuilt
-- node, the node is kept; otherwise it is replaced by a fresh variable.
-- Anything that is not a 'Node' is replaced by a fresh variable.
--
-- No type signature is given here; the type is left to inference, which
-- picks up the constraints required by 'unifies', 'vmap' and 'fresh_var'.
{-# INLINE walk #-}
walk dp = go
  where
    -- The rule selection does not depend on the term being visited, so it
    -- is computed once per call of 'walk' rather than once per node.
    candidates = filter (not . strict) dp

    -- Rule variables are tagged with 'Left' and the Int variables of the
    -- rebuilt term with 'Right', so the two variable sets cannot clash
    -- during unification.
    matchesSomeLhs t' =
      any (\u -> unifies (vmap Left (lhs u)) (vmap Right t')) candidates

    go (Node f args) = do
      t' <- Node f <$> forM args go
      if matchesSomeLhs t'
        then fresh_var
        else return t'
    go _ = fresh_var