module Idris.TypeSearch (
  searchByType, searchPred, defaultScoreFunction
) where
import Control.Applicative (Applicative (..), (<$>), (<*>), (<|>))
import Control.Arrow (first, second, (&&&), (***))
import Control.Monad (when, guard)
import Data.List (find, partition, (\\))
import Data.Map (Map)
import qualified Data.Map as M
import Data.Maybe (catMaybes, fromMaybe, isJust, maybeToList, mapMaybe)
import Data.Monoid (Monoid (mempty, mappend))
import Data.Ord (comparing)
import qualified Data.PriorityQueue.FingerTree as Q
import Data.Set (Set)
import qualified Data.Set as S
import qualified Data.Text as T (pack, isPrefixOf)
import Data.Traversable (traverse)
import Idris.AbsSyntax (addUsingConstraints, addImpl, getIState, putIState, implicit, logLvl)
import Idris.AbsSyntaxTree (class_instances, ClassInfo, defaultSyntax, eqTy, Idris,
  IState (idris_classes, idris_docstrings, tt_ctxt, idris_outputmode),
  implicitAllowed, OutputMode(..), PTerm, toplevel)
import Idris.Core.Evaluate (Context (definitions), Def (Function, TyDecl, CaseOp), normaliseC)
import Idris.Core.TT hiding (score)
import Idris.Core.Unify (match_unify)
import Idris.Delaborate (delabTy)
import Idris.Docstrings (noDocs, overview)
import Idris.Elab.Type (elabType)
import Idris.Output (iputStrLn, iRenderOutput, iPrintResult, iRenderError, iRenderResult, prettyDocumentedIst)
import Idris.IBC
import Prelude hiding (pred)
import Util.Pretty (text, char, vsep, (<>), Doc, annotate)
-- | Search for definitions whose type matches the given type and print them.
--
-- The indexes of the packages named in @pkgs@ are loaded first, the query
-- @pterm@ is elaborated to a type, every candidate is scored with
-- 'searchPred', and at most 50 results are rendered in the form required by
-- the current output mode.  The interpreter state captured on entry is
-- written back at the end.
searchByType :: [String] -> PTerm -> Idris ()
searchByType pkgs pterm = do
  i <- getIState
  when (not (null pkgs)) $
     iputStrLn $ "Searching packages: " ++ showSep ", " pkgs

  mapM_ loadPkgIndex pkgs
  pterm' <- addUsingConstraints syn emptyFC pterm
  -- The result is not needed here; the action is kept in case it updates
  -- state that the elaboration below depends on (TODO confirm).
  _ <- implicit toplevel syn name pterm'
  ty <- elabType toplevel syn (fst noDocs) (snd noDocs) emptyFC [] name NoFC pterm'
  -- NOTE(review): the search runs against the state captured *before*
  -- loadPkgIndex was called; confirm that the freshly loaded package
  -- definitions are meant to be visible here.
  let names = searchUsing searchPred i ty
  let names' = take numLimit names
  let docs =
       [ let docInfo = (n, delabTy i n, fmap (overview . fst) (lookupCtxtExact n (idris_docstrings i))) in
         displayScore theScore <> char ' ' <> prettyDocumentedIst i docInfo
                | (n, theScore) <- names']
  if (not (null docs))
     then case idris_outputmode i of
               RawOutput _  -> do mapM_ iRenderOutput docs
                                  iPrintResult ""
               IdeMode _ _ -> iRenderResult (vsep docs)
     else iRenderError $ text "No results found"
  -- Write the entry-time state back.
  putIState i
  where
    -- Maximum number of results shown.
    numLimit = 50
    syn = defaultSyntax { implicitAllowed = True }
    -- Machine-generated name under which the query type is elaborated.
    name = sMN 0 "searchType"
-- | Run a matching predicate over every searchable definition.
--
-- The query type and each candidate's type are normalised in the context of
-- the given state.  Definitions without a type (see 'typeFromDef') and
-- definitions stored under a \"special\" key are skipped.
searchUsing :: (IState -> Type -> [(Name, Type)] -> [(Name, a)]) 
  -> IState -> Type -> [(Name, a)]
searchUsing pred istate ty = pred istate normalisedQuery candidates
  where
  ctxt = tt_ctxt istate
  normalisedQuery = normaliseC ctxt [] ty
  -- All (name, normalised type) pairs, visited in ascending key order.
  candidates = concatMap candidatesFor (M.toAscList (definitions ctxt))
  candidatesFor (key, defs) = M.toAscList (M.mapMaybe (candidateType key) defs)
  candidateType key def
    | special key = Nothing
    | otherwise   = normaliseC ctxt [] <$> typeFromDef def
  -- Names that must never show up in search results.
  special :: Name -> Bool
  special nm = case nm of
    NS inner _ -> special inner
    SN _       -> True
    UN txt     -> T.pack "default#" `T.isPrefixOf` txt || txt `elem` hiddenNames
    _          -> False
  hiddenNames = map T.pack ["believe_me", "really_believe_me"]
-- | The default matcher: 'matchTypesBulk' with a maximum score of 100.
searchPred :: IState -> Type -> [(Name, Type)] -> [(Name, Score)]
searchPred istate ty1 = matchTypesBulk istate scoreLimit ty1
  where
  scoreLimit = 100
-- | Extract the type stored in the first component of a context entry.
-- Returns 'Nothing' for definition forms other than 'Function', 'TyDecl'
-- and 'CaseOp'.
typeFromDef :: (Def, i, b, c, d) -> Maybe Type
typeFromDef (def, _, _, _, _) = case def of
  Function ty _       -> Just ty
  TyDecl _ ty         -> Just ty
  CaseOp _ ty _ _ _ _ -> Just ty
  _                   -> Nothing
-- | Remove laziness wrappers from a type: wherever the name @Delayed@ is
-- applied to two arguments, only the second argument is kept.  The
-- traversal descends through binders, applications and projections.
unLazy :: Type -> Type
unLazy (App _ (App _ (P _ lazy _) _) ty)
  | lazy == sUN "Delayed" = unLazy ty
unLazy (Bind name binder ty) = Bind name (fmap unLazy binder) (unLazy ty)
unLazy (App s t1 t2) = App s (unLazy t1) (unLazy t2)
unLazy (Proj ty i) = Proj (unLazy ty) i
unLazy typ = typ
-- | Reverse the edges of a graph given as adjacency sets: in the result,
-- each node is paired with the keys of all nodes whose original set
-- contained that node's key.  Node order is preserved.
reverseDag :: Ord k => [((k, a), Set k)] -> [((k, a), Set k)]
reverseDag xs =
  [ (node, S.fromList [ k' | ((k', _), ks) <- xs, S.member k ks ])
  | (node@(k, _), _) <- xs ]
