{-# LANGUAGE ScopedTypeVariables, FlexibleContexts #-}
{-# OPTIONS_GHC -Wall #-}
-----------------------------------------------------------------------------
-- |
-- Module      :  ToySolver.SAT.Encoder.PB.Internal.BDD
-- Copyright   :  (c) Masahiro Sakai 2016
-- License     :  BSD-style
-- 
-- Maintainer  :  masahiro.sakai@gmail.com
-- Stability   :  provisional
-- Portability :  non-portable
--
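-- Encoding of linear pseudo-Boolean at-least constraints into clauses via a
-- BDD whose nodes are defined with the Tseitin encoder, in the style of [ES06].
--
-- References: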
--
-- * [ES06] N. Eén and N. Sörensson. Translating Pseudo-Boolean
--   Constraints into SAT. JSAT 2:1–26, 2006.
--
-----------------------------------------------------------------------------
module ToySolver.SAT.Encoder.PB.Internal.BDD
  ( addPBLinAtLeastBDD
  , encodePBLinAtLeastBDD
  ) where

import Control.Monad.State.Strict
import Control.Monad.Primitive
import Data.Ord
import Data.List
import Data.Map.Strict (Map)
import qualified Data.Map.Strict as Map
import qualified ToySolver.SAT.Types as SAT
import qualified ToySolver.SAT.Encoder.Tseitin as Tseitin

-- | Assert a linear pseudo-Boolean at-least constraint: encode it with
-- 'encodePBLinAtLeastBDD' and add the resulting literal as a unit clause.
addPBLinAtLeastBDD :: PrimMonad m => Tseitin.Encoder m -> SAT.PBLinAtLeast -> m ()
addPBLinAtLeastBDD enc constr =
  encodePBLinAtLeastBDD enc constr >>= \lit -> SAT.addClause enc [lit]

-- | Encode a linear pseudo-Boolean constraint @sum c_i * l_i >= rhs@ as a BDD
-- whose nodes are defined with the Tseitin encoder, and return a literal that
-- stands for the whole constraint.
--
-- Terms are processed in order of decreasing coefficient. A BDD node is
-- identified by the remaining suffix of terms together with the remaining
-- bound, and nodes are memoised on that pair so that shared sub-diagrams are
-- encoded only once. Node literals are introduced with positive polarity only
-- ('Tseitin.polarityPos').
--
-- NOTE(review): the slack reasoning assumes every coefficient is positive
-- (i.e. the caller has normalised the constraint) — confirm at call sites.
encodePBLinAtLeastBDD :: forall m. PrimMonad m => Tseitin.Encoder m -> SAT.PBLinAtLeast -> m SAT.Lit
encodePBLinAtLeastBDD enc (lhs,rhs) = do
  let lhs' = sortBy (flip (comparing fst)) lhs
  flip evalStateT Map.empty $ do
    -- @f xs k slack@ encodes "the terms in @xs@ sum to at least @k@".
    -- Invariant: @slack == sum (map fst xs) - k@; it is threaded through so
    -- infeasibility can be detected without re-summing @xs@.
    let f :: SAT.PBLinSum -> Integer -> Integer -> StateT (Map (SAT.PBLinSum, Integer) SAT.Lit) m SAT.Lit
        f xs k slack
          | k <= 0    = lift $ Tseitin.encodeConj enc [] -- true
          | slack < 0 = lift $ Tseitin.encodeDisj enc [] -- false
          | otherwise = do
              memo <- get
              case Map.lookup (xs,k) memo of
                Just l -> return l
                Nothing ->
                  case xs of
                    -- Unreachable while the slack invariant holds (an empty
                    -- suffix with k > 0 has negative slack), but handled for
                    -- totality: no terms can reach a positive bound.
                    [] -> lift $ Tseitin.encodeDisj enc [] -- false
                    -- With one term, slack >= 0 means c >= k > 0, so the
                    -- constraint holds exactly when l is true.
                    [(_,l)] -> return l
                    (c,l) : xs' -> do
                      -- "then" branch: l is true and contributes c to the sum.
                      thenLit <- f xs' (k - c) slack
                      l2 <- lift $ Tseitin.encodeConjWithPolarity enc Tseitin.polarityPos [l, thenLit]
                      -- "else" branch: l is false, which costs c of slack; when
                      -- c > slack that branch is constantly false and is omitted.
                      l3 <- if c > slack then
                              return l2
                            else do
                              elseLit <- f xs' k (slack - c)
                              lift $ Tseitin.encodeDisjWithPolarity enc Tseitin.polarityPos [l2, elseLit]
                      modify (Map.insert (xs,k) l3)
                      return l3
    f lhs' rhs (sum [c | (c,_) <- lhs'] - rhs)