{-# OPTIONS_GHC -Wall #-}
{-# OPTIONS_HADDOCK show-extensions #-}
{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE ScopedTypeVariables #-}
module ToySolver.SAT.Encoder.Cardinality.Internal.ParallelCounter
( addAtLeastParallelCounter
, encodeAtLeastWithPolarityParallelCounter
) where
import Control.Monad.Primitive
import Control.Monad.State.Strict
import Data.Bits
import Data.Vector (Vector)
import qualified Data.Vector as V
import qualified ToySolver.SAT.Types as SAT
import qualified ToySolver.SAT.Encoder.Tseitin as Tseitin
-- | Adds an 'SAT.AtLeast' constraint using the parallel counter encoding.
--
-- The constraint is encoded with positive polarity
-- ('Tseitin.polarityPos'), and the resulting literal is asserted by adding
-- it as a unit clause through the encoder.
addAtLeastParallelCounter :: PrimMonad m => Tseitin.Encoder m -> SAT.AtLeast -> m ()
addAtLeastParallelCounter enc constr = do
  l <- encodeAtLeastWithPolarityParallelCounter enc Tseitin.polarityPos constr
  SAT.addClause enc [l]
-- | Encodes the constraint @sum lhs >= rhs@ (an 'SAT.AtLeast') with the
-- parallel counter encoding. It returns a literal standing for the
-- constraint under the given 'Tseitin.Polarity'.
--
-- Trivial cases are short-circuited:
--
-- * @rhs <= 0@ yields the literal of the empty conjunction (true).
--
-- * @rhs@ greater than the number of literals yields the literal of the
--   empty disjunction (false).
--
-- Otherwise a counter of width @w@ is built, where @w@ is the number of
-- binary digits of @rhs@, so @rhs < 2^w@. The constraint holds iff either:
--
-- * the (truncated) count compares @>= rhs@, or
--
-- * some adder carried past bit @w@, which already implies a count of at
--   least @2^w > rhs@.
encodeAtLeastWithPolarityParallelCounter :: forall m. PrimMonad m => Tseitin.Encoder m -> Tseitin.Polarity -> SAT.AtLeast -> m SAT.Lit
encodeAtLeastWithPolarityParallelCounter enc polarity (lhs, rhs)
  | rhs <= 0 =
      Tseitin.encodeConjWithPolarity enc polarity [] -- true
  | length lhs < rhs =
      Tseitin.encodeDisjWithPolarity enc polarity [] -- false
  | otherwise = do
      let rhsBits :: [Bool]
          rhsBits = bits (fromIntegral rhs)
      (cnt, overflowBits) <- encodeSumParallelCounter enc (length rhsBits) lhs
      isGE <- encodeGE enc polarity cnt rhsBits
      Tseitin.encodeDisjWithPolarity enc polarity (isGE : overflowBits)
  where
    -- Binary digits of a non-negative integer, least-significant bit first.
    -- Digits are produced until every set bit has been cleared, so
    -- @bits 0 == []@ and the last element is always 'True' otherwise.
    bits :: Integer -> [Bool]
    bits n0 = go n0 0
      where
        go :: Integer -> Int -> [Bool]
        go 0 !_ = []
        go k i = testBit k i : go (clearBit k i) (i + 1)
-- | @encodeSumParallelCounter enc w lits@ builds a parallel counter over
-- @lits@ and returns @(cnt, overflow)@.
--
-- * @cnt@ is the number of true literals in binary, least-significant bit
--   first, truncated to at most @w@ bits.
--
-- * @overflow@ collects every carry literal produced out of bit position
--   @w@. If any of them is true, the count is at least @2^w@.
--
-- The input is split into two halves of equal length, and each half is
-- counted recursively. The two counts are then added with a ripple-carry
-- chain of full adders. The carry-in is chosen by input length:
--
-- * odd length: the left-over last literal is the carry-in;
--
-- * even length: a constant-false literal (the empty disjunction) is the
--   carry-in.
encodeSumParallelCounter :: forall m. PrimMonad m => Tseitin.Encoder m -> Int -> [SAT.Lit] -> m ([SAT.Lit], [SAT.Lit])
encodeSumParallelCounter enc w lits = runStateT (count (V.fromList lits)) []
  where
    -- Ripple-carry addition of two LSB-first bit vectors plus a carry-in.
    -- Carries leaving bit position @w@ are pushed onto the overflow state.
    add :: [SAT.Lit] -> [SAT.Lit] -> SAT.Lit -> StateT [SAT.Lit] m [SAT.Lit]
    add = go 0 []

    go :: Int -> [SAT.Lit] -> [SAT.Lit] -> [SAT.Lit] -> SAT.Lit -> StateT [SAT.Lit] m [SAT.Lit]
    go i acc _ _ c
      | i == w = do
          -- Width limit reached: record the carry as an overflow bit.
          modify (c :)
          return (reverse acc)
    -- Both operands consumed below the limit: the final carry becomes the
    -- new most-significant bit.
    go _ acc [] [] c = return (reverse (c : acc))
    go i acc (x : xs) (y : ys) c = do
      s <- lift $ Tseitin.encodeFASum enc x y c
      c' <- lift $ Tseitin.encodeFACarry enc x y c
      go (i + 1) (s : acc) xs ys c'
    -- Both halves have equal length, so their counts have equal width.
    go _ _ _ _ _ = error "encodeSumParallelCounter: should not happen"

    count :: Vector SAT.Lit -> StateT [SAT.Lit] m [SAT.Lit]
    count xs
      | V.null xs = return []
      | otherwise = do
          let half = V.length xs `div` 2
          cnt1 <- count (V.slice 0 half xs)
          cnt2 <- count (V.slice half half xs)
          carryIn <-
            if even (V.length xs)
              then lift (Tseitin.encodeDisj enc []) -- constant false
              else return (V.last xs) -- safe: xs is non-empty here
          add cnt1 cnt2 carryIn
-- | @encodeGE enc polarity lhs rhs@ returns a literal for the unsigned
-- comparison @lhs >= rhs@.
--
-- Here @lhs@ is a vector of literals and @rhs@ is a constant. Both are given
-- least-significant bit first, and missing high bits on either side count
-- as zero.
--
-- The comparison is folded from the low bits upward. The accumulator @r@
-- means "the bits seen so far of @lhs@ are >= the same bits of @rhs@".
--
-- * If the next @rhs@ bit is 1, @lhs@ needs that bit set and a
--   >= lower part, i.e. @l /\\ r@.
--
-- * If the next @rhs@ bit is 0, either @lhs@ has the bit set or the lower
--   part is >=, i.e. @l \\/ r@.
--
-- All sub-formulas are encoded with the same 'Tseitin.Polarity'.
encodeGE :: forall m. PrimMonad m => Tseitin.Encoder m -> Tseitin.Polarity -> [SAT.Lit] -> [Bool] -> m SAT.Lit
encodeGE enc polarity lhs rhs = do
  t <- Tseitin.encodeConjWithPolarity enc polarity [] -- true
  f lhs rhs t
  where
    f :: [SAT.Lit] -> [Bool] -> SAT.Lit -> m SAT.Lit
    f [] [] r = return r
    -- rhs still has a set bit that lhs cannot match: false
    f [] (True : _) _ = Tseitin.encodeDisjWithPolarity enc polarity []
    f [] (False : bs) r = f [] bs r
    f (l : ls) (True : bs) r =
      f ls bs =<< Tseitin.encodeConjWithPolarity enc polarity [l, r]
    f (l : ls) (False : bs) r =
      f ls bs =<< Tseitin.encodeDisjWithPolarity enc polarity [l, r]
    -- rhs exhausted: its remaining (implicit) bits are 0
    f (l : ls) [] r =
      f ls [] =<< Tseitin.encodeDisjWithPolarity enc polarity [l, r]