-- | Split a chain of leading Pi binders into its arguments and result.
--
-- Arguments whose type satisfies the predicate are returned separately
-- (second component); every other argument is paired with the set of
-- variables used in its type (first component).  Both lists keep the
-- original binder order.  The third component is whatever remains once no
-- leading Pi binder is left.
computeDagP :: Ord n 
  => (TT n -> Bool) 
  -> TT n
  -> ([((n, TT n), Set n)], [(n, TT n)], TT n)
computeDagP removePred t = (map withDeps kept, removed, result) where
  withDeps arg@(_, ty) = (arg, M.keysSet (usedVars ty))
  (kept, removed, result) = peel t
  -- Walk the binders front to back, sorting each argument into its list.
  peel (Bind n (Pi _ ty _) sc)
    | removePred ty = (keptRest, (n, ty) : removedRest, res)
    | otherwise     = ((n, ty) : keptRest, removedRest, res)
    where (keptRest, removedRest, res) = peel sc
  peel other = ([], [], other)
-- | Collect the bound variables occurring in a term, each with its type and
-- a flag.  The flag starts out 'True' and is turned off for occurrences in
-- the argument of an application whose function part is not injective.
-- Variables bound inside the term are removed from the result of the
-- binder's scope.  The term must not contain de Bruijn variables ('V').
usedVars :: Ord n => TT n -> Map n (TT n, Bool)
usedVars = go True where
  go inj (P Bound n t) = M.insert n (t, inj) (go inj t)
  go inj (Bind n binder sc) = M.delete n (go inj sc) `M.union` binderVars
    where
    binderVars = case binder of
      Let t v   -> go inj t `M.union` go inj v
      Guess t v -> go inj t `M.union` go inj v
      other     -> go inj (binderTy other)
  go inj (App _ t1 t2) = go inj t1 `M.union` go (inj && isInjective t1) t2
  go inj (Proj t _) = go inj t
  go _ (V _) = error "unexpected! run vToP first"
  go _ _ = M.empty
