module Idris.Erasure (performUsageAnalysis, mkFieldName) where
import Idris.AbsSyntax
import Idris.ASTUtils
import Idris.Core.CaseTree
import Idris.Core.TT
import Idris.Core.Evaluate
import Idris.Primitives
import Idris.Error
import Debug.Trace
import System.IO.Unsafe
import Control.Category
import Prelude hiding (id, (.))
import Control.Arrow
import Control.Applicative
import Control.Monad.State
import Data.Maybe
import Data.List
import qualified Data.Set as S
import qualified Data.IntSet as IS
import qualified Data.Map as M
import qualified Data.IntMap as IM
import Data.Set (Set)
import Data.IntSet (IntSet)
import Data.Map (Map)
import Data.IntMap (IntMap)
import Data.Text (pack)
import qualified Data.Text as T
-- | Final result of the analysis: for every name, the argument positions that
-- are used, each with the set of reasons recorded for it.
type UseMap = Map Name (IntMap (Set Reason))
-- | One position of a name: a numbered argument, or the result of the name.
data Arg = Arg Int | Result deriving (Eq, Ord)
-- | Arguments are shown as their index, the result as @*@.
instance Show Arg where
    show (Arg i) = show i
    show Result  = "*"
-- | A node of the dependency graph: a name together with one of its positions.
type Node = (Name, Arg)
-- | Conditional dependencies: each condition maps to the nodes (with reasons)
-- that it implies.  Entries under the empty condition are the ones
-- 'forwardChain' treats as established.
type Deps = Map Cond DepSet
-- | A name and an argument index, attached to a node as the reason for it.
type Reason = (Name, Int)  
-- | A set of nodes, each annotated with the reasons recorded for it.
type DepSet = Map Node (Set Reason)
-- | A condition: a set of nodes.
type Cond = Set Node
-- | What the analysis knows about a variable bound in a case tree.
data VarInfo = VI
    { viDeps   :: DepSet      -- ^ nodes that become dependencies when the variable is referenced
    , viFunArg :: Maybe Int   -- ^ index of the function argument the variable was bound from, if any
    , viMethod :: Maybe Name  -- ^ field name, set when the variable is a field of an instance constructor
    }
    deriving Show
