{- |
Module      : Data.RME.Base
Copyright   : Galois, Inc. 2016
License     : BSD3
Maintainer  : huffman@galois.com
Stability   : experimental
Portability : portable

Reed-Muller Expansion normal form for Boolean Formulas.
-}

module Data.RME.Base
  ( RME
  , true, false, lit
  , constant, isBool
  , compl, xor, conj, disj, iff, mux
  , eval
  , sat, allsat
  , degree
  , depth, size
  , explode
  ) where

-- | Boolean formulas in Algebraic Normal Form, using a representation
-- based on the Reed-Muller expansion.
--
-- 'R0' and 'R1' are the constant formulas false and true. A node
-- @Node n a b@ stands for @a XOR (x_n AND b)@, where @x_n@ is the
-- variable with index @n@.
--
-- Invariants: the last argument to a 'Node' constructor should never
-- be 'R0'. Also the 'Int' arguments should strictly increase as you
-- go deeper in the tree.
data RME = Node !Int !RME !RME | R0 | R1
  deriving (Eq, Ord, Show)

-- | Evaluate a formula under the given variable assignment.
--
-- The second argument maps a variable index to its truth value. A node
-- @Node n a b@ evaluates to @a XOR (x_n AND b)@; the exclusive-or is
-- expressed with '/=' on 'Bool'.
eval :: RME -> (Int -> Bool) -> Bool
eval anf v =
  case anf of
    R0 -> False
    R1 -> True
    Node n a b -> eval a v /= (v n && eval b v)

-- | Normalizing constructor.
--
-- Behaves like 'Node', except that a node whose last argument is 'R0'
-- collapses to its first branch. This maintains the invariant that the
-- last argument of a 'Node' is never 'R0'.
node :: Int -> RME -> RME -> RME
node _ a R0 = a
node n a b = Node n a b

-- | Constant true formula (the leaf 'R1').
true :: RME
true = R1

-- | Constant false formula (the leaf 'R0').
false :: RME
false = R0

-- | Boolean constant formulas: 'False' maps to 'false' and 'True' maps
-- to 'true'.
constant :: Bool -> RME
constant False = false
constant True = true

-- | Test whether an RME formula is a constant boolean.
--
-- Returns @Just False@ for 'R0', @Just True@ for 'R1', and 'Nothing'
-- for any 'Node'.
isBool :: RME -> Maybe Bool
isBool R0 = Just False
isBool R1 = Just True
isBool _ = Nothing

-- | Boolean literals.
--
-- @lit n@ is the formula consisting of just the variable with index @n@,
-- represented as @Node n R0 R1@ (that is, @0 XOR (x_n AND 1)@).
lit :: Int -> RME
lit n = Node n R0 R1

-- | Logical complement.
--
-- Swaps the two constants. For a node only the first branch is
-- complemented: negating @a XOR (x_n AND b)@ is the same as negating @a@
-- and leaving @b@ alone. The last argument is unchanged, so the raw
-- 'Node' constructor is used directly.
compl :: RME -> RME
compl R0 = R1
compl R1 = R0
compl (Node n a b) = Node n (compl a) b

-- | Logical exclusive-or.
--
-- Constants are handled first: xor with 'R0' is the identity and xor with
-- 'R1' is 'compl'. For two nodes, the smaller variable index stays on top.
--
-- * When the indices differ, the other operand is merged into the first
--   branch only; the last argument is reused unchanged, so the raw 'Node'
--   constructor is used.
-- * When the indices are equal, both branches are combined pairwise. The
--   combined last arguments may cancel to 'R0', so the normalizing 'node'
--   is used instead.
xor :: RME -> RME -> RME
xor R0 y = y
xor R1 y = compl y
xor x R0 = x
xor x R1 = compl x
xor x@(Node i a b) y@(Node j c d)
  | i < j = Node i (xor a y) b
  | j < i = Node j (xor x c) d
  | otherwise = node i (xor a c) (xor b d)

-- | Logical conjunction.
--
-- Constants are handled first: 'R0' absorbs and 'R1' is the identity.
-- For two nodes, the smaller variable index stays on top.
--
-- * When the indices differ, the operand with the larger index is
--   conjoined into both branches of the other node.
-- * When the indices are equal, the first branch is the conjunction of
--   the first branches (@ac@, the result when the variable is false).
--   The last argument is @ac@ xor-ed with the result when the variable
--   is true, which is @(a XOR b) AND (c XOR d)@.
--
-- Every node is built with 'node', since the computed last argument may
-- be 'R0'.
conj :: RME -> RME -> RME
conj R0 _ = R0
conj R1 y = y
conj _ R0 = R0
conj x R1 = x
conj x@(Node i a b) y@(Node j c d)
  | i < j = node i (conj a y) (conj b y)
  | j < i = node j (conj x c) (conj x d)
  | otherwise = node i ac (xor ac (conj (xor a b) (xor c d)))
  where ac = conj a c

