module TypeLet.Plugin (plugin) where
import Prelude hiding (cycle)
import Data.Traversable (forM)
import GHC.Plugins (Plugin(..), defaultPlugin, purePlugin)
import TypeLet.Plugin.Constraints
import TypeLet.Plugin.GhcTcPluginAPI
import TypeLet.Plugin.NameResolution
import TypeLet.Plugin.Substitution
-- | Type-checker plugin entry point.
--
-- * Marked 'purePlugin', so enabling it does not by itself force
--   recompilation.
-- * Command-line options are ignored.
-- * 'resolveNames' runs once at initialisation. Its result is handed to every
--   later stage: solver, rewriter and stop.
-- * No type-family rewriters are registered (an empty 'UniqFM').
-- * Nothing needs cleaning up when the plugin stops.
plugin :: Plugin
plugin = defaultPlugin {
      pluginRecompile = purePlugin
    , tcPlugin        = \_cmdline -> Just . mkTcPlugin $ TcPlugin {
          tcPluginInit    = resolveNames
        , tcPluginSolve   = solve
        , tcPluginRewrite = \_st -> emptyUFM
        , tcPluginStop    = \_st -> return ()
        }
    }
-- | Solver entry point.
--
-- When there are no wanted constraints, only the givens are passed to
-- 'simplifyGivens'. Otherwise both givens and wanteds go to
-- 'simplifyWanteds'.
solve :: ResolvedNames -> TcPluginSolver
solve rn given wanted
  | null wanted = simplifyGivens  rn given
  | otherwise   = simplifyWanteds rn given wanted
-- | Simplify givens.
--
-- Givens are never simplified. The result reports no solved constraints
-- and no new constraints.
simplifyGivens ::
     ResolvedNames
  -> [Ct]
  -> TcPluginM 'Solve TcPluginSolveResult
simplifyGivens _st _given = return $ TcPluginOk [] []
-- | Simplify wanteds.
--
-- 1. Every given is parsed with 'parseLet'. If 'parseAll' reports an invalid
--    let, it is formatted with 'formatInvalidLet' and reported as an error
--    at that constraint's location.
-- 2. The parsed lets are turned into a substitution with 'letsToSubst'. If
--    they form a cycle, the cycle is formatted with 'formatLetCycle' and
--    reported as an error.
-- 3. For every wanted produced by @parseAll' (withOrig (parseEqual rn))@:
--
--    * the original wanted is marked solved, with evidence from
--      'evidenceEqual';
--    * a new nominal-equality wanted is emitted between the two sides, after
--      the substitution has been applied to each.
--
--    Wanteds not returned by 'parseAll'' are neither solved nor replaced.
simplifyWanteds ::
     ResolvedNames
  -> [Ct]  -- ^ Givens
  -> [Ct]  -- ^ Wanteds
  -> TcPluginM 'Solve TcPluginSolveResult
simplifyWanteds rn given wanted =
    case parseAll (parseLet rn) given of
      Left err ->
        errWith $ formatInvalidLet <$> err
      Right lets ->
        case letsToSubst lets of
          Left cycle ->
            -- The lets refer to one another cyclically, so no substitution
            -- can be built from them.
            errWith $ formatLetCycle cycle
          Right subst -> do
            (solved, new) <- fmap unzip $
              forM (parseAll' (withOrig (parseEqual rn)) wanted) $
                uncurry (solveEqual subst)
            return $ TcPluginOk solved new
  where
    -- Create a new wanted for predicate @w@ at location @l@. 'newWanted' is
    -- run under 'setCtLocM' @l@.
    newWanted' :: CtLoc -> PredType -> TcPluginM 'Solve CtEvidence
    newWanted' l w = setCtLocM l $ newWanted l w

    -- Report an error message:
    --   * turn it into a type with 'mkTcPluginErrorTy';
    --   * emit that type as a new wanted at the message's location;
    --   * return the new wanted as the sole contradiction.
    errWith ::
         GenLocated CtLoc TcPluginErrorMessage
      -> TcPluginM 'Solve TcPluginSolveResult
    errWith (L l err) = do
        errAsTyp <- mkTcPluginErrorTy err
        mkErr <$> newWanted' l errAsTyp
      where
        mkErr :: CtEvidence -> TcPluginSolveResult
        mkErr = TcPluginContradiction . (:[]) . mkNonCanonical

    -- Solve one parsed equality wanted.
    --
    -- The original constraint @orig@ is paired with evidence from
    -- 'evidenceEqual'. In its place we return a new wanted: a nominal
    -- equality between 'equalLHS' and 'equalRHS', each with the
    -- let-substitution applied.
    solveEqual ::
         Subst
      -> Ct
      -> GenLocated CtLoc CEqual
      -> TcPluginM 'Solve ((EvTerm, Ct), Ct)
    solveEqual subst orig (L l parsed) = do
        ev <- newWanted' l $
                mkEqPredRole
                  Nominal
                  (substTy subst (equalLHS parsed))
                  (substTy subst (equalRHS parsed))
        return ((evidenceEqual rn parsed, orig), mkNonCanonical ev)