{-# LANGUAGE LambdaCase        #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE Strict            #-}
module Tokstyle.C.Linter.SizeArg (descr) where

import           Data.Functor.Identity           (Identity)
import qualified Data.List                       as List
import qualified Data.Map                        as Map
import           Data.Text                       (Text)
import           Language.C.Analysis.AstAnalysis (ExprSide (..), defaultMD,
                                                  tExpr)
import           Language.C.Analysis.ConstEval   (constEval, intValue)
import           Language.C.Analysis.SemError    (invalidAST)
import           Language.C.Analysis.SemRep      (GlobalDecls, ParamDecl (..),
                                                  Type (..))
import           Language.C.Analysis.TravMonad   (Trav, TravT, catchTravError,
                                                  throwTravError)
import           Language.C.Analysis.TypeUtils   (canonicalType)
import           Language.C.Data.Ident           (Ident (..))
import qualified Language.C.Pretty               as C
import           Language.C.Syntax.AST           (CExpr, CExpression (..),
                                                  annotation)
import           Prettyprinter                   (pretty, (<+>))
import           Tokstyle.C.Env                  (Env, recordLinterError)
import           Tokstyle.C.Patterns
import           Tokstyle.C.TraverseAst          (AstActions (..), astActions,
                                                  traverseAst)
import           Tokstyle.C.TravUtils            (backticks)


-- | Check the zipped @(parameter, argument, argument type)@ triples of a call
-- to @funId@ for constant size arguments that exceed an array's size.
--
-- Whenever an argument whose (canonical) type is a sized array is immediately
-- followed by a parameter whose name contains @"size"@ or @"len"@, both the
-- array size and the size argument are constant-evaluated. If both are
-- constants and the size argument is strictly larger than the array size, a
-- linter error is recorded at the size argument.
--
-- Each array/size pair is checked independently. A failure to evaluate one
-- pair only skips that pair; the remaining arguments are still checked.
checkArraySizes :: Ident -> [(ParamDecl, CExpr, Type)] -> Trav Env ()
checkArraySizes funId ((_, _, arrTy@(ArrayTypeSize arrSize)) : (ParamName sizeParam, sizeArg, _) : args)
    | any (`List.isInfixOf` sizeParam) ["size", "len"] = do
        -- Ignore any name lookup errors here. VLAs have locally defined
        -- array sizes, but we don't check VLAs.
        --
        -- Only this pair's evaluation goes inside the handler. If the
        -- recursive call below were also wrapped, one unevaluable pair would
        -- silently skip every later argument. NOTE(review): language-c's
        -- catchTravError also appears to resume from the state before the
        -- failed action, which would drop diagnostics recorded inside it;
        -- confirm against the library version in use.
        catchTravError
            (do
                arrSizeVal <- intValue <$> constEval defaultMD Map.empty arrSize
                sizeArgVal <- intValue <$> constEval defaultMD Map.empty sizeArg
                case (arrSizeVal, sizeArgVal) of
                    (Just arrSizeConst, Just sizeArgConst) | arrSizeConst < sizeArgConst ->
                        recordLinterError (annotation sizeArg) $
                            "size parameter" <+> backticks (pretty sizeParam) <+> "is passed constant value"
                            <+> backticks (pretty (show (C.pretty sizeArg)))
                            <+> "(= " <> pretty sizeArgConst <> "),\n"
                            <> "  which is greater than the array size of" <+> backticks (pretty (show (C.pretty arrTy))) <> ",\n"
                            <> "  potentially causing buffer overrun in" <+> backticks (pretty (show (C.pretty funId)))
                    _ -> return ()  -- not constant, or size arg <= array size
            )
            (const (return ()))
        -- Continue with the arguments after the size argument, outside the
        -- error handler.
        checkArraySizes funId args

-- Not an array/size pair at the head: slide forward by one argument.
checkArraySizes funId (_ : xs) = checkArraySizes funId xs
checkArraySizes _ [] = return ()


-- | AST actions for the size-arg check.
--
-- For every call whose callee is a plain identifier, the callee is typed as
-- an rvalue. If that type matches 'FunPtrParams', the call's arguments are
-- typed (canonicalised) and handed to 'checkArraySizes' together with the
-- declared parameters. Any other callee type is reported as an invalid AST.
-- All other expressions are traversed unchanged.
linter :: AstActions (TravT Env Identity)
linter = astActions{doExpr = onExpr}
  where
    onExpr :: CExpr -> Trav Env () -> Trav Env ()
    onExpr node@(CCall callee@(CVar funId _) args _) act = do
        calleeTy <- tExpr [] RValue callee
        case calleeTy of
            FunPtrParams params -> do
                argTys <- mapM (\arg -> canonicalType <$> tExpr [] RValue arg) args
                checkArraySizes funId (zip3 params args argTys)
                act
            other ->
                throwTravError $ invalidAST (annotation node) $ show other
    onExpr _ act = act


-- | Run the size-arg 'linter' actions over all global declarations.
analyse :: GlobalDecls -> Trav Env ()
analyse decls = traverseAst linter decls


-- | Registration entry for this linter: the analysis function paired with
-- its flag name and a one-line description.
descr :: (GlobalDecls -> Trav Env (), (Text, Text))
descr = (analyse, (flagName, summary))
  where
    flagName, summary :: Text
    flagName = "size-arg"
    summary = "Checks that the size argument passed to a function matches the array size of the preceding argument."