{-# OPTIONS_HADDOCK show-extensions #-}
module ToySolver.EUF.EUFSolver
(
Solver
, newSolver
, FSym
, Term (..)
, ConstrID
, VAFun (..)
, newFSym
, newFun
, newConst
, assertEqual
, assertEqual'
, assertNotEqual
, assertNotEqual'
, check
, areEqual
, explain
, Entity
, EntityTuple
, Model (..)
, getModel
, eval
, evalAp
, pushBacktrackPoint
, popBacktrackPoint
, termToFlatTerm
, termToFSym
, fsymToTerm
, fsymToFlatTerm
, flatTermToFSym
) where
import Control.Monad
import Control.Monad.Trans
import Control.Monad.Trans.Except
import Data.Either
import Data.IntSet (IntSet)
import qualified Data.IntSet as IntSet
import Data.Map (Map)
import qualified Data.Map as Map
import Data.Set (Set)
import qualified Data.Set as Set
import Data.IORef
import qualified ToySolver.Internal.Data.Vec as Vec
import ToySolver.EUF.CongruenceClosure (FSym, Term (..), ConstrID, VAFun (..))
import ToySolver.EUF.CongruenceClosure (Model (..), Entity, EntityTuple, eval, evalAp)
import qualified ToySolver.EUF.CongruenceClosure as CC
-- | Solver for equalities and disequalities between terms.
--
-- Equalities are handed to the underlying congruence-closure solver;
-- disequalities are stored here and checked against it by 'check'.
data Solver
  = Solver
  { svCCSolver :: !CC.Solver
    -- ^ Underlying congruence-closure solver that receives all equalities.
  , svDisequalities :: IORef (Map (Term, Term) (Maybe ConstrID))
    -- ^ Asserted disequalities, keyed by a canonically ordered pair of
    -- terms, together with the optional constraint ID they were tagged with.
  , svExplanation :: IORef IntSet
    -- ^ Constraint IDs recorded by the most recent 'check' that found a
    -- violated disequality.
  , svBacktrackPoints :: !(Vec.Vec (Map (Term, Term) ()))
    -- ^ One frame per open backtrack point; frame @i@ holds the
    -- disequalities asserted while the level was @i + 1@, so that they can
    -- be dropped again when that backtrack point is popped.
  }
-- | Create a solver with no assertions and no open backtrack points.
newSolver :: IO Solver
newSolver = do
  cc <- CC.newSolver
  deqs <- newIORef Map.empty
  -- Start from an empty explanation rather than 'undefined': otherwise
  -- @'explain' solver Nothing@ called before any conflicting 'check' would
  -- crash as soon as its result is forced.
  expl <- newIORef IntSet.empty
  bp <- Vec.new
  return
    Solver
      { svCCSolver = cc
      , svDisequalities = deqs
      , svExplanation = expl
      , svBacktrackPoints = bp
      }
-- | Allocate a new function symbol in the underlying congruence-closure solver.
newFSym :: Solver -> IO FSym
newFSym = CC.newFSym . svCCSolver
-- | Create a new constant term in the underlying congruence-closure solver.
newConst :: Solver -> IO Term
newConst = CC.newConst . svCCSolver
-- | Create a new function via 'CC.newFun'; the shape of the returned value
-- is determined by the 'VAFun' instance of @a@.
newFun :: CC.VAFun a => Solver -> IO a
newFun = CC.newFun . svCCSolver
-- | Assert @lhs = rhs@ without attaching a constraint ID.
assertEqual :: Solver -> Term -> Term -> IO ()
assertEqual solver lhs rhs = assertEqual' solver lhs rhs Nothing
-- | Assert @t1 = t2@, optionally tagged with a constraint ID, by merging
-- the two terms in the underlying congruence-closure solver ('CC.merge'').
assertEqual' :: Solver -> Term -> Term -> Maybe ConstrID -> IO ()
assertEqual' = CC.merge' . svCCSolver
-- | Assert @lhs /= rhs@ without attaching a constraint ID.
assertNotEqual :: Solver -> Term -> Term -> IO ()
assertNotEqual solver lhs rhs = assertNotEqual' solver lhs rhs Nothing
-- | Assert @t1 /= t2@, optionally tagged with a constraint ID.
--
-- The pair is stored in canonical order (smaller term first), so the
-- disequality is symmetric. If the same pair is already stored, the call
-- does nothing and the previously stored constraint ID is kept. While a
-- backtrack point is open, the pair is also recorded in the innermost
-- backtrack frame so that 'popBacktrackPoint' removes it again.
assertNotEqual' :: Solver -> Term -> Term -> Maybe ConstrID -> IO ()
assertNotEqual' solver t1 t2 cid = do
  let key
        | t1 < t2 = (t1, t2)
        | otherwise = (t2, t1)
  deqs <- readIORef (svDisequalities solver)
  unless (key `Map.member` deqs) $ do
    -- Run 'termToFSym' on both sides (first, then second) purely for its
    -- effect on the congruence-closure solver; the symbols are discarded.
    -- NOTE(review): presumably this names the terms so that model
    -- construction can refer to them — confirm.
    mapM_ (termToFSym solver) [fst key, snd key]
    writeIORef (svDisequalities solver) $! Map.insert key cid deqs
    level <- getCurrentLevel solver
    -- Frame (level - 1) belongs to the innermost open backtrack point.
    when (level /= 0) $
      Vec.unsafeModify' (svBacktrackPoints solver) (level - 1) (Map.insert key ())
-- | Check the asserted disequalities against the current equalities.
--
-- Returns 'True' if no stored disequality @t1 /= t2@ has @t1@ and @t2@
-- congruent. Otherwise it stores an explanation for the first violated
-- disequality (in ascending key order) in 'svExplanation' and returns
-- 'False'. The explanation is the set returned by 'CC.explain' plus the
-- disequality's own constraint ID, if it has one, and can be read back
-- with @'explain' solver Nothing@.
check :: Solver -> IO Bool
check solver = do
  ds <- readIORef (svDisequalities solver)
  fmap isRight $ runExceptT $ forM_ (Map.toList ds) $ \((t1, t2), cid) -> do
    congruent <- lift $ CC.areCongruent (svCCSolver solver) t1 t2
    when congruent $ do
      -- 'CC.explain' is expected to return 'Just' here because the terms
      -- are congruent. Match explicitly instead of using a failable
      -- @Just cs <- ...@ bind, which would go through 'MonadFail' and
      -- surface as an uninformative pattern-match IOError.
      ret <- lift $ CC.explain (svCCSolver solver) t1 t2
      cs <- case ret of
        Nothing -> error "ToySolver.EUF.EUFSolver.check: should not happen"
        Just explanation -> return explanation
      -- Add the violated disequality's own ID, if any.
      lift $ writeIORef (svExplanation solver) $! maybe cs (`IntSet.insert` cs) cid
      -- Stop at the first violation; 'runExceptT' yields Left, i.e. False.
      throwE ()
-- | Whether the two terms are currently congruent in the underlying
-- congruence-closure solver.
areEqual :: Solver -> Term -> Term -> IO Bool
areEqual = CC.areCongruent . svCCSolver
-- | Explanation as a set of constraint IDs.
--
-- * @explain solver Nothing@ returns the explanation stored by the most
--   recent 'check' that returned 'False'.
-- * @explain solver (Just (t1, t2))@ returns the result of 'CC.explain' for
--   the pair. If 'CC.explain' returns 'Nothing', this raises an error, so
--   presumably it should only be used on pairs known to be equal
--   (e.g. via 'areEqual').
explain :: Solver -> Maybe (Term, Term) -> IO IntSet
explain solver query =
  case query of
    Nothing -> readIORef (svExplanation solver)
    Just (t1, t2) ->
      CC.explain (svCCSolver solver) t1 t2
        >>= maybe (error "ToySolver.EUF.EUFSolver.explain: should not happen") return
-- | Build a model from the underlying congruence-closure solver.
getModel :: Solver -> IO Model
getModel solver = CC.getModel (svCCSolver solver)
-- | Backtracking depth: the number of open backtrack points (0 at the root).
type Level = Int

-- | Current backtracking depth, i.e. the number of frames in
-- 'svBacktrackPoints'.
getCurrentLevel :: Solver -> IO Level
getCurrentLevel = Vec.getSize . svBacktrackPoints
-- | Open a new backtrack point.
--
-- Pushes a backtrack point in the congruence-closure solver, then opens an
-- empty frame for disequalities asserted at the new level.
pushBacktrackPoint :: Solver -> IO ()
pushBacktrackPoint solver =
  CC.pushBacktrackPoint (svCCSolver solver)
    >> Vec.push (svBacktrackPoints solver) Map.empty
-- | Close the innermost backtrack point.
--
-- Pops the congruence-closure solver's backtrack point and removes from
-- the disequality store every pair recorded in the innermost frame.
-- Calling this at the root level (no open backtrack point) is an error.
popBacktrackPoint :: Solver -> IO ()
popBacktrackPoint solver = do
  level <- getCurrentLevel solver
  when (level == 0) $
    error "ToySolver.EUF.EUFSolver.popBacktrackPoint: root level"
  CC.popBacktrackPoint (svCCSolver solver)
  added <- Vec.unsafePop (svBacktrackPoints solver)
  modifyIORef' (svDisequalities solver) (\deqs -> deqs `Map.difference` added)
-- | Low-level: convert a term to a flat term via the congruence-closure solver.
termToFlatTerm :: Solver -> Term -> IO CC.FlatTerm
termToFlatTerm solver = CC.termToFlatTerm (svCCSolver solver)
-- | Low-level: obtain the function symbol for a term via the
-- congruence-closure solver.
termToFSym :: Solver -> Term -> IO FSym
termToFSym solver = CC.termToFSym (svCCSolver solver)
-- | Low-level: map a function symbol back to a term via the
-- congruence-closure solver.
fsymToTerm :: Solver -> FSym -> IO Term
fsymToTerm solver = CC.fsymToTerm (svCCSolver solver)
-- | Low-level: map a function symbol to a flat term via the
-- congruence-closure solver.
fsymToFlatTerm :: Solver -> FSym -> IO CC.FlatTerm
fsymToFlatTerm solver = CC.fsymToFlatTerm (svCCSolver solver)
-- | Low-level: obtain the function symbol for a flat term via the
-- congruence-closure solver.
flatTermToFSym :: Solver -> CC.FlatTerm -> IO FSym
flatTermToFSym solver = CC.flatTermToFSym (svCCSolver solver)