-- | Variables in scope, keyed by name.
type Vars = Map Name VarInfo
-- | Run the usage analysis starting from the given names.
--
-- The dependency map is built and forward-chained into the minimal usage; the
-- intermediate results are logged; if 'WarnReach' is among the command line
-- options, reachable inaccessible arguments are reported; reachable postulates
-- make the analysis fail; finally the used argument positions are stored in
-- the call graph and the reachable names are returned.
--
-- With no start names, nothing is analysed and the result is empty.
performUsageAnalysis :: [Name] -> Idris [Name]
performUsageAnalysis [] = return []
performUsageAnalysis roots = do
    ist <- getIState
    let depMap = buildDepMap (idris_classes ist)
                             (idris_erasureUsed ist)
                             (S.toList (idris_externs ist))
                             (tt_ctxt ist)
                             roots
        (residDeps, (reachableNames, minUse)) = minimalUsage depMap
        usage = M.toList minUse

    logLvl 5 $ "Original deps:\n" ++ unlines (map fmtItem (M.toList depMap))
    logLvl 3 $ "Reachable names:\n" ++ unlines [indent (show n) | n <- S.toList reachableNames]
    logLvl 4 $ "Minimal usage:\n" ++ fmtUseMap usage
    logLvl 5 $ "Residual deps:\n" ++ unlines (map fmtItem (M.toList residDeps))

    -- Accessibility warnings are opt-in.
    warnReach <- elem WarnReach . opt_cmdline . idris_options <$> getIState
    when warnReach $
        mapM_ (checkAccessibility (idris_optimisation ist)) usage

    -- A postulate that is reachable is an error.
    postulated <- S.intersection reachableNames . idris_postulates <$> getIState
    when (not (S.null postulated)) $
        ifail ("reachable postulates:\n" ++ intercalate "\n" ["  " ++ show n | n <- S.toList postulated])

    mapM_ storeUsage usage
    return (S.toList reachableNames)
  where
    indent :: String -> String
    indent s = "  " ++ s

    -- One line per dependency entry: "  [condition] -> [implied nodes]".
    fmtItem :: (Cond, DepSet) -> String
    fmtItem (cond, deps) = indent (show (S.toList cond) ++ " -> " ++ show (M.toList deps))

    -- One line per name: "  name -> used argument positions".
    fmtUseMap :: [(Name, IntMap (Set Reason))] -> String
    fmtUseMap entries = unlines [indent (show n ++ " -> " ++ fmtIxs ixs) | (n, ixs) <- entries]

    fmtIxs :: IntMap (Set Reason) -> String
    fmtIxs ixs = intercalate ", " [fmtArg i rs | (i, rs) <- IM.toList ixs]
      where
        fmtArg i rs
            | S.null rs = show i
            | otherwise = show i ++ " from " ++ intercalate ", " (map show (S.toList rs))

    -- Record the used positions (with their reasons) in the call graph entry.
    storeUsage :: (Name, IntMap (Set Reason)) -> Idris ()
    storeUsage (n, args) =
        fputState (cg_usedpos . ist_callgraph n) [(i, S.toList rs) | (i, rs) <- IM.toList args]

    -- Report every argument marked inaccessible that is nevertheless used.
    checkAccessibility :: Ctxt OptInfo -> (Name, IntMap (Set Reason)) -> Idris ()
    checkAccessibility opt (fn, reachable) =
        case lookupCtxtExact fn opt of
            Just (Optimise inaccessible _)
                | offending@(_:_) <- [ describe argName (S.toList rs)
                                     | (i, argName) <- inaccessible
                                     , Just rs <- [IM.lookup i reachable] ]
                -> logLvl 0 $ show fn ++ ": inaccessible arguments reachable:\n  " ++ intercalate "\n  " offending
            _ -> return ()
      where
        describe argName [] = show argName ++ " (no more information available)"
        describe argName rs = show argName ++ " from " ++ intercalate ", " [show rn ++ " arg# " ++ show ri | (rn, ri) <- rs]
-- | Forward-chain the dependencies and split the established nodes into the
-- set of names whose result is needed and the map of used argument positions.
-- The first component is the residual dependency map returned by 'forwardChain'.
minimalUsage :: Deps -> (Deps, (Set Name, UseMap))
minimalUsage deps = (residual, M.foldrWithKey collect (S.empty, M.empty) established)
  where
    (residual, established) = forwardChain deps

    -- A 'Result' node makes the name reachable; an 'Arg' node records a used
    -- position together with its reasons.
    collect :: Node -> Set Reason -> (Set Name, UseMap) -> (Set Name, UseMap)
    collect (n, Result) _  (names, uses) = (S.insert n names, uses)
    collect (n, Arg i)  rs (names, uses) =
        (names, M.insertWith (IM.unionWith S.union) n (IM.singleton i rs) uses)
-- | Repeatedly take the nodes stored under the empty condition as established,
-- strike them from every remaining condition (merging entries whose
-- conditions become equal), and continue until no entry with an empty
-- condition is left.  Returns the remaining dependencies and all established
-- nodes with their reasons.
forwardChain :: Deps -> (Deps, DepSet)
forwardChain deps = case M.lookup S.empty deps of
    Nothing       -> (deps, M.empty)
    Just trivials ->
        let withoutTrivial = M.delete S.empty deps
            -- drop the newly established nodes from all conditions
            pruned = M.mapKeysWith (M.unionWith S.union) (S.\\ M.keysSet trivials) withoutTrivial
            (residual, established) = forwardChain pruned
        in (residual, M.unionWith S.union trivials established)
