module Language.Haskell.Liquid.Constraint.RewriteCase
(getCaseRewrites)
where
import Language.Fixpoint.Types
import qualified Language.Fixpoint.Misc as M
import Language.Haskell.Liquid.Constraint.Types
import Language.Haskell.Liquid.Types.Types
import Language.Haskell.Liquid.Types.RType
import Data.Maybe
import Data.Tuple
import qualified Data.HashMap.Strict as M
import qualified Data.HashSet as S
-- | Extract local rewrites from the top-level refinement of @spec@.
--
-- The refinement predicate is split into its equalities ('getEqualities').
-- Expressions sharing a common equal partner are then paired up
-- ('groupUnifiableEqualities'). Each pair is structurally unified ('unify'),
-- which yields candidate bindings for local variables. Finally, 'unloop'
-- discards any candidate that would make the rewrites cyclic.
--
-- Constructor names are the keys of @constEnv γ@ and global names are the
-- keys of @reGlobal (renv γ)@; 'unify' treats both as non-local.
getCaseRewrites :: CGEnv -> SpecType -> LocalRewrites
getCaseRewrites γ spec = unloop candidates
  where
    Reft (_, refinement) = ur_reft (rt_reft spec)

    ctors, globals :: S.HashSet Symbol
    ctors   = keySetOf (seBinds (constEnv γ))
    globals = keySetOf (reGlobal (renv γ))

    candidates :: [(Symbol, Expr)]
    candidates =
      [ binding
      | (lhs, rhs) <- groupUnifiableEqualities (getEqualities refinement)
      , binding    <- unify ctors globals lhs rhs
      ]

    -- The set of keys of a symbol-indexed map.
    keySetOf :: M.HashMap Symbol v -> S.HashSet Symbol
    keySetOf = S.fromList . M.keys
-- | Structurally unify two expressions, collecting bindings @(x, e)@ for
-- local variables @x@.
--
-- * Equal expressions produce no bindings.
-- * A local variable on either side is bound to the other side. The left
--   side is tried first.
-- * Two applications of the same constructor (a member of @ctors@) with the
--   same number of arguments are unified argument by argument.
-- * Anything else produces no bindings.
--
-- A symbol is /local/ unless it is in @globals@, is in @ctors@, or starts
-- with 'anfPrefix'.
unify :: S.HashSet Symbol -> S.HashSet Symbol -> Expr -> Expr -> [(Symbol, Expr)]
unify ctors globals = match
  where
    match :: Expr -> Expr -> [(Symbol, Expr)]
    match lhs rhs
      | lhs == rhs = []
    match (EVar x) rhs
      | isLocal x = [(x, rhs)]
    match lhs (EVar x)
      | isLocal x = [(x, lhs)]
    -- Arguments are paired as (rhs argument, lhs argument). The order only
    -- matters when both sides are local variables: it decides which one
    -- gets bound.
    match lhs rhs =
      case (splitEApp rhs, splitEApp lhs) of
        ((EVar headR, argsR), (EVar headL, argsL))
          | headR == headL
          , isCtor headR
          , length argsR == length argsL
          -> concat (zipWith match argsR argsL)
        _ -> []

    isCtor :: Symbol -> Bool
    isCtor x = x `S.member` ctors

    isLocal :: Symbol -> Bool
    isLocal x =
      not (x `S.member` globals)
        && not (isCtor x)
        && not (isPrefixOfSym anfPrefix x)
-- | Pair up expressions that are asserted equal to a common expression.
--
-- Every equality @a = b@ is considered in both directions, and the
-- equalities are grouped by their left-hand side. For each group
-- @k -> [p1, ..., pn]@, every pair @(pi, pj)@ with @i < j@ is emitted.
-- The shared key @k@ itself is not emitted, so a lone equality @a = b@
-- produces no pairs.
groupUnifiableEqualities :: [(Expr, Expr)] -> [(Expr, Expr)]
groupUnifiableEqualities eqs =
  [ pair
  | (_, partners) <- M.groupList (eqs ++ map swap eqs)
  , pair          <- pairsOf partners
  ]
  where
    -- All pairs (x, y) where x occurs strictly before y in the list.
    pairsOf :: [a] -> [(a, a)]
    pairsOf []         = []
    pairsOf (x : rest) = [ (x, y) | y <- rest ] ++ pairsOf rest
