{-# LANGUAGE CPP #-}
{-# LANGUAGE BlockArguments #-}
{-# LANGUAGE ViewPatterns #-}
{-# LANGUAGE DerivingVia #-}
{-# OPTIONS_GHC -Wno-x-partial -Wno-incomplete-uni-patterns -Wno-unused-imports #-}
-- | @Eq@ synthesizer: two values are equal iff same constructor and all fields equal.
module Stock.Eq where
-- Most names below (data-con/type builders, coercion builders, occ-name
-- helpers, …) are re-exported by 'GHC.Plugins', so we only import explicitly
-- the ones it does not provide.
import GHC.Plugins hiding (TcPlugin)
import GHC.Tc.Plugin
import GHC.Tc.Types
import GHC.Tc.Types.Constraint
#if MIN_VERSION_ghc(9,12,0)
import GHC.Tc.Types.CtLoc (CtLoc)
#else
import GHC.Tc.Types.Constraint (CtLoc)
#endif
import GHC.Tc.Types.Evidence
import GHC.Tc.Utils.Monad (addErrTc)
import GHC.Tc.Errors.Types (mkTcRnUnknownMessage)
import GHC.Types.Error (mkPlainError, noHints)
import GHC.Core.Class (Class, className, classMethods, classOpItems, classTyCon)
import GHC.Core.Predicate (classifyPredType, Pred(ClassPred), mkClassPred)
import GHC.Builtin.Types.Prim (intPrimTy)
import GHC.Builtin.PrimOps (PrimOp(TagToEnumOp))
import GHC.Builtin.PrimOps.Ids (primOpId)
import GHC.Builtin.Names ( eqClassName, ordClassName, appendName
                         , enumClassName, mapName, numClassName
                         , enumFromToName, enumFromThenToName
                         , eqStringName
                         , genClassName, repTyConName, u1TyConName, k1TyConName
                         , prodTyConName, sumTyConName
                         , monoidClassName, foldableClassName, functorClassName
                         , semigroupClassName )
import Stock.Compat ( gHC_INTERNAL_SHOW, gHC_INTERNAL_READ
                    , gHC_INTERNAL_LIST, gHC_INTERNAL_GENERICS )
import GHC.Core.Reduction (mkReduction)
import GHC.Core.TyCo.Rep (UnivCoProvenance(PluginProv))
import GHC.Rename.Fixity (lookupFixityRn)
import GHC.Types.Fixity (Fixity(..), defaultFixity)
import GHC.Core.TyCo.Compare (eqType)
import GHC.Core.Multiplicity (scaledThing)
import GHC.Core.SimpleOpt (defaultSimpleOpts)
import GHC.Core.Unfold.Make (mkInlineUnfoldingWithArity)
import GHC.Core.InstEnv (classInstances, is_dfun, is_tys)
import GHC.Runtime.Loader (getValueSafely)
import Stock.Derive
import Data.Maybe (catMaybes, fromJust, isJust, fromMaybe)
import qualified Data.Monoid as Mon (Alt(..))  -- 'Alt' clashes with GHC.Core's case-alt 'Alt'
import Stock.Trans (MaybeT(..))
import Control.Monad (forM, zipWithM, unless, guard)
import Data.IORef (IORef, newIORef, readIORef, modifyIORef')
import Stock.Internal

-- | The @Eq@ 'Deriver'.  @(==)@ eliminates both arguments with 'matchSOP'
-- (dispatch on @a@, then on @b@): when the constructor indices coincide the
-- fields are compared pairwise, each field's @Eq@ dictionary being requested
-- through 'field'; otherwise the answer is @False@.  @(/=)@ flips the result
-- of @(==)@.
eqDeriver :: Deriver
eqDeriver = Deriver \cls dt -> do
  let viaTy = dtVia dt
      eqOp  = classMethod "==" cls

      trueE, falseE :: CoreExpr
      trueE  = Var (dataConWorkId trueDataCon)
      falseE = Var (dataConWorkId falseDataCon)

      -- "if cond then rest else False" as a Core case; the Bool
      -- alternatives are listed in constructor order (False, then True).
      andThen :: CoreExpr -> CoreExpr -> Synth CoreExpr
      andThen cond rest = do
        cb <- fresh boolTy "c"
        pure $ Case cond cb boolTy
          [ Alt (DataAlt falseDataCon) [] falseE
          , Alt (DataAlt trueDataCon)  [] rest ]

      -- (==) @fTy d l r, where d is the field type's Eq dictionary (a wanted).
      compareField :: (Type, CoreExpr, CoreExpr) -> Synth CoreExpr
      compareField (fTy, l, r) = do
        d <- field cls fTy
        pure (mkApps (Var eqOp) [Type fTy, d, l, r])

      -- x0==y0 && x1==y1 && …, short-circuiting via nested cases; the final
      -- field is the bare comparison (as @&&@ and stock @deriving@ produce).
      allEqual :: [(Type, CoreExpr, CoreExpr)] -> Synth CoreExpr
      allEqual []         = pure trueE
      allEqual [f]        = compareField f
      allEqual (f : more) = do
        here  <- compareField f
        there <- allEqual more
        andThen here there

  lhs <- fresh viaTy "a"
  rhs <- fresh viaTy "b"
  eqBody <- matchSOP dt boolTy (Var lhs) \i con xs ->
    matchSOP dt boolTy (Var rhs) \j _ ys ->
      if i /= j
        then pure falseE
        else allEqual (zip3 (conFields con) xs ys)
  let eqFn = mkLams [lhs, rhs] eqBody

  -- (/=) = \a b -> case (==) a b of { False -> True; True -> False }
  nl <- fresh viaTy "a"
  nr <- fresh viaTy "b"
  nc <- fresh boolTy "c"
  let neFn = mkLams [nl, nr] $
        Case (mkApps eqFn [Var nl, Var nr]) nc boolTy
          [ Alt (DataAlt falseDataCon) [] trueE
          , Alt (DataAlt trueDataCon)  [] falseE ]
  pure (classDict cls viaTy [eqFn, neFn])

-- | Synthesize an @Eq@ dictionary for @wrappedTy@, a type whose values are
-- viewed as @innerTy@ through the coercion @co@ (each argument is scrutinised
-- as @(v |> co)@).  @dcons@ lists the constructors of @innerTy@, each paired
-- with one coercion per field; 'conj' compares a field at the right-hand
-- kind of its coercion.
--
-- @(==)@ scrutinises both arguments: equal constructors conjoin their
-- per-field @(==)@s (each field's @Eq@ a wanted), different constructors give
-- @False@.  @(/=)@ negates @(==)@.  Returns the dictionary together with all
-- field @Eq@ wanteds.
--
-- NOTE(review): the generated case alternatives follow the order of @dcons@.
-- Core requires data alternatives in constructor-tag order, so callers are
-- presumably passing constructors in declaration order — confirm.
synthEq :: Class -> CtLoc -> Type -> Type -> Coercion -> [(DataCon, [Coercion])]
        -> TcPluginM (EvTerm, [Ct])
synthEq cls loc wrappedTy innerTy co dcons = do
  aId <- freshId wrappedTy "a"
  bId <- freshId wrappedTy "b"

  -- case (a|>co) of { Ci x.. -> case (b|>co) of { Cj y.. -> body i j } }
  outer <- forM indexed \(i, (dci, cosI)) -> do
    xs    <- freshFields "x" dci
    inner <- forM indexed \(j, (dcj, _)) -> do
      ys <- freshFields "y" dcj
      if i == j
        then do
          (body, ws) <- conj loc (zip3 xs ys cosI)
          pure (Alt (DataAlt dcj) ys body, ws)
        else pure (Alt (DataAlt dcj) ys false_, [])
    innerBndr <- freshId innerTy "cb"
    let (ialts, iws) = unzip inner
    pure ( Alt (DataAlt dci) xs (Case (scrut bId) innerBndr boolTy ialts)
         , concat iws )

  outerBndr <- freshId innerTy "ca"
  let (oalts, ows) = unzip outer
      eqImpl = mkLams [aId, bId] (Case (scrut aId) outerBndr boolTy oalts)

  -- (/=) = \a b -> case (==) a b of { False -> True; True -> False }
  na <- freshId wrappedTy "a"
  nb <- freshId wrappedTy "b"
  ns <- freshId boolTy "c"
  let neqImpl = mkLams [na, nb] $
        Case (mkApps eqImpl [Var na, Var nb]) ns boolTy
          [ Alt (DataAlt falseDataCon) [] true_
          , Alt (DataAlt trueDataCon)  [] false_ ]
      dict = mkClassDict cls wrappedTy [eqImpl, neqImpl]
  pure (EvExpr dict, concat ows)
  where
    true_, false_ :: CoreExpr
    true_  = Var (dataConWorkId trueDataCon)
    false_ = Var (dataConWorkId falseDataCon)

    -- (v |> co) :: innerTy
    scrut :: Id -> CoreExpr
    scrut v = Cast (Var v) co

    indexed :: [(Int, (DataCon, [Coercion]))]
    indexed = zip [0 ..] dcons

    -- Fresh binders <prefix>0, <prefix>1, … for a constructor's fields, each
    -- at the field's real (bind) type.
    freshFields :: String -> DataCon -> TcPluginM [Id]
    freshFields prefix dc =
      zipWithM (\n ft -> freshId ft (prefix ++ show n))
               [0 :: Int ..] (fieldTysAt innerTy dc)

-- | Conjoin per-field equalities, @x0 == y0 && x1 == y1 && …@, building the
-- short-circuiting chain with 'andE'.  One @Eq@ wanted is created per field
-- and returned (as non-canonical constraints) alongside the expression.
-- Each triple is @(x, y, fieldCo)@: both bound variables are cast with
-- 'castInto' along @fieldCo@ and compared at @coercionRKind fieldCo@ (the
-- field's real type when the coercion is reflexive).
conj :: CtLoc -> [(Id, Id, Coercion)] -> TcPluginM (CoreExpr, [Ct])
conj loc triples = do
  eqClass <- tcLookupClass eqClassName
  let eqMethod = classMethod "==" eqClass
      -- the type each field is compared at
      cmpTy (_, _, fco) = coercionRKind fco
  wanteds <- forM triples \t ->
    newWanted loc (mkClassPred eqClass [cmpTy t])
  let compareOne t@(x, y, fco) ev =
        mkApps (Var eqMethod)
          [ Type (cmpTy t)
          , ctEvExpr ev
          , castInto (Var x) fco
          , castInto (Var y) fco ]
  chain <- andE (zipWith compareOne triples wanteds)
  pure (chain, mkNonCanonical <$> wanteds)

-- | Synthesize a full @Ord (Stock Inner)@ dictionary for any single-level
-- algebraic type: tag order between constructors, lexicographic within.  Every
-- comparison is derived from a single @compare@.  Returns the field @Ord@ and
-- @Eq@-superclass wanteds.