{-# LANGUAGE TemplateHaskellQuotes #-}
module Test.MockCat.TH.TypeUtils
  ( splitApps,
    substituteType,
    isNotConstantFunctionType,
    collectTypeVars,
    needsTypeable,
    collectTypeableTargets,
    isStandardTypeCon
  )
where

import qualified Data.Map.Strict as Map
import Language.Haskell.TH (Name, Type (..))
import Test.MockCat.Param (Param)
import Test.MockCat.Cons ((:>))

-- | Decompose a chain of type applications into its head and its arguments.
--
-- @splitApps (f a b c)@ yields @(f, [a, b, c])@, with arguments in source
-- (left-to-right) order. Only 'AppT' nodes are peeled off, so the head may be
-- any other constructor, including a 'ParensT' or 'SigT' wrapper. A type that
-- is not an application comes back unchanged with an empty argument list.
splitApps :: Type -> (Type, [Type])
splitApps = peel []
  where
    -- Walk down the left spine, pushing each argument onto the front so the
    -- outermost (last) argument ends up at the end of the list.
    peel :: [Type] -> Type -> (Type, [Type])
    peel args (AppT fun arg) = peel (arg : args) fun
    peel args headTy = (headTy, args)

-- | Replace type variables according to a substitution map.
--
-- Every @VarT name@ whose @name@ is a key of @subMap@ is replaced by the
-- mapped type. Variables that are not in the map are left as they are. The
-- traversal descends through:
--
-- * applications
-- * the type part of kind signatures (the kind itself is not rewritten)
-- * parentheses
-- * infix and unresolved-infix operands
-- * implicit-parameter types
-- * both the context and the body of a @forall@
--
-- Any other constructor is returned unchanged.
--
-- NOTE(review): binders introduced by a 'ForallT' are not treated as
-- shadowing. A bound variable that also appears as a key in @subMap@ is
-- substituted too. Callers may rely on substitution passing through binders,
-- so confirm before changing this.
substituteType :: Map.Map Name Type -> Type -> Type
substituteType subMap = go
  where
    go :: Type -> Type
    go (VarT name) = Map.findWithDefault (VarT name) name subMap
    go (AppT t1 t2) = AppT (go t1) (go t2)
    go (SigT t k) = SigT (go t) k
    go (ParensT t) = ParensT (go t)
    go (InfixT t1 n t2) = InfixT (go t1) n (go t2)
    go (UInfixT t1 n t2) = UInfixT (go t1) n (go t2)
    -- Implicit-parameter types are traversed by the other type walkers in this
    -- module. Without this case, a variable under @?x :: a@ would be reported
    -- by them but silently left unsubstituted here.
    go (ImplicitParamT n t) = ImplicitParamT n (go t)
    go (ForallT tvs ctx t) = ForallT tvs (map go ctx) (go t)
    go t = t

-- | Whether a type contains a function arrow anywhere in its structure.
--
-- Returns 'True' as soon as a saturated arrow application
-- (@AppT (AppT ArrowT _) _@) is found. That arrow may be at the top level or
-- nested inside type applications, infix applications, parentheses, kind
-- signatures, or the body of a @forall@ (the @forall@ context is not
-- inspected). Everything else yields 'False', including:
--
-- * a bare or partially applied 'ArrowT' with no arrow in its argument
-- * tuple types whose components contain no arrow
isNotConstantFunctionType :: Type -> Bool
isNotConstantFunctionType (AppT (AppT ArrowT _) _) = True
isNotConstantFunctionType (AppT t1 t2) =
  isNotConstantFunctionType t1 || isNotConstantFunctionType t2
-- Parentheses and kind signatures are transparent wrappers: @(a -> b)@ is the
-- same function type as @a -> b@, so look through them.
isNotConstantFunctionType (ParensT t) = isNotConstantFunctionType t
isNotConstantFunctionType (SigT t _) = isNotConstantFunctionType t
-- An infix application is an ordinary application written differently, so
-- descend into both operands exactly as the 'AppT' case does.
isNotConstantFunctionType (InfixT t1 _ t2) =
  isNotConstantFunctionType t1 || isNotConstantFunctionType t2
isNotConstantFunctionType (UInfixT t1 _ t2) =
  isNotConstantFunctionType t1 || isNotConstantFunctionType t2
isNotConstantFunctionType (ForallT _ _ t) = isNotConstantFunctionType t
isNotConstantFunctionType _ = False

-- | Whether a type mentions at least one type variable.
--
-- The search descends through:
--
-- * applications
-- * infix and unresolved-infix operands
-- * the type part of kind signatures
-- * parentheses
-- * implicit-parameter types
-- * the body of a @forall@
--
-- The binder list and context of a @forall@ are not inspected, and any other
-- constructor counts as variable-free.
needsTypeable :: Type -> Bool
needsTypeable ty =
  case ty of
    VarT _ -> True
    -- Binary nodes: a variable on either side is enough.
    AppT lhs rhs -> needsTypeable lhs || needsTypeable rhs
    InfixT lhs _ rhs -> needsTypeable lhs || needsTypeable rhs
    UInfixT lhs _ rhs -> needsTypeable lhs || needsTypeable rhs
    -- Wrappers: only the wrapped type matters.
    SigT inner _ -> needsTypeable inner
    ParensT inner -> needsTypeable inner
    ImplicitParamT _ inner -> needsTypeable inner
    ForallT _ _ inner -> needsTypeable inner
    _ -> False

-- | List every type-variable occurrence in a type, left to right.
--
-- Duplicates are kept: a variable that occurs twice is listed twice. The walk
-- descends through:
--
-- * applications
-- * infix and unresolved-infix operands
-- * the type part of kind signatures
-- * parentheses
-- * implicit-parameter types
-- * the body of a @forall@
--
-- The @forall@ binder list and context are skipped. Any other constructor
-- contributes nothing.
collectTypeVars :: Type -> [Name]
collectTypeVars root = walk root []
  where
    -- @walk ty rest@ prepends the variables of @ty@ (in order) onto @rest@.
    -- Threading the tail through avoids left-nested (++) on long spines.
    walk :: Type -> [Name] -> [Name]
    walk ty rest =
      case ty of
        VarT name -> name : rest
        AppT lhs rhs -> walk lhs (walk rhs rest)
        InfixT lhs _ rhs -> walk lhs (walk rhs rest)
        UInfixT lhs _ rhs -> walk lhs (walk rhs rest)
        SigT inner _ -> walk inner rest
        ParensT inner -> walk inner rest
        ImplicitParamT _ inner -> walk inner rest
        ForallT _ _ inner -> walk inner rest
        _ -> rest

-- | Collect the sub-types that should receive a 'Typeable' constraint.
--
-- A bare type variable is itself a target. For an application, the head is
-- found with 'splitApps':
--
-- * If 'isStandardTypeCon' accepts the head (arrow, list, tuple, or one of
--   its known named constructors), the arguments are searched recursively.
-- * Otherwise the whole application is returned as a single target.
--
-- Infix applications are always descended into, whatever the operator.
-- Kind signatures, parentheses, implicit-parameter types and @forall@ bodies
-- are looked through; the @forall@ context is ignored. Any other constructor
-- yields no targets. Results come out left to right and may contain
-- duplicates.
collectTypeableTargets :: Type -> [Type]
collectTypeableTargets ty =
  case ty of
    VarT _ -> [ty]
    AppT _ _ ->
      let (f, args) = splitApps ty
       in if isStandardTypeCon f
            then concatMap collectTypeableTargets args
            else [ty]
    SigT t _ -> collectTypeableTargets t
    ParensT t -> collectTypeableTargets t
    InfixT t1 _ t2 -> collectTypeableTargets t1 ++ collectTypeableTargets t2
    UInfixT t1 _ t2 -> collectTypeableTargets t1 ++ collectTypeableTargets t2
    -- Keep in step with needsTypeable, which looks inside implicit-parameter
    -- types. Without this case that predicate could report a variable here
    -- while no target was produced for it.
    ImplicitParamT _ t -> collectTypeableTargets t
    ForallT _ _ t -> collectTypeableTargets t
    _ -> []

-- | Whether a type-application head belongs to the fixed set of "standard"
-- constructors.
--
-- Accepted heads are:
--
-- * the function arrow ('ArrowT') and the list type ('ListT')
-- * tuple constructors of any arity written as 'TupleT'
-- * the named constructors listed in @knownConNames@ below
--
-- Every other head, including type variables, is rejected.
isStandardTypeCon :: Type -> Bool
isStandardTypeCon ty =
  case ty of
    ArrowT -> True
    ListT -> True
    TupleT _ -> True
    ConT name -> name `elem` knownConNames
    _ -> False
  where
    -- Named constructors recognised when a head appears as a plain 'ConT'.
    knownConNames :: [Name]
    knownConNames =
      [ ''Maybe
      , ''IO
      , ''Either
      , ''[]
      , ''(,)
      , ''(,,)
      , ''(,,,)
      , ''(,,,,)
      , ''Param
      , ''(:>)
      ]