module Language.Haskell.Liquid.Constraint.RewriteCase 
    (getCaseRewrites) 
    where

import           Language.Fixpoint.Types
import qualified Language.Fixpoint.Misc as M
import           Language.Haskell.Liquid.Constraint.Types
import           Language.Haskell.Liquid.Types.Types
import           Language.Haskell.Liquid.Types.RType

import           Data.Maybe
import           Data.Tuple

import qualified Data.HashMap.Strict as M
import qualified Data.HashSet        as S

-- | Computes the local rewrites induced by the equalities found in the
-- refinement of @spec@.
--
-- The refinement is scanned for equalities ('getEqualities'), equalities that
-- share a side are combined ('groupUnifiableEqualities'), every resulting pair
-- is unified ('unify'), and the substitutions obtained are filtered so that
-- they cannot loop ('unloop').
getCaseRewrites :: CGEnv -> SpecType -> LocalRewrites
getCaseRewrites γ spec =
    -- NOTE(review): 'rt_reft' is a record selector; if some constructors of
    -- the refinement type lack that field this is partial -- confirm.
    let Reft (_, refinement) = ur_reft $ rt_reft spec
        -- Names of constructors, both type constructors and data constructors
        ctors   = toSet $ seBinds  $ constEnv γ
        -- All the global names (top level functions, types, etc...)
        globals = toSet $ reGlobal $ renv     γ
    in unloop
       $ concatMap (uncurry $ unify ctors globals)
       $ groupUnifiableEqualities
       $ getEqualities refinement
    -- Only the keys of the environments matter here; the bound values are
    -- discarded.
    where toSet = S.fromList . M.keys

-- | Generates substitutions for local variables that make @e1@ and @e2@
-- equal.
--
-- A variable is local when it is neither in @globals@ nor in @ctors@ and its
-- name does not start with 'anfPrefix'. If @v@ is local, and @C@ is in
-- @ctors@
--
--  * two syntactically equal expressions produce no substitution
--  * @v@ and @e2@ produces @(v, e2)@
--  * @e1@ and @v@ produces @(v, e1)@
--  * @C a₁ ... aₙ@ and @C b₁ ... bₙ@ produces the substitutions from unifying
--    the arguments pairwise
--
-- If any unification fails, the substitutions from the unifications that
-- succeed are still produced.
--
unify :: S.HashSet Symbol -> S.HashSet Symbol -> Expr -> Expr -> [(Symbol, Expr)]
unify ctors globals = go
    where
        go e1 e2 | e1 == e2 = []
        -- NOTE: We don't need to check for ECst because the expressions aren't
        -- elaborated
        go (EVar s1) e2 | isLocal s1 = [(s1, e2)]
        go e1 (EVar s2) | isLocal s2 = [(s2, e1)]
        -- TODO: Technically we could also unify under lambdas but you have to
        -- be careful about alpha equivalence; it is unclear whether the effort
        -- is worth it.
        go e1 e2
            -- Performing the unification under a constructor is safe because
            -- C a₁ ... aₙ = C b₁ ... bₙ ⟺ ∀ i . aᵢ = bᵢ
            --
            -- Note that @args1@ is taken from @e2@ and @args2@ from @e1@, so
            -- the recursive calls see the two sides swapped.
            | (EVar name1 , args1) <- splitEApp e2
            , (EVar name2 , args2) <- splitEApp e1
            , name1 == name2
            , isCtor name1
            , length args1 == length args2
            = concat $ zipWith go args1 args2
        -- Anything else is a failed unification and yields no substitution.
        go _ _ = []

        isCtor  name = name `S.member` ctors
        isLocal name = not (name `S.member` globals
                           || name `S.member` ctors
                           || isPrefixOfSym anfPrefix name)


-- | Given a list of equalities this function produces the equalities that
-- result from applying transitivity exactly once. For instance, if we have
-- @[e1=e2, e2=e3, e1=e4]@ this function will produce @[e1=e3, e2=e4]@.
--
-- The input equalities themselves are not part of the result.
groupUnifiableEqualities :: [(Expr, Expr)] -> [(Expr, Expr)]
groupUnifiableEqualities = concat . concatMap mkEqs . grouping
  where
    -- Pairs every element of the list with each element that follows it.
    mkEqs (e1 : es) = [ (e1, e) | e <- es ] : mkEqs es
    mkEqs _         = []

    -- Every equality is considered in both directions before being handed to
    -- 'M.groupList'; only the grouped values are kept, the keys are dropped.
    -- NOTE(review): relies on 'M.groupList' collecting the values that share
    -- a key -- confirm against its definition.
    grouping eqs = fmap snd $ M.groupList $ eqs ++ fmap swap eqs

