{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE Strict     #-}
module Tokstyle.Analysis.Symbolic
    ( SVal (..)
    , Constraint (..)
    , SState (..)
    , emptyState
    , assign
    , addConstraint
    , canBeNull
    , merge
    , negateConstraint
    , sIte
    , sBinOp
    , sUnaryOp
    , sVar
    , sAddr
    , valDepth
    , lookupStore
    ) where

import           Control.Applicative          ((<|>))
import qualified Data.Map.Merge.Strict        as Map
import           Data.Map.Strict              (Map)
import qualified Data.Map.Strict              as Map
import           Data.Maybe                   (fromMaybe, listToMaybe)
import           Data.Set                     (Set)
import qualified Data.Set                     as Set
import           Language.Cimple              (BinaryOp (..), UnaryOp (..))
import           Tokstyle.Analysis.AccessPath

-- | A Symbolic Value represents the result of an expression.
data SVal
    = STop                      -- ^ Unknown value
    | SVar AccessPath           -- ^ Initial value of a path
    | SAddr AccessPath          -- ^ The address of a memory location (e.g. &a, or a literal string)
    | SNull                     -- ^ Literal NULL / 0 / nullptr
    | SBinOp BinaryOp SVal SVal -- ^ A binary operator applied to two values
    | SUnaryOp UnaryOp SVal     -- ^ A unary operator applied to a value
    | SIte SVal SVal SVal       -- ^ If-Then-Else (Phi node): condition, then-value, else-value
    deriving (Show, Eq, Ord)

-- | A constraint on symbolic values.
data Constraint
    = SEquals SVal SVal    -- ^ The two values are equal
    | SNotEquals SVal SVal -- ^ The two values are different
    | SBool SVal           -- ^ The value is true/non-zero
    deriving (Show, Eq, Ord)

-- | The symbolic state at a point in the program.
data SState = SState
    { store       :: Map AccessPath SVal -- ^ Maps paths (e.g. "p->x") to their symbolic value
    , constraints :: Set Constraint      -- ^ Known truths (e.g. @SNotEquals (SVar p) SNull@)
    } deriving (Show, Eq, Ord)

-- | The state with no stored path values and no known constraints.
emptyState :: SState
emptyState = SState
    { store       = Map.empty
    , constraints = Set.empty
    }

-- | Structural depth of a symbolic value, used by the smart constructors to
-- cap expression growth.
--
-- 'SVar' and 'SAddr' take the depth of their access path; operator and
-- if-then-else nodes add one to their deepest operand; every other leaf
-- ('STop', 'SNull') has depth 1.
valDepth :: SVal -> Int
valDepth = \case
    SVar p         -> pathDepth p
    SAddr p        -> pathDepth p
    SBinOp _ v1 v2 -> 1 + max (valDepth v1) (valDepth v2)
    SUnaryOp _ v   -> 1 + valDepth v
    SIte c t e     -> 1 + max (valDepth c) (max (valDepth t) (valDepth e))
    _              -> 1

-- | Smart constructor for 'SVar'.
--
-- Paths deeper than 10 are not tracked individually and become 'STop'.
sVar :: AccessPath -> SVal
sVar p
    | pathDepth p > 10 = STop
    | otherwise        = SVar p

-- | Smart constructor for 'SAddr'.
--
-- Addresses of paths deeper than 10 are not tracked individually and
-- become 'STop'.
sAddr :: AccessPath -> SVal
sAddr p
    | pathDepth p > 10 = STop
    | otherwise        = SAddr p

-- | Smart constructor for 'SIte' (@if c then t else e@, a phi node).
--
-- Simplifications, tried in this order:
--
-- * identical branches collapse to that branch;
-- * a negated condition @!c'@ is removed by swapping the branches;
-- * a then-branch that is itself an 'SIte' on the same condition is
--   replaced by its own then-branch (the else-branch is treated alike);
-- * if any operand is deeper than 10, the result is 'STop'.
sIte :: SVal -> SVal -> SVal -> SVal
sIte c t e
    | t == e                          = t
    | SUnaryOp UopNot c' <- c         = sIte c' e t
    | SIte c' t' _ <- t, c == c'      = sIte c t' e
    | SIte c' _ e' <- e, c == c'      = sIte c t e'
    | any ((> 10) . valDepth) [c, t, e] = STop
    | otherwise                       = SIte c t e

-- | Smart constructor for 'SBinOp'.
--
-- If either operand is deeper than 10 the result is 'STop'.
sBinOp :: BinaryOp -> SVal -> SVal -> SVal
sBinOp op v1 v2
    | valDepth v1 > 10 || valDepth v2 > 10 = STop
    | otherwise                            = SBinOp op v1 v2

-- | Smart constructor for 'SUnaryOp'.
--
-- A double logical negation @!!v@ collapses to @v@. Otherwise, an operand
-- deeper than 10 gives 'STop'.
--
-- NOTE(review): in C, @!!v@ equals @v@ only as a truth value (the former is
-- always 0 or 1), so the collapse is exact only in boolean contexts.
sUnaryOp :: UnaryOp -> SVal -> SVal
sUnaryOp UopNot (SUnaryOp UopNot v) = v
sUnaryOp op v
    | valDepth v > 10 = STop
    | otherwise       = SUnaryOp op v

-- | Look up the symbolic value stored for a path.
--
-- When the path has no entry of its own, the enclosing path (obtained by
-- stripping one 'PathField' or 'PathDeref' step, recursively) is consulted.
-- If that enclosing path is 'STop', the result is 'STop', because nothing
-- is known about sub-paths of an unknown value. In every other case the
-- result is 'Nothing'. 'merge', for example, then falls back to the path's
-- initial value via 'sVar'.
lookupStore :: AccessPath -> SState -> Maybe SVal
lookupStore path st = go path
  where
    go :: AccessPath -> Maybe SVal
    go p = Map.lookup p (store st) <|> fromParent p

    -- Only an unknown ('STop') parent tells us anything about a sub-path.
    fromParent :: AccessPath -> Maybe SVal
    fromParent = \case
        PathField p' _ -> topOnly (go p')
        PathDeref p'   -> topOnly (go p')
        _              -> Nothing

    topOnly :: Maybe SVal -> Maybe SVal
    topOnly (Just STop) = Just STop
    topOnly _           = Nothing

