{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# OPTIONS_GHC -Wall #-}
{-# OPTIONS_HADDOCK show-extensions #-}
-----------------------------------------------------------------------------
-- |
-- Module      :  ToySolver.SAT.Encoder.PB.Internal.BDD
-- Copyright   :  (c) Masahiro Sakai 2016
-- License     :  BSD-style
--
-- Maintainer  :  masahiro.sakai@gmail.com
-- Stability   :  provisional
-- Portability :  non-portable
--
-- References:
--
-- * [ES06] N. Eén and N. Sörensson. Translating Pseudo-Boolean
--   Constraints into SAT. JSAT 2:1–26, 2006.
--
-----------------------------------------------------------------------------
module ToySolver.SAT.Encoder.PB.Internal.BDD
  ( addPBLinAtLeastBDD
  , encodePBLinAtLeastWithPolarityBDD
  ) where

import Control.Monad.State.Strict
import Control.Monad.Primitive
import Data.Ord
import Data.List
import Data.Map.Strict (Map)
import qualified Data.Map.Strict as Map
import qualified ToySolver.SAT.Types as SAT
import qualified ToySolver.SAT.Encoder.Tseitin as Tseitin

-- | Add the pseudo-Boolean constraint @sum c_i * x_i >= rhs@ to the encoder.
--
-- The constraint is first turned into a single literal by
-- 'encodePBLinAtLeastWithPolarityBDD', and that literal is then asserted with a
-- unit clause.
addPBLinAtLeastBDD :: PrimMonad m => Tseitin.Encoder m -> SAT.PBLinAtLeast -> m ()
addPBLinAtLeastBDD enc constr = do
  -- The literal is only ever asserted true (unit clause below), so only the
  -- positive polarity of its definition is requested.
  l <- encodePBLinAtLeastWithPolarityBDD enc Tseitin.polarityPos constr
  SAT.addClause enc [l]

-- | Encode the pseudo-Boolean constraint @sum c_i * x_i >= rhs@ as a literal,
-- following the BDD construction of [ES06].
--
-- Terms are processed in order of decreasing coefficient. A BDD node is
-- determined by the suffix of the sorted term list that remains to be decided
-- together with the remaining bound. Nodes are shared through a memo table, so
-- each such sub-constraint is encoded at most once.
--
-- Node definitions are emitted with 'Tseitin.encodeConjWithPolarity' and
-- 'Tseitin.encodeDisjWithPolarity' using the requested @polarity@.
--
-- NOTE(review): a node is encoded as @(x && then) || else@ rather than a full
-- if-then-else. This relies on @else@ implying @then@, which holds when every
-- coefficient is positive. Presumably callers normalize the constraint before
-- calling this function — confirm.
encodePBLinAtLeastWithPolarityBDD
  :: forall m. PrimMonad m
  => Tseitin.Encoder m -> Tseitin.Polarity -> SAT.PBLinAtLeast -> m SAT.Lit
encodePBLinAtLeastWithPolarityBDD enc polarity (lhs, rhs) =
  evalStateT (go (length lhs') lhs' rhs (sum (map fst lhs') - rhs)) Map.empty
  where
    -- Largest coefficients first. 'sortBy' is stable, so terms with equal
    -- coefficients keep their input order.
    lhs' :: SAT.PBLinSum
    lhs' = sortBy (flip (comparing fst)) lhs

    -- @go n xs k slack@ returns a literal equivalent to @sum xs >= k@.
    --
    -- * @xs@ is always a suffix of @lhs'@ and @n == length xs@, so @(n, k)@
    --   identifies the node uniquely. Keying the memo table on @n@ instead of
    --   the list itself avoids O(length xs) comparisons on every lookup.
    -- * Invariant: @slack == sum (map fst xs) - k@.
    go :: Int -> SAT.PBLinSum -> Integer -> Integer
       -> StateT (Map (Int, Integer) SAT.Lit) m SAT.Lit
    go n xs k slack
      | k <= 0    = lift $ Tseitin.encodeConjWithPolarity enc polarity [] -- true
      | slack < 0 = lift $ Tseitin.encodeDisjWithPolarity enc polarity [] -- false
      | otherwise = do
          memo <- get
          case Map.lookup (n, k) memo of
            Just l -> return l
            Nothing ->
              case xs of
                -- Unreachable: with xs empty the invariant gives slack == -k < 0,
                -- which the guards above already handled.
                [] -> error "encodePBLinAtLeastWithPolarityBDD: should not happen"
                -- One term left: slack >= 0 means c >= k > 0, so the node is l.
                [(_, l)] -> return l
                (c, l) : xs' -> do
                  -- "then" branch: l holds, so c counts towards the bound.
                  thenLit <- go (n - 1) xs' (k - c) slack
                  l2 <- lift $ Tseitin.encodeConjWithPolarity enc polarity [l, thenLit]
                  -- "else" branch: l is false, so c is lost from the slack. When
                  -- c > slack that branch is unsatisfiable and the node is just l2.
                  l3 <- if c > slack
                          then return l2
                          else do
                            elseLit <- go (n - 1) xs' k (slack - c)
                            lift $ Tseitin.encodeDisjWithPolarity enc polarity [l2, elseLit]
                  modify (Map.insert (n, k) l3)
                  return l3