{-# LANGUAGE TemplateHaskellQuotes #-}
module Test.MockCat.TH.ClassAnalysis
  ( ClassName2VarNames (..),
    VarName2ClassNames (..),
    toClassInfos,
    toClassInfo,
    getTypeNames,
    filterClassInfo,
    filterMonadicVarInfos,
    hasMonadInVarInfo,
    getClassName,
    getClassNames,
    VarAppliedType (..),
    applyVarAppliedTypes,
    updateType,
    findClass,
    hasClass
  )
where

import Control.Monad (guard)
import Data.List (find)
import qualified Data.Map.Strict as Map
import Data.Maybe (isJust)
import Data.Text (pack, splitOn, unpack)
import Language.Haskell.TH
  ( Name,
    Pred,
    Type (..),
  )

-- | A class constraint broken into its class name and the names that appear
-- in its argument positions (built by 'toClassInfo').
data ClassName2VarNames = ClassName2VarNames Name [Name]

-- | Rendered by 'showClassDef': the unqualified class name, a single space,
-- then the shown argument names separated by spaces.
instance Show ClassName2VarNames where
  show info = case info of
    ClassName2VarNames className varNames -> showClassDef className varNames

-- | A variable name paired with the names of the classes constraining it.
data VarName2ClassNames = VarName2ClassNames Name [Name]

-- | Rendered as @\<var\> class is \<Class1\> \<Class2\> ...@, where the
-- variable is shown as-is and the class names are shown unqualified.
instance Show VarName2ClassNames where
  show (VarName2ClassNames var classes) =
    concat [show var, " class is ", unwords (map showClassName classes)]

-- | Convert each constraint of a context with 'toClassInfo', keeping order.
toClassInfos :: [Pred] -> [ClassName2VarNames]
toClassInfos preds = [toClassInfo p | p <- preds]

-- | Decompose a class constraint such as @C a b@ into its class name and the
-- names found in its argument positions.
--
-- The constraint must be a (possibly empty) left-nested chain of 'AppT'
-- whose head is a 'ConT'. Each argument contributes whatever 'getTypeNames'
-- returns for it, so compound arguments such as @(m a)@ contribute nothing.
-- Any other shape is rejected with an error that shows the whole offending
-- constraint, in the same style as 'getClassName'.
toClassInfo :: Pred -> ClassName2VarNames
toClassInfo constraint = go constraint
  where
    go :: Type -> ClassName2VarNames
    go (AppT t1 t2) =
      let (ClassName2VarNames name vars) = go t1
       in ClassName2VarNames name (vars ++ getTypeNames t2)
    go (ConT name) = ClassName2VarNames name []
    -- Report the complete constraint, not just the unsupported inner node,
    -- so the failing declaration can actually be identified.
    go _ = error $ "Unsupported Type structure: " <> show constraint

-- | Names contributed by one constraint argument: a type variable or a type
-- constructor yields its own name; any other (compound) type yields none.
getTypeNames :: Pred -> [Name]
getTypeNames ty = case ty of
  VarT name -> [name]
  ConT name -> [name]
  _ -> []

-- | Keep only the class infos whose argument names include the given name,
-- preserving their order.
filterClassInfo :: Name -> [ClassName2VarNames] -> [ClassName2VarNames]
filterClassInfo name infos =
  [ info
  | info@(ClassName2VarNames _ varNames) <- infos,
    name `elem` varNames
  ]

-- | Keep only the variables that have 'Monad' among their classes
-- (see 'hasMonadInVarInfo'), preserving order.
filterMonadicVarInfos :: [VarName2ClassNames] -> [VarName2ClassNames]
filterMonadicVarInfos infos = [info | info <- infos, hasMonadInVarInfo info]

-- | Whether the class list of a variable literally contains 'Monad'.
-- Subclasses such as @MonadIO@ are not recognised by this check.
hasMonadInVarInfo :: VarName2ClassNames -> Bool
hasMonadInVarInfo (VarName2ClassNames _ classNames) =
  any (== ''Monad) classNames

-- | The class name at the head of a (possibly applied) class type: strips
-- every 'AppT' argument until a 'ConT' is reached. Any other head shape is
-- an error that shows the offending type.
getClassName :: Type -> Name
getClassName ty = case ty of
  ConT name -> name
  AppT headTy _ -> getClassName headTy
  _ -> error ("unsupported class definition: " <> show ty)

-- | Collect constructor names from a chain of type applications.
--
-- A 'ConT' is collected when it is the argument (right-hand side) of an
-- 'AppT', or when it is the function (left-hand side) of an 'AppT' whose
-- argument is itself a 'ConT'. Type variables, constructors in other
-- positions, and any non-'AppT' node contribute nothing.
getClassNames :: Type -> [Name]
getClassNames = collect
  where
    collect :: Type -> [Name]
    collect (AppT lhs rhs) = case (lhs, rhs) of
      (ConT l, ConT r) -> [l, r]
      (_, ConT r) -> collect lhs ++ [r]
      _ -> collect lhs ++ collect rhs
    collect _ = []

-- | Render a 'Name' without its module qualifier: the text after the last
-- @.@ in its 'show' output.
showClassName :: Name -> String
showClassName = splitLast "." . show

-- | Render a class constraint: the unqualified class name, one space, then
-- the shown argument names separated by spaces. The separating space is
-- emitted even when there are no argument names.
showClassDef :: Name -> [Name] -> String
showClassDef className varNames =
  concat [showClassName className, " ", unwords (map show varNames)]

-- | The segment after the last occurrence of the delimiter, or the whole
-- string when the delimiter does not occur. The delimiter must be
-- non-empty ('Data.Text.splitOn' rejects an empty one).
splitLast :: String -> String -> String
splitLast delimiter str = last (split delimiter str)

-- | Split a string on every occurrence of a (non-empty) delimiter, going
-- through 'Data.Text' and converting the pieces back to 'String'.
split :: String -> String -> [String]
split delimiter = map unpack . splitOn (pack delimiter) . pack

-- | A type variable paired with the name it is applied to ('Nothing' when it
-- has none). 'applyVarAppliedTypes' and 'updateType' replace the variable
-- with @'ConT' name@ when a name is present.
data VarAppliedType = VarAppliedType Name (Maybe Name)
  deriving Show

-- | Substitute type variables with concrete type constructors throughout a
-- type.
--
-- Every @'VarT' v@ for which the list holds @'VarAppliedType' v ('Just' c)@
-- becomes @'ConT' c@. Entries carrying 'Nothing' are ignored. When one
-- variable is listed with several names, the last one wins ('Map.fromList'
-- semantics).
--
-- The traversal descends into:
--
-- * applications;
-- * signatures (the type only, never the kind);
-- * parentheses;
-- * infix applications;
-- * visible kind applications;
-- * implicit parameters;
-- * invisible foralls (context and body);
-- * visible foralls.
--
-- Forall binders are not inspected. A mapped variable is therefore
-- substituted even under a forall that rebinds it. Any other type is
-- returned unchanged.
applyVarAppliedTypes :: [VarAppliedType] -> Type -> Type
applyVarAppliedTypes varAppliedTypes = transform
  where
    -- Built once and shared by every recursive call of 'transform'.
    mapping :: Map.Map Name Name
    mapping =
      Map.fromList
        [ (varName, className)
        | VarAppliedType varName (Just className) <- varAppliedTypes
        ]

    transform :: Type -> Type
    transform (VarT n) =
      case Map.lookup n mapping of
        Just className -> ConT className
        Nothing -> VarT n
    transform (AppT t1 t2) = AppT (transform t1) (transform t2)
    transform (SigT t k) = SigT (transform t) k
    transform (ParensT t) = ParensT (transform t)
    transform (InfixT t1 n t2) = InfixT (transform t1) n (transform t2)
    transform (UInfixT t1 n t2) = UInfixT (transform t1) n (transform t2)
    transform (ForallT tvs ctx t) = ForallT tvs (map transform ctx) (transform t)
    -- Previously unhandled: without these cases, mapped variables under
    -- @f \@k a@, @?x :: m a@ or @forall a -> m a@ were silently left alone.
    transform (AppKindT t k) = AppKindT (transform t) k
    transform (ImplicitParamT ip t) = ImplicitParamT ip (transform t)
    transform (ForallVisT tvs t) = ForallVisT tvs (transform t)
    transform t = t

-- | For a type of the exact shape @v1 v2@ (one variable applied to another),
-- replace each variable that has a class recorded via 'findClass' with
-- @'ConT' class@. Every other type is returned unchanged.
updateType :: Type -> [VarAppliedType] -> Type
updateType ty varAppliedTypes = case ty of
  AppT (VarT v1) (VarT v2) -> AppT (resolve v1) (resolve v2)
  _ -> ty
  where
    resolve :: Name -> Type
    resolve v = maybe (VarT v) ConT (findClass v varAppliedTypes)

-- | Whether any entry for the given variable carries a class name.
hasClass :: Name -> [VarAppliedType] -> Bool
hasClass varName = any matches
  where
    matches (VarAppliedType v c) = isJust c && v == varName

-- | Look up the class name recorded for a type variable.
--
-- Returns the class of the first entry for @varName@ that actually carries
-- one. Entries for @varName@ with 'Nothing' are skipped, so they no longer
-- hide a later entry that does name a class. As a result
-- @isJust (findClass v ts) == hasClass v ts@ always holds, and the list is
-- traversed only once.
findClass :: Name -> [VarAppliedType] -> Maybe Name
findClass varName types = do
  VarAppliedType _ className <- find carriesClassFor types
  className
  where
    carriesClassFor (VarAppliedType v c) = v == varName && isJust c