{-# LANGUAGE TemplateHaskellQuotes #-}
module Test.MockCat.TH.ContextBuilder
  ( -- Constraint rewrite
    liftConstraint,
    -- MonadVar helpers
    mockTType,
    tyVarBndrToType,
    applyFamilyArg,
    -- Context builder
    MockType (..),
    buildContext,
    toVarTs,
    constructClassAppT,
    getTypeVarNames,
    getTypeVarName,
    convertTyVarBndr
  )
where
import Language.Haskell.TH
  ( Name,
    TyVarBndr (..),
    Type (..),
    Pred
  )
import Control.Monad.IO.Class (MonadIO)
import Test.MockCat.MockT (MockT)
import Test.MockCat.TH.ClassAnalysis (ClassName2VarNames (..), toClassInfos, VarAppliedType (..), updateType)

-- | Rewrite a constraint so that occurrences of the monad type variable @m@
-- used as an application argument refer to @'MockT' m@ instead.
--
-- * @Monad m@ is returned untouched.
-- * Any application whose last argument is exactly @m@ has that argument
--   replaced, e.g. @MonadIO m@ becomes @MonadIO (MockT m)@ and
--   @MonadReader r m@ becomes @MonadReader r (MockT m)@. The function part
--   of such an application is not traversed further.
-- * Any other type application is traversed recursively on both sides.
-- * Every other 'Type' form is returned unchanged.
liftConstraint :: Name -> Type -> Type
liftConstraint monadVarName = rewrite
  where
    rewrite :: Type -> Type
    rewrite constraint = case constraint of
      -- Keep a bare @Monad m@ exactly as it was.
      AppT (ConT cls) (VarT v)
        | v == monadVarName && cls == ''Monad -> constraint
      -- Last argument is the monad variable: wrap it in MockT.
      AppT f (VarT v)
        | v == monadVarName -> AppT f (AppT (ConT ''MockT) (VarT v))
      AppT f x -> AppT (rewrite f) (rewrite x)
      other -> other

-- MonadVar helpers
-- | @mockTType m@ is the type @'MockT' m@: the given type variable wrapped
-- in the 'MockT' transformer.
mockTType :: Name -> Type
mockTType = AppT (ConT ''MockT) . VarT

-- | Turn a type-variable name into a 'Type'. The monad variable becomes
-- @'MockT' m@ (via 'mockTType'); any other name becomes a plain 'VarT'.
liftTyVar :: Name -> Name -> Type
liftTyVar monadVarName varName =
  if monadVarName == varName
    then mockTType monadVarName
    else VarT varName

-- | Convert a type-variable binder into a 'Type'. The binder's flag and kind
-- are discarded, and the monad variable is replaced by @'MockT' m@
-- (see 'liftTyVar').
tyVarBndrToType :: Name -> TyVarBndr a -> Type
tyVarBndrToType monadVarName binder =
  liftTyVar monadVarName $ case binder of
    PlainTV binderName _ -> binderName
    KindedTV binderName _ _ -> binderName

-- | Same behaviour as 'tyVarBndrToType', exported under a separate name
-- (presumably for building type-family arguments; confirm at call sites).
applyFamilyArg :: Name -> TyVarBndr a -> Type
applyFamilyArg monadVarName binder = tyVarBndrToType monadVarName binder

-- Context builder
-- | Which flavour of mock a context is being built for.
data MockType
  = -- | Context without the class constraint itself.
    Total
  | -- | Context that additionally includes the class, applied to its
    -- type variables.
    Partial
  deriving (Eq)

-- | Assemble the instance context for a generated mock.
--
-- 1. Each constraint of @cxt@ is rewritten with 'liftConstraint'.
-- 2. A bare @Monad m@ (for the monad variable @m@) is removed.
-- 3. @MonadIO m@ is appended unless 'toClassInfos' reports some remaining
--    constraint whose class is 'MonadIO'.
-- 4. For 'Partial' only, the class named @className@ is applied to the
--    binders in @tyVars@, passed through 'updateType' with
--    @varAppliedTypes@, and appended last.
buildContext ::
  [Pred] ->
  MockType ->
  Name ->
  Name ->
  [TyVarBndr a] ->
  [VarAppliedType] ->
  [Pred]
buildContext cxt mockType className monadVarName tyVars varAppliedTypes =
  liftedCxt ++ monadIOCxt ++ classCxt
  where
    liftedCxt :: [Pred]
    liftedCxt =
      filter (not . isBareMonad) (map (liftConstraint monadVarName) cxt)

    -- @Monad m@ for the monad variable itself.
    isBareMonad :: Pred -> Bool
    isBareMonad (AppT (ConT cls) (VarT v)) = cls == ''Monad && v == monadVarName
    isBareMonad _ = False

    monadIOCxt :: [Pred]
    monadIOCxt
      | any isMonadIOInfo (toClassInfos liftedCxt) = []
      | otherwise = [AppT (ConT ''MonadIO) (VarT monadVarName)]

    isMonadIOInfo :: ClassName2VarNames -> Bool
    isMonadIOInfo (ClassName2VarNames cls _) = cls == ''MonadIO

    classCxt :: [Pred]
    classCxt = case mockType of
      Total -> []
      Partial ->
        [ updateType
            (constructClassAppT className (toVarTs tyVars))
            varAppliedTypes
        ]

-- | Turn each binder into a bare type variable ('VarT'), preserving order and
-- discarding flags and kinds.
toVarTs :: [TyVarBndr a] -> [Type]
toVarTs = map (VarT . getTypeVarName)

-- | Apply the class constructor to the given arguments from left to right:
-- @constructClassAppT c [t1, t2]@ is @AppT (AppT (ConT c) t1) t2@.
-- With no arguments the result is just @ConT c@.
constructClassAppT :: Name -> [Type] -> Type
constructClassAppT className = applyAll (ConT className)
  where
    applyAll acc [] = acc
    applyAll acc (arg : rest) = applyAll (AppT acc arg) rest

-- | The names bound by a list of binders, in the same order.
getTypeVarNames :: [TyVarBndr a] -> [Name]
getTypeVarNames tyVars = [getTypeVarName tyVar | tyVar <- tyVars]

-- | The name bound by a binder, ignoring its flag and any kind annotation.
getTypeVarName :: TyVarBndr a -> Name
getTypeVarName tyVar = case tyVar of
  PlainTV name _ -> name
  KindedTV name _ _ -> name

-- | Replace a binder's flag with @()@, keeping its name and (if present) its
-- kind unchanged.
convertTyVarBndr :: TyVarBndr a -> TyVarBndr ()
convertTyVarBndr tyVar = case tyVar of
  PlainTV name _ -> PlainTV name ()
  KindedTV name _ kind -> KindedTV name () kind