{-# LANGUAGE FlexibleContexts    #-}
{-# LANGUAGE LambdaCase          #-}
{-# LANGUAGE OverloadedStrings   #-}
{-# LANGUAGE ScopedTypeVariables #-}
module Tokstyle.Analysis.CallGraph
    ( buildCallGraph
    , getCallees
    ) where

import           Control.Monad               (forM_)
import           Control.Monad.State.Strict  (State, execState, get, modify,
                                              modify')
import           Data.Fix                    (Fix (..))
import           Data.Map.Strict             (Map)
import qualified Data.Map.Strict             as Map
import           Data.Maybe                  (fromMaybe)
import           Data.Set                    (Set)
import qualified Data.Set                    as Set
import           Data.Text                   (Text)
import qualified Language.Cimple             as C
import           Language.Cimple.TraverseAst (AstActions, astActions, doNode,
                                              traverseAst)
import           Tokstyle.Analysis.Types     (AbstractLocation (..), CallGraph,
                                              CallSite (..), CallType (..),
                                              CalleeMap, FunctionName, NodeId,
                                              PointsToMap, toAbstractLocation)

-- | The state threaded through the call-graph traversal.
--
-- All fields are strict, so forcing the state also forces its fields.
-- Without this, the repeated record updates made during traversal (notably
-- the 'bsGraph' updates in 'addCall') would pile up as a chain of
-- unevaluated thunks until the final graph is demanded.
data BuilderState = BuilderState
    { bsGraph           :: !CallGraph
      -- ^ Call edges recorded so far.
    , bsCurrentFunction :: !(Maybe FunctionName)
      -- ^ The function whose body is currently being traversed, if any.
    , bsFunctionNames   :: !(Set FunctionName)
      -- ^ Names of all functions defined or declared in the input.
    , bsPointsToMap     :: !PointsToMap
      -- ^ Points-to information used to resolve indirect calls.
    }

-- | AST traversal actions that record every call edge into 'bsGraph'.
--
-- While the body of a function definition is traversed, 'bsCurrentFunction'
-- holds that function's name. Each 'C.FunctionCall' found inside is
-- attributed to it as the caller. Calls encountered while no function is
-- current are not recorded.
callGraphActions :: AstActions (State BuilderState) Text
callGraphActions = astActions
    { doNode = \_file node act -> case unFix node of
        -- Entering a function definition: make it the current caller while
        -- its body is traversed, then restore whatever was current before.
        C.FunctionDefn _ (Fix (C.FunctionPrototype _ (C.L _ _ name) _)) _ -> do
            previous <- bsCurrentFunction <$> get
            modify $ \st -> st { bsCurrentFunction = Just name }
            act
            modify $ \st -> st { bsCurrentFunction = previous }

        -- Found a function call: record it (only when inside a function),
        -- then continue into the arguments, which may contain further calls.
        C.FunctionCall calleeExpr _ -> do
            st <- get
            forM_ (bsCurrentFunction st) $ \callerName ->
                recordCall st callerName (C.getNodeId node) calleeExpr
            act

        -- For all other nodes, just continue the traversal.
        _ -> act
    }
  where
    -- | Classify one call expression and add the matching edge(s) from
    -- @callerName@, all tagged with the call node's id.
    recordCall
        :: BuilderState -> FunctionName -> Int -> C.Node (C.Lexeme Text)
        -> State BuilderState ()
    recordCall st callerName nodeId calleeExpr = case unFix calleeExpr of
        -- A direct call to a function found by the name-collection pre-pass.
        C.VarExpr (C.L _ _ name) | Set.member name (bsFunctionNames st) ->
            addCall callerName name (CallSite nodeId DirectCall)

        -- Anything else is an indirect call (e.g. through a function pointer).
        calleeNode
            | Set.null resolved ->
                -- The points-to map yields no targets: fall back to the
                -- variable's own name, or a placeholder for other callee
                -- expressions.
                case calleeNode of
                    C.VarExpr (C.L _ _ name) -> indirect name
                    _                        -> indirect "<indirect>"
            | otherwise ->
                -- Add an indirect edge to every function/variable target;
                -- other kinds of abstract location are ignored.
                forM_ resolved $ \case
                    FunctionLocation callee -> indirect callee
                    VarLocation      callee -> indirect callee
                    _                       -> return ()
          where
            resolved = fromMaybe Set.empty $
                Map.lookup (toAbstractLocation calleeExpr) (bsPointsToMap st)
            indirect callee =
                addCall callerName callee (CallSite nodeId IndirectCall)

-- | Pre-pass collecting the name of every function that is defined or
-- declared anywhere in the given translation units.
--
-- 'callGraphActions' treats a call to one of these names as a 'DirectCall'.
collectFunctionNames :: [(FilePath, [C.Node (C.Lexeme Text)])] -> Set FunctionName
collectFunctionNames tus = execState (traverseAst nameCollectorActions tus) Set.empty
  where
    nameCollectorActions :: AstActions (State (Set FunctionName)) Text
    nameCollectorActions = astActions
        { doNode = \_file node act -> do
            -- Strict update so the set is built as we go, rather than as a
            -- chain of pending 'Set.insert' thunks.
            forM_ (definedName (unFix node)) $ \name ->
                modify' (Set.insert name)
            act
        }

    -- | The name introduced by a function definition or declaration, if any.
    definedName = \case
        C.FunctionDefn _ (Fix (C.FunctionPrototype _ (C.L _ _ name) _)) _ -> Just name
        C.FunctionDecl _ (Fix (C.FunctionPrototype _ (C.L _ _ name) _))   -> Just name
        _                                                                 -> Nothing

-- | Build the call graph for a list of translation units.
--
-- Two traversals are made over @tus@. The first ('collectFunctionNames')
-- gathers every defined or declared function name. The second
-- ('callGraphActions') records the call edges, consulting @pointsToMap@ to
-- resolve indirect calls.
buildCallGraph :: [(FilePath, [C.Node (C.Lexeme Text)])] -> PointsToMap -> CallGraph
buildCallGraph tus pointsToMap = bsGraph finalState
  where
    finalState = execState (traverseAst callGraphActions tus) initialState

    initialState = BuilderState
        { bsGraph           = Map.empty
        , bsCurrentFunction = Nothing
        , bsFunctionNames   = collectFunctionNames tus
        , bsPointsToMap     = pointsToMap
        }

-- | Record that @caller@ calls @callee@ at @callSite@.
--
-- Call sites for the same caller/callee pair are collected in a set, so
-- recording an identical call site again leaves the graph unchanged.
addCall :: FunctionName -> FunctionName -> CallSite -> State BuilderState ()
addCall caller callee callSite = modify' $ \st ->
    let newGraph = Map.insertWith (Map.unionWith Set.union) caller
            (Map.singleton callee (Set.singleton callSite))
            (bsGraph st)
    -- Force the updated graph now. Otherwise every recorded call would stack
    -- another unevaluated insertion on top of the previous graph.
    in newGraph `seq` st { bsGraph = newGraph }

-- | Look up the functions called by @callerName@, each mapped to the set of
-- call sites where it is called. Yields an empty map when the graph has no
-- entry for @callerName@.
getCallees :: CallGraph -> FunctionName -> CalleeMap
getCallees graph callerName = fromMaybe Map.empty (Map.lookup callerName graph)