{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE Strict            #-}
module Tokstyle.C.Linter.Memset (descr) where

import           Control.Monad                   (unless)
import           Data.Functor.Identity           (Identity)
import           Data.Text                       (Text)
import           Language.C.Analysis.AstAnalysis (ExprSide (..), tExpr)
import           Language.C.Analysis.DefTable    (lookupTag)
import           Language.C.Analysis.SemRep      (CompType (..),
                                                  CompTypeRef (..), GlobalDecls,
                                                  MemberDecl (..), TagDef (..),
                                                  Type (..), TypeName (..),
                                                  VarDecl (..))
import           Language.C.Analysis.TravMonad   (MonadTrav, Trav, TravT,
                                                  getDefTable, throwTravError)
import           Language.C.Analysis.TypeUtils   (canonicalType)
import           Language.C.Data.Error           (userErr)
import           Language.C.Data.Ident           (Ident (..))
import qualified Language.C.Pretty               as C
import           Language.C.Syntax.AST           (CExpression (..), annotation)
import           Prettyprinter                   (pretty, (<+>))
import           Tokstyle.C.Env                  (Env, recordLinterError)
import           Tokstyle.C.TraverseAst          (AstActions (..), astActions,
                                                  traverseAst)
import           Tokstyle.C.TravUtils            (backticks)


-- | Determine whether a value of the given type is, or contains, a pointer.
--
-- The type is first normalised with 'canonicalType'. Then:
--
-- * a pointer type contains a pointer;
-- * a struct/union contains pointers if /any/ of its members does (checked
--   with 'memberHasPtrs');
-- * an array contains pointers if its element type does;
-- * every other type is treated as pointer-free.
--
-- Throws a traversal error when a referenced struct/union has no complete
-- definition in the current definition table.
hasPtrs :: MonadTrav m => Type -> m Bool
hasPtrs ty = case canonicalType ty of
    DirectType (TyComp (CompTypeRef name _ _)) _ _ -> do
        defs <- getDefTable
        case lookupTag name defs of
            Just (Right (CompDef (CompType _ _ members _ _))) ->
                -- One pointer-holding member is enough for the aggregate to
                -- count as containing pointers ('and' would require all of
                -- them, and would also flag empty aggregates).
                or <$> mapM memberHasPtrs members
            _ ->
                throwTravError $ userErr $
                    "couldn't find struct/union type `" <> show (C.pretty name) <> "`"
    PtrType{} -> return True
    -- Look through arrays, e.g. a member `char *names[4]` or an array of
    -- structs that themselves hold pointers.
    ArrayType elemTy _ _ _ -> hasPtrs elemTy
    _ -> return False

-- | Whether a single struct/union member contains pointers.
--
-- A named member declaration is checked by delegating its declared type to
-- 'hasPtrs'. Any other member form (presumably anonymous bit-fields; confirm
-- against language-c's 'MemberDecl') is treated as pointer-free.
memberHasPtrs :: MonadTrav m => MemberDecl -> m Bool
memberHasPtrs decl = case decl of
    MemberDecl (VarDecl _ _ memberTy) _ _ -> hasPtrs memberTy
    _                                     -> pure False


-- | Decide whether a value of the given type may be passed as memset's
-- destination argument.
--
-- After 'canonicalType', the type must be a pointer or an array. memset is
-- allowed only when the pointed-to (or element) type does not contain
-- pointers according to 'hasPtrs'. Any other type raises a traversal error.
memsetAllowed :: MonadTrav m => Type -> m Bool
memsetAllowed ty
    | Just target <- destination (canonicalType ty) = not <$> hasPtrs target
    | otherwise =
        throwTravError $ userErr $
            "value of type `" <> show (C.pretty ty) <> "` cannot be passed to memset"
  where
    -- The type of the memory memset would write to, if the argument type
    -- denotes such a region at all.
    destination (PtrType pointee _ _)    = Just pointee
    destination (ArrayType elemTy _ _ _) = Just elemTy
    destination _                        = Nothing


-- | AST actions for the memset check.
--
-- For every three-argument call @memset(dst, _, _)@, the right-value type of
-- @dst@ is computed with 'tExpr' and checked with 'memsetAllowed'. If it is
-- rejected, a linter error is recorded at @dst@'s source location. In all
-- cases the traversal then continues into the expression via @act@.
linter :: AstActions (TravT Env Identity)
linter = astActions { doExpr = checkMemset }
  where
    checkMemset node act = do
        case node of
            CCall (CVar (Ident "memset" _ _) _) [dst, _, _] _ -> do
                dstTy <- tExpr [] RValue dst
                ok <- memsetAllowed dstTy
                unless ok $
                    recordLinterError (annotation dst) $
                        "disallowed memset argument" <+> quoted dst
                        <+> "of type" <+> quoted dstTy
                        <> ", which contains pointers"
            _ -> pure ()
        act

    -- Render a C AST node or type as backtick-quoted text for messages.
    quoted x = backticks (pretty (show (C.pretty x)))


-- | Run the memset check ('linter') over all global declarations.
analyse :: GlobalDecls -> Trav Env ()
analyse decls = traverseAst linter decls


-- | Registration entry for this linter: the analysis to run, paired with the
-- linter's name and a one-line description of what it checks.
descr :: (GlobalDecls -> Trav Env (), (Text, Text))
descr = (analyse, (checkName, summary))
  where
    checkName, summary :: Text
    checkName = "memset"
    summary   = "Checks for memset calls on types that contain pointers."