-- | Assign a value to a path, and handle invalidation of dependent paths.
--
-- The path itself and every stored path it is a prefix of are removed
-- first (assigning to @p@ invalidates @p->x@). The new value is then
-- stored, unless it is 'STop', in which case the path stays unstored.
--
-- NOTE(review): an unstored path reads back as 'Nothing' from
-- 'lookupStore', which 'merge' interprets as the path's initial value
-- ('sVar') rather than as unknown. Confirm that callers intend an 'STop'
-- assignment to behave this way.
assign :: AccessPath -> SVal -> SState -> SState
assign path val st = st { store = write (Map.delete path invalidated) }
  where
    -- Drop every path that has @path@ as a prefix. The explicit
    -- 'Map.delete' above guarantees the old value of @path@ itself is gone
    -- even if 'isPathPrefixOf' is not reflexive; otherwise an 'STop'
    -- assignment would leave a stale value behind.
    invalidated :: Map AccessPath SVal
    invalidated = Map.filterWithKey (\k _ -> not (path `isPathPrefixOf` k)) (store st)

    write :: Map AccessPath SVal -> Map AccessPath SVal
    write
        | val == STop = id
        | otherwise   = Map.insert path val

-- | Add a new constraint to the state, with basic simplification.
--
-- If the state already holds more than 100 constraints, the new one is
-- ignored and the state is returned unchanged.
--
-- Simplification rewrites:
--
-- * @v != NULL@ and @NULL != v@ into @'SNotEquals' v 'SNull'@;
-- * @v == NULL@ and @NULL == v@ into @'SEquals' v 'SNull'@;
-- * @a && b@ into the simplified constraints for @a@ and for @b@;
-- * @!(a || b)@ into the simplified constraints for @!a@ and for @!b@;
-- * @!!v@, @!(a == b)@ and @!(a != b)@ into the simplified constraint for
--   @v@, @a != b@ and @a == b@ respectively.
--
-- Every other constraint (including @!(a && b)@, which is a disjunction and
-- cannot be split) is added unchanged.
addConstraint :: Constraint -> SState -> SState
addConstraint c st
    | Set.size (constraints st) > 100 = st
    | otherwise = st { constraints = constraints st `Set.union` simplify c }
  where
    simplify :: Constraint -> Set Constraint
    simplify = \case
        SBool (SBinOp BopNe v SNull) -> Set.singleton (SNotEquals v SNull)
        SBool (SBinOp BopNe SNull v) -> Set.singleton (SNotEquals v SNull)
        SBool (SBinOp BopEq v SNull) -> Set.singleton (SEquals v SNull)
        SBool (SBinOp BopEq SNull v) -> Set.singleton (SEquals v SNull)
        SBool (SBinOp BopAnd v1 v2) ->
            simplify (SBool v1) `Set.union` simplify (SBool v2)
        SBool (SUnaryOp UopNot (SBinOp BopOr v1 v2)) ->
            simplify (SBool (SUnaryOp UopNot v1))
                `Set.union` simplify (SBool (SUnaryOp UopNot v2))
        SBool (SUnaryOp UopNot (SUnaryOp UopNot v)) ->
            simplify (SBool v)
        SBool (SUnaryOp UopNot (SBinOp BopEq v1 v2)) ->
            simplify (SBool (SBinOp BopNe v1 v2))
        SBool (SUnaryOp UopNot (SBinOp BopNe v1 v2)) ->
            simplify (SBool (SBinOp BopEq v1 v2))
        other -> Set.singleton other