-- | Remove the node with the given name from the graph and erase that name
-- from the name set of every remaining node.
deleteFromDag :: Ord n => n -> [((n, TT n), (a, Set n))] -> [((n, TT n), (a, Set n))]
deleteFromDag name dag =
  [ (node, (ix, S.delete name deps))
  | (node@(name2, _), (ix, deps)) <- dag
  , name /= name2 ]
-- | Drop every entry with the given name from an argument list.
deleteFromArgList :: Ord n => n -> [(n, TT n)] -> [(n, TT n)]
deleteFromArgList n args = [ arg | arg <- args, fst arg /= n ]
-- | Counters for the modifications applied to one side of a match.
-- They combine by field-wise addition (see the 'Monoid' instance).
data AsymMods = Mods
  { argApp         :: !Int -- ^ argument applications, counted while resolving unifier solutions
  , typeClassApp   :: !Int -- ^ class instances applied (one per instance tried)
  , typeClassIntro :: !Int -- ^ class-constraint counter, set when a subset of the constraints is kept
  } deriving (Eq, Show)
-- | A strict pair of values of the same type, one for each side of a match.
-- In the matcher's 'State' the left side holds the data of the query type
-- and the right side the data of the candidate type.
data Sided a = Sided
  { left  :: !a -- ^ value for the left side
  , right :: !a -- ^ value for the right side
  } deriving (Eq, Show)
-- | Combine the two sides with a binary function (left argument first).
sided :: (a -> a -> b) -> Sided a -> b
sided f sd = case sd of
  Sided l r -> f l r
-- | Apply the same function to each side.
both :: (a -> b) -> Sided a -> Sided b
both f sd = case sd of
  Sided l r -> Sided (f l) (f r)
-- | The cost of a (partial) match.  Scores are compared through
-- 'defaultScoreFunction' and combined field-wise.
data Score = Score
  { transposition :: !Int -- ^ accumulated distance between the positions of matched arguments
  , equalityFlips :: !Int -- ^ number of equality types whose sides were swapped
  , asymMods      :: !(Sided AsymMods) -- ^ per-side modification counters
  } deriving (Eq, Show)
-- | Render a one-character summary of a score: @=@ when neither side
-- needed modifications, @<@ when only the right side did, @>@ when only
-- the left side did, and @_@ when both did.  The first three are annotated
-- with the corresponding 'Ordering'.
displayScore :: Score -> Doc OutputAnnotation
displayScore score = case asymMods score of
  Sided l r -> case (unmodified l, unmodified r) of
    (True,  True)  -> tagged EQ "="
    (True,  False) -> tagged LT "<"
    (False, True)  -> tagged GT ">"
    (False, False) -> text "_"
  where
  tagged ordr str = annotate (AnnSearchResult ordr) (text str)
  unmodified (Mods app tcApp tcIntro) = app + tcApp + tcIntro == 0
-- | Decide whether a score is still worth pursuing.  A score is rejected
-- when both sides have argument applications, when the two sides together
-- have more than 4 argument applications, or when either side has more
-- than 3 class applications or more than 3 class introductions.
scoreCriterion :: Score -> Bool
scoreCriterion (Score _ _ (Sided l r)) =
     not (argApp l > 0 && argApp r > 0)
  && argApp l + argApp r <= 4
  && classModsOk l
  && classModsOk r
  where
  classModsOk (Mods _ tcApp tcIntro) = tcApp <= 3 && tcIntro <= 3
-- | Collapse a 'Score' into a single number; smaller is better.
--
-- The result is the sum of the transposition count, the equality flips, a
-- linear penalty in which the left side's modifications weigh three times
-- as much as the right side's, and a penalty of 100 times the product of
-- the two sides' modification weights, which is non-zero only when both
-- sides were modified.
defaultScoreFunction :: Score -> Int
defaultScoreFunction (Score trans eqFlip (Sided l r)) = 
  trans + eqFlip + linearPenalty + upAndDowncastPenalty
  where
  linearPenalty = 3 * linearWeight l + linearWeight r
  linearWeight (Mods app tcApp tcIntro) = 3 * app + 4 * tcApp + 2 * tcIntro
  upAndDowncastPenalty = 100 * (castWeight l * castWeight r)
  castWeight (Mods app tcApp tcIntro) = 2 * app + tcApp + tcIntro
