module Idris.Core.CaseTree(CaseDef(..), SC, SC'(..), CaseAlt, CaseAlt'(..), ErasureInfo,
                     Phase(..), CaseTree, CaseType(..),
                     simpleCase, small, namesUsed, findCalls, findUsedArgs,
                     substSC, substAlt, mkForce) where
import Idris.Core.TT
import Control.Applicative hiding (Const)
import Control.Monad.State
import Control.Monad.Reader
import Data.Maybe
import Data.List hiding (partition)
import qualified Data.List(partition)
import Debug.Trace
-- | The result of compiling a set of pattern clauses.
--
-- As built by 'simpleCase', the fields are: the generated argument names,
-- the case tree itself (a strict field), and the list of terms accumulated
-- in the builder state (in 'match' these are the left-hand sides of the
-- clauses that remain after the first clause with no patterns left).
data CaseDef = CaseDef [Name] !SC [Term]
    deriving Show
-- | A case tree whose leaves carry values of type @t@.
data SC' t = Case CaseType Name [CaseAlt' t]  -- ^ inspect the named variable and pick one of the alternatives
           | ProjCase t [CaseAlt' t] -- ^ like 'Case', but the scrutinee is a @t@ rather than a variable name
           | STerm !t -- ^ leaf holding the result (strict field)
           | UnmatchedCase String -- ^ leaf for "no clause matched"; carries the message (shown as @error ...@)
           | ImpossibleCase -- ^ leaf for a clause whose right-hand side is 'Impossible' (see 'match')
    deriving (Eq, Ord, Functor)
-- | Annotation carried by every 'Case' node.  'caseGroups' picks 'Updatable'
-- when the first constructor group has its flag set and 'Shared' otherwise;
-- the 'Show' instance for 'SC'' marks 'Updatable' cases with a @!@.
data CaseType = Updatable | Shared
   deriving (Eq, Ord, Show)
-- | Case trees whose leaves are 'Term's.
type SC = SC' Term
-- | One alternative of a 'Case' or 'ProjCase'.  The sub-tree of every
-- alternative is a strict field.
data CaseAlt' t = ConCase Name Int [Name] !(SC' t) -- ^ constructor name, the 'Int' carried by the matching 'PCon' pattern (presumably its tag), names bound to the arguments
                | FnCase Name [Name]      !(SC' t) -- ^ built from 'PReflected' patterns: name and names bound to the arguments
                | ConstCase Const         !(SC' t) -- ^ match a constant
                | SucCase Name            !(SC' t) -- ^ successor match; the name is bound by the alternative (shown as @n+1@)
                | DefaultCase             !(SC' t) -- ^ fallback alternative (shown as @_@)
    deriving (Show, Eq, Ord, Functor)
-- | Case alternatives whose leaves are 'Term's.
type CaseAlt = CaseAlt' Term
-- | Render a case tree as indented, human-readable text.  Each nesting
-- level of alternatives is indented by four more spaces.
instance Show t => Show (SC' t) where
    show = render 1
      where
        -- Render a tree whose alternatives live at nesting depth @lvl@.
        render lvl tree = case tree of
            Case up n alts ->
                "case" ++ marker up ++ show n ++ " of\n" ++ pad lvl
                    ++ altLines lvl alts
            ProjCase tm alts ->
                "case " ++ show tm ++ " of " ++ altLines lvl alts
            STerm tm -> show tm
            UnmatchedCase str -> "error " ++ show str
            ImpossibleCase -> "impossible"

        -- Updatable cases are flagged with a '!' after the keyword.
        marker Updatable = "! "
        marker Shared = " "

        -- One alternative per line, each line starting at the current depth.
        altLines lvl alts =
            showSep ("\n" ++ pad lvl) (map (renderAlt lvl) alts)

        pad lvl = concat (replicate lvl "    ")

        argList args = showSep (", ") (map show args)

        renderAlt lvl alt = case alt of
            ConCase n _ args sc ->
                show n ++ "(" ++ argList args ++ ") => " ++ render (lvl + 1) sc
            FnCase n args sc ->
                "FN " ++ show n ++ "(" ++ argList args ++ ") => "
                    ++ render (lvl + 1) sc
            ConstCase c sc -> show c ++ " => " ++ render (lvl + 1) sc
            SucCase n sc -> show n ++ "+1 => " ++ render (lvl + 1) sc
            DefaultCase sc -> "_ => " ++ render (lvl + 1) sc
-- | Synonym for 'SC'.
type CaseTree = SC
-- | A clause during compilation: the patterns still to be matched, paired
-- with two terms.  'simpleCase' fills the pair with the clause's left-hand
-- side and right-hand side.
type Clause   = ([Pat], (Term, Term))
-- | State threaded through the 'CaseBuilder': terms collected by 'match',
-- the counter used by 'getVar' for fresh names, and the (variable, type)
-- pairs recorded by 'varRule' and 'argsToAlt'.
type CS = ([Term], Int, [(Name, Type)])
-- | Size of a case tree: the size of its alternatives for a case node, the
-- size of the term for an 'STerm' leaf, and 1 for any other leaf.
instance TermSize SC where
    termsize n sc = case sc of
        Case _ _ alts   -> termsize n alts
        ProjCase _ alts -> termsize n alts
        STerm tm        -> termsize n tm
        _               -> 1
-- | The size of an alternative is the size of the tree it leads to.
instance TermSize CaseAlt where
    termsize n alt = termsize n subtree
      where
        subtree = case alt of
            ConCase _ _ _ s -> s
            FnCase _ _ s    -> s
            ConstCase _ s   -> s
            SucCase _ s     -> s
            DefaultCase s   -> s
-- | A tree counts as small when no top-level argument is reported more than
-- once by 'findAllUsedArgs' and its 'termsize' (relative to @n@) is below 10.
small :: Name -> [Name] -> SC -> Bool
small n args t = noRepeats && termsize n t < 10
  where
    usedArgs = findAllUsedArgs t args
    noRepeats = length usedArgs == length (nub usedArgs)
-- | Names referenced in the leaf terms (and 'ProjCase' scrutinees) of a case
-- tree, without duplicates.  Names bound by an enclosing alternative or by a
-- binder inside a term are left out, as is the variable scrutinised by an
-- enclosing 'Case'.
namesUsed :: SC -> [Name]
namesUsed tree = nub (scNames [] tree)
  where
    -- @bound@ holds the names that are locally bound at this point.
    scNames bound sc = case sc of
        Case _ scrut alts ->
            nub (concatMap (altNames bound) alts) \\ [scrut]
        ProjCase tm alts ->
            nub (tmNames bound tm ++ concatMap (altNames bound) alts)
        STerm tm -> nub (tmNames bound tm)
        _ -> []

    altNames bound alt = case alt of
        ConCase _ _ args sc -> nub (scNames (bound ++ args) sc) \\ args
        FnCase _ args sc    -> nub (scNames (bound ++ args) sc) \\ args
        ConstCase _ sc      -> scNames bound sc
        SucCase _ sc        -> scNames bound sc
        DefaultCase sc      -> scNames bound sc

    tmNames bound tm = case tm of
        P _ n _
            | n `elem` bound -> []
            | otherwise      -> [n]
        App _ f a -> tmNames bound f ++ tmNames bound a
        Proj t _ -> tmNames bound t
        -- For a let, only the bound value is inspected, not its type.
        Bind n (Let _ v) body -> tmNames bound v ++ tmNames (n : bound) body
        Bind n _ body -> tmNames (n : bound) body
        _ -> []
-- | Collect the names applied or referenced in a case tree that are not
-- locally bound.  Each name is paired with one list per argument of the
-- call; that list holds the top-level arguments which the argument uses
-- directly (according to 'directUse').  A bare reference gets an empty list.
findCalls :: SC -> [Name] -> [(Name, [[Name]])]
findCalls tree topargs = nub (scCalls topargs tree)
  where
    -- @bound@ starts out as the top-level arguments and grows with every
    -- scrutinee, alternative argument and term binder that is passed.
    scCalls bound sc = case sc of
        Case _ scrut alts ->
            nub (concatMap (altCalls (scrut : bound)) alts)
        ProjCase tm alts ->
            nub (tmCalls bound tm ++ concatMap (altCalls bound) alts)
        STerm tm -> nub (tmCalls bound tm)
        _ -> []

    altCalls bound alt = case alt of
        ConCase _ _ args sc -> nub (scCalls (bound ++ args) sc)
        FnCase _ args sc    -> nub (scCalls (bound ++ args) sc)
        ConstCase _ sc      -> scCalls bound sc
        SucCase _ sc        -> scCalls bound sc
        DefaultCase sc      -> scCalls bound sc

    tmCalls bound tm = case tm of
        P _ n _
            | n `elem` bound -> []
            | otherwise      -> [(n, [])]
        App _ f a -> case unApply tm of
            -- Any application headed by a name is handled here, whatever
            -- kind of name it is.
            (P _ n _, args)
                | n `elem` bound -> tmCalls bound f ++ tmCalls bound a
                | otherwise ->
                    (n, map argNames args) : concatMap (tmCalls bound) args
            _ -> tmCalls bound f ++ tmCalls bound a
        Bind n (Let _ v) body -> tmCalls bound v ++ tmCalls (n : bound) body
        Proj t _ -> tmCalls bound t
        Bind n _ body -> tmCalls (n : bound) body
        _ -> []

    -- Top-level arguments used directly by @tm@, in 'topargs' order.
    argNames tm = let used = directUse tm in
                      filter (`elem` used) topargs
-- | Names a term uses directly.  Arguments handed to an application headed
-- by a 'Ref' or a type constructor do not count, with two exceptions: the
-- single applied argument of @prim_fork@ and the third argument of @Force@.
directUse :: TT Name -> [Name]
directUse tm = case tm of
    P _ n _ -> [n]
    Bind n (Let ty val) body ->
        nub (directUse val ++ (directUse body \\ [n]) ++ directUse ty)
    Bind n b body ->
        nub (directUse (binderTy b) ++ (directUse body \\ [n]))
    App _ f a -> appUse f a
    Proj x _ -> nub (directUse x)
    _ -> []
  where
    -- Only called when @tm@ is an application, so @spine@ is its head and
    -- argument list.
    spine = unApply tm

    appUse f a
        | (P Ref (UN pfk) _, [App _ e w]) <- spine
        , pfk == txt "prim_fork"
            = directUse e ++ directUse w
        | (P Ref (UN fce) _, [_, _, forced]) <- spine
        , fce == txt "Force"
            = directUse forced
        | (P Ref _ _, _) <- spine = []
        | (P (TCon _ _) _ _, _) <- spine = []
        | otherwise = nub (directUse f ++ directUse a)
-- | The top-level arguments used by a case tree, each listed once.
findUsedArgs :: SC -> [Name] -> [Name]
findUsedArgs sc = nub . findAllUsedArgs sc
-- | Every occurrence of a top-level argument in a case tree, with repeats:
-- once for each 'Case' that scrutinises it and once for each direct use
-- (see 'directUse') in a leaf term or 'ProjCase' scrutinee.
findAllUsedArgs :: SC -> [Name] -> [Name]
findAllUsedArgs sc topargs = [x | x <- scUses sc, x `elem` topargs]
  where
    scUses tree = case tree of
        Case _ n alts   -> n : concatMap altUses alts
        ProjCase t alts -> directUse t ++ concatMap altUses alts
        STerm t         -> directUse t
        _               -> []

    altUses alt = case alt of
        ConCase _ _ _ body -> scUses body
        FnCase _ _ body    -> scUses body
        ConstCase _ body   -> scUses body
        SucCase _ body     -> scUses body
        DefaultCase body   -> scUses body
-- | Does the name occur in the tree, either as the variable scrutinised by
-- a 'Case' or among the free names of a leaf term or 'ProjCase' scrutinee?
isUsed :: SC -> Name -> Bool
isUsed sc n = inTree sc
  where
    inTree tree = case tree of
        Case _ scrut alts -> n == scrut || any inAlt alts
        ProjCase t alts   -> mentions t || any inAlt alts
        STerm t           -> mentions t
        _                 -> False

    mentions t = n `elem` freeNames t

    inAlt alt = case alt of
        ConCase _ _ _ body -> inTree body
        FnCase _ _ body    -> inTree body
        ConstCase _ body   -> inTree body
        SucCase _ body     -> inTree body
        DefaultCase body   -> inTree body
-- | Maps a name to a list of argument positions.  'caseGroups' looks up
-- constructor names here (via 'inaccessibleArgs') and moves the arguments at
-- those positions to the end of the matching order.
type ErasureInfo = Name -> [Int]  
-- | The monad used while building a case tree: read-only 'ErasureInfo' on
-- top of the 'CS' state.
type CaseBuilder a = ReaderT ErasureInfo (State CS) a
-- | Run a builder with the given erasure information, yielding a function
-- from the initial state to the result and the final state.
runCaseBuilder :: ErasureInfo -> CaseBuilder a -> (CS -> (a, CS))
runCaseBuilder ei bld = runState (runReaderT bld ei)
-- | The phase a definition is compiled for.  In 'simpleCase', 'RunTime'
-- skips the accessibility check, enables the projecting simplifications of
-- 'prune' and runs 'stripLambdas' on the result.
data Phase = CompileTime | RunTime
    deriving (Show, Eq)
-- | Compile a list of pattern clauses into a 'CaseDef'.
--
-- Arguments, in order:
--
-- * @tc@      - forwarded to 'toPats'
-- * @defcase@ - the tree used where no clause matches (handed to 'match'
--               as its fallback)
-- * @reflect@ - forwarded to 'toPats'; when set, the accessibility check
--               is skipped
-- * @phase@   - at 'RunTime' the accessibility check is skipped, 'prune'
--               runs in projecting mode and 'stripLambdas' is applied
-- * @fc@      - source location, used in the "No pattern clauses" message
--               and to wrap errors in 'At'
-- * @inacc@   - argument positions; the generated names at these positions
--               are paired with @True@ before being passed to 'order'
-- * @argtys@  - not used by this function
-- * @cs@      - the clauses: (names to check for accessibility, left-hand
--               side, right-hand side)
-- * @erInfo@  - environment for the 'CaseBuilder' run
--
-- Clauses whose right-hand side is 'Impossible' are dropped before anything
-- else happens.
simpleCase :: Bool -> SC -> Bool ->
              Phase -> FC -> [Int] -> [Type] ->
              [([Name], Term, Term)] ->
              ErasureInfo ->
              TC CaseDef
simpleCase tc defcase reflect phase fc inacc argtys cs erInfo
      = sc' tc defcase phase fc (filter (\(_, _, r) ->
                                          case r of
                                            Impossible -> False
                                            _ -> True) cs)
          where
 -- No clauses left: the definition is a single 'UnmatchedCase' leaf.
 sc' tc defcase phase fc []
                 = return $ CaseDef [] (UnmatchedCase (show fc ++ ":No pattern clauses")) []
 sc' tc defcase phase fc cs
      = let proj       = phase == RunTime
            -- NOTE(review): vnames is not referenced anywhere below.
            vnames     = fstT (head cs)
            -- Turn each left-hand side into a list of patterns, keeping the
            -- original (lhs, rhs) pair alongside.
            pats       = map (\ (avs, l, r) ->
                                   (avs, toPats reflect tc l, (l, r))) cs
            chkPats    = mapM chkAccessible pats in
            case chkPats of
                OK pats ->
                    -- The first clause determines the number of arguments.
                    let numargs    = length (fst (head pats))
                        ns         = take numargs args
                        (ns', ps') = order [(n, i `elem` inacc) | (i,n) <- zip [0..] ns] pats
                        -- The fresh-name counter starts at numargs.
                        (tree, st) = runCaseBuilder erInfo
                                         (match ns' ps' defcase)
                                         ([], numargs, [])
                        t          = CaseDef ns (prune proj (depatt ns' tree)) (fstT st) in
                        if proj then return (stripLambdas t)
                                else return t
                Error err -> Error (At fc err)
    -- Infinite supply of argument names, numbered from 0.
    where args = map (\i -> sMN i "e") [0..]
          -- NOTE(review): defaultCase and lstT are not referenced anywhere
          -- in this definition.
          defaultCase True = STerm Erased
          defaultCase False = UnmatchedCase "Error"
          fstT (x, _, _) = x
          lstT (_, _, x) = x
          -- Check that every name in avs appears as a pattern variable
          -- somewhere in the clause's patterns.  Skipped at run time and
          -- when reflecting.
          chkAccessible (avs, l, c)
               | phase == RunTime || reflect = return (l, c)
               | otherwise = do mapM_ (acc l) avs
                                return (l, c)
          -- Search the patterns (descending into constructor and successor
          -- patterns) for the variable n; fail with 'Inaccessible' if it is
          -- not found.
          acc [] n = Error (Inaccessible n)
          acc (PV x t : xs) n | x == n = OK ()
          acc (PCon _ _ _ ps : xs) n = acc (ps ++ xs) n
          acc (PSuc p : xs) n = acc (p : xs) n
          acc (_ : xs) n = acc xs n
-- | Check that the alternatives of every 'Case' are compatible with the
-- recorded type of the scrutinised variable.  Where no type is recorded for
-- the variable, only the sub-trees are checked.  Anything that is not a
-- 'Case' is accepted.
checkSameTypes :: [(Name, Type)] -> SC -> Bool
checkSameTypes tys (Case _ n alts) = case lookup n tys of
    Just t -> all (altOk t) alts
    _      -> all (checkSameTypes tys . subTree) alts
  where
    altOk ty alt = case alt of
        ConCase c _ _ sc -> isType c ty && checkSameTypes tys sc
        -- A constant alternative needs a constant type to compare against.
        ConstCase c sc -> case ty of
            Constant cty -> isConstType c cty && checkSameTypes tys sc
            _            -> False
        _ -> True

    subTree alt = case alt of
        ConCase _ _ _ sc -> sc
        FnCase _ _ sc    -> sc
        ConstCase _ sc   -> sc
        SucCase _ sc     -> sc
        DefaultCase sc   -> sc
checkSameTypes _ _ = True
-- | True when the head of the (possibly applied) term is a type constructor
-- or a 'Ref'.  The first argument is ignored.
isType _ t = case unApply t of
    (P (TCon _ _) _ _, _) -> True
    (P Ref _ _, _)        -> True
    _                     -> False
-- | Does the constant (first argument) belong to the constant type (second
-- argument)?  Fixed-width bit constants are accepted for any integer type.
isConstType c ty = case (c, ty) of
    (I _,   AType (ATInt ITNative)) -> True
    (BI _,  AType (ATInt ITBig))    -> True
    (Fl _,  AType ATFloat)          -> True
    (Ch _,  AType (ATInt ITChar))   -> True
    (Str _, StrType)                -> True
    (B8 _,  AType (ATInt _))        -> True
    (B16 _, AType (ATInt _))        -> True
    (B32 _, AType (ATInt _))        -> True
    (B64 _, AType (ATInt _))        -> True
    _                               -> False
-- | Patterns, as produced from left-hand-side terms by 'toPat'.
data Pat = PCon Bool Name Int [Pat] -- ^ constructor: flag, name and 'Int' taken from the term's 'DCon', then the argument patterns
         | PConst Const -- ^ a non-type constant
         | PV Name Type -- ^ a bound pattern variable with its type
         | PSuc Pat -- ^ successor pattern, built from @prim__addBigInt p 1@
         | PReflected Name [Pat] -- ^ name applied to argument patterns; only produced when reflection is on
         | PAny -- ^ matches anything
         | PTyPat -- ^ a type constant
    deriving Show
-- | Convert the arguments of an applied left-hand side into patterns, in
-- application order.  The head of the application itself is discarded.
toPats :: Bool -> Bool -> Term -> [Pat]
toPats reflect tc lhs = reverse (toPat reflect tc (spineArgs lhs))
  where
    -- Arguments of the application spine, outermost (last) argument first.
    spineArgs tm = case tm of
        App _ f a -> a : spineArgs f
        _         -> []
-- | Convert a list of terms into patterns.
--
-- The first flag enables reflected patterns ('PReflected'); the second flag
-- is not consulted by this function.
--
-- The local worker takes the arguments collected so far from an application
-- spine together with the head term.  Its equations are tried in order, so
-- the special cases (@Delay@, @Ownership.Read@, successor) must stay ahead
-- of the general ones.  Anything unrecognised becomes 'PAny'.
toPat :: Bool -> Bool -> [Term] -> [Pat]
toPat reflect tc = map $ toPat' []
  where
    -- @Delay@ applied to three arguments: only the third is matched on.
    toPat' [_,_,arg] (P (DCon t a uniq) nm@(UN n) _)
        | n == txt "Delay"
        = PCon uniq nm t [PAny, PAny, toPat' [] arg]
    -- @Ownership.Read@: the flag is cleared on this constructor and on every
    -- constructor pattern nested directly beneath it.
    toPat' args (P (DCon t a uniq) nm@(NS (UN n) [own]) _)
        | n == txt "Read" && own == txt "Ownership"
        = PCon False nm t (map (shareCons . toPat' []) args)
      where shareCons (PCon _ n i ps) = PCon False n i (map shareCons ps)
            shareCons p = p
    -- Any other data constructor.
    toPat' args (P (DCon t a uniq) n _)
        = PCon uniq n t $ map (toPat' []) args
    -- @prim__addBigInt p 1@ is a successor pattern (n + 1).
    toPat' [p, Constant (BI 1)] (P _ (UN pabi) _)
        | pabi == txt "prim__addBigInt"
        = PSuc $ toPat' [] p
    toPat' []   (P Bound n ty) = PV n ty
    -- Walk down the application spine, collecting arguments.
    toPat' args (App _ f a)    = toPat' (a : args) f
    toPat' [] (Constant x) | isTypeConst x = PTyPat
                           | otherwise     = PConst x
    -- A non-dependent function type is reflected as an application of @->@.
    toPat' [] (Bind n (Pi _ t _) sc)
        | reflect && noOccurrence n sc
        = PReflected (sUN "->") [toPat' [] t, toPat' [] sc]
    toPat' args (P _ n _)
        | reflect
        = PReflected n $ map (toPat' []) args
    toPat' _ t = PAny
-- | A run of consecutive clauses that all begin with the same kind of
-- pattern, as produced by 'partition'.
data Partition = Cons [Clause] -- ^ first patterns satisfy 'isConPat'
               | Vars [Clause] -- ^ first patterns satisfy 'isVarPat'
    deriving Show
-- | Does the clause start with a pattern that matches without inspecting
-- the value (a variable, a wildcard or a type pattern)?
isVarPat (p : _, _) = case p of
    PV _ _ -> True
    PAny   -> True
    PTyPat -> True
    _      -> False
isVarPat _ = False
-- | Does the clause start with a pattern that inspects the value (a
-- constructor, reflected application, successor or constant)?
isConPat (p : _, _) = case p of
    PCon _ _ _ _   -> True
    PReflected _ _ -> True
    PSuc _         -> True
    PConst _       -> True
    _              -> False
isConPat _ = False
-- | Split the clauses into maximal runs that all start with a variable-like
-- pattern or all start with a constructor-like pattern, preserving order.
-- A clause that is neither (for example one with no patterns) is an error.
partition :: [Clause] -> [Partition]
partition [] = []
partition ms@(m : _)
    | isVarPat m = runOf Vars isVarPat
    | isConPat m = runOf Cons isConPat
    | otherwise  = error $ "Partition " ++ show ms
  where
    -- Peel off the longest prefix satisfying @p@ and carry on with the rest.
    runOf wrap p = let (run, rest) = span p ms in
                       wrap run : partition rest
-- | Choose the order in which arguments are matched.
--
-- Takes the argument names, each paired with a flag, and the clauses.
-- Returns the names in their new order together with the clauses, whose
-- pattern lists are permuted in the same way.
order :: [(Name, Bool)] -> [Clause] -> ([Name], [Clause])
order []  cs = ([], cs)
order ns' [] = (map fst ns', [])
-- patnames holds one list per argument position (a column); each entry
-- pairs the (name, flag) of that position with one clause's pattern.
order ns' cs = let patnames = transpose (map (zip ns') (map fst cs))
                   -- Only columns without a clash (see clashPat) take part
                   -- in the sort; the clashing columns keep their relative
                   -- order and are placed after the sorted ones.
                   
                   
                   
                   (patnames_ord, patnames_rest)
                        = Data.List.partition (noClash . map snd) patnames
                   -- The columns are reversed before the (stable) sortBy,
                   -- then everything is transposed back into rows.
                   pats' = transpose (sortBy moreDistinct (reverse patnames_ord)
                                         ++ patnames_rest) in
                   (getNOrder pats', zipWith rebuild pats' cs)
  where
    -- The new name order is read off the first row.
    getNOrder [] = error $ "Failed order on " ++ show (map fst ns', cs)
    getNOrder (c : _) = map (fst . fst) c
    -- Replace a clause's patterns with its reordered row.
    rebuild patnames clause = (map snd patnames, snd clause)
    -- A column is clash-free when no two of its patterns clash.
    noClash [] = True
    noClash (p : ps) = not (any (clashPat p) ps) && noClash ps
    -- Two patterns clash when one is a constructor and the other a constant
    -- or successor, or when both are constructors carrying the same Int but
    -- different names.
    clashPat (PCon _ _ _ _) (PConst _) = True
    clashPat (PConst _) (PCon _ _ _ _) = True
    clashPat (PCon _ _ _ _) (PSuc _) = True
    clashPat (PSuc _) (PCon _ _ _ _) = True
    clashPat (PCon _ n i _) (PCon _ n' i' _) | i == i' = n /= n'
    clashPat _ _ = False
    -- Orders columns by the flag of their name first (False before True).
    -- For equal flags the counts are deliberately swapped, so the column
    -- with more distinct constructors/constants comes first.
    moreDistinct xs ys = compare (snd . fst . head $ xs, numNames [] (map snd ys))
                                 (snd . fst . head $ ys, numNames [] (map snd xs))
    -- Number of distinct constructor names and constants in a column.
    numNames xs (PCon _ n _ _ : ps)
        | not (Left n `elem` xs) = numNames (Left n : xs) ps
    numNames xs (PConst c : ps)
        | not (Right c `elem` xs) = numNames (Right c : xs) ps
    numNames xs (_ : ps) = numNames xs ps
    numNames xs [] = length xs
-- | Build the case tree for the given variables and clauses; the last
-- argument is the tree to fall back on when nothing matches.
--
-- Once no variables remain and the first clause has no patterns left, that
-- clause wins.  The left-hand sides of all later clauses are appended to
-- the term list in the builder state.
match :: [Name] -> [Clause] -> SC
                            -> CaseBuilder SC
match [] (([], ret) : others) _ = do
    (ts, v, ntys) <- get
    put (ts ++ map (fst . snd) others, v, ntys)
    case snd ret of
        Impossible -> return ImpossibleCase
        rhs        -> return (STerm rhs)
match vs cs err = mixture vs (partition cs) err
-- | Compile a sequence of partitions.  The tree for the later partitions is
-- built first and becomes the fallback of the earlier one; with no
-- partitions left the result is the supplied fallback itself.
mixture :: [Name] -> [Partition] -> SC -> CaseBuilder SC
mixture _ [] err = return err
mixture vs (p : ps) err = do
    fallthrough <- mixture vs ps err
    case p of
        Cons ms -> conRule vs ms fallthrough
        Vars ms -> varRule vs ms fallthrough
-- | Look up the argument positions recorded for a name in the builder's
-- 'ErasureInfo' environment.
inaccessibleArgs :: Name -> CaseBuilder [Int]
inaccessibleArgs n = asks ($ n)
-- | What a group of clauses matches on; 'groupCons' derives it from the
-- first pattern of each clause.
data ConType = CName Name Int -- ^ from 'PCon': constructor name and its 'Int'
             | CFn Name -- ^ from 'PReflected': the applied name
             | CSuc -- ^ from 'PSuc'
             | CConst Const -- ^ from 'PConst': the constant
   deriving (Show, Eq)
-- | Clauses that share the same leading constructor-like pattern.
data Group = ConGroup Bool -- ^ flag taken from the 'PCon' pattern (False for other kinds); 'caseGroups' uses it to choose 'Updatable'
                      ConType -- ^ what the group matches on
                      [([Pat], Clause)] -- ^ per clause: the argument patterns of the leading pattern, and the rest of the clause
   deriving Show
-- | Handle clauses that all start with a constructor-like pattern: group
-- them by constructor and build a 'Case' on the first variable.
conRule :: [Name] -> [Clause] -> SC -> CaseBuilder SC
conRule vars@(_ : _) cs err =
    groupCons cs >>= \groups -> caseGroups vars groups err
-- | Build the 'Case' node for the first variable from the clause groups.
-- Each group becomes one alternative, a 'DefaultCase' leading to the
-- fallback tree is added, and the alternatives are sorted.
caseGroups :: [Name] -> [Group] -> SC -> CaseBuilder SC
caseGroups (v : vs) gs err = do
    alts <- mapM groupAlt gs
    return $ Case sharing v (sort (alts ++ [DefaultCase err]))
  where
    -- Only the flag of the first group is consulted.
    sharing = case gs of
        ConGroup True _ _ : _ -> Updatable
        _                     -> Shared

    groupAlt (ConGroup _ ctor rows) = case ctor of
        -- Constructor: the arguments at the positions reported by
        -- 'inaccessibleArgs' are matched last, after the remaining
        -- variables.
        CName n i -> do
            inacc <- inaccessibleArgs n
            (newVars, accVars, inaccVars, nextCs) <- argsToAlt inacc rows
            sub <- match (accVars ++ vs ++ inaccVars) nextCs err
            return $ ConCase n i newVars sub
        CFn n -> do
            (newVars, _, [], nextCs) <- argsToAlt [] rows
            sub <- match (newVars ++ vs) nextCs err
            return $ FnCase n newVars sub
        CSuc -> do
            ([newVar], _, [], nextCs) <- argsToAlt [] rows
            sub <- match (newVar : vs) nextCs err
            return $ SucCase newVar sub
        CConst c -> do
            (_, _, [], nextCs) <- argsToAlt [] rows
            sub <- match vs nextCs err
            return $ ConstCase c sub
-- | Prepare the clauses of one group for matching on the constructor's
-- arguments.
--
-- The first argument lists the argument positions to be matched last.  The
-- result is: one fresh variable per argument pattern of the first clause;
-- those variables whose position is not in the list; those whose position
-- is in the list; and the clauses with their argument patterns spliced in
-- (listed positions after the clause's remaining patterns, the others
-- before them).
argsToAlt :: [Int] -> [([Pat], Clause)] -> CaseBuilder ([Name], [Name], [Name], [Clause])
argsToAlt _ [] = return ([], [], [], [])
argsToAlt inacc rows@((firstPats, _) : _) = do
    newVars <- freshVars firstPats
    let (accVars, inaccVars) = splitAcc newVars
    return (newVars, accVars, inaccVars, map splice rows)
  where
    -- One fresh name per pattern; the base of the name depends on the kind
    -- of pattern.
    freshVars :: [Pat] -> CaseBuilder [Name]
    freshVars [] = return []
    freshVars (PV _ ty : rest) = do
        v <- getVar "e"
        vs <- freshVars rest
        -- The variable's type is recorded only after the remaining
        -- patterns have been processed.
        (ts, next, ntys) <- get
        put (ts, next, (v, ty) : ntys)
        return (v : vs)
    freshVars (PAny   : rest) = (:) <$> getVar "i" <*> freshVars rest
    freshVars (PTyPat : rest) = (:) <$> getVar "t" <*> freshVars rest
    freshVars (_      : rest) = (:) <$> getVar "e" <*> freshVars rest

    -- Split a list by position: (positions not in inacc, positions in inacc).
    splitAcc xs =
        ( [x | (i, x) <- zip [0..] xs, i `notElem` inacc]
        , [x | (i, x) <- zip [0..] xs, i    `elem` inacc]
        )

    -- Put a clause's argument patterns around its remaining patterns, in the
    -- same arrangement as the variables returned above.
    splice (argPats, (ps, res)) = (accPats ++ ps ++ inaccPats, res)
      where
        (accPats, inaccPats) = splitAcc argPats
-- | Produce a fresh machine-generated name with the given base, advancing
-- the counter in the builder state by one.
getVar :: String -> CaseBuilder Name
getVar base = do
    (ts, next, ntys) <- get
    put (ts, next + 1, ntys)
    return (sMN next base)
-- | Group clauses by their leading constructor-like pattern.  Groups appear
-- in order of first occurrence and clauses keep their order within a group.
-- A leading pattern that is not a constructor, constant, successor or
-- reflected application makes the builder fail.
groupCons :: [Clause] -> CaseBuilder [Group]
groupCons = foldM step []
  where
    step groups (p : ps, res) = case p of
        PCon uniq con i args ->
            return $ insertGroup uniq (CName con i) args row groups
        PConst cval -> return $ insertConst cval row groups
        PSuc n -> return $ insertGroup False CSuc [n] row groups
        PReflected fn args ->
            return $ insertGroup False (CFn fn) args row groups
        pat -> fail $ show pat ++ " is not a constructor or constant (can't happen)"
      where
        row = (ps, res)

    -- Append to the group with the same constructor, or start a new group
    -- at the end.  An existing group takes the flag of the clause added to
    -- it.
    insertGroup uniq c conargs row groups = case groups of
        [] -> [ConGroup uniq c [(conargs, row)]]
        g@(ConGroup _ c' rows) : gs
            | c == c'   -> ConGroup uniq c (rows ++ [(conargs, row)]) : gs
            | otherwise -> g : insertGroup uniq c conargs row gs

    -- The same for constants; only groups whose flag is False are
    -- considered a match.
    insertConst con row groups = case groups of
        [] -> [ConGroup False (CConst con) [([], row)]]
        ConGroup False (CConst n) rows : gs
            | con == n -> ConGroup False (CConst n) (rows ++ [([], row)]) : gs
        g : gs -> g : insertConst con row gs
-- | Handle clauses that all start with a variable-like pattern: drop that
-- pattern from each clause and continue with the remaining variables.  For
-- a pattern variable, the case variable is substituted for it in the
-- right-hand side and its type is recorded in the builder state.
varRule :: [Name] -> [Clause] -> SC -> CaseBuilder SC
varRule (v : vs) alts err =
    mapM bindVar alts >>= \alts' -> match vs alts' err
  where
    bindVar (PV p ty : ps, (lhs, res)) = do
        (ts, next, ntys) <- get
        put (ts, next, (v, ty) : ntys)
        return (ps, (lhs, subst p (P Bound v ty) res))
    bindVar (PAny : ps, res) = return (ps, res)
    bindVar (PTyPat : ps, res) = return (ps, res)
-- | Replace, in the leaf terms of a case tree, any application that merely
-- rebuilds a value already matched on by a reference to the matched
-- variable.  Only 'Case' nodes and 'STerm' leaves are processed; other
-- nodes are returned unchanged.  The first argument is not used.
depatt :: [Name] -> SC -> SC
depatt ns tm = dp [] tm
  where
    -- ms maps a scrutinised variable to the constructor (or function) name
    -- and argument names of the alternative we are currently inside.
    dp ms (STerm tm) = STerm (applyMaps ms tm)
    dp ms (Case up x alts) = Case up x (map (dpa ms x) alts)
    dp ms sc = sc
    dpa ms x (ConCase n i args sc)
        = ConCase n i args (dp ((x, (n, args)) : ms) sc)
    dpa ms x (FnCase n args sc)
        = FnCase n args (dp ((x, (n, args)) : ms) sc)
    dpa ms x (ConstCase c sc) = ConstCase c (dp ms sc)
    dpa ms x (SucCase n sc) = SucCase n (dp ms sc)
    dpa ms x (DefaultCase sc) = DefaultCase (dp ms sc)
    -- An application headed by a name: rewrite the arguments first, then
    -- look for a recorded match with the same name and the same arguments.
    -- If the head is not a name, the guard fails and the next equation
    -- applies.
    applyMaps ms f@(App _ _ _)
       | (P nt cn pty, args) <- unApply f
            = let args' = map (applyMaps ms) args in
                  applyMap ms nt cn pty args'
        where
          applyMap [] nt cn pty args' = mkApp (P nt cn pty) args'
          applyMap ((x, (n, args)) : ms) nt cn pty args'
            -- Same head name, same number of arguments, and every argument
            -- is exactly the variable bound at that position.
            | and ((length args == length args') :
                     (n == cn) : zipWith same args args') = P Ref x Erased
            | otherwise = applyMap ms nt cn pty args'
          same n (P _ n' _) = n == n'
          same _ _ = False
    applyMaps ms (App s f a) = App s (applyMaps ms f) (applyMaps ms a)
    applyMaps ms t = t
-- | Simplify a case tree.
--
-- Every alternative is pruned recursively, and default alternatives that
-- lead to an erased term or to 'ImpossibleCase' are removed.  What remains
-- decides the result:
--
-- * no alternatives: 'ImpossibleCase';
-- * a single constructor alternative (projecting mode only): the case is
--   dropped if none of the bound arguments is used in the sub-tree;
-- * a single successor alternative (projecting mode only): the case is
--   dropped and the bound name is replaced by a projection of the
--   scrutinee;
-- * a single constant alternative: the case is dropped;
-- * a successor alternative followed by a default: the default is turned
--   into an explicit match on the big-integer constant 0, placed first;
-- * anything else: the case is kept with the pruned alternatives.
--
-- Trees that are not a 'Case' are returned unchanged.
prune :: Bool -- ^ projecting mode: enables the two single-alternative rewrites marked above
      -> SC -> SC
prune proj (Case up n alts) = case alts' of
    [] -> ImpossibleCase

    -- A lone constructor alternative.  The case is only removed when none
    -- of the constructor's arguments are needed; when @proj@ is off the
    -- guard fails and the final alternative below applies.
    as@[ConCase cn i args sc]
        | proj -> let sc' = prune proj sc in
                      if any (isUsed sc') args
                         then Case up n [ConCase cn i args sc']
                         else sc'

    -- A lone successor alternative: refer to the predecessor as a
    -- projection of the scrutinee.  The index -1 is the marker for
    -- "predecessor"; a successor pattern has no field 1 to project.
    -- NOTE(review): the consumers of 'Proj' are outside this file --
    -- confirm that they treat -1 this way.
    [SucCase cn sc]
        | proj
        -> projRep cn n (-1) $ prune proj sc

    [ConstCase _ sc]
        -> prune proj sc

    -- The default branch is made an explicit zero case and put first.
    [s@(SucCase _ _), DefaultCase dc]
        -> Case up n [ConstCase (BI 0) dc, s]

    as  -> Case up n as
  where
    alts' = filter (not . erased) $ map pruneAlt alts

    pruneAlt (ConCase cn i ns sc) = ConCase cn i ns (prune proj sc)
    pruneAlt (FnCase cn ns sc) = FnCase cn ns (prune proj sc)
    pruneAlt (ConstCase c sc) = ConstCase c (prune proj sc)
    pruneAlt (SucCase n sc) = SucCase n (prune proj sc)
    pruneAlt (DefaultCase sc) = DefaultCase (prune proj sc)

    erased (DefaultCase (STerm Erased)) = True
    erased (DefaultCase ImpossibleCase) = True
    erased _ = False

    -- Replace the variable @arg@ by projection @i@ of the variable @n@
    -- throughout a tree.  A 'Case' on @arg@ becomes a 'ProjCase' on the
    -- projection.
    projRep :: Name -> Name -> Int -> SC -> SC
    projRep arg n i (Case up x alts) | x == arg
        = ProjCase (Proj (P Bound n Erased) i) $ map (projRepAlt arg n i) alts
    projRep arg n i (Case up x alts)
        = Case up x (map (projRepAlt arg n i) alts)
    projRep arg n i (ProjCase t alts)
        = ProjCase (projRepTm arg n i t) $ map (projRepAlt arg n i) alts
    projRep arg n i (STerm t) = STerm (projRepTm arg n i t)
    projRep arg n i c = c

    projRepAlt arg n i (ConCase cn t args rhs)
        = ConCase cn t args (projRep arg n i rhs)
    projRepAlt arg n i (FnCase cn args rhs)
        = FnCase cn args (projRep arg n i rhs)
    projRepAlt arg n i (ConstCase t rhs)
        = ConstCase t (projRep arg n i rhs)
    projRepAlt arg n i (SucCase sn rhs)
        = SucCase sn (projRep arg n i rhs)
    projRepAlt arg n i (DefaultCase rhs)
        = DefaultCase (projRep arg n i rhs)

    projRepTm arg n i t = subst arg (Proj (P Bound n Erased) i) t
prune _ t = t
-- | While the whole tree is a single leaf that is a lambda, turn the lambda
-- binder into an additional argument of the definition (appended to the
-- argument list) and continue with the instantiated body.
stripLambdas :: CaseDef -> CaseDef
stripLambdas def = case def of
    CaseDef ns (STerm (Bind x (Lam _) body)) tm ->
        stripLambdas $
            CaseDef (ns ++ [x]) (STerm (instantiate (P Bound x Erased) body)) tm
    _ -> def
-- | Replace the variable @n@ by the variable @repl@ throughout a case tree,
-- both where it is scrutinised and inside leaf terms.  'ProjCase' nodes are
-- not supported and raise an error.
substSC :: Name -> Name -> SC -> SC
substSC n repl sc = case sc of
    Case up scrut alts ->
        Case up (if n == scrut then repl else scrut)
                (map (substAlt n repl) alts)
    STerm t -> STerm (subst n (P Bound repl Erased) t)
    UnmatchedCase errmsg -> UnmatchedCase errmsg
    ImpossibleCase -> ImpossibleCase
    _ -> error $ "unsupported in substSC: " ++ show sc
-- | Apply 'substSC' to the sub-tree of an alternative.  The names bound by
-- the alternative itself are left as they are.
substAlt :: Name -> Name -> CaseAlt -> CaseAlt
substAlt n repl alt = case alt of
    ConCase cn a ns sc -> ConCase cn a ns (inside sc)
    FnCase fn ns sc    -> FnCase fn ns (inside sc)
    ConstCase c sc     -> ConstCase c (inside sc)
    -- The bound name stays the same whether or not it equals @n@.
    SucCase bound sc   -> SucCase bound (inside sc)
    DefaultCase sc     -> DefaultCase (inside sc)
  where
    inside = substSC n repl
-- | @mkForce n arg tree@ redirects every 'Case' in the tree that
-- scrutinises @arg@ so that it scrutinises @n@ instead.  The scrutinee of a
-- 'ProjCase' and all leaves are left untouched; only the sub-trees are
-- traversed.
mkForce :: Name -> Name -> SC -> SC
mkForce = mkForceSC
  where
    mkForceSC n arg (Case up x alts)
        = Case up (if x == arg then n else x) (map (mkForceAlt n arg) alts)
    mkForceSC n arg (ProjCase t alts)
        = ProjCase t $ map (mkForceAlt n arg) alts
    mkForceSC n arg c = c

    mkForceAlt n arg (ConCase cn t args rhs)
        = ConCase cn t args (mkForceSC n arg rhs)
    mkForceAlt n arg (FnCase cn args rhs)
        = FnCase cn args (mkForceSC n arg rhs)
    mkForceAlt n arg (ConstCase t rhs)
        = ConstCase t (mkForceSC n arg rhs)
    mkForceAlt n arg (SucCase sn rhs)
        = SucCase sn (mkForceSC n arg rhs)
    mkForceAlt n arg (DefaultCase rhs)
        = DefaultCase (mkForceSC n arg rhs)