-- | Logical negation of a constraint.
--
-- Equality and inequality swap; a negated boolean loses its negation; any
-- other boolean is negated with 'sUnaryOp'. Because 'sUnaryOp' returns
-- 'STop' for operands deeper than 10, negating a very deep condition
-- yields @'SBool' 'STop'@.
negateConstraint :: Constraint -> Constraint
negateConstraint = \case
    SEquals v1 v2             -> SNotEquals v1 v2
    SNotEquals v1 v2          -> SEquals v1 v2
    SBool (SUnaryOp UopNot v) -> SBool v
    SBool v                   -> SBool (sUnaryOp UopNot v)

-- | Solver: check whether a value may be null in the given state.
--
-- The answer is 'False' only when the value is proven non-null. If the
-- proof fails or the search budget runs out, the value is assumed nullable.
--
-- The first argument is a predicate telling whether a value is non-null by
-- declaration (e.g. an initial 'SVar' declared non-null). It is consulted
-- for values whose nullness is not decided by their shape.
--
-- The search starts at the given value and also visits every value linked
-- to it by an 'SEquals' constraint. One step counter (limit 10) bounds both
-- the number of visited values and the nesting of 'SIte' branches.
canBeNull :: (SVal -> Bool) -> SVal -> SState -> Bool
canBeNull isDeclNonNull val st = canBeNull' 0 val st
  where
    -- Nullability of one value, starting with the step counter at @d@.
    canBeNull' :: Int -> SVal -> SState -> Bool
    canBeNull' d v s = not (isKnownNonNull d (Set.singleton v) Set.empty s)

    -- Worklist search: take the smallest pending value and test it
    -- directly. Failing that, continue with the values it is known to equal.
    isKnownNonNull :: Int -> Set SVal -> Set SVal -> SState -> Bool
    isKnownNonNull depth todo seen s
        | depth > 10    = False -- Limit recursion
        | Set.null todo = False
        | otherwise =
            -- Kept in a 'let' under this guard, not in a 'where': with the
            -- 'Strict' extension, where-bindings are strict and would be
            -- evaluated before the guards, so 'Set.deleteFindMin' would run
            -- on an empty worklist.
            let (v, rest) = Set.deleteFindMin todo
                seen'     = Set.insert v seen
                aliases   = Set.fromList
                    [ if v1 == v then v2 else v1
                    | SEquals v1 v2 <- Set.toList (constraints s)
                    , v1 == v || v2 == v
                    ]
                pending   = (rest `Set.union` aliases) `Set.difference` seen'
            in directlyNonNull depth v s
                || isKnownNonNull (depth + 1) pending seen' s

    -- Is this single value non-null, without following equalities?
    directlyNonNull :: Int -> SVal -> SState -> Bool
    directlyNonNull depth v s = case v of
        SNull                           -> False
        SAddr _                         -> True
        SBinOp op _ _ | isComparison op -> True
        -- Unknown condition: each branch must be non-null on its own.
        SIte STop t e -> nonNullIn t s && nonNullIn e s
        -- Known condition: check each branch under the condition selecting it.
        SIte c t e ->
            nonNullIn t (addConstraint (SBool c) s)
                && nonNullIn e (addConstraint (negateConstraint (SBool c)) s)
        _ -> isDeclNonNull v
            || any (isNonNullConstraint v) (constraints s)
            || any (isEqualAddress v) (constraints s)
      where
        nonNullIn :: SVal -> SState -> Bool
        nonNullIn x s' = not (canBeNull' (depth + 1) x s')

    -- Comparison and logical operators; their results are treated as non-null.
    isComparison :: BinaryOp -> Bool
    isComparison = (`elem` [BopEq, BopNe, BopLt, BopLe, BopGt, BopGe, BopAnd, BopOr])

    -- Does the constraint state that @v@ is non-null (or true)?
    isNonNullConstraint :: SVal -> Constraint -> Bool
    isNonNullConstraint v = \case
        SNotEquals v1 v2 -> (v1 == v && v2 == SNull) || (v1 == SNull && v2 == v)
        SBool v'         -> v' == v
        _                -> False

    -- Does the constraint state that @v@ equals some address?
    isEqualAddress :: SVal -> Constraint -> Bool
    isEqualAddress v = \case
        SEquals v1 (SAddr _) -> v1 == v
        SEquals (SAddr _) v2 -> v2 == v
        _                    -> False