-- | Collect the @lhs = rhs@ atoms of a predicate, descending through
-- (possibly nested) conjunctions. Any other form contributes nothing.
getEqualities :: Expr -> [(Expr, Expr)]
getEqualities expr = case expr of
  PAtom Eq lhs rhs -> [(lhs, rhs)]
  PAnd conjuncts   -> concatMap getEqualities conjuncts
  _                -> []
-- | Rewrites @x := e@, keyed by the rewritten symbol, that are intended to
-- stay free of cycles. 'insert' rejects any binding that would let a symbol
-- reach itself by following right-hand sides.
newtype AcyclicRewrites = AR (M.HashMap Symbol Expr)
-- | Turn candidate bindings into 'LocalRewrites' while keeping them acyclic.
--
-- Candidates are offered to 'insert' from left to right. A candidate that
-- 'insert' rejects is dropped, and the rewrites accumulated so far are kept
-- unchanged.
unloop :: [(Symbol, Expr)] -> LocalRewrites
unloop candidates =
  LocalRewrites (toRewrites (foldl' addIfAcyclic empty candidates))
  where
    addIfAcyclic :: AcyclicRewrites -> (Symbol, Expr) -> AcyclicRewrites
    addIfAcyclic acc (x, rhs) = fromMaybe acc (insert acc x rhs)
-- | Unwrap the underlying symbol-to-expression map.
toRewrites :: AcyclicRewrites -> M.HashMap Symbol Expr
toRewrites rewrites = case rewrites of
  AR bindings -> bindings
-- | @existsPath ar src dst@ is 'True' when @dst@ is reachable from @src@ by
-- following rewrites. From a bound symbol @x := e@ the search steps to every
-- symbol occurring in @e@. Every symbol reaches itself.
--
-- The search keeps a visited set, so each symbol is expanded at most once.
-- Without it, rewrites that share sub-terms are re-explored along every
-- path, which is exponential in the worst case, and the search would never
-- terminate on a cyclic map.
existsPath :: AcyclicRewrites -> Symbol -> Symbol -> Bool
existsPath (AR m) src dst = search S.empty [src]
  where
    -- Depth-first search over an explicit stack of pending symbols.
    search :: S.HashSet Symbol -> [Symbol] -> Bool
    search _ [] = False
    search seen (x : pending)
      | x == dst            = True
      | x `S.member` seen   = search seen pending
      | otherwise           = search (S.insert x seen) (successors x ++ pending)

    -- Symbols mentioned by the right-hand side bound to @x@, or none if
    -- @x@ is unbound.
    successors :: Symbol -> [Symbol]
    successors x = maybe [] (S.toList . exprSymbolsSet) (M.lookup x m)
-- | The rewrite set with no bindings (trivially acyclic).
empty :: AcyclicRewrites
empty = AR mempty
-- | Add the rewrite @s := e@ if doing so keeps the set acyclic.
--
-- Returns 'Nothing' in two cases: when @s@ occurs in @e@, or when some
-- symbol of @e@ can already reach @s@ through the existing rewrites (see
-- 'existsPath'). Otherwise the binding is added, replacing any previous
-- binding for @s@.
insert :: AcyclicRewrites -> Symbol -> Expr -> Maybe AcyclicRewrites
insert ar@(AR bindings) s e
  | closesCycle = Nothing
  | otherwise   = Just (AR (M.insert s e bindings))
  where
    rhsSymbols  = exprSymbolsSet e
    closesCycle =
      s `S.member` rhsSymbols
        || any (\v -> existsPath ar v s) (S.toList rhsSymbols)