{-# LANGUAGE CPP #-}
{-# LANGUAGE BlockArguments #-}
-- | A deliberately /minimal, forward-safe/ 'Data.Type.Equality.TestEquality'
-- (and 'Data.Type.Coercion.TestCoercion') synthesizer.
--
-- It handles exactly the unambiguous "finite singleton" GADT: a one-parameter
-- type whose every constructor is nullary, has no existentials, and pins the
-- parameter to a /ground/ type:
--
-- > data T a where { TInt :: T Int; TBool :: T Bool }
--
-- For these the lawful behaviour is forced: @testEquality x y@ is @Just Refl@
-- exactly when the type /indices/ of @x@ and @y@ coincide (NOT when they are
-- the same constructor: two constructors pinning the same type are equal), and
-- @Nothing@ otherwise.  Because that is the only law-abiding implementation, it
-- can never disagree with a future, more general design, so it commits us to
-- nothing.  Anything outside the subset is refused.
module Stock.TestEquality (synthTestEquality, synthTestCoercion) where

import GHC.Plugins hiding (TcPlugin)
import GHC.Tc.Plugin (TcPluginM, unsafeTcPluginTcM)
import GHC.Tc.Types.Constraint (Ct)
#if MIN_VERSION_ghc(9,12,0)
import GHC.Tc.Types.CtLoc (CtLoc)
#else
import GHC.Tc.Types.Constraint (CtLoc)
#endif
import GHC.Tc.Types.Evidence (EvTerm(EvExpr))
import GHC.Core.Class (Class, classMethods)
import GHC.Core.TyCo.Compare (eqType)
import Stock.Internal

-- | A datacon's GADT equality refinements (no public accessor; via the sig).
--
-- This is the third component of 'dataConFullSig'; every other component of
-- the signature is ignored.
dcEqSpec :: DataCon -> [EqSpec]
dcEqSpec dc = eqs
  where
    (_, _, eqs, _, _, _) = dataConFullSig dc

-- | A constructor in the supported subset; returns its pinned ground index.
--
-- A constructor qualifies only when all of the following hold:
--
--   * it carries exactly one GADT equality refinement;
--   * it binds no existential type or coercion variables;
--   * it has no constraint context besides that refinement.  Each such
--     constraint is an additional field of the constructor in Core, which the
--     single-binder case alternatives built by 'synthEqLike' would not bind;
--   * it is nullary (no value fields);
--   * the pinned index is ground, i.e. has no free type or coercion variables.
--
-- Anything else yields 'Nothing', so the caller refuses the whole type.
pinnedGround :: DataCon -> Maybe Type
pinnedGround dc = case dataConFullSig dc of
  (_, exVars, [es], theta, fields, _)
    | null exVars                           -- no existentials
    , null theta                            -- no extra constraint fields
    , null fields                           -- nullary (no value fields)
    , let ty = snd (eqSpecPair es)
    , isEmptyVarSet (tyCoVarsOfType ty)     -- ground (closed) index
    -> Just ty
  _ -> Nothing

-- | Entry points for the two supported classes.  Both delegate to
-- 'synthEqLike' and differ only in the witness they ask it to build:
-- 'synthTestEquality' selects the @(:~:)@ \/ @Refl@ witness,
-- 'synthTestCoercion' the @Coercion@ witness.
synthTestEquality, synthTestCoercion
  :: GenEnv -> Class -> CtLoc -> Type -> Type -> TcPluginM (Maybe (EvTerm, [Ct]))
synthTestEquality = synthEqLike True
synthTestCoercion = synthEqLike False

-- | @useRefl = True@ ⇒ 'TestEquality' (@(:~:)@ \/ @Refl@); @False@ ⇒
-- 'TestCoercion' (@Coercion@).
--
-- Builds the class dictionary for @cls@ at @wrappedTy@ by case-analysing the
-- constructors of the tycon heading @f@.  Evidence is produced only when
--
--   * 'geStock1' is available in the environment,
--   * @f@ is a bare type constructor (no arguments applied),
--   * that tycon has at least one constructor and 'pinnedGround' accepts
--     every one of them,
--   * the witness tycon (@(:~:)@ or @Coercion@) occurs in the type of the
--     class's first method and has exactly one data constructor.
--
-- In every other case the result is 'Nothing' (the constraint is refused).
-- The 'CtLoc' argument is unused, and no new constraints are ever emitted:
-- the @[Ct]@ component is always empty.
synthEqLike :: Bool -> GenEnv -> Class -> CtLoc -> Type -> Type
            -> TcPluginM (Maybe (EvTerm, [Ct]))
synthEqLike useRefl gen cls _loc wrappedTy f =
  case (geStock1 gen, tyConAppTyCon_maybe f) of
    (Just st1Tc, Just fTc)
      | null (tyConAppArgs f)                       -- F is a bare one-param tycon
      , dcons@(dc0 : _) <- tyConDataCons fTc
      , Just pins <- traverse pinnedGround dcons
      , (es0 : _) <- dcEqSpec dc0
      -- the witness type (@(:~:)@ \/ @Coercion@) straight from the method's
      -- signature, so we never have to name a (re-exported) module.
      , (meth : _) <- classMethods cls
      , (witTc : _) <-
          [ tc
          | tc <- nonDetEltsUniqSet (tyConsOfType (varType meth))
          , nameOccName (tyConName tc)
              == mkTcOcc (if useRefl then ":~:" else "Coercion")
          ]
      -- Refuse, rather than panic, if the witness tycon unexpectedly does
      -- not have exactly one data constructor.
      , Just witCon <- tyConSingleDataCon_maybe witTc -> do
          let -- kind of the (single) type parameter of F
              kK      = tyVarKind (fst (eqSpecPair es0))
              -- NOTE(review): presumably the coercion taking @wrappedTy t@
              -- down to @f t@; confirm against 'coDown1' in Stock.Internal.
              coAt    = coDown1 gen st1Tc wrappedTy f f
              -- every constructor paired with the ground index it pins
              indexed = zip dcons pins
          aTv <- freshTyVarK kK "a"
          bTv <- freshTyVarK kK "b"
          let aTy = mkTyVarTy aTv
              bTy = mkTyVarTy bTv
          xId <- freshId (mkAppTy wrappedTy aTy) "x"
          yId <- freshId (mkAppTy wrappedTy bTy) "y"
          -- case binders for the two scrutinees (after the cast)
          wbX <- freshId (mkTyConApp fTc [aTy]) "wx"
          wbY <- freshId (mkTyConApp fTc [bTy]) "wy"
          let witOf x y = mkTyConApp witTc [kK, x, y]
              witTy     = witOf aTy bTy
              resTy     = mkTyConApp maybeTyCon [witTy]
              nothingE  = mkCoreConApps nothingDataCon [Type witTy]
              -- same index: cox : a~#t, coy : b~#t  ⇒  abCo : a~#b.
              --   Refl     :: forall k (a b). (b ~# a)     => a :~: b    (eqSpec)
              --   Coercion :: forall k (a b). Coercible a b => Coercion a b
              -- so we feed the proof directly; for Coercion we first box the
              -- representational coercion into a Coercible dictionary with the
              -- wired-in 'coercibleDataCon' (@MkCoercible :: (a ~R# b) ->
              -- Coercible a b@).  No Cast, no constraint solving.
              same cox coy =
                let abCo = mkTransCo (mkCoVarCo cox) (mkSymCo (mkCoVarCo coy))
                    proof
                      | useRefl   = Coercion (mkSymCo abCo)   -- b ~# a (nominal)
                      | otherwise =
                          mkCoreConApps coercibleDataCon
                            [ Type kK, Type aTy, Type bTy
                            , Coercion (mkSubCo abCo) ]       -- a ~R# b boxed
                    wit = mkCoreConApps witCon
                            [Type kK, Type aTy, Type bTy, proof]
                in mkCoreConApps justDataCon [Type witTy, wit]
              -- testEquality compares the type /indices/, not constructor tags:
              -- two constructors pinning the same ground type ⇒ Just Refl.
              innerAlt ti cox (dcj, tj) = do
                coy <- freshCoVar (mkPrimEqPred bTy tj)
                pure (Alt (DataAlt dcj) [coy]
                          (if eqType ti tj then same cox coy else nothingE))
              -- one outer alternative per constructor of @x@, each holding a
              -- complete inner case over the constructors of @y@
              outerAlt (dci, ti) = do
                cox   <- freshCoVar (mkPrimEqPred aTy ti)
                ialts <- mapM (innerAlt ti cox) indexed
                pure (Alt (DataAlt dci) [cox]
                          (Case (Cast (Var yId) (coAt bTy)) wbY resTy ialts))
          outerAlts <- mapM outerAlt indexed
          let impl = mkCoreLams [aTv, bTv, xId, yId]
                       (Case (Cast (Var xId) (coAt aTy)) wbX resTy outerAlts)
              -- TestEquality/TestCoercion are poly-kinded (@class C (f :: k ->
              -- Type)@), so the dictionary takes the kind @k@ first.
              dict = mkCoreConApps (classDataCon cls)
                       [Type kK, Type wrappedTy, impl]
          pure (Just (EvExpr dict, []))
    _ -> pure Nothing

-- | Make a fresh coercion variable of the given type (callers pass a
-- primitive equality predicate such as @a ~# t@).  The variable gets a
-- system name with occurrence name @co@ and a unique drawn from the
-- type-checker monad.
freshCoVar :: Type -> TcPluginM CoVar
freshCoVar ty = do
  u <- unsafeTcPluginTcM getUniqueM
  pure (mkCoVar (mkSystemName u (mkVarOccFS (fsLit "co"))) ty)