{-# LANGUAGE FlexibleContexts    #-}
{-# LANGUAGE LambdaCase          #-}
{-# LANGUAGE OverloadedStrings   #-}
{-# LANGUAGE ScopedTypeVariables #-}
module Tokstyle.Analysis.VTable
    ( VTableMap
    , resolveVTables
    ) where

import           Control.Monad.State.Strict  (State, execState, gets, modify)
import           Data.Fix                    (Fix (..))
import           Data.Map.Strict             (Map)
import qualified Data.Map.Strict             as Map
import           Data.Text                   (Text)
import qualified Language.Cimple             as C
import           Language.Cimple.Pretty      (showNodePlain)
import           Language.Cimple.TraverseAst (AstActions (..), astActions,
                                              traverseAst)
import           Tokstyle.Analysis.Scope     (ScopedId (..))
import           Tokstyle.Common.TypeSystem  (TypeDescr (..), TypeInfo (..),
                                              TypeRef (..), TypeSystem,
                                              lookupType)

-- | Resolved v-tables. Each v-table constant (keyed by its 'ScopedId') maps
-- every resolved struct field name to the 'ScopedId' named by that field's
-- initialiser.
type VTableMap = Map.Map ScopedId (Map.Map Text ScopedId)

-- | State threaded through the v-table traversal.
--
-- Both fields are strict. 'vtsMap' is rebuilt every time a v-table is
-- recorded, and "Control.Monad.State.Strict" does not force the state value
-- itself. With lazy fields those rebuilds would accumulate as a chain of
-- unevaluated 'Map.insert' thunks for the whole traversal.
data VTableState = VTableState
    { vtsMap        :: !VTableMap
      -- ^ V-tables recorded so far.
    , vtsTypeSystem :: !TypeSystem
      -- ^ Used to look up the declared struct type of each candidate constant.
    }

-- | Traverses the AST to find global v-tables and resolve their function pointers.
--
-- Runs 'vtableActions' over every node of @ast@. The traversal starts from an
-- empty 'VTableMap' and uses @typeSystem@ for struct lookups. The result is
-- the map accumulated by the traversal.
resolveVTables :: [C.Node (C.Lexeme ScopedId)] -> TypeSystem -> VTableMap
resolveVTables ast typeSystem = vtsMap finalState
  where
    finalState :: VTableState
    finalState = execState (traverseAst vtableActions ast) startState

    startState :: VTableState
    startState = VTableState Map.empty typeSystem

-- | AST actions for the v-table pass.
--
-- Every constant definition with 'C.Global' or 'C.Static' scope whose value
-- is an initialiser list is passed to 'handleConstDefn'. The children of
-- every node, matching or not, are traversed afterwards.
vtableActions :: AstActions (State VTableState) ScopedId
vtableActions = astActions { doNode = visitNode }
  where
    visitNode _file node traverseChildren =
        inspect (unFix node) >> traverseChildren

    inspect = \case
        C.ConstDefn C.Global ty name (Fix (C.InitialiserList inits)) ->
            handleConstDefn ty name inits
        C.ConstDefn C.Static ty name (Fix (C.InitialiserList inits)) ->
            handleConstDefn ty name inits
        _ -> pure ()

-- | Record the constant @vtableSid@ as a v-table when its declared type is a
-- struct with at least one field and only function-pointer fields (per
-- 'isFuncPtrType').
--
-- Initialisers are matched to the struct's fields by position. Only
-- initialisers that 'getInitializerFuncSid' can resolve produce an entry.
-- Those are variable references, possibly under casts. Any other initialiser
-- leaves its field unmapped, just like a field with no initialiser at all.
-- Skipping such initialisers, rather than letting 'getInitializerFuncSid'
-- call 'error', keeps one unusual definition from aborting the whole
-- analysis. This can happen easily: 'isFuncPtrType' accepts a pointer to
-- any 'ExternalType', so structs that are not v-tables can reach this point.
handleConstDefn
    :: C.Node (C.Lexeme ScopedId)
    -> C.Lexeme ScopedId
    -> [C.Node (C.Lexeme ScopedId)]
    -> State VTableState ()
handleConstDefn tyNode (C.L _ _ vtableSid) initializers = do
    typeSystem <- gets vtsTypeSystem
    case getTypeName tyNode of
        Just typeName ->
            case lookupType typeName typeSystem of
                Just (StructDescr _ fields) | isVTableLayout fields -> do
                    let fieldNames = map (C.lexemeText . fst) fields
                        vtable = Map.fromList
                            [ (fieldName, getInitializerFuncSid initializer)
                            | (fieldName, initializer) <- zip fieldNames initializers
                            , isFuncDesignator initializer
                            ]
                    modify $ \st -> st { vtsMap = Map.insert vtableSid vtable (vtsMap st) }
                _ -> return () -- Not a struct or not all fields are function pointers
        _ -> return () -- Could not determine type name
  where
    -- A non-empty struct whose fields are all function pointers.
    isVTableLayout fields = not (null fields) && all (isFuncPtrType . snd) fields

    -- Must accept exactly the forms 'getInitializerFuncSid' resolves, so that
    -- it is only ever called on initialisers it can handle.
    isFuncDesignator (Fix (C.VarExpr _))       = True
    isFuncDesignator (Fix (C.CastExpr _ expr)) = isFuncDesignator expr
    isFuncDesignator _                         = False

-- | Whether a struct field's type counts as a function pointer for v-table
-- detection.
--
-- 'Const' qualifiers are looked through. Accepted forms are a 'FuncRef' type
-- reference, a pointer to one, or a pointer to an 'ExternalType'.
--
-- NOTE(review): the 'ExternalType' case ignores the type's name, so it
-- presumably also accepts pointers to external types that are not
-- functions. Confirm that this is intended.
isFuncPtrType :: TypeInfo -> Bool
isFuncPtrType = \case
    Const inner                 -> isFuncPtrType inner
    TypeRef FuncRef _           -> True
    Pointer (TypeRef FuncRef _) -> True
    Pointer (ExternalType _)    -> True -- For typedef'd function pointers
    _                           -> False

-- | The name of a user-defined or struct type, looking through any number of
-- 'C.TyConst' wrappers. Returns 'Nothing' for every other type form.
getTypeName :: C.Node (C.Lexeme ScopedId) -> Maybe Text
getTypeName = nameOf . unFix
  where
    nameOf = \case
        C.TyConst inner               -> getTypeName inner
        C.TyStruct (C.L _ _ sid)      -> Just (sidName sid)
        C.TyUserDefined (C.L _ _ sid) -> Just (sidName sid)
        _                             -> Nothing

-- | The 'ScopedId' named by a v-table initialiser. The initialiser must be a
-- 'C.VarExpr', optionally wrapped in any number of 'C.CastExpr's.
--
-- Calls 'error' (including the pretty-printed node) for any other
-- initialiser form.
getInitializerFuncSid :: C.Node (C.Lexeme ScopedId) -> ScopedId
getInitializerFuncSid n@(Fix node) = case node of
    C.VarExpr (C.L _ _ sid) -> sid
    C.CastExpr _ expr       -> getInitializerFuncSid expr
    _                       ->
        error $ "VTable initializer is not a function designator: "
             ++ show (showNodePlain n)