-- | Merge two symbolic states at a control-flow join.
--
-- Values that agree are kept. Values that differ become an if-then-else
-- ('sIte', a phi node) on a condition. The condition is the one provided,
-- or failing that one inferred with 'inferCond'. If neither is available,
-- differing values become 'STop' (unknown).
--
-- A path stored in only one state is merged the same way, against the value
-- the other state implicitly holds for it: its 'lookupStore' result, or its
-- initial value 'sVar' when there is none. (Dropping such paths would
-- silently reset them to their initial value and lose, e.g., a @p = NULL@
-- made in just one branch.)
--
-- Only constraints present in both states are kept. Non-nullness is also
-- preserved: if a path is non-null in both branches, the merged value is
-- constrained to be non-null.
merge :: (SVal -> Bool) -> Maybe SVal -> SState -> SState -> SState
merge isDeclNonNull mCond s1 s2 =
    foldl (flip addConstraint) joined nonNullConstraints
  where
    inferredCond :: Maybe SVal
    inferredCond = mCond <|> inferCond s1 s2

    -- The value a state holds for a path, whether stored or implicit.
    valueIn :: SState -> AccessPath -> SVal
    valueIn s path = fromMaybe (sVar path) (lookupStore path s)

    joined :: SState
    joined = SState
        { store = Map.merge
            (Map.mapMissing (\k v1 -> mergeVal v1 (valueIn s2 k)))
            (Map.mapMissing (\k v2 -> mergeVal (valueIn s1 k) v2))
            (Map.zipWithMatched (const mergeVal))
            (store s1)
            (store s2)
        , constraints = constraints s1 `Set.intersection` constraints s2
        }

    -- Paths written in EITHER branch are candidates for keeping non-nullness.
    allPaths :: Set AccessPath
    allPaths = Map.keysSet (store s1) `Set.union` Map.keysSet (store s2)

    nonNullConstraints :: [Constraint]
    nonNullConstraints =
        [ SNotEquals (valueIn joined path) SNull
        | path <- Set.toList allPaths
        , not (canBeNull isDeclNonNull (valueIn s1 path) s1)
        , not (canBeNull isDeclNonNull (valueIn s2 path) s2)
        ]

    mergeVal :: SVal -> SVal -> SVal
    mergeVal v1 v2
        | v1 == v2               = v1
        | Just c <- inferredCond = sIte c v1 v2
        | otherwise              = STop

-- | Infer the condition that separates two states by looking for mismatched
-- constraints.
--
-- Returns the first constraint of the first state (in 'Set' order) whose
-- negation holds in the second state, turned into a value. That value is
-- true in the first state and false in the second. Returns 'Nothing' if no
-- such constraint exists.
inferCond :: SState -> SState -> Maybe SVal
inferCond s1 s2 = listToMaybe
    [ asCondition c
    | c <- Set.toList (constraints s1)
    , negateConstraint c `Set.member` constraints s2
    ]
  where
    asCondition :: Constraint -> SVal
    asCondition = \case
        SBool v          -> v
        SEquals v1 v2    -> sBinOp BopEq v1 v2
        SNotEquals v1 v2 -> sBinOp BopNe v1 v2