-- | Collects the pairs @(e1, e2)@ of every equality atom @e1 = e2@ that is
-- either the expression itself or reachable through nested conjunctions
-- ('PAnd'). Any other form of expression contributes no equalities.
getEqualities :: Expr -> [(Expr, Expr)]
getEqualities (PAtom Eq e1 e2) = [(e1, e2)]
getEqualities (PAnd eqs)       = concatMap getEqualities eqs
getEqualities _                = []


-- +-----------------------------------------------------+
-- | AcyclicRewrites: collection of rewrites that cannot |
-- | cause an infinite loop                              |
-- +-----------------------------------------------------+
-- This could be extracted as a separate module

-- | A collection of rewrites that cannot cause an infinite loop.
--
-- The wrapped map sends each rewritten symbol to the expression it is
-- rewritten to.
newtype AcyclicRewrites = AR (M.HashMap Symbol Expr)

-- We can think of the map as a directed graph where every symbol is a vertex and
-- there is an edge v₁ → v₂ if v₂ is free in the expression that v₁ is rewritten to.
-- To guarantee that the set of rewrite rules is terminating, we ensure that there
-- aren't any cycles in the graph.

-- | Drops rewrites that would cause an infinite loop. The procedure is order
-- biased: the rewrites are offered to 'insert' from left to right, and a
-- rewrite that 'insert' rejects against the ones accepted so far is skipped.
unloop :: [(Symbol, Expr)] -> LocalRewrites
unloop = LocalRewrites . toRewrites . foldl' doInsert empty
    where
        -- Keeps the accumulated rewrites unchanged when the insertion is
        -- rejected.
        doInsert ar (s, e) = fromMaybe ar (insert ar s e)

-- | Get the "raw" map of rewrites, i.e. unwrap the 'AR' constructor.
toRewrites :: AcyclicRewrites -> M.HashMap Symbol Expr
toRewrites (AR m) = m

-- | @existsPath g s1 s2@ checks if there is a path from @s1@ to @s2@ in @g@
-- (a path of length zero counts, so @existsPath g s s@ is @True@).
--
-- Every symbol is expanded at most once, so the expression of each rewrite
-- is inspected at most once per call.
existsPath :: AcyclicRewrites -> Symbol -> Symbol -> Bool
existsPath (AR m) s1' s2 = go S.empty [s1']
  where
    -- Depth-first search over a worklist of symbols still to visit. The set
    -- of already expanded symbols prevents re-exploring vertices that are
    -- reachable through several paths.
    go :: S.HashSet Symbol -> [Symbol] -> Bool
    go _ [] = False
    go seen (s1 : pending)
      | s1 == s2           = True
      | s1 `S.member` seen = go seen pending
      | otherwise          = go (S.insert s1 seen) (successors s1 ++ pending)

    -- Symbols occurring in the expression @s@ is rewritten to; symbols
    -- without a rewrite have no outgoing edges.
    successors :: Symbol -> [Symbol]
    successors s = maybe [] (S.toList . exprSymbolsSet) (M.lookup s m)

-- | Constructs an empty set of rewrites
empty :: AcyclicRewrites
empty = AR M.empty

-- | Inserts the rewrite of @s@ to @e@ if it won't cause an infinite loop,
-- and yields 'Nothing' otherwise.
--
-- If @s@ already has a rewrite and the new one is accepted, the previous
-- rewrite is replaced.
insert :: AcyclicRewrites -> Symbol -> Expr -> Maybe AcyclicRewrites
insert ar@(AR m) s e
  -- There are two ways to break the DAG invariant:
  -- 1. The rewrite by itself is a cycle (@s@ occurs in @e@)
  | s `S.member` syms                               = Nothing
  -- 2. The rewrite closes a loop (some symbol of @e@ already reaches @s@)
  | any (\s2 -> existsPath ar s2 s) (S.toList syms) = Nothing
  | otherwise                                       = Just $ AR $ M.insert s e m
  where
    -- Computed once and shared by both checks.
    syms = exprSymbolsSet e