-- | Logical disjunction.
--
-- Constants are handled first: 'R0' is the identity and 'R1' absorbs.
-- For two nodes, the smaller variable index stays on top.
--
-- * When the indices differ, the first branch is the disjunction of the
--   top node's first branch with the other operand, and the last argument
--   is the top node's last argument conjoined with the complement of the
--   other operand.
-- * When the indices are equal, the first branch is the disjunction of
--   the first branches (@ac@, the result when the variable is false).
--   The last argument is @ac@ xor-ed with the result when the variable
--   is true, which is @(a XOR b) OR (c XOR d)@.
--
-- Every node is built with 'node', since the computed last argument may
-- be 'R0'.
disj :: RME -> RME -> RME
disj R0 y = y
disj R1 _ = R1
disj x R0 = x
disj _ R1 = R1
disj x@(Node i a b) y@(Node j c d)
  | i < j = node i (disj a y) (conj b (compl y))
  | j < i = node j (disj x c) (conj (compl x) d)
  | otherwise = node i ac (xor ac (disj (xor a b) (xor c d)))
  where ac = disj a c

-- | Logical equivalence.
--
-- Defined as the exclusive-or of the complemented first argument with
-- the second argument.
iff :: RME -> RME -> RME
iff x y = xor (compl x) y
{-
iff R0 y = compl y
iff R1 y = y
iff x R0 = compl x
iff x R1 = x
iff x@(Node i a b) y@(Node j c d)
  | i < j = Node i (iff a y) b
  | j < i = Node j (iff x c) d
  | otherwise = node i (iff a c) (xor b d)
-}

-- | Logical if-then-else.
--
-- @mux b x y@ selects @x@ where @b@ holds and @y@ otherwise. A constant
-- condition returns the chosen branch directly. In general the result is
-- computed as @(b AND (x XOR y)) XOR y@, which needs a single 'conj'
-- instead of the two in the alternative formulation kept in the comment
-- below the signature.
mux :: RME -> RME -> RME -> RME
--mux w x y = xor (conj w x) (conj (compl w) y)
mux R0 _ y = y
mux R1 x _ = x
mux b x y = xor (conj b (xor x y)) y

{-
mux R0 x y = y
mux R1 x y = x
mux w R0 y = conj (compl w) y
mux w R1 y = disj w y
mux w x R0 = conj w x
mux w x R1 = disj (compl w) x
mux w@(Node i a b) x@(Node j c d) y@(Node k e f)
  | i < j && i < k = node i (mux a x y) (conj b (xor x y))
  | j < i && j < k = node i (mux w c y) (conj w d)
  | k < i && k < j = node i (mux w x e) (conj (compl w) f)
  | i == j && i < k = node i (mux a c y) _
-}

-- | Satisfiability checker.
--
-- Returns 'Nothing' for 'R0' and @Just []@ for 'R1'. For a node, the
-- variable is first tried as 'False', which reduces the formula to its
-- first branch. If that branch has no satisfying assignment, the variable
-- is set to 'True' and the search continues in the last argument.
--
-- The resulting list pairs variable indices with values and mentions only
-- the variables visited on the path taken.
sat :: RME -> Maybe [(Int, Bool)]
sat R0 = Nothing
sat R1 = Just []
sat (Node n a b) =
  case sat a of
    Just xs -> Just ((n, False) : xs)
    Nothing -> fmap ((n, True) :) (sat b)

-- | List of all satisfying assignments.
--
-- 'R0' has none and 'R1' has exactly one, the empty assignment. For a
-- node, the assignments with the variable 'False' come from the first
-- branch @a@, followed by the assignments with the variable 'True', which
-- come from @a XOR b@ (the formula's value when the variable is true).
--
-- Each assignment mentions only the variables visited on its path.
allsat :: RME -> [[(Int, Bool)]]
allsat R0 = []
allsat R1 = [[]]
allsat (Node n a b) =
  map ((n, False) :) (allsat a) ++ map ((n, True) :) (allsat (xor a b))

-- | Maximum polynomial degree.
--
-- Constants have degree 0. For @Node _ a b@, the terms under @b@ are
-- multiplied by one more variable, so the degree is the larger of
-- @degree a@ and @1 + degree b@.
degree :: RME -> Int
degree R0 = 0
degree R1 = 0
degree (Node _ a b) = max (degree a) (1 + degree b)

-- | Tree depth.
--
-- The constants 'R0' and 'R1' have depth 0; a node has depth one more
-- than the deeper of its two branches.
depth :: RME -> Int
depth R0 = 0
depth R1 = 0
depth (Node _ a b) = 1 + max (depth a) (depth b)

-- | Tree size.
--
-- Counts every constructor in the tree: each constant leaf counts as 1,
-- and a node counts as 1 plus the sizes of both branches.
size :: RME -> Int
size R0 = 1
size R1 = 1
size (Node _ a b) = 1 + size a + size b

-- | Convert to an explicit polynomial representation.
--
-- The result is a list of monomials, each given as the list of variable
-- indices it multiplies. 'R0' is the empty list and 'R1' is the single
-- empty monomial. For @Node i a b@, the monomials of @a@ come first,
-- followed by the monomials of @b@ with @i@ prepended to each.
explode :: RME -> [[Int]]
explode R0 = []
explode R1 = [[]]
explode (Node i a b) = explode a ++ map (i :) (explode b)