{-# LANGUAGE CPP #-}
{-# LANGUAGE BlockArguments #-}
{-# LANGUAGE ViewPatterns #-}
{-# LANGUAGE DerivingVia #-}
{-# OPTIONS_GHC -Wno-x-partial -Wno-incomplete-uni-patterns -Wno-unused-imports #-}
module Stock.Eq where
import GHC.Plugins hiding (TcPlugin)
import GHC.Tc.Plugin
import GHC.Tc.Types
import GHC.Tc.Types.Constraint
#if MIN_VERSION_ghc(9,12,0)
import GHC.Tc.Types.CtLoc (CtLoc)
#else
import GHC.Tc.Types.Constraint (CtLoc)
#endif
import GHC.Tc.Types.Evidence
import GHC.Tc.Utils.Monad (addErrTc)
import GHC.Tc.Errors.Types (mkTcRnUnknownMessage)
import GHC.Types.Error (mkPlainError, noHints)
import GHC.Core.Class (Class, className, classMethods, classOpItems, classTyCon)
import GHC.Core.Predicate (classifyPredType, Pred(ClassPred), mkClassPred)
import GHC.Builtin.Types.Prim (intPrimTy)
import GHC.Builtin.PrimOps (PrimOp(TagToEnumOp))
import GHC.Builtin.PrimOps.Ids (primOpId)
import GHC.Builtin.Names ( eqClassName, ordClassName, appendName
, enumClassName, mapName, numClassName
, enumFromToName, enumFromThenToName
, eqStringName
, genClassName, repTyConName, u1TyConName, k1TyConName
, prodTyConName, sumTyConName
, monoidClassName, foldableClassName, functorClassName
, semigroupClassName )
import Stock.Compat ( gHC_INTERNAL_SHOW, gHC_INTERNAL_READ
, gHC_INTERNAL_LIST, gHC_INTERNAL_GENERICS )
import GHC.Core.Reduction (mkReduction)
import GHC.Core.TyCo.Rep (UnivCoProvenance(PluginProv))
import GHC.Rename.Fixity (lookupFixityRn)
import GHC.Types.Fixity (Fixity(..), defaultFixity)
import GHC.Core.TyCo.Compare (eqType)
import GHC.Core.Multiplicity (scaledThing)
import GHC.Core.SimpleOpt (defaultSimpleOpts)
import GHC.Core.Unfold.Make (mkInlineUnfoldingWithArity)
import GHC.Core.InstEnv (classInstances, is_dfun, is_tys)
import GHC.Runtime.Loader (getValueSafely)
import Stock.Derive
import Data.Maybe (catMaybes, fromJust, isJust, fromMaybe)
import qualified Data.Monoid as Mon (Alt(..))
import Stock.Trans (MaybeT(..))
import Control.Monad (forM, zipWithM, unless, guard)
import Data.IORef (IORef, newIORef, readIORef, modifyIORef')
import Stock.Internal
-- | Deriver for 'Eq' driven by the 'Datatype' description @dt@.
--
-- Produces a dictionary for @dtVia dt@ by handing two method bodies to
-- 'classDict', in the order @[(==), (/=)]@:
--
-- * @(==)@ matches both arguments with 'matchSOP'. When the constructor
--   indices agree, the fields are compared pairwise with the class's @==@
--   at the types given by 'conFields' (each field dictionary obtained via
--   'field'), and the chain of comparisons stops at the first 'False'.
--   A constructor with no fields compares 'True'; differing indices give
--   'False'.
-- * @(/=)@ applies the @(==)@ lambda to fresh arguments and flips the
--   resulting 'Bool' with a case.
eqDeriver :: Deriver
eqDeriver = Deriver \cls dt -> do
  let viaTy :: Type
      viaTy = dtVia dt

      eqOp :: Id
      eqOp = classMethod "==" cls

      trueE, falseE :: Expr b
      trueE = Var (dataConWorkId trueDataCon)
      falseE = Var (dataConWorkId falseDataCon)

      -- Case over a Bool scrutinee. Alternatives are listed False then
      -- True, i.e. in increasing constructor-tag order as Core expects.
      onBool :: CoreExpr -> Id -> CoreExpr -> CoreExpr -> CoreExpr
      onBool scrutE bndr whenFalse whenTrue =
        Case scrutE bndr boolTy
          [ Alt (DataAlt falseDataCon) [] whenFalse
          , Alt (DataAlt trueDataCon) [] whenTrue
          ]

      -- One field comparison: (==) @fieldTy dict l r.
      cmpField :: (Type, Arg Id, Arg Id) -> Synth (Arg Id)
      cmpField (fieldTy, l, r) = do
        dictE <- field cls fieldTy
        pure (mkApps (Var eqOp) [Type fieldTy, dictE, l, r])

      -- Short-circuiting conjunction of field comparisons. The singleton
      -- case returns the last comparison directly instead of wrapping it
      -- in another case.
      allFields :: [(Type, Arg Id, Arg Id)] -> Synth (Arg Id)
      allFields fs = case fs of
        [] -> pure trueE
        [f] -> cmpField f
        f : more -> do
          firstE <- cmpField f
          restE <- allFields more
          b <- fresh boolTy "c"
          pure (onBool firstE b falseE restE)

  argA <- fresh viaTy "a"
  argB <- fresh viaTy "b"
  eqBody <- matchSOP dt boolTy (Var argA) \ixA conA fieldsA ->
    matchSOP dt boolTy (Var argB) \ixB _ fieldsB ->
      if ixA /= ixB
        then pure falseE
        else allFields (zip3 (conFields conA) fieldsA fieldsB)
  let eqImpl = mkLams [argA, argB] eqBody

  na <- fresh viaTy "a"
  nb <- fresh viaTy "b"
  nres <- fresh boolTy "c"
  let neqImpl =
        mkLams [na, nb] (onBool (mkApps eqImpl [Var na, Var nb]) nres trueE falseE)

  pure (classDict cls viaTy [eqImpl, neqImpl])
-- | Build an 'Eq' dictionary for @wrappedTy@ by comparing values of the
-- representation type @innerTy@ constructor by constructor.
--
-- Both arguments are cast with @co@ before being scrutinised; the case
-- binders are typed at @innerTy@. Each entry of @dcons@ pairs a
-- constructor with per-field coercions. Those coercions are handed to
-- 'conj' only for same-constructor (diagonal) alternatives; every other
-- alternative returns 'False'. @(/=)@ is the @(==)@ lambda applied to
-- fresh arguments with its 'Bool' result flipped.
--
-- Returns the dictionary (built with 'mkClassDict', methods in the order
-- @[(==), (/=)]@) as an 'EvExpr', together with every wanted constraint
-- produced by 'conj', in constructor order.
synthEq :: Class -> CtLoc -> Type -> Type -> Coercion -> [(DataCon, [Coercion])]
        -> TcPluginM (EvTerm, [Ct])
synthEq cls loc wrappedTy innerTy co dcons = do
  argA <- freshId wrappedTy "a"
  argB <- freshId wrappedTy "b"
  (outerAlts, outerWanteds) <- unzip <$> forM numbered \(i, (dci, fieldCos)) -> do
    xs <- fieldBinders "x" dci
    (innerAlts, innerWanteds) <- unzip <$> forM numbered \(j, (dcj, _)) -> do
      ys <- fieldBinders "y" dcj
      (rhsE, ws) <-
        if i /= j
          then pure (falseE, [])
          else conj loc (zip3 xs ys fieldCos)
      pure (Alt (DataAlt dcj) ys rhsE, ws)
    cb <- freshId innerTy "cb"
    let innerCase = Case (castArg argB) cb boolTy innerAlts
    pure (Alt (DataAlt dci) xs innerCase, concat innerWanteds)
  ca <- freshId innerTy "ca"
  let eqImpl = mkLams [argA, argB] (Case (castArg argA) ca boolTy outerAlts)
  na <- freshId wrappedTy "a"
  nb <- freshId wrappedTy "b"
  ns <- freshId boolTy "c"
  let neqImpl =
        mkLams [na, nb] $
          Case (mkApps eqImpl [Var na, Var nb]) ns boolTy
            [ Alt (DataAlt falseDataCon) [] trueE
            , Alt (DataAlt trueDataCon) [] falseE
            ]
      dict = mkClassDict cls wrappedTy [eqImpl, neqImpl]
  pure (EvExpr dict, concat outerWanteds)
  where
    trueE, falseE :: Expr b
    trueE = Var (dataConWorkId trueDataCon)
    falseE = Var (dataConWorkId falseDataCon)

    -- Cast an argument binder by co; the resulting scrutinee is matched
    -- by case binders typed at innerTy.
    castArg :: Id -> Expr b
    castArg v = Cast (Var v) co

    numbered :: [(Int, (DataCon, [Coercion]))]
    numbered = zip [0 ..] dcons

    -- Fresh binders for a constructor's fields (instantiated at innerTy),
    -- named prefix ++ field index.
    fieldBinders :: String -> DataCon -> TcPluginM [Id]
    fieldBinders prefix dc =
      zipWithM (\n ft -> freshId ft (prefix ++ show n)) [0 :: Int ..] (fieldTysAt innerTy dc)
-- | Conjunction of field-wise equality tests for one constructor.
--
-- For each @(x, y, fco)@ a new wanted @Eq t@ constraint is created at
-- @loc@, where @t@ is the right-hand type of @fco@ ('coercionRKind').
-- Its evidence supplies the dictionary argument to @(==)@ at @t@, which is
-- applied to @x@ and @y@, each passed through @castInto ... fco@. The
-- comparisons are combined by 'andE' (presumably a short-circuiting
-- @&&@ chain -- confirm against its definition).
--
-- Returns the combined expression and the wanteds, as non-canonical
-- 'Ct's, in the same order as @triples@.
conj :: CtLoc -> [(Id, Id, Coercion)] -> TcPluginM (CoreExpr, [Ct])
conj loc triples = do
  eqCls <- tcLookupClass eqClassName
  let eqOp = classMethod "==" eqCls
      fieldTy (_, _, fco) = coercionRKind fco
      test (x, y, fco) ev =
        mkApps (Var eqOp)
          [ Type (coercionRKind fco)
          , ctEvExpr ev
          , castInto (Var x) fco
          , castInto (Var y) fco
          ]
  evs <- traverse (\t -> newWanted loc (mkClassPred eqCls [fieldTy t])) triples
  body <- andE (zipWith test triples evs)
  pure (body, fmap mkNonCanonical evs)