-- | Scores are ordered by their 'defaultScoreFunction' value.
instance Ord Score where
  compare s1 s2 = compare (defaultScoreFunction s1) (defaultScoreFunction s2)
-- | Side-wise combination.
instance Monoid a => Monoid (Sided a) where
  mempty = Sided mempty mempty
  mappend (Sided l1 r1) (Sided l2 r2) = Sided (mappend l1 l2) (mappend r1 r2)
-- | Field-wise addition, with all-zero counters as the identity.
instance Monoid AsymMods where
  mempty = Mods 0 0 0
  mappend (Mods a1 b1 c1) (Mods a2 b2 c2) = Mods (a1 + a2) (b1 + b2) (c1 + c2)
-- | Field-wise combination: the two counters add up and the per-side
-- modifications are combined with their own 'Monoid' instance.
instance Monoid Score where
  mempty = Score 0 0 mempty
  mappend (Score t1 e1 m1) (Score t2 e2 m2) = Score (t1 + t2) (e1 + e2) (mappend m1 m2)
-- | Function arguments that still have to be matched: each argument (name
-- and type) is paired with its position and a set of argument names taken
-- from the reversed dependency graph of the signature.
type ArgsDAG = [((Name, Type), (Int, Set Name))]
-- | Class-constraint arguments (name and type) that still have to be matched.
type Classes = [(Name, Type)]
-- | A node of the search performed by the matcher.
data State = State
  { holes          :: ![(Name, Type)] -- ^ names (with types) handed to the unifier as holes
  , argsAndClasses :: !(Sided (ArgsDAG, Classes))
     -- ^ arguments and class constraints not yet matched, for each side
  , score     :: !Score -- ^ cost accumulated so far
  , usedNames :: ![Name] -- ^ names already in use; passed to @uniqueBinders@ when instance types are introduced
  } deriving Show
-- | Apply a function to every argument type and every class-constraint type.
modifyTypes :: (Type -> Type) -> (ArgsDAG, Classes) -> (ArgsDAG, Classes)
modifyTypes f ~(dag, classes) = (map onDagEntry dag, map onClass classes)
  where
  onDagEntry ~(~(n, ty), meta) = ((n, f ty), meta)
  onClass ~(n, ty) = (n, f ty)
-- | Look an argument up by name, returning its type and its position.
findNameInArgsDAG :: Name -> ArgsDAG -> Maybe (Type, Maybe Int)
findNameInArgsDAG name dag = case find ((name ==) . fst . fst) dag of
  Just ((_, ty), (ix, _)) -> Just (ty, Just ix)
  Nothing                 -> Nothing
-- | Look a name up among the arguments and, failing that, among the class
-- constraints.  Arguments come with their position; class constraints have
-- no position, so they yield 'Nothing' in the second component.
findName :: Name -> (ArgsDAG, Classes) -> Maybe (Type, Maybe Int)
findName n (args, classes) =
  -- The fallback previously ended in @<*> Nothing@, which is always
  -- 'Nothing' in the 'Maybe' applicative, so class constraints could never
  -- be found.  Pair the type with an absent position instead.
  findNameInArgsDAG n args <|> ((\ty -> (ty, Nothing)) <$> lookup n classes)
-- | Remove a name from both the argument graph and the class constraints.
deleteName :: Name -> (ArgsDAG, Classes) -> (ArgsDAG, Classes)
deleteName n (args, classes) = (remainingArgs, remainingClasses)
  where
  remainingArgs = deleteFromDag n args
  remainingClasses = deleteFromArgList n classes
-- | Convert a type-checker result to 'Maybe', discarding the error.
tcToMaybe :: TC a -> Maybe a
tcToMaybe result = case result of
  OK x    -> Just x
  Error _ -> Nothing
-- | Apply a function to the type of every argument in the graph.
inArgTys :: (Type -> Type) -> ArgsDAG -> ArgsDAG
inArgTys f = map onEntry
  where
  onEntry ~(~(n, ty), meta) = ((n, f ty), meta)
