{-# LANGUAGE OverloadedStrings #-}
module Tokstyle.Analysis.Types
    ( CallGraph
    , CallSite(..)
    , CallType(..)
    , FunctionName
    , CalleeMap
    , PointsToMap
    , PointsToSummary
    , PointsToSummaryData(..)
    , getCallers
    , AbstractLocation(..)
    , toAbstractLocation
    , NodeId
    , Context
    , lookupOrError
    ) where

import           Data.Fix        (Fix (..))
import           Data.Map.Strict (Map)
import qualified Data.Map.Strict as Map
import           Data.Maybe      (fromMaybe)
import           Data.Set        (Set)
import           Data.Text       (Text)
import qualified Data.Text       as Text
import           GHC.Stack       (HasCallStack)
import qualified Language.Cimple as C

-- | An abstract memory location: somewhere a value can be stored. This lets
-- the analysis distinguish between different variables, fields and pointees.
--
-- 'Ord' is derived so locations can be used as 'Map' keys and 'Set'
-- elements (see 'PointsToMap').
data AbstractLocation
    = VarLocation Text
      -- ^ A local variable or parameter, by name.
    | GlobalVarLocation Text
      -- ^ A global or static variable, by name.
    | FieldLocation AbstractLocation Text
      -- ^ A struct/union field of another location, e.g. @msg.data@.
    | DerefLocation AbstractLocation
      -- ^ The memory pointed to by a pointer location, e.g. @*p@.
    | ReturnLocation Text
      -- ^ The return value of a function (identified by the 'Text').
    | HeapLocation Int
      -- ^ An abstract location on the heap.
      -- NOTE(review): what the 'Int' identifies (e.g. an allocation site) is
      -- not fixed by this type; confirm with the code that constructs it.
    | FunctionLocation Text
      -- ^ The address of a function (identified by the 'Text').
    deriving (Eq, Ord, Show)

-- | Identifies one node of the C AST. Each node is meant to carry a
-- distinct value.
type NodeId = Int

-- | A call-string context, represented as a list of AST node IDs.
--
-- The analysis bounds this list to some depth @k@. That bound is not
-- enforced by the type itself; the code that builds contexts must
-- maintain it.
type Context = [NodeId]

-- | The name of a function, kept as plain 'Text' rather than a newtype.
type FunctionName = Text

-- | Describes how a function is called at a particular call site.
--
-- The derived 'Ord' orders 'DirectCall' before 'IndirectCall'.
data CallType
    = DirectCall
      -- ^ NOTE(review): presumably a call through the function's own name.
    | IndirectCall
      -- ^ NOTE(review): presumably a call through a function pointer.
    deriving (Eq, Ord, Show)

-- | A single call site: the call expression's node plus how it reaches its
-- callee.
--
-- 'Ord' is derived (comparing 'csNodeId' first, then 'csCallType') so call
-- sites can be stored in a 'Set' (see 'CalleeMap').
data CallSite = CallSite
    { csNodeId   :: NodeId   -- ^ The unique ID of the call expression node.
    , csCallType :: CallType -- ^ Whether the call is direct or indirect.
    } deriving (Eq, Ord, Show)

-- | For one caller: each function it calls, mapped to the set of call sites
-- at which that call occurs.
type CalleeMap = Map FunctionName (Set CallSite)

-- | The call graph, keyed by caller name; each value is that caller's
-- 'CalleeMap'.
type CallGraph = Map FunctionName CalleeMap

-- | The data-flow fact of the points-to analysis: for each pointer's
-- abstract location, the set of abstract locations it can point to.
type PointsToMap = Map AbstractLocation (Set AbstractLocation)

-- | The points-to summary of one function, analysed under one specific
-- calling 'Context' (all contexts together form a 'PointsToSummary').
data PointsToSummaryData = PointsToSummaryData
    { returnPointsTo :: Set AbstractLocation
      -- ^ Locations the function's return value can point to.
    , outputPointsTo :: Map AbstractLocation (Set AbstractLocation)
      -- ^ Points-to facts the function makes visible to its caller.
      -- NOTE(review): presumably covers output parameters and globals;
      -- confirm against the analysis that fills this in.
    } deriving (Eq, Ord, Show)

-- | Every context-sensitive summary computed for a single function, keyed by
-- the calling 'Context' it was analysed under.
type PointsToSummary = Map Context PointsToSummaryData

-- | Collects every function that calls the given callee.
--
-- The result maps each caller's name to the set of call sites at which it
-- calls @calleeName@. Callers that never call @calleeName@ are left out, so
-- a callee with no callers (or an unknown name) yields an empty map.
getCallers :: CallGraph -> FunctionName -> Map FunctionName (Set CallSite)
getCallers graph calleeName =
    -- Keep exactly those callers whose CalleeMap mentions the callee.
    Map.mapMaybe (Map.lookup calleeName) graph

-- | Converts an lvalue-like expression node into the 'AbstractLocation' it
-- designates.
--
-- * @x@, a declaration of @x@, and a 'C.ConstId' literal @X@ all become a
--   'VarLocation' with that name.
-- * @s.f@ becomes a 'FieldLocation' @f@ on the location of @s@.
-- * @p->f@ becomes a 'FieldLocation' @f@ on the dereference of @p@.
-- * @*p@ becomes a 'DerefLocation' of the location of @p@.
-- * For @&e@ and casts, the result is simply the location of the inner
--   expression.
-- * @a[N]@ with an integer literal index becomes a pseudo-field @_index_N@
--   of @a@. Any other index collapses onto the location of @a@ itself.
--
-- Calls 'error' (with a call stack) for any other node shape.
toAbstractLocation :: HasCallStack => C.Node (C.Lexeme Text) -> AbstractLocation
toAbstractLocation (Fix node) = case node of
    C.VarExpr (C.L _ _ name) ->
        VarLocation name
    C.MemberAccess struct (C.L _ _ field) ->
        FieldLocation (toAbstractLocation struct) field
    C.PointerAccess ptr (C.L _ _ field) ->
        FieldLocation (DerefLocation (toAbstractLocation ptr)) field
    C.UnaryExpr C.UopDeref ptr ->
        DerefLocation (toAbstractLocation ptr)
    C.UnaryExpr C.UopAddress inner ->
        toAbstractLocation inner
    C.CastExpr _ inner ->
        toAbstractLocation inner
    C.ArrayAccess array (Fix (C.LiteralExpr C.Int (C.L _ _ index))) ->
        -- A constant index gets its own pseudo-field so distinct elements
        -- stay distinguishable.
        FieldLocation (toAbstractLocation array) (Text.append "_index_" index)
    C.ArrayAccess array _ ->
        -- Fallback for a non-constant index: all elements share one location.
        toAbstractLocation array
    C.VarDecl _ (C.L _ _ name) _ ->
        VarLocation name
    C.LiteralExpr C.ConstId (C.L _ _ name) ->
        VarLocation name
    _ ->
        error $ "toAbstractLocation: Unhandled LHS node: " ++ show node

-- | Like 'Map.!', but with a more useful failure message.
--
-- Returns the value stored under @k@. If the key is absent, calls 'error'
-- with a message that starts with the caller-supplied @context@ and shows
-- the missing key. The 'HasCallStack' constraint makes the reported call
-- stack include the caller of this helper, not only the helper itself.
lookupOrError :: (HasCallStack, Ord k, Show k) => String -> Map k a -> k -> a
lookupOrError context m k =
    fromMaybe missing (Map.lookup k m)
  where
    missing = error $ context ++ ": Key not found in map: " ++ show k