-- | Build the dependency map for everything reachable from @startNames@.
--
-- The definitions are explored depth-first from the start names; the
-- unconditional usages (start names, @used@ positions, primitives, externs,
-- ...) are then added on top by @addPostulates@.
buildDepMap :: Ctxt ClassInfo -> [(Name, Int)] -> [(Name, Int)] ->
               Context -> [Name] -> Deps
buildDepMap ci used externs ctx startNames 
    = addPostulates used $ dfs S.empty M.empty startNames
  where
    
    -- | Merge the unconditional usages into a dependency map.
    --
    -- The first argument lists extra (name, argument index) positions to mark
    -- as used.  Everything is inserted under the empty condition, i.e. it
    -- holds without any premise.
    addPostulates :: [(Name, Int)] -> Deps -> Deps
    addPostulates used deps = foldr (\(ds, rs) -> M.insertWith (M.unionWith S.union) ds rs) deps (postulates used)
      where
        -- @conds ==> nodes@: one dependency entry whose nodes carry no reasons.
        (==>) ds rs = (S.fromList ds, M.fromList [(r, S.empty) | r <- rs])
        -- The given argument positions of a user-level name.
        it n is = [(sUN n, Arg i) | i <- is]

        -- prim__believe_me is listed explicitly below with a single used
        -- argument, so it is excluded from the "all arguments used" rule.
        specialPrims = S.fromList [sUN "prim__believe_me"]
        usedNames = allNames deps S.\\ specialPrims
        -- primitives that actually occur in the dependency map
        usedPrims = [(p_name p, p_arity p) | p <- primitives, p_name p `S.member` usedNames]

        postulates used =
            [ [] ==> concat
                -- the results of the start names are needed
                [ map (\n -> (n, Result)) startNames
                , [(sUN "run__IO", Result), (sUN "run__IO", Arg 1)]
                , [(sUN "call__IO", Result), (sUN "call__IO", Arg 2)]

                -- positions requested explicitly by the caller
                , map (\(n, i) -> (n, Arg i)) used

                , it "MkIO"         [2]
                , it "prim__IO"     [1]

                , [ (pairCon, Arg 2)
                  , (pairCon, Arg 3)
                  ]

                , it "prim_fork"    [0]
                , it "unsafePerformPrimIO"  [1]

                , it "prim__believe_me" [2]

                -- every argument (0 .. arity - 1) of a used primitive is used
                , [(n, Arg i) | (n, arity) <- usedPrims, i <- [0 .. arity - 1]]

                -- likewise for every argument of an extern
                , [(n, Arg i) | (n, arity) <- externs, i <- [0 .. arity - 1]]
                ]
            ]
    
    
    
    -- | Depth-first traversal over names: accumulate the dependencies of every
    -- name on the work list, pushing the not-yet-visited names they mention
    -- onto the front of the list.
    dfs :: Set Name -> Deps -> [Name] -> Deps
    dfs _    acc [] = acc
    dfs seen acc (cur : pending)
        | S.member cur seen = dfs seen acc pending
        | otherwise =
            let found     = getDeps cur
                mentioned = S.delete cur (allNames found)
                fresh     = filter (`S.notMember` seen) (S.toList mentioned)
            in dfs (S.insert cur seen)
                   (M.unionWith (M.unionWith S.union) found acc)
                   (fresh ++ pending)
    
    
    -- | Every name occurring in a dependency map, whether in a condition or
    -- among the implied nodes.
    allNames :: Deps -> Set Name
    allNames depMap =
        S.unions [ nodeNames cond `S.union` nodeNames (M.keysSet implied)
                 | (cond, implied) <- M.toList depMap ]
      where
        nodeNames = S.map fst
    
    -- | Dependencies contributed by the definition of a name.  Field names of
    -- instance constructors (the shape produced by 'mkFieldName') contribute
    -- nothing here; any other name must have a definition in the context.
    getDeps :: Name -> Deps
    getDeps (SN (WhereN _ (SN (InstanceCtorN _)) (MN _ _))) = M.empty
    getDeps n =
        maybe (error $ "erasure checker: unknown reference: " ++ show n)
              (getDepsDef n)
              (lookupDefExact n ctx)
    -- | Dependencies of a definition.  Only case-tree definitions contribute;
    -- type declarations and operators have none, and a 'Function' is not
    -- expected here.
    getDepsDef :: Name -> Def -> Deps
    getDepsDef _  (Function _ _) = error "a function encountered"
    getDepsDef _  (TyDecl   _ _) = M.empty
    getDepsDef _  (Operator _ _ _) = M.empty
    getDepsDef fn (CaseOp _ _ tys _ _ cdefs)
        = getDepsSC fn etaVars (etaMap `M.union` varMap) sc
      where
        -- Argument positions present in @tys@ but not bound by the runtime
        -- case tree get fresh "eta" variables; 'getDepsSC' applies the body
        -- to them.  The range is empty when the case tree binds everything.
        etaIdx = [length vars .. length tys - 1]
        etaVars = [eta i | i <- etaIdx]
        etaMap = M.fromList [varPair (eta i) i | i <- etaIdx]
        eta i = MN i (pack "eta")

        -- the variables bound by the case tree, numbered from 0
        varMap = M.fromList [varPair v i | (v, i) <- zip vars [0..]]
        -- using the variable means using argument @argNo@ of @fn@
        varPair n argNo = (n, VI
            { viDeps   = M.singleton (fn, Arg argNo) S.empty
            , viFunArg = Just argNo
            , viMethod = Nothing
            })
        (vars, sc) = cases_runtime cdefs
            
            
    -- | Apply a term to the given names, in order, as erased references.
    etaExpand :: [Name] -> Term -> Term
    etaExpand names t = foldl applyTo t names
      where
        applyTo f n = App Complete f (P Ref n Erased)
    -- | Dependencies of a case tree belonging to function @fn@; @es@ are the
    -- eta variables the leaf terms get applied to, @vs@ the variables in scope.
    getDepsSC :: Name -> [Name] -> Vars -> SC -> Deps
    getDepsSC fn es vs tree = case tree of
        ImpossibleCase  -> M.empty
        UnmatchedCase _ -> M.empty

        -- a projection case is reduced to the term being projected from ...
        ProjCase (Proj t _) alts -> getDepsSC fn es vs (ProjCase t alts)
        -- ... and, once that is a plain name, treated like an ordinary case
        ProjCase (P _ n _) alts  -> getDepsSC fn es vs (Case Shared n alts)
        ProjCase t alts -> error $ "ProjCase not supported:\n" ++ show (ProjCase t alts)

        -- a leaf: its dependencies hold whenever the result of @fn@ is needed
        STerm t -> getDepsTerm vs [] (S.singleton (fn, Result)) (etaExpand es t)

        Case _ n alts ->
            let scrutinee = fromMaybe (error $ "nonpatvar in case: " ++ show n) (M.lookup n vs)
                altDeps   = unionMap (getDepsAlt fn es vs scrutinee) alts
            in case alts of
                -- a single alternative adds no dependency on the scrutinee
                [_] -> altDeps
                -- otherwise the result of @fn@ depends on the scrutinee
                _   -> M.insertWith (M.unionWith S.union) (S.singleton (fn, Result)) (viDeps scrutinee) altDeps
    -- | Dependencies of one case alternative; @var@ describes the variable
    -- being scrutinised.
    getDepsAlt :: Name -> [Name] -> Vars -> VarInfo -> CaseAlt -> Deps
    getDepsAlt fn es vs var alt = case alt of
        FnCase _ _ _    -> M.empty
        ConstCase _ sc  -> getDepsSC fn es vs sc
        DefaultCase sc  -> getDepsSC fn es vs sc
        -- the bound name carries the same information as the scrutinee
        SucCase n sc    -> getDepsSC fn es (M.insert n var vs) sc
        -- the newly bound fields take precedence over variables already in scope
        ConCase ctor _ fields sc ->
            getDepsSC fn es (M.fromList (zipWith (bindField ctor) fields [0..]) `M.union` vs) sc
      where
        -- Field @j@ of the constructor: using it additionally uses argument
        -- @j@ of the constructor, the reason being the scrutinised argument
        -- of @fn@.
        bindField ctor v j = (v, VI
            { viDeps   = M.insertWith S.union (ctor, Arg j) (S.singleton (fn, scrutIdx)) (viDeps var)
            , viFunArg = viFunArg var
            , viMethod = methodOf ctor j
            })

        scrutIdx = fromJust (viFunArg var)

        -- fields of instance constructors are remembered as methods
        methodOf :: Name -> Int -> Maybe Name
        methodOf ctor j = case ctor of
            SN (InstanceCtorN _) -> Just (mkFieldName ctor j)
            _                    -> Nothing
    
    -- | Dependencies of a term.
    --
    -- @vs@ are the variables bound by the case tree, @bs@ the locally bound
    -- names (innermost first, also indexed by 'V'), each with a function
    -- producing its dependencies under a given condition, and @cd@ is the
    -- condition under which the term is needed.
    getDepsTerm :: Vars -> [(Name, Cond -> Deps)] -> Cond -> Term -> Deps

    -- named references
    getDepsTerm vs bs cd (P _ n _)
        -- locally bound name
        | Just deps <- lookup n bs
        = deps cd

        -- variable bound by the case tree
        | Just var <- M.lookup n vs
        = M.singleton cd (viDeps var)

        -- machine-generated names are expected to be bound by now
        | MN _ _ <- n
        = error $ "erasure analysis: variable " ++ show n ++ " unbound in " ++ show (S.toList cd)

        -- anything else: the result of the referenced name is needed
        | otherwise = M.singleton cd (M.singleton (n, Result) S.empty)

    -- de Bruijn reference to a locally bound name
    getDepsTerm vs bs cd (V i) = snd (bs !! i) cd

    getDepsTerm vs bs cd (Bind n bdr body)
        -- the bound variable itself contributes nothing inside the body
        | Lam ty <- bdr = getDepsTerm vs ((n, const M.empty) : bs) cd body
        | Pi _ ty _ <- bdr = getDepsTerm vs ((n, const M.empty) : bs) cd body

        -- the bound value is analysed under the same condition as the body
        |  Let ty t <- bdr = var t cd `union` getDepsTerm vs ((n, const M.empty) : bs) cd body
        | NLet ty t <- bdr = var t cd `union` getDepsTerm vs ((n, const M.empty) : bs) cd body
        -- any other binder falls through to the catch-all equation at the end
      where
        var t cd = getDepsTerm vs bs cd t

    -- applications
    getDepsTerm vs bs cd app@(App _ _ _)
        | (fun, args) <- unApply app = case fun of

            -- instance constructor: its arguments are additionally analysed
            -- as method bodies
            P (DCon _ _ _) ctorName@(SN (InstanceCtorN className)) _
                -> conditionalDeps ctorName args
                    `union` unionMap (methodDeps ctorName) (zip [0..] args)

            -- type constructor: all arguments are analysed under @cd@
            P (TCon _ _) n _ -> unconditionalDeps args
            -- data constructor: each argument is needed only if the
            -- corresponding constructor argument is
            P (DCon _ _ _) n _ -> conditionalDeps n args

            -- mkForeignPrim: the first four arguments are not analysed
            P _ (UN n) _
                | n == T.pack "mkForeignPrim"
                -> unconditionalDeps $ drop 4 args

            P _ n _
                -- locally bound name
                | Just deps <- lookup n bs
                    -> deps cd `union` unconditionalDeps args

                -- case-tree variable known to be a method
                | Just var  <- M.lookup n vs
                , Just meth <- viMethod var
                    -> viDeps var `ins` conditionalDeps meth args

                -- any other case-tree variable
                | Just var <- M.lookup n vs
                    -> viDeps var `ins` unconditionalDeps args

                -- global name
                | otherwise
                    -> conditionalDeps n args

            V i -> snd (bs !! i) cd `union` unconditionalDeps args

            -- an applied lambda is analysed as the equivalent chain of lets
            Bind n (Lam ty) t -> getDepsTerm vs bs cd (lamToLet [] app)

            -- an applied let is turned into an applied lambda first
            Bind n ( Let ty t') t -> getDepsTerm vs bs cd (App Complete (Bind n (Lam ty) t) t')
            Bind n (NLet ty t') t -> getDepsTerm vs bs cd (App Complete (Bind n (Lam ty) t) t')

            Proj t i
                -> error $ "cannot[0] analyse projection !" ++ show i ++ " of " ++ show t

            Erased -> M.empty

            _ -> error $ "cannot analyse application of " ++ show fun ++ " to " ++ show args
      where
        union = M.unionWith $ M.unionWith S.union
        -- add a dependency set under the current condition
        ins = M.insertWith (M.unionWith S.union) cd

        -- all arguments are needed whenever @cd@ holds
        unconditionalDeps :: [Term] -> Deps
        unconditionalDeps = unionMap (getDepsTerm vs bs cd)

        -- The result of @n@ is needed under @cd@; argument @i@ (for @i@ below
        -- the arity of @n@) is needed only if additionally @(n, Arg i)@ is
        -- used, and arguments beyond the arity are needed under @cd@ alone.
        conditionalDeps :: Name -> [Term] -> Deps
        conditionalDeps n
            = ins (M.singleton (n, Result) S.empty) . unionMap (getDepsArgs n) . zip indices
          where
            indices = map Just [0 .. getArity n - 1] ++ repeat Nothing
            getDepsArgs n (Just i,  t) = getDepsTerm vs bs (S.insert (n, Arg i) cd) t
            getDepsArgs n (Nothing, t) = getDepsTerm vs bs cd t

        -- Analyse argument @methNo@ of an instance constructor as the body of
        -- the field named by 'mkFieldName': its leading lambdas become the
        -- arguments of that field, and the body is needed when the result of
        -- the field is.
        methodDeps :: Name -> (Int, Term) -> Deps
        methodDeps ctorName (methNo, t)
            = getDepsTerm (vars `M.union` vs) (bruijns ++ bs) cond body
          where
            vars = M.fromList [(v, VI
                { viDeps   = deps i
                , viFunArg = Just i
                , viMethod = Nothing
                }) | (v, i) <- zip args [0..]]
            deps i   = M.singleton (metameth, Arg i) S.empty
            -- innermost binder first, matching de Bruijn indexing
            bruijns  = reverse [(n, \cd -> M.singleton cd (deps i)) | (i, n) <- zip [0..] args]
            cond     = S.singleton (metameth, Result)
            metameth = mkFieldName ctorName methNo
            (args, body) = unfoldLams t

    -- A projection with index -1 contributes exactly the dependencies of the
    -- projected term.  NOTE(review): -1 is presumably the predecessor
    -- projection on optimised naturals -- confirm against the code that
    -- generates 'Proj'.
    getDepsTerm vs bs cd (Proj t (-1)) = getDepsTerm vs bs cd t
    getDepsTerm vs bs cd (Proj t i) = error $ "cannot[1] analyse projection !" ++ show i ++ " of " ++ show t

    -- terms without dependencies
    getDepsTerm vs bs cd (Constant _) = M.empty
    getDepsTerm vs bs cd (TType    _) = M.empty
    getDepsTerm vs bs cd (UType    _) = M.empty
    getDepsTerm vs bs cd  Erased      = M.empty
    getDepsTerm vs bs cd  Impossible  = M.empty

    getDepsTerm vs bs cd t = error $ "cannot get deps of: " ++ show t
    
    -- | Number of arguments of a name.
    --
    -- For a field name of an instance constructor (the shape produced by
    -- 'mkFieldName') this is the number of arguments of the type of that
    -- field; otherwise it is read off the definition found in the context.
    -- Calls 'error' for unknown names and unsupported kinds of definition.
    getArity :: Name -> Int
    getArity (SN (WhereN _ ctorName (MN i _)))
        | Just (TyDecl (DCon _ _ _) ty) <- lookupDefExact ctorName ctx
        = let argTys = map snd $ getArgTys ty
            -- valid field numbers are 0 .. length argTys - 1; anything else
            -- must reach the error below rather than crash inside (!!)
            in if 0 <= i && i < length argTys
                then length $ getArgTys (argTys !! i)
                else error $ "invalid field number " ++ show i ++ " for " ++ show ctorName
        | otherwise = error $ "unknown instance constructor: " ++ show ctorName
    getArity n = case lookupDefExact n ctx of
        Just (CaseOp _ _ tys _ _ _)         -> length tys
        Just (TyDecl (DCon _ arity _) _)    -> arity
        Just (TyDecl (Ref) ty)              -> length $ getArgTys ty
        Just (Operator _ arity _)           -> arity
        Just df -> error $ "Erasure/getArity: unrecognised entity '"
                             ++ show n ++ "' with definition: "  ++ show df
        Nothing -> error $ "Erasure/getArity: definition not found for " ++ show n
    
    
    -- | Turn an application of nested lambdas into nested lets: the arguments
    -- of the application spine are collected and each one is bound by the
    -- matching lambda.  Arguments left over once the lambdas run out are
    -- applied to the remaining term.
    lamToLet :: [Term] -> Term -> Term
    lamToLet pending term = case (pending, term) of
        -- walk down the spine, stacking the arguments
        (_, App _ f x)                   -> lamToLet (x : pending) f
        -- a lambda consumes the next stacked argument
        (x : rest, Bind n (Lam ty) body) -> Bind n (Let ty x) (lamToLet rest body)
        -- no lambda left: apply the remaining arguments
        (x : rest, _)                    -> App Complete (lamToLet rest term) x
        ([], _)                          -> term
    
    -- | Strip the leading lambdas off a term, returning the bound names
    -- (outermost first) and the remaining body.
    unfoldLams :: Term -> ([Name], Term)
    unfoldLams term = case term of
        Bind n (Lam _) inner -> first (n :) (unfoldLams inner)
        _                    -> ([], term)
    -- | Merge two dependency maps; entries under the same condition have
    -- their node sets merged, and reasons for the same node are united.
    union :: Deps -> Deps -> Deps
    union lhs rhs = M.unionWith mergeNodes lhs rhs
      where
        mergeNodes = M.unionWith S.union
    -- | Merge a list of dependency maps in the same way as 'union'.
    unions :: [Deps] -> Deps
    unions = foldl' union M.empty
    -- | Compute a dependency map for every element and merge the results.
    unionMap :: (a -> Deps) -> [a] -> Deps
    unionMap f xs = M.unionsWith (M.unionWith S.union) [f x | x <- xs]
-- | Name standing for field number @fieldNo@ of the constructor @ctorName@:
-- a 'WhereN' name carrying the field number, the constructor name and a
-- machine-generated @"field"@ name with the same number.
mkFieldName :: Name -> Int -> Name
mkFieldName ctorName fieldNo = SN (WhereN fieldNo ctorName fieldTag)
  where
    fieldTag = sMN fieldNo "field"