{-# LANGUAGE CPP #-}
{-# LANGUAGE BlockArguments #-}
module Stock.TestEquality (synthTestEquality, synthTestCoercion) where
import GHC.Plugins hiding (TcPlugin)
import GHC.Tc.Plugin (TcPluginM, unsafeTcPluginTcM)
import GHC.Tc.Types.Constraint (Ct)
#if MIN_VERSION_ghc(9,12,0)
import GHC.Tc.Types.CtLoc (CtLoc)
#else
import GHC.Tc.Types.Constraint (CtLoc)
#endif
import GHC.Tc.Types.Evidence (EvTerm(EvExpr))
import GHC.Core.Class (Class, classMethods)
import GHC.Core.TyCo.Compare (eqType)
import Stock.Internal
-- | The 'EqSpec' component of 'dataConFullSig' for a data constructor.
-- Each entry pairs a type variable with the type it is equated to
-- (see 'eqSpecPair').
dcEqSpec :: DataCon -> [EqSpec]
dcEqSpec dc = eqSpec
  where
    (_, _, eqSpec, _, _, _) = dataConFullSig dc

-- | If @dc@ is a nullary constructor that fixes its single index to a closed
-- type, return that type; otherwise 'Nothing'.
--
-- A constructor qualifies when, according to 'dataConFullSig', it has:
--
--   * no existential type/coercion variables,
--   * exactly one equality-spec entry,
--   * no other constraints in its context,
--   * no value fields,
--
-- and the pinned type has no free type or coercion variables. For example
-- @C :: T Int@ yields @Just Int@, while @D :: T a@, @E :: Int -> T Bool@ and
-- @F :: Show Int => T Bool@ yield 'Nothing'.
--
-- The context check matters: the @case@ alternatives built in 'synthEqLike'
-- bind exactly one (coercion) variable per constructor. A constructor that
-- carries dictionary or value fields would need more binders, so it must be
-- rejected here rather than produce a mis-shaped alternative.
pinnedGround :: DataCon -> Maybe Type
pinnedGround dc = case dataConFullSig dc of
  (_, [], [spec], [], [], _)
    | let ty = snd (eqSpecPair spec)
    , isEmptyVarSet (tyCoVarsOfType ty)
    -> Just ty
  _ -> Nothing

-- | Try to build a dictionary for a class whose method's type mentions
-- ':~:' (e.g. @TestEquality@). The method answers @Just Refl@ when both
-- indices coincide. Thin wrapper around 'synthEqLike' with the
-- ':~:' / @Refl@ witness selected.
synthTestEquality :: GenEnv -> Class -> CtLoc -> Type -> Type
                  -> TcPluginM (Maybe (EvTerm, [Ct]))
synthTestEquality = synthEqLike True

-- | Try to build a dictionary for a class whose method's type mentions a
-- @Coercion@ type constructor (e.g. @TestCoercion@). The method answers
-- @Just Coercion@ when both indices coincide. Thin wrapper around
-- 'synthEqLike' with the @Coercion@ witness selected.
synthTestCoercion :: GenEnv -> Class -> CtLoc -> Type -> Type
                  -> TcPluginM (Maybe (EvTerm, [Ct]))
synthTestCoercion = synthEqLike False

-- | Shared worker behind 'synthTestEquality' and 'synthTestCoercion'.
--
-- Arguments, in order:
--
--   * @useRefl@ selects the witness: 'True' looks for ':~:' and builds
--     @Refl@; 'False' looks for @Coercion@ and builds @Coercion@.
--   * the generator environment.
--   * the class to build a dictionary for.
--   * the constraint location (unused).
--   * @wrappedTy@, the type the dictionary is for.
--   * @f@, its representation type.
--
-- The function succeeds only when all of these hold:
--
--   * 'geStock1' yields a tycon.
--   * @f@ is a type constructor applied to no arguments.
--   * Every data constructor of @f@ is accepted by 'pinnedGround'.
--   * The class has a method whose type mentions the witness type
--     constructor.
--
-- Otherwise it returns 'Nothing'. On success the evidence carries no new
-- constraints.
--
-- The generated method is, schematically,
--
-- > \ @a @b (x :: wrappedTy a) (y :: wrappedTy b) ->
-- >   case x |> toRep a of
-- >     Ci (cx :: a ~# ti) -> case y |> toRep b of
-- >       Cj (cy :: b ~# tj) -> Just <witness from cx ; sym cy>  -- ti == tj
-- >       Cj (cy :: b ~# tj) -> Nothing                           -- otherwise
--
-- There is one outer alternative per constructor and, inside each, one
-- inner alternative per constructor. Whether @ti == tj@ is decided at
-- generation time with 'eqType'. @toRep@ is the coercion returned by
-- 'coDown1' (presumably @wrappedTy t ~R# f t@ — confirm against
-- "Stock.Internal").
synthEqLike :: Bool -> GenEnv -> Class -> CtLoc -> Type -> Type
            -> TcPluginM (Maybe (EvTerm, [Ct]))
synthEqLike useRefl gen cls _loc wrappedTy f
  | Just st1Tc <- geStock1 gen
  , Just fTc <- tyConAppTyCon_maybe f
  , null (tyConAppArgs f)
  , cons@(firstCon : _) <- tyConDataCons fTc
  , Just grounds <- traverse pinnedGround cons
  , (firstSpec : _) <- dcEqSpec firstCon
  , (method : _) <- classMethods cls
  , (witTc : _) <- filter isWitnessTc
                     (nonDetEltsUniqSet (tyConsOfType (varType method)))
  = Just <$> build st1Tc fTc cons grounds firstSpec witTc
  | otherwise
  = pure Nothing
  where
    -- Name of the witness type constructor searched for in the method type.
    witnessOcc = mkTcOcc (if useRefl then ":~:" else "Coercion")

    isWitnessTc :: TyCon -> Bool
    isWitnessTc tc = nameOccName (tyConName tc) == witnessOcc

    -- Generate the dictionary once every precondition has been established.
    build :: TyCon -> TyCon -> [DataCon] -> [Type] -> EqSpec -> TyCon
          -> TcPluginM (EvTerm, [Ct])
    build st1Tc fTc cons grounds firstSpec witTc = do
      -- Kind of f's index, read off the variable the first constructor pins.
      let idxKind = tyVarKind (fst (eqSpecPair firstSpec))
          toRep = coDown1 gen st1Tc wrappedTy f f
          witCon = tyConSingleDataCon witTc
      tvA <- freshTyVarK idxKind "a"
      tvB <- freshTyVarK idxKind "b"
      argX <- freshId (mkAppTy wrappedTy (mkTyVarTy tvA)) "x"
      argY <- freshId (mkAppTy wrappedTy (mkTyVarTy tvB)) "y"
      repX <- freshId (mkTyConApp fTc [mkTyVarTy tvA]) "wx"
      repY <- freshId (mkTyConApp fTc [mkTyVarTy tvB]) "wy"
      let tyA = mkTyVarTy tvA
          tyB = mkTyVarTy tvB
          witTy = mkTyConApp witTc [idxKind, tyA, tyB]
          resTy = mkTyConApp maybeTyCon [witTy]
          nothingE = mkCoreConApps nothingDataCon [Type witTy]
          -- Field for the witness constructor, given co :: a ~# b.
          -- Refl's field is an equality of the form b ~# a, hence the sym.
          -- Coercion's field is a Coercible dictionary built from a ~R# b.
          proofOf co
            | useRefl = Coercion (mkSymCo co)
            | otherwise = mkCoreConApps coercibleDataCon
                [Type idxKind, Type tyA, Type tyB, Coercion (mkSubCo co)]
          -- cx :: a ~# t and cy :: b ~# t combine into a ~# b.
          matched cx cy =
            let co = mkCoVarCo cx `mkTransCo` mkSymCo (mkCoVarCo cy)
                wit = mkCoreConApps witCon
                        [Type idxKind, Type tyA, Type tyB, proofOf co]
            in mkCoreConApps justDataCon [Type witTy, wit]
          table = zip cons grounds
          innerAlt ti cx (dcj, tj) = do
            cy <- freshCoVar (mkPrimEqPred tyB tj)
            let rhs | eqType ti tj = matched cx cy
                    | otherwise = nothingE
            pure (Alt (DataAlt dcj) [cy] rhs)
          outerAlt (dci, ti) = do
            cx <- freshCoVar (mkPrimEqPred tyA ti)
            alts <- traverse (innerAlt ti cx) table
            let scrutY = Cast (Var argY) (toRep tyB)
            pure (Alt (DataAlt dci) [cx] (Case scrutY repY resTy alts))
      outerAlts <- traverse outerAlt table
      let scrutX = Cast (Var argX) (toRep tyA)
          impl = mkCoreLams [tvA, tvB, argX, argY]
                   (Case scrutX repX resTy outerAlts)
          dict = mkCoreConApps (classDataCon cls)
                   [Type idxKind, Type wrappedTy, impl]
      pure (EvExpr dict, [])

-- | Create a coercion variable named @co@ of the given type, with a unique
-- drawn from the type checker. Callers in this module pass primitive
-- equalities built with 'mkPrimEqPred'.
freshCoVar :: Type -> TcPluginM CoVar
freshCoVar ty = fmap named (unsafeTcPluginTcM getUniqueM)
  where
    named u = mkCoVar (mkSystemName u (mkVarOccFS (fsLit "co"))) ty