module Idris.ErrReverse where
import Idris.AbsSyntax
import Idris.Core.TT
import Util.Pretty
import Data.List
import Debug.Trace
-- | Apply the error-reversal rules stored in @idris_errRev ist@ to a term,
-- presumably so that terms appearing in error messages are shown in a more
-- readable form.
--
-- Each rule is a pair @(lhs, rhs)@. Every outermost subterm matching @lhs@
-- is replaced by the corresponding instance of @rhs@. The whole rule set is
-- applied repeatedly, at most 5 passes, and stops early as soon as a pass
-- leaves the term unchanged.
errReverse :: IState -> Term -> Term
errReverse ist t = rewrite 5 t
  where
    -- Run up to @n@ passes of the rule set. Reaching a fixed point ends the
    -- loop early; the pass budget bounds the work for rule sets that never
    -- converge. Rules are folded over the reversed list, so the last entry
    -- of @idris_errRev@ is applied first within each pass.
    rewrite :: Int -> Term -> Term
    rewrite 0 tm = tm
    rewrite n tm =
      let tm' = foldl' applyRule tm (reverse (idris_errRev ist))
      in if tm == tm'
           then tm
           -- Count the budget down (the original lacked the '-' here).
           else rewrite (n - 1) tm'

    -- Apply a single rule to the term.
    applyRule :: Term -> (Term, Term) -> Term
    applyRule tm (l, r) = applyNames [] tm l r

    -- Open the rule's leading pattern-variable binders. Each bound name
    -- becomes a free reference (P Ref) that 'match'' treats as a pattern
    -- variable. If both sides start with PVar binders whose names differ,
    -- the rule is skipped and the term is returned unchanged.
    applyNames :: [Name] -> Term -> Term -> Term -> Term
    applyNames ns tm (Bind n (PVar ty) scl) (Bind n' (PVar ty') scr)
      | n == n'   = applyNames (n : ns) tm (instantiate (P Ref n ty) scl)
                                           (instantiate (P Ref n' ty') scr)
      | otherwise = tm
    applyNames ns tm l r = matchTerm ns l r tm

    -- Replace the outermost subterms matching @l@ with instances of @r@.
    -- A subterm that was replaced is not searched again in this pass.
    -- Otherwise the search descends through applications and binders only;
    -- every other term form is left as is.
    matchTerm :: [Name] -> Term -> Term -> Term -> Term
    matchTerm ns l r tm
      | Just nmap <- match ns l tm = substNames nmap r
    matchTerm ns l r (App s f a) =
      App s (matchTerm ns l r f) (matchTerm ns l r a)
    matchTerm ns l r (Bind n b sc) =
      Bind n (fmap (matchTerm ns l r) b) (matchTerm ns l r sc)
    matchTerm _ _ _ tm = tm

    -- Match a term against a pattern and return the pattern-variable
    -- bindings. Fails if one variable would be bound to two different terms.
    match :: [Name] -> Term -> Term -> Maybe [(Name, Term)]
    match ns l tm = match' ns l tm >>= combine . nub

    -- Runs after 'nub' has removed identical duplicates, so a variable that
    -- still occurs twice was bound to different subterms. That makes the
    -- match inconsistent.
    combine :: [(Name, Term)] -> Maybe [(Name, Term)]
    combine [] = Just []
    combine ((x, tm) : xs)
      | Just _ <- lookup x xs = Nothing
      | otherwise             = fmap ((x, tm) :) (combine xs)

    -- Structural matcher. A pattern variable (a name in @ns@) matches any
    -- term. Applications are matched component-wise. Any other pattern form
    -- must be equal to the term.
    match' :: [Name] -> Term -> Term -> Maybe [(Name, Term)]
    match' ns (P Ref n _) tm | n `elem` ns = Just [(n, tm)]
    match' ns (App _ f a) (App _ f' a') = do
      funBinds <- match' ns f f'
      argBinds <- match' ns a a'
      Just (funBinds ++ argBinds)
    match' _ x y = if x == y then Just [] else Nothing

    -- Replace any lambda whose scope has 'size' greater than 200 with a
    -- placeholder reference named "...".
    -- NOTE(review): this helper is not referenced by 'errReverse' itself.
    -- Confirm whether it is meant to be applied to the input before
    -- rewriting.
    elideLambdas :: Term -> Term
    elideLambdas (App s f a) = App s (elideLambdas f) (elideLambdas a)
    elideLambdas (Bind _ (Lam _) sc)
      | size sc > 200 = P Ref (sUN "...") Erased
    elideLambdas (Bind n b sc) =
      Bind n (fmap elideLambdas b) (elideLambdas sc)
    elideLambdas tm = tm