{-# OPTIONS_HADDOCK show-extensions #-}
-----------------------------------------------------------------------------
-- |
-- Module      :  ToySolver.EUF.EUFSolver
-- Copyright   :  (c) Masahiro Sakai 2015
-- License     :  BSD-style
--
-- Maintainer  :  masahiro.sakai@gmail.com
-- Stability   :  unstable
-- Portability :  non-portable
--
-----------------------------------------------------------------------------
module ToySolver.EUF.EUFSolver
  ( -- * The @Solver@ type
    Solver
  , newSolver

  -- * Problem description
  , FSym
  , Term (..)
  , ConstrID
  , VAFun (..)
  , newFSym
  , newFun
  , newConst
  , assertEqual
  , assertEqual'
  , assertNotEqual
  , assertNotEqual'

  -- * Query
  , check
  , areEqual

  -- * Explanation
  , explain

  -- * Model Construction
  , Entity
  , EntityTuple
  , Model (..)
  , getModel
  , eval
  , evalAp

  -- * Backtracking
  , pushBacktrackPoint
  , popBacktrackPoint

  -- * Low-level operations
  , termToFlatTerm
  , termToFSym
  , fsymToTerm
  , fsymToFlatTerm
  , flatTermToFSym
  ) where

import Control.Monad
import Control.Monad.Trans
import Control.Monad.Trans.Except
import Data.Either
import Data.IntSet (IntSet)
import qualified Data.IntSet as IntSet
import Data.Map (Map)
import qualified Data.Map as Map
import Data.Set (Set)
import qualified Data.Set as Set
import Data.IORef

import qualified ToySolver.Internal.Data.Vec as Vec
import ToySolver.EUF.CongruenceClosure (FSym, Term (..), ConstrID, VAFun (..))
import ToySolver.EUF.CongruenceClosure (Model (..), Entity, EntityTuple, eval, evalAp)
import qualified ToySolver.EUF.CongruenceClosure as CC

-- | State of an EUF solver: a congruence-closure solver for the equalities
-- plus the bookkeeping this module keeps for disequalities.
data Solver
  = Solver
  { svCCSolver :: !CC.Solver
    -- ^ Underlying congruence-closure solver; all equalities are merged here.
  , svDisequalities :: IORef (Map (Term, Term) (Maybe ConstrID))
    -- ^ Asserted disequalities, keyed by the ordered pair of terms
    -- (smaller term first), with the optional constraint ID supplied
    -- when the disequality was asserted.
  , svExplanation :: IORef IntSet
    -- ^ Explanation recorded by 'check' for the disequality it found violated.
  , svBacktrackPoints :: !(Vec.Vec (Map (Term, Term) ()))
    -- ^ One entry per open backtrack point, holding the disequalities
    -- added while that point was the innermost one.
  }

-- | Create a fresh solver with no disequalities and no backtrack points.
--
-- The stored explanation starts out as the empty set, so that
-- @'explain' solver Nothing@ is well defined even before any call to
-- 'check' has failed.
newSolver :: IO Solver
newSolver = do
  cc <- CC.newSolver
  deqs <- newIORef Map.empty
  expl <- newIORef IntSet.empty
  bp <- Vec.new
  return
    Solver
    { svCCSolver = cc
    , svDisequalities = deqs
    , svExplanation = expl
    , svBacktrackPoints = bp
    }

-- | Allocate a new function symbol in the underlying congruence-closure solver.
newFSym :: Solver -> IO FSym
newFSym solver = CC.newFSym (svCCSolver solver)

-- | Create a new constant term in the underlying congruence-closure solver.
newConst :: Solver -> IO Term
newConst solver = CC.newConst (svCCSolver solver)

-- | Create a new function via 'CC.newFun' on the underlying solver; the
-- result type is selected by the 'CC.VAFun' instance.
newFun :: CC.VAFun a => Solver -> IO a
newFun solver = CC.newFun (svCCSolver solver)

-- | Assert @t1 = t2@ without attaching a constraint ID.
assertEqual :: Solver -> Term -> Term -> IO ()
assertEqual solver t1 t2 = assertEqual' solver t1 t2 Nothing

-- | Assert @t1 = t2@, optionally tagged with a constraint ID, by merging the
-- two terms in the underlying congruence-closure solver.
assertEqual' :: Solver -> Term -> Term -> Maybe ConstrID -> IO ()
assertEqual' solver t1 t2 cid = CC.merge' (svCCSolver solver) t1 t2 cid

-- | Assert @t1 /= t2@ without attaching a constraint ID.
assertNotEqual :: Solver -> Term -> Term -> IO ()
assertNotEqual solver t1 t2 = assertNotEqual' solver t1 t2 Nothing

-- | Assert @t1 /= t2@, optionally tagged with a constraint ID.
--
-- The pair is stored with the smaller term first, so asserting the same
-- disequality in either orientation yields a single entry.  A disequality
-- that is already recorded is left untouched (its original constraint ID is
-- kept).  When asserted below the root level, the disequality is also
-- recorded in the innermost backtrack point so that 'popBacktrackPoint'
-- can retract it.
assertNotEqual' :: Solver -> Term -> Term -> Maybe ConstrID -> IO ()
assertNotEqual' solver t1 t2 cid = do
  ds <- readIORef (svDisequalities solver)
  unless (deq `Map.member` ds) $ do
    _ <- termToFSym solver (fst deq) -- It is important to name the term for model generation
    _ <- termToFSym solver (snd deq) -- It is important to name the term for model generation
    writeIORef (svDisequalities solver) $! Map.insert deq cid ds
    lv <- getCurrentLevel solver
    -- At level 0 there is no backtrack point to record into.
    unless (lv == 0) $
      Vec.unsafeModify' (svBacktrackPoints solver) (lv - 1) $ Map.insert deq ()
  where
    -- Normalised key: smaller term first.
    deq = if t1 < t2 then (t1, t2) else (t2, t1)

-- | Check the recorded disequalities against the current equalities.
--
-- Returns 'True' when no recorded disequality has both sides congruent in
-- the underlying congruence-closure solver.  Otherwise returns 'False' and
-- stores an explanation for the first violated disequality (in key order of
-- the disequality map): the constraint IDs reported by 'CC.explain', plus
-- the disequality's own constraint ID if it has one.  The stored set can be
-- read back with @'explain' solver Nothing@.
check :: Solver -> IO Bool
check solver = do
  ds <- readIORef (svDisequalities solver)
  liftM isRight $ runExceptT $ forM_ (Map.toList ds) $ \((t1, t2), cid) -> do
    b <- lift $ CC.areCongruent (svCCSolver solver) t1 t2
    when b $ do
      ret <- lift $ CC.explain (svCCSolver solver) t1 t2
      cs <- case ret of
        Just cs' -> return cs'
        -- The terms were just reported congruent, so an explanation is expected.
        Nothing -> error "ToySolver.EUF.EUFSolver.check: should not happen"
      lift $ writeIORef (svExplanation solver) $!
        case cid of
          Nothing -> cs
          Just c -> IntSet.insert c cs
      -- Abort the traversal at the first conflict.
      throwE ()

-- | Ask the underlying congruence-closure solver whether the two terms are
-- congruent under the equalities asserted so far.
areEqual :: Solver -> Term -> Term -> IO Bool
areEqual solver t1 t2 = CC.areCongruent (svCCSolver solver) t1 t2

-- | Retrieve a set of constraint IDs.
--
-- * With 'Nothing', returns whatever explanation is currently stored in the
--   solver (written by 'check' when it detects a conflict).
--
-- * With @'Just' (t1,t2)@, asks the underlying congruence-closure solver to
--   explain @t1 = t2@.  Calls 'error' if that solver returns no explanation
--   (presumably when the terms are not congruent -- confirm against
--   'CC.explain').
explain :: Solver -> Maybe (Term,Term) -> IO IntSet
explain solver Nothing = readIORef (svExplanation solver)
explain solver (Just (t1,t2)) = do
  ret <- CC.explain (svCCSolver solver) t1 t2
  case ret of
    Nothing -> error "ToySolver.EUF.EUFSolver.explain: should not happen"
    Just cs -> return cs

-- -------------------------------------------------------------------
-- Model construction
-- -------------------------------------------------------------------

-- | Build a model by delegating to 'CC.getModel' on the underlying
-- congruence-closure solver.
getModel :: Solver -> IO Model
getModel = CC.getModel . svCCSolver

-- -------------------------------------------------------------------
-- Backtracking
-- -------------------------------------------------------------------

-- | Backtracking depth; 0 is the root level (no open backtrack point).
type Level = Int

-- | Current backtracking depth, i.e. the number of open backtrack points.
getCurrentLevel :: Solver -> IO Level
getCurrentLevel solver = Vec.getSize (svBacktrackPoints solver)

-- | Open a new backtrack point, both in the underlying congruence-closure
-- solver and in this solver's own stack, where it starts with an empty set
-- of disequalities.
pushBacktrackPoint :: Solver -> IO ()
pushBacktrackPoint solver = do
  CC.pushBacktrackPoint (svCCSolver solver)
  Vec.push (svBacktrackPoints solver) Map.empty

-- | Close the innermost backtrack point.
--
-- Pops the underlying congruence-closure solver's backtrack point as well,
-- and removes every disequality that was recorded in the popped entry from
-- the set of active disequalities.  Calls 'error' when invoked at the root
-- level (no backtrack point to pop).
popBacktrackPoint :: Solver -> IO ()
popBacktrackPoint solver = do
  lv <- getCurrentLevel solver
  if lv == 0 then
    error "ToySolver.EUF.EUFSolver.popBacktrackPoint: root level"
  else do
    CC.popBacktrackPoint (svCCSolver solver)
    -- Safe to use the unchecked pop: lv > 0 guarantees a non-empty stack.
    xs <- Vec.unsafePop (svBacktrackPoints solver)
    modifyIORef' (svDisequalities solver) (`Map.difference` xs)

-- | 'CC.termToFlatTerm' applied to the underlying congruence-closure solver.
termToFlatTerm :: Solver -> Term -> IO CC.FlatTerm
termToFlatTerm = CC.termToFlatTerm . svCCSolver
-- | 'CC.termToFSym' applied to the underlying congruence-closure solver.
termToFSym :: Solver -> Term -> IO FSym
termToFSym = CC.termToFSym . svCCSolver
-- | 'CC.fsymToTerm' applied to the underlying congruence-closure solver.
fsymToTerm :: Solver -> FSym -> IO Term
fsymToTerm = CC.fsymToTerm . svCCSolver
-- | 'CC.fsymToFlatTerm' applied to the underlying congruence-closure solver.
fsymToFlatTerm :: Solver -> FSym -> IO CC.FlatTerm
fsymToFlatTerm = CC.fsymToFlatTerm . svCCSolver
-- | 'CC.flatTermToFSym' applied to the underlying congruence-closure solver.
flatTermToFSym :: Solver -> CC.FlatTerm -> IO FSym
flatTermToFSym = CC.flatTermToFSym . svCCSolver