{-# LANGUAGE AllowAmbiguousTypes    #-}
{-# LANGUAGE FlexibleContexts       #-}
{-# LANGUAGE FlexibleInstances      #-}
{-# LANGUAGE FunctionalDependencies #-}
{-# LANGUAGE KindSignatures         #-}
{-# LANGUAGE LambdaCase             #-}
{-# LANGUAGE MultiParamTypeClasses  #-}
{-# LANGUAGE MultiWayIf             #-}
{-# LANGUAGE OverloadedStrings      #-}
{-# LANGUAGE ScopedTypeVariables    #-}
{-# LANGUAGE TupleSections          #-}

-- | This module provides a generic framework for forward data flow analysis
-- on C code, represented by the 'Language.Cimple.Ast'. It includes tools
-- for building a control flow graph (CFG) from a function definition and
-- a fixpoint solver to compute data flow facts.
--
-- The core components are:
--
-- * 'CFG': A control flow graph representation, where nodes contain basic
--   blocks of statements.
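-- * 'DataFlow': A type class that defines the specific analysis to be
--   performed (e.g., reaching definitions or constant propagation). Only
--   forward analyses fit this framework (facts flow from predecessors to
--   successors), so backward analyses such as liveness are not supported
--   directly.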
-- * 'buildCFG': A function to construct a 'CFG' from a 'C.FunctionDefn'.
-- * 'fixpoint': A generic solver that iteratively computes data flow facts
--   until a stable state (fixpoint) is reached.
--
-- To use this module, you need to:
--
-- 1. Define a data type for your data flow facts.
-- 2. Create an instance of the 'DataFlow' type class for your data type,
--    implementing 'emptyFacts', 'transfer', and 'join'.
-- 3. Build the CFG for a function using 'buildCFG'.
-- 4. Run the 'fixpoint' solver on the generated CFG.
-- 5. Extract and use the computed 'cfgInFacts' and 'cfgOutFacts' from the
--    resulting CFG.
module Tokstyle.Analysis.DataFlow
    ( CFGNode (..)
    , CFG
    , DataFlow (..)
    , fixpoint
    , buildCFG
    ) where

import           Control.Monad           (foldM)
import           Data.Fix                (Fix (Fix, unFix))
import           Data.Foldable           (foldl')
import           Data.Kind               (Type)
import           Data.Map.Strict         (Map)
import qualified Data.Map.Strict         as Map
import           Data.Maybe              (mapMaybe)
import           Data.Set                (Set)
import qualified Data.Set                as Set
import           Data.String             (IsString)
import           Debug.Trace             (trace)
import           Language.Cimple         (NodeF (..))
import qualified Language.Cimple         as C
import           Language.Cimple.Pretty  (showNodePlain)
import           Prettyprinter           (Pretty (..))
import           Text.Groom              (groom)
import qualified Tokstyle.Analysis.CFG   as CFGBuilder
import           Tokstyle.Analysis.Types (lookupOrError)
import           Tokstyle.Worklist

-- | Switch for the 'dtrace' debug output of this module. While it is
-- 'False', 'dtrace' returns its value untouched and no message is built.
debugging :: Bool
debugging = False

-- | Conditional 'trace': prints @msg@ and returns @x@ when 'debugging' is
-- enabled, and otherwise behaves like @const id@, so @msg@ is never forced.
dtrace :: String -> a -> a
dtrace msg x
    | debugging = trace msg x
    | otherwise = x

-- | One basic block of the control flow graph, together with the data flow
-- facts that hold on entry to it and on exit from it.
--
-- @l@ is the lexeme payload type of the Cimple AST; @a@ is the fact type of
-- the analysis.
data CFGNode l a = CFGNode
    { cfgNodeId   :: Int
      -- ^ Unique identifier of this block.
    , cfgPreds    :: [Int]
      -- ^ Identifiers of the blocks control can arrive from.
    , cfgSuccs    :: [Int]
      -- ^ Identifiers of the blocks control can continue to.
    , cfgStmts    :: [C.Node (C.Lexeme l)]
      -- ^ Statements of the block, in the order they are transferred.
    , cfgInFacts  :: a
      -- ^ Facts holding before the first statement of the block.
    , cfgOutFacts :: a
      -- ^ Facts holding after the last statement of the block.
    }
    deriving (Eq, Show)

-- | A complete control flow graph: every 'CFGNode', looked up by node id.
type CFG l a = Map.Map Int (CFGNode l a)

-- | The interface a concrete analysis implements to be driven by 'fixpoint'.
--
-- * @m@ is the monad the analysis runs in.
-- * @c@ is the analysis context, applied to the lexeme type @l@.
-- * @a@ is the fact type; through the functional dependencies it fixes both
--   @l@ and @callCtx@.
-- * @callCtx@ is paired with an @l@ value in the work items that 'transfer'
--   reports.
class (Eq a, Show a, Monad m, Ord callCtx)
        => DataFlow m (c :: Type -> Type) l a callCtx
        | a -> l, a -> callCtx
  where
    -- | Facts for a block about which nothing is known yet. 'buildCFG' seeds
    -- every block except block 0 with this value.
    emptyFacts :: c l -> m a

    -- | Effect of a single statement on the facts, plus any new work items
    -- discovered while processing it.
    transfer
        :: c l                      -- ^ Analysis context.
        -> l                        -- ^ Function name handed to 'fixpoint'.
        -> Int                      -- ^ Id of the block holding the statement.
        -> a                        -- ^ Facts before the statement.
        -> C.Node (C.Lexeme l)      -- ^ The statement itself.
        -> m (a, Set (l, callCtx))  -- ^ Facts after it, and new work.

    -- | Merge the facts arriving along two control flow edges. 'fixpoint'
    -- folds it over the exit facts of all predecessors of a block, e.g.
    -- after an if-statement or at a loop head.
    join :: c l -> a -> a -> m a

-- | A generic worklist solver for forward data flow analysis.
--
-- Every block starts on the worklist. Popping a block recomputes its entry
-- facts by 'join'ing the exit facts of its predecessors, then threads those
-- facts through 'transfer' for each statement of the block in order. If the
-- resulting exit facts differ from the stored ones, the block's successors
-- are pushed back onto the worklist. Solving stops when the worklist is
-- empty, so termination relies on 'join' and 'transfer' eventually
-- producing no further changes.
--
-- Returns the updated CFG together with the union of every work item that
-- 'transfer' reported along the way.
fixpoint
    :: forall m c l a callCtx. (DataFlow m c l a callCtx, Show l, Ord l)
    => c l      -- ^ Analysis context, passed to 'transfer' and 'join'.
    -> l        -- ^ Name of the analysed function, forwarded to 'transfer'.
    -> CFG l a  -- ^ CFG carrying the starting facts (e.g. from 'buildCFG').
    -> m (CFG l a, Set (l, callCtx))
fixpoint ctx funcName cfg = go (fromList (Map.keys cfg)) cfg Set.empty
  where
    -- | Main worklist loop.
    go :: Worklist Int -> CFG l a -> Set (l, callCtx) -> m (CFG l a, Set (l, callCtx))
    go worklist cfg' accumulatedWork = case pop worklist of
        Nothing -> return (cfg', accumulatedWork)
        Just (currentId, worklist') -> do
            let node      = lookupOrError "fixpoint" cfg' currentId
                -- Predecessor ids that are not in the map are ignored.
                predNodes = mapMaybe (`Map.lookup` cfg') (cfgPreds node)
            inFacts' <- joinPreds node predNodes
            (outFacts', blockWork) <- transferBlock node inFacts'
            let outFactsChanged = outFacts' /= cfgOutFacts node
                node'           = node { cfgInFacts = inFacts', cfgOutFacts = outFacts' }
                cfg''           = dtrace (traceStep currentId node inFacts' outFacts' outFactsChanged)
                                $ Map.insert currentId node' cfg'
                -- Successors only need revisiting when our exit facts moved.
                worklist''
                    | outFactsChanged = pushList (cfgSuccs node) worklist'
                    | otherwise       = worklist'
            go worklist'' cfg'' (Set.union accumulatedWork blockWork)

    -- | Entry facts of a block: the 'join' of its predecessors' exit facts,
    -- or its currently stored entry facts when it has no known predecessor.
    -- Pattern matching replaces the former 'head'/'tail' pair, so this is
    -- total by construction.
    --
    -- NOTE(review): when a block has predecessors, its stored entry facts
    -- are not part of the join. If the entry block (id 0) can have
    -- predecessors (e.g. a loop back-edge), the initial facts seeded by
    -- 'buildCFG' would be dropped. Confirm against "Tokstyle.Analysis.CFG".
    joinPreds :: CFGNode l a -> [CFGNode l a] -> m a
    joinPreds node []       = return (cfgInFacts node)
    joinPreds _    (p : ps) = foldM (join ctx) (cfgOutFacts p) (map cfgOutFacts ps)

    -- | Thread facts through every statement of a block with 'transfer',
    -- collecting the work items each statement reports.
    transferBlock :: CFGNode l a -> a -> m (a, Set (l, callCtx))
    transferBlock node inFacts = foldM step (inFacts, Set.empty) (cfgStmts node)
      where
        step :: (a, Set (l, callCtx)) -> C.Node (C.Lexeme l) -> m (a, Set (l, callCtx))
        step (accFacts, accWork) stmt = do
            (newFacts, newWork) <-
                transfer ctx funcName (cfgNodeId node)
                    (dtrace ("fixpoint fold: accFacts=" <> show accFacts) accFacts)
                    stmt
            return (newFacts, Set.union accWork newWork)

    -- | Debug dump of one solver step; only built when 'dtrace' emits it.
    traceStep :: Int -> CFGNode l a -> a -> a -> Bool -> String
    traceStep currentId node inFacts' outFacts' outFactsChanged = unlines
        [ "fixpoint (" <> show funcName <> ", node " <> show currentId <> "):"
        , "  inFacts': " <> groom inFacts'
        , "  outFacts': " <> groom outFacts'
        , "  old outFacts: " <> groom (cfgOutFacts node)
        , "  outFactsChanged: " <> show outFactsChanged
        ]

-- | Build a data flow CFG for a C function definition.
--
-- The block structure is taken from 'CFGBuilder.buildCFG'; this function
-- only attaches facts to each block. Block @0@ is seeded with
-- @initialFacts@ and every other block with 'emptyFacts', and each block's
-- entry and exit facts start out equal. Any node that is not a
-- 'C.FunctionDefn' whose prototype is a 'C.FunctionPrototype' yields an
-- empty CFG.
buildCFG
    :: forall m c l a callCtx. (DataFlow m c l a callCtx, Pretty l, Ord l, Show l, IsString l)
    => c l                  -- ^ Analysis context, passed to 'emptyFacts'.
    -> C.Node (C.Lexeme l)  -- ^ The function definition to analyse.
    -> a                    -- ^ Facts to seed block 0 with.
    -> m (CFG l a)
buildCFG ctx cNode@(Fix (C.FunctionDefn _ (Fix (C.FunctionPrototype _ (C.L _ _ funcName) _)) _)) initialFacts =
    traceCFG <$> withFacts
  where
    -- Starting facts of a block: the caller's facts for block 0 and the
    -- analysis' empty facts everywhere else.
    seedFacts :: Int -> m a
    seedFacts blockId
        | blockId == 0 = return initialFacts
        | otherwise    = emptyFacts ctx

    -- Copy a structural block and attach the same facts to entry and exit.
    decorate :: CFGBuilder.CFGNode l -> a -> CFGNode l a
    decorate block facts = CFGNode
        { cfgNodeId   = CFGBuilder.cfgNodeId block
        , cfgPreds    = CFGBuilder.cfgPreds block
        , cfgSuccs    = CFGBuilder.cfgSuccs block
        , cfgStmts    = CFGBuilder.cfgStmts block
        , cfgInFacts  = facts
        , cfgOutFacts = facts
        }

    -- Every block of the structural CFG, visited in key order, with facts.
    withFacts :: m (CFG l a)
    withFacts = Map.traverseWithKey
        (\blockId block -> decorate block <$> seedFacts blockId)
        (CFGBuilder.buildCFG cNode)

    -- Emit a debug dump of the finished CFG via 'dtrace'.
    traceCFG :: CFG l a -> CFG l a
    traceCFG dfaCFG = dtrace (banner dfaCFG) dfaCFG

    banner :: CFG l a -> String
    banner dfaCFG =
        "\n--- CFG for " <> show funcName <> " ---\n" <> show (fmap summarise dfaCFG)

    summarise n = (cfgNodeId n, cfgPreds n, cfgSuccs n, map showNodePlain (cfgStmts n))
buildCFG _ _ _ = return Map.empty