{-# LANGUAGE NamedFieldPuns    #-}
{-# LANGUAGE OverloadedStrings #-}
module Tokstyle.Linter.PointsTo (descr) where

import           Control.Monad                       (foldM)
import           Control.Monad.State.Strict          (get, runState)
import           Data.Fix
import           Data.IntMap.Strict                  (IntMap)
import qualified Data.IntMap.Strict                  as IntMap
import           Data.IntSet                         (IntSet)
import qualified Data.IntSet                         as IntSet
import           Data.List                           (foldl', splitAt)
import qualified Data.Map                            as Map
import           Data.Maybe                          (fromMaybe)
import           Data.Set                            (Set)
import qualified Data.Set                            as Set
import           Data.Text                           (Text)
import qualified Data.Text                           as Text
import qualified Language.Cimple                     as C
import qualified Language.Cimple.Diagnostics         as Diagnostics
import           Language.Cimple.TraverseAst
import           Tokstyle.Analysis.DataFlow          (CFGNode (..), buildCFG,
                                                      fixpoint, transfer)
import           Tokstyle.Analysis.PointsTo          (evalExpr)
import           Tokstyle.Analysis.PointsTo.Fixpoint (findEntryPointsAndFuncMap,
                                                      runGlobalFixpoint)
import           Tokstyle.Analysis.PointsTo.Types
import           Tokstyle.Analysis.Scope             (ScopedId (..),
                                                      runScopePass)
import           Tokstyle.Analysis.VTable            (resolveVTables)
import           Tokstyle.Common.TypeSystem          (collect)

-- | Run the whole-program points-to analysis and lint every analysed function
-- with 'checkNode', which reports uses of call return values that the
-- analysis could not resolve.
--
-- All translation units are concatenated into a single AST before scoping,
-- vtable resolution and the global fixpoint, so functions from different
-- files are analysed together. Every (function, calling-context) key recorded
-- in the resulting 'GlobalEnv' is then checked using that key's CFGs from the
-- fixpoint's CFG cache.
--
-- Diagnostics are returned in discovery order with exact duplicates removed.
-- Without this, a function analysed under several calling contexts would
-- report the same problem once per context. An empty input yields no
-- diagnostics, instead of relying on 'head' never being forced.
analyse :: [(FilePath, [C.Node (C.Lexeme Text)])] -> [Text]
analyse [] = []
analyse sources@((filePath, _) : _) =
    dedupe (concatMap lintFunction (Map.keys globalEnvMap))
  where
    -- 1. Setup
    flatAst = concatMap snd sources
    (scopedAsts, _) = runScopePass flatAst
    -- NOTE(review): 'collect' receives a synthetic "test.c" unit holding every
    -- AST *and* each real unit, so every declaration is passed to it twice.
    -- Kept as-is to preserve behaviour; confirm whether the synthetic entry is
    -- still needed.
    typeSystem = collect (("test.c", flatAst) : sources)
    vtableMap = resolveVTables scopedAsts typeSystem
    (_, funcMap) = findEntryPointsAndFuncMap scopedAsts

    -- 2. Run global points-to analysis
    -- NOTE(review): 'filePath' is the first unit's path and is used as
    -- 'pcFilePath' for every diagnostic, even for functions defined in other
    -- units; the merged AST carries no per-function file information here.
    dummyId = ScopedId 0 "" C.Global
    ctx = PointsToContext filePath typeSystem vtableMap (GlobalEnv Map.empty) funcMap dummyId Map.empty
    (gEnv, _, cfgCache, pool) = runGlobalFixpoint ctx scopedAsts
    GlobalEnv globalEnvMap = gEnv

    -- 3. Lint each function using the analysis results. The linting context
    -- carries the final global environment and names the function being
    -- checked.
    lintFunction key =
        case Map.lookup key cfgCache of
            Nothing        -> []
            Just (cfgs, _) ->
                let lintCtx = ctx { pcGlobalEnv = gEnv, pcCurrentFunc = fst key }
                in concatMap (concatMap (checkNode lintCtx pool) . Map.elems) cfgs

    -- Drop repeated diagnostics, keeping the first occurrence of each so the
    -- output order is otherwise unchanged.
    dedupe :: [Text] -> [Text]
    dedupe = go Set.empty
      where
        go _ [] = []
        go seen (d : ds)
            | d `Set.member` seen = go seen ds
            | otherwise           = d : go (Set.insert d seen) ds

-- | Replay the points-to transfer function over the statements of one CFG
-- node, running 'checkStmt' on each statement against the facts that hold
-- immediately before it.
--
-- The walk starts from the node's 'cfgInFacts'. The facts produced by
-- 'transfer' are threaded from each statement to the next. The analysis state
-- starts from the given 'MemLocPool', and the final pool state is discarded
-- once the node has been checked.
--
-- Returns the node's diagnostics in statement order.
checkNode :: PointsToContext ScopedId -> MemLocPool -> CFGNode ScopedId PointsToFact -> [Text]
checkNode ctx pool node =
    concat (reverse reversedChunks)
  where
    nodeId = cfgNodeId node

    -- Each statement's diagnostics are pushed onto the front of the
    -- accumulator, and the list is reversed once at the end. Appending with
    -- `acc ++ new` inside the fold would re-copy the accumulator per statement.
    step (facts, acc) stmt = do
        -- Advance the facts past this statement (callees are not needed here)...
        (nextFacts, _) <- transfer ctx (pcCurrentFunc ctx) nodeId facts stmt
        -- ...but check the statement against the facts from *before* it.
        newDiags <- checkStmt ctx nodeId facts stmt
        return (nextFacts, newDiags : acc)

    ((_, reversedChunks), _) =
        runState (foldM step (cfgInFacts node, []) (cfgStmts node)) pool

-- | Check a single statement for a use of a call's return value that the
-- points-to analysis resolved to 'UnknownLoc'.
--
-- Only the right-hand side of three statement shapes is inspected:
--
--   * a 'C.VarDeclStmt' with an initialiser,
--   * a 'C.ExprStmt' whose expression is a 'C.AssignExpr',
--   * a 'C.Return' with a value.
--
-- @facts@ are the points-to facts before the statement, and @nodeId@ is the
-- id of the enclosing CFG node; both are passed to 'evalExpr'. Returns at most
-- one diagnostic. All other statements yield @[]@.
checkStmt :: PointsToContext ScopedId -> Int -> PointsToFact -> C.Node (C.Lexeme ScopedId) -> PointsToAnalysis [Text]
checkStmt ctx nodeId facts stmt =
    case unFix stmt of
        C.VarDeclStmt _ (Just rhs)              -> checkRhs rhs
        C.ExprStmt (Fix (C.AssignExpr _ _ rhs)) -> checkRhs rhs
        C.Return (Just expr)                    -> checkRhs expr
        _                                       -> return []
  where
    -- Peel casts off an expression and return the callee if what remains is
    -- a function call.
    findCallee :: C.Node (C.Lexeme ScopedId) -> Maybe (C.Node (C.Lexeme ScopedId))
    findCallee (Fix (C.FunctionCall callee _)) = Just callee
    findCallee (Fix (C.CastExpr _ inner))      = findCallee inner
    findCallee _                               = Nothing

    checkRhs :: C.Node (C.Lexeme ScopedId) -> PointsToAnalysis [Text]
    checkRhs rhs =
        case findCallee rhs of
            -- Only calls whose callee is a plain variable expression are
            -- checked; any other callee shape is skipped.
            Just (Fix (C.VarExpr lexeme@(C.L _ _ sid))) -> do
                returnValueLocs <- evalExpr facts ctx nodeId rhs
                unknownLoc <- intern UnknownLoc
                let isUnresolved = IntSet.member (unIMemLoc unknownLoc) returnValueLocs
                return [ Diagnostics.sloc (pcFilePath ctx) lexeme
                           <> ": The return value of function '" <> sidName sid
                           <> "' is used here, but its value could not be determined by the analysis."
                       | isUnresolved
                       ]
            _ -> return [] -- Not a direct call or a call we want to check

-- | Linter registration entry: the analysis function paired with the linter's
-- name and its human-readable description (joined with 'Text.unlines').
descr :: ([(FilePath, [C.Node (C.Lexeme Text)])] -> [Text], (Text, Text))
descr = (analyse, (linterName, linterDoc))
  where
    linterName :: Text
    linterName = "points-to"

    linterDoc :: Text
    linterDoc = Text.unlines
        [ "Checks for use of return values from unsummarized external functions."
        , ""
        , "**Reason:** Calling an external function that is not summarized and using its"
        , "return value can lead to a loss of precision in the points-to analysis,"
        , "potentially hiding bugs. It's better to provide a summary for the function."
        ]