-- | Try to use an instance (given by its type @tyTry@) to discharge the
-- class constraint @ty@.
--
-- The instance's result type is matched against @ty@ with the instance's
-- non-class arguments as holes.  The attempt succeeds only when every hole
-- is solved; the result is the instance's own class-constraint arguments
-- with the solution substituted into their types.
typeclassUnify :: Ctxt ClassInfo -> Context -> Type -> Type -> Maybe [(Name, Type)]
typeclassUnify classInfo ctxt ty tyTry =
  case tcToMaybe (match_unify ctxt [] (ty, Nothing) (retTy, Nothing) [] theHoles []) of
    Just res
      | null (theHoles \\ map fst res) -> Just (map (second (applySolution res)) tcArgs)
    _ -> Nothing
  where
  -- Compose the substitutions; the first solution is applied last.
  applySolution res = foldr (.) id [ subst n t | (n, t) <- res ]
  tyTry' = vToP tyTry
  retTy = getRetTy tyTry'
  theHoles = map fst nonTcArgs
  (tcArgs, nonTcArgs) = partition (isTypeClassArg classInfo . snd) $ getArgTys tyTry'
-- | Test whether a type is an application of a type constructor whose name
-- is found in the class table.
isTypeClassArg :: Ctxt ClassInfo -> Type -> Bool
isTypeClassArg classInfo ty = case fst (unApply ty) of
  P (TCon _ _) className _ -> not (null (lookupCtxt className classInfo))
  _                        -> False
-- | All sub-lists of a list, preserving element order.  Sub-lists that
-- contain the first element come before those that do not, so the full
-- list is first and the empty list is last.
subsets :: [a] -> [[a]]
subsets = foldr extend [[]]
  where
  extend x rest = map (x :) rest ++ rest
-- | Enumerate the variants of a type obtained by swapping the two sides of
-- equality types, each paired with the number of swaps made.
--
-- An application of the equality type constructor to four arguments yields
-- itself (0 swaps) and the version with both pairs of arguments exchanged
-- (1 swap).  Binders and other applications combine the variants of their
-- parts; for a binder only the swap count of the binder's type is added to
-- that of the scope.
flipEqualities :: Type -> [(Int, Type)]
flipEqualities t = case t of
  eq1@(App _ (App _ (App _ (App _ eq@(P _ eqty _) tyL) tyR) valL) valR) | eqty == eqTy ->
    [(0, eq1), (1, app (app (app (app eq tyR) tyL) valR) valL)]
  Bind n binder sc ->
    [ (fst (binderTy binder') + j, Bind n (fmap snd binder') sc')
    | binder' <- traverse flipEqualities binder
    , (j, sc') <- flipEqualities sc ]
  App _ f x ->
    [ (i + j, app f' x')
    | (i, f') <- flipEqualities f
    , (j, x') <- flipEqualities x ]
  t' -> [(0, t')]
 where app = App Complete
-- | Match one type against many candidate types.
--
-- Every candidate gets its own priority queue of search states, and these
-- queues are kept in an outer priority queue keyed by their best score.
-- Results are produced lazily as @(info, score)@ pairs; production stops as
-- soon as the best remaining queue scores above @maxScore@.
matchTypesBulk :: forall info. IState -> Int -> Type -> [(info, Type)] -> [(info, Score)]
matchTypesBulk istate maxScore type1 types = getAllResults startQueueOfQueues where
  -- Build the initial state queue of one candidate, keyed by its best
  -- score.  Nothing when no variant of the candidate unifies on the
  -- result type.
  getStartQueues :: (info, Type) -> Maybe (Score, (info, Q.PQueue Score State))
  getStartQueues nty@(info, type2) = case mapMaybe startStates ty2s of
    [] -> Nothing
    xs -> Just (minimum (map fst xs), (info, Q.fromList xs))
    where
    -- Every combination of equality flips in the candidate's arguments and
    -- result type, with the total number of flips.
    ty2s = (\(i, dag) (j, retTy) -> (i + j, dag, retTy))
      <$> flipEqualitiesDag dag2 <*> flipEqualities retTy2
    flipEqualitiesDag dag = case dag of 
      [] -> [(0, [])]
      ((n, ty), (pos, deps)) : xs -> 
         (\(i, ty') (j, xs') -> (i + j , ((n, ty'), (pos, deps)) : xs'))
           <$> flipEqualities ty <*> flipEqualitiesDag xs
    -- Unify the two result types; the query's data goes on the left side of
    -- the state and the candidate's on the right.
    startStates (numEqFlips, sndDag, sndRetTy) = do
      state <- unifyQueue (State startingHoles 
                (Sided (dag1, typeClassArgs1) (sndDag, typeClassArgs2))
                (mempty { equalityFlips = numEqFlips }) usedns) [(retTy1, sndRetTy)]
      return (score state, state)
    -- The candidate's binders are renamed away from the query's argument names.
    (dag2, typeClassArgs2, retTy2) = makeDag (uniqueBinders (map fst argNames1) type2)
    argNames2 = map fst dag2
    usedns = map fst startingHoles
    startingHoles = argNames1 ++ argNames2
    startingTypes = [(retTy1, retTy2)] 
  -- One entry per candidate that survived the initial unification.
  startQueueOfQueues :: Q.PQueue Score (info, Q.PQueue Score State)
  startQueueOfQueues = Q.fromList $ mapMaybe getStartQueues types
  -- Repeatedly advance the candidate with the lowest score: a finished
  -- candidate is emitted, an exhausted one is dropped, and any other is put
  -- back under the score of its new best state.
  getAllResults :: Q.PQueue Score (info, Q.PQueue Score State) -> [(info, Score)]
  getAllResults q = case Q.minViewWithKey q of
    Nothing -> []
    Just ((nextScore, (info, stateQ)), q') ->
      if defaultScoreFunction nextScore <= maxScore
        then case nextStepsQueue stateQ of
          Nothing -> getAllResults q'
          Just (Left stateQ') -> case Q.minViewWithKey stateQ' of
             Nothing -> getAllResults q'
             Just ((newQscore,_), _) -> getAllResults (Q.add newQscore (info, stateQ') q')
          Just (Right theScore) -> (info, theScore) : getAllResults q'
        else []
  ctxt = tt_ctxt istate
  classInfo = idris_classes istate
  (dag1, typeClassArgs1, retTy1) = makeDag type1
  argNames1 = map fst dag1
  -- Decompose a type into its argument graph (edges reversed, arguments
  -- numbered from 0), its class-constraint arguments and its result type.
  makeDag :: Type -> (ArgsDAG, Classes, Type)
  makeDag = first3 (zipWith (\i (ty, deps) -> (ty, (i, deps))) [0..] . reverseDag) . 
    computeDagP (isTypeClassArg classInfo) . vToP . unLazy
  first3 f (a,b,c) = (f a, b, c)
  
  
  -- Apply the solutions returned by the unifier to a search state.
  --
  -- Each solution removes the names it settles from the holes and from the
  -- unmatched arguments and class constraints, substitutes the solved name
  -- in the remaining types, and updates the score.  The result also carries
  -- further pairs of types that still have to be unified.  Nothing is
  -- returned when a solved name is found on neither side.
  resolveUnis :: [(Name, Type)] -> State -> Maybe (State, [(Type, Type)])
  resolveUnis [] state = Just (state, [])
  -- A name solved by a bare bound variable, where the two names are found
  -- on opposite sides: the two entries are matched with each other and
  -- their types are queued for unification.
  resolveUnis ((name, term@(P Bound name2 _)) : xs) 
    state | isJust findArgs = do
    ((ty1, ix1), (ty2, ix2)) <- findArgs
    (state'', queue) <- resolveUnis xs state'
    -- Distance between the two positions, or 0 when either entry has no
    -- position.  The operator has to be subtraction: applying the unit
    -- value here does not type-check.
    let transScore = fromMaybe 0 (abs <$> ((-) <$> ix1 <*> ix2))
    return (inScore (\s -> s { transposition = transposition s + transScore }) state'', (ty1, ty2) : queue)
    where
    unresolved = argsAndClasses state
    inScore f stat = stat { score = f (score stat) }
    -- The two names may be found either way round.
    findArgs = ((,) <$> findName name  (left unresolved) <*> findName name2 (right unresolved)) <|>
               ((,) <$> findName name2 (left unresolved) <*> findName name  (right unresolved))
    matchnames = [name, name2]
    deleteArgs = deleteName name . deleteName name2
    state' = state { holes = filter (not . (`elem` matchnames) . fst) (holes state)
                   , argsAndClasses = both (modifyTypes (subst name term) . deleteArgs) unresolved}
  -- Any other solution: the name must be found on exactly one side.
  resolveUnis ((name, term) : xs)
    state@(State hs unresolved _ _) = case both (findName name) unresolved of
      Sided Nothing  Nothing  -> Nothing
      Sided (Just _) (Just _) -> error "Idris internal error: TypeSearch.resolveUnis"
      oneOfEach -> first (addScore (both scoreFor oneOfEach)) <$> nextStep
    where
    -- The side holding the name is charged one argument application; the
    -- other side is charged one per hole used in a non-injective position
    -- of the solution.
    scoreFor (Just _) = mempty { argApp = 1 }
    scoreFor Nothing  = mempty { argApp = otherApplied }
    matchedVarMap = usedVars term
    bothT f (x, y) = (f x, f y)
    -- Holes occurring in the solution, split by the flag from usedVars.
    (injUsedVars, notInjUsedVars) = bothT M.keys . M.partition id . M.filterWithKey (\k _-> k `elem` map fst hs) $ M.map snd matchedVarMap 
    varsInTy = injUsedVars ++ notInjUsedVars
    toDelete = name : varsInTy
    deleteMany = foldr (.) id (map deleteName toDelete)
    otherApplied = length notInjUsedVars
    addScore additions theState = theState { score = let s = score theState in
      s { asymMods = asymMods s `mappend` additions } }
    state' = state { holes = filter (not . (`elem` toDelete) . fst) hs  
                   , argsAndClasses = both (modifyTypes (subst name term) . deleteMany) (argsAndClasses state) }
    nextStep = resolveUnis xs state'
  
  -- Unify a queue of type pairs one after another.  Each pair is matched
  -- with the state's current holes, the solutions are applied with
  -- resolveUnis, and any pairs produced by that step are appended to the
  -- queue.  Fails when a match fails or the score stops satisfying
  -- scoreCriterion.
  unifyQueue :: State -> [(Type, Type)] -> Maybe State
  unifyQueue state pending = case pending of
    [] -> Just state
    (ty1, ty2) : rest -> do
      let currentHoles = holes state
          holeEnv = [ (n, Pi Nothing ty (TType (UVar 0))) | (n, ty) <- currentHoles ]
          holeNames = map fst currentHoles
      solutions <- tcToMaybe $
        match_unify ctxt holeEnv (ty1, Nothing) (ty2, Nothing) [] holeNames []
      (state', extraPairs) <- resolveUnis solutions state
      guard $ scoreCriterion (score state')
      unifyQueue state' (rest ++ extraPairs)
  -- The normalised types of all instances recorded for the class at the
  -- head of the given constraint type, with binders renamed away from the
  -- names already in use.  Empty when the head is not a type constructor
  -- or is not found in the class table.
  possClassInstances :: [Name] -> Type -> [Type]
  possClassInstances usedns ty =
    [ vToP (uniqueBinders usedns (normaliseC ctxt [] instTy))
    | className <- headClassName
    , classDef <- lookupCtxt className classInfo
    , inst <- class_instances classDef
    , def <- lookupCtxt (fst inst) (definitions ctxt)
    , instTy <- maybeToList (typeFromDef def)
    ]
    where
    headClassName = case fst (unApply ty) of
      P (TCon _ _) className _ -> [className]
      _                        -> []
  
  
  -- Advance a candidate's state queue by one step.  Nothing means the queue
  -- is empty.  When the best state is final its score is returned on the
  -- Right; otherwise that state is replaced by its successors (or by
  -- nothing when its score fails scoreCriterion) and the updated queue is
  -- returned on the Left.
  nextStepsQueue :: Q.PQueue Score State -> Maybe (Either (Q.PQueue Score State) Score)
  nextStepsQueue queue = case Q.minViewWithKey queue of
    Nothing -> Nothing
    Just ((nextScore, next), rest)
      | isFinal next -> Just (Right nextScore)
      | otherwise    -> Just (Left (Q.union rest (successors nextScore next)))
    where
    successors sc st
      | scoreCriterion sc = Q.fromList [ (score state, state) | state <- nextSteps st ]
      | otherwise         = Q.empty
    -- A state is final once nothing is left to match on either side.
    isFinal st = case st of
      State [] (Sided ([], []) ([], [])) _ _ -> True
      _                                      -> False
  
  
  
  -- Compute the successor states of a search state.
  nextSteps :: State -> [State]
  -- No holes and no arguments left on either side: only the class
  -- constraints c1 (left) and c2 (right) remain.
  nextSteps (State [] unresolved@(Sided ([], c1) ([], c2)) scoreAcc usedns) = 
    if null results3 then results4 else results3
    where
    -- Try to unify every left constraint with every right constraint,
    -- removing the pair and substituting the left name for the right one.
    results3 =
         catMaybes [ unifyQueue (State [] 
         (Sided ([], deleteFromArgList n1 c1)
                ([], map (second subst2for1) (deleteFromArgList n2 c2)))
         scoreAcc usedns) [(ty1, ty2)] 
     | (n1, ty1) <- c1, (n2, ty2) <- c2, let subst2for1 = psubst n2 (P Bound n1 ty1)]
    -- Otherwise apply a class instance on one side: that side's constraint
    -- list becomes the instance's own constraints, its counters are added
    -- to the score, and the new names are appended to the used names.
    results4 = [ State [] (both (\(cs, _, _) -> ([], cs)) sds) 
               (scoreAcc `mappend` Score 0 0 (both (\(_, amods, _) -> amods) sds))
               (usedns ++ sided (++) (both (\(_, _, hs) -> hs) sds))
               | sds <- allMods ]
      where
      allMods = parallel defMod mods
      mods :: Sided [( Classes, AsymMods, [Name] )]
      mods = both (instanceMods . snd) unresolved
      -- The unchanged version of each side.
      defMod :: Sided (Classes, AsymMods, [Name])
      defMod = both (\(_, cs) -> (cs, mempty , [])) unresolved
      -- Change exactly one side at a time, keeping the other at its default.
      parallel :: Sided a -> Sided [a] -> [Sided a]
      parallel (Sided l r) (Sided ls rs) = map (flip Sided r) ls ++ map (Sided l) rs
      -- One entry for every instance that unifies with one of the constraints.
      instanceMods :: Classes -> [( Classes , AsymMods, [Name] )]
      instanceMods classes = [ ( newClassArgs, mempty { typeClassApp = 1 }, newHoles )
                      | (_, ty) <- classes
                      , inst <- possClassInstances usedns ty
                      , newClassArgs <- maybeToList $ typeclassUnify classInfo ctxt ty inst
                      , let newHoles = map fst newClassArgs ]
  -- Every other state: match an argument of the left side with an argument
  -- of the right side.
  nextSteps (State hs (Sided (dagL, c1) (dagR, c2)) scoreAcc usedns) = results where
    results = concatMap takeSomeClasses results1
    -- Arguments whose name set in the graph is empty.
    canBeFirst :: ArgsDAG -> [(Name, Type)]
    canBeFirst = map fst . filter (S.null . snd . snd)
    -- Unify the types of each eligible pair after removing both arguments
    -- and substituting the left name for the right one on the right side.
    results1 = catMaybes [ unifyQueue (State (filter (not . (`elem` [n1,n2]) . fst) hs) 
         (Sided (deleteFromDag n1 dagL, c1) 
                (inArgTys subst2for1 $ deleteFromDag n2 dagR, map (second subst2for1) c2))
          scoreAcc usedns) [(ty1, ty2)] 
     | (n1, ty1) <- canBeFirst dagL, (n2, ty2) <- canBeFirst dagR
     , let subst2for1 = psubst n2 (P Bound n1 ty1)]
  
  
  -- Once a state has no holes and no arguments left, enumerate the ways of
  -- keeping only a subset of the class constraints on each side.  Every
  -- combination of a left subset and a right subset becomes one state.  Any
  -- other state is returned unchanged.
  takeSomeClasses :: State -> [State]
  takeSomeClasses (State [] unresolved@(Sided ([], _) ([], _)) scoreAcc usedns) = 
    map statesFromMods . prod $ both (classMods . snd) unresolved
    where
    swap (Sided l r) = Sided r l
    -- The counters computed for one side are recorded on the opposite side.
    statesFromMods :: Sided (Classes, AsymMods) -> State
    statesFromMods sides = let classes = both (\(c, _) -> ([], c)) sides
                               mods    = swap (both snd sides) in
      State [] classes (scoreAcc `mappend` (mempty { asymMods = mods })) usedns
    -- Each subset of the constraints, charged with the number of
    -- constraints left out of it.  The operator has to be subtraction:
    -- applying the count to further arguments does not type-check.
    classMods :: Classes -> [(Classes, AsymMods)]
    classMods cs = let lcs = length cs in 
      [ (cs', mempty { typeClassIntro = lcs - length cs' }) | cs' <- subsets cs ]
    -- All pairings of a left choice with a right choice.
    prod :: Sided [a] -> [Sided a]
    prod (Sided ls rs) = [Sided l r | l <- ls, r <- rs]
  
  takeSomeClasses s = [s]