{-# LANGUAGE CPP #-}
{-# LANGUAGE BlockArguments #-}
{-# LANGUAGE ViewPatterns #-}
{-# LANGUAGE DerivingVia #-}
{-# OPTIONS_GHC -Wno-x-partial -Wno-incomplete-uni-patterns -Wno-unused-imports #-}
-- | @Eq@ synthesizer: two values are equal iff same constructor and all fields equal.
module Stock.Eq where
-- Most names below (data-con/type builders, coercion builders, occ-name
-- helpers, …) are re-exported by 'GHC.Plugins', so we only import explicitly
-- the ones it does not provide.
import GHC.Plugins hiding (TcPlugin)
import GHC.Tc.Plugin
import GHC.Tc.Types
import GHC.Tc.Types.Constraint
#if MIN_VERSION_ghc(9,12,0)
import GHC.Tc.Types.CtLoc (CtLoc)
#else
import GHC.Tc.Types.Constraint (CtLoc)
#endif
import GHC.Tc.Types.Evidence
import GHC.Tc.Utils.Monad (addErrTc)
import GHC.Tc.Errors.Types (mkTcRnUnknownMessage)
import GHC.Types.Error (mkPlainError, noHints)
import GHC.Core.Class (Class, className, classMethods, classOpItems, classTyCon)
import GHC.Core.Predicate (classifyPredType, Pred(ClassPred), mkClassPred)
import GHC.Builtin.Types.Prim (intPrimTy)
import GHC.Builtin.PrimOps (PrimOp(TagToEnumOp))
import GHC.Builtin.PrimOps.Ids (primOpId)
import GHC.Builtin.Names ( eqClassName, ordClassName, appendName
                         , enumClassName, mapName, numClassName
                         , enumFromToName, enumFromThenToName
                         , eqStringName
                         , genClassName, repTyConName, u1TyConName, k1TyConName
                         , prodTyConName, sumTyConName
                         , monoidClassName, foldableClassName, functorClassName
                         , semigroupClassName )
import Stock.Compat ( gHC_INTERNAL_SHOW, gHC_INTERNAL_READ
                    , gHC_INTERNAL_LIST, gHC_INTERNAL_GENERICS )
import GHC.Core.Reduction (mkReduction)
import GHC.Core.TyCo.Rep (UnivCoProvenance(PluginProv))
import GHC.Rename.Fixity (lookupFixityRn)
import GHC.Types.Fixity (Fixity(..), defaultFixity)
import GHC.Core.TyCo.Compare (eqType)
import GHC.Core.Multiplicity (scaledThing)
import GHC.Core.SimpleOpt (defaultSimpleOpts)
import GHC.Core.Unfold.Make (mkInlineUnfoldingWithArity)
import GHC.Core.InstEnv (classInstances, is_dfun, is_tys)
import GHC.Runtime.Loader (getValueSafely)
import Stock.Derive
import Data.Maybe (catMaybes, fromJust, isJust, fromMaybe)
import qualified Data.Monoid as Mon (Alt(..))  -- 'Alt' clashes with GHC.Core's case-alt 'Alt'
import Stock.Trans (MaybeT(..))
import Control.Monad (forM, zipWithM, unless, guard)
import Data.IORef (IORef, newIORef, readIORef, modifyIORef')
import Stock.Internal

-- | @Eq@ via the generic SOP view: @(==)@ eliminates @a@ and then @b@ (two
-- nested 'matchSOP's).  When both constructors agree, their fields are
-- compared pairwise with @(==)@ (each field's @Eq@ dictionary obtained through
-- 'field') and the results conjoined; different constructors give @False@.
-- @(/=)@ is the negation of the synthesized @(==)@.
eqDeriver :: Deriver
eqDeriver = Deriver \cls dt -> do
  let selfTy     = dtVia dt
      eqMethod   = classMethod "==" cls
      conExpr dc = Var (dataConWorkId dc)
      -- case scr of bndr { False -> onFalse; True -> onTrue }
      boolCase scr bndr onFalse onTrue =
        Case scr bndr boolTy
          [ Alt (DataAlt falseDataCon) [] onFalse
          , Alt (DataAlt trueDataCon)  [] onTrue ]
      -- A single field comparison: (==) fty fieldDict x y.
      eqField (fty, x, y) =
        (\fieldDict -> mkApps (Var eqMethod) [Type fty, fieldDict, x, y])
          <$> field cls fty
      -- x0==y0 && x1==y1 && … , short-circuiting via nested Bool cases; the
      -- last field is the bare comparison (as && and stock deriving produce),
      -- and a constructor without fields compares as True.
      conjEq fields = case fields of
        []              -> pure (conExpr trueDataCon)
        [lastField]     -> eqField lastField
        cur : remaining -> do
          curEq  <- eqField cur
          restEq <- conjEq remaining
          bndr   <- fresh boolTy "c"
          pure (boolCase curEq bndr (conExpr falseDataCon) restEq)
  aVar <- fresh selfTy "a"
  bVar <- fresh selfTy "b"
  eqBody <- matchSOP dt boolTy (Var aVar) \i con xs ->
    matchSOP dt boolTy (Var bVar) \j _ ys ->
      if i /= j
        then pure (conExpr falseDataCon)
        else conjEq (zip3 (conFields con) xs ys)
  let eqImpl = mkLams [aVar, bVar] eqBody
  negA   <- fresh selfTy "a"
  negB   <- fresh selfTy "b"
  negScr <- fresh boolTy "c"
  let neqImpl =
        mkLams [negA, negB] $
          boolCase (mkApps eqImpl [Var negA, Var negB]) negScr
                   (conExpr trueDataCon) (conExpr falseDataCon)
  pure (classDict cls selfTy [eqImpl, neqImpl])

-- | Synthesize an @Eq@ dictionary for @wrappedTy@ directly in Core, by case
-- analysis on its representation: each argument @v :: wrappedTy@ is
-- scrutinised as @v |> co :: innerTy@.  Two values are equal iff they share a
-- constructor and all their fields are equal (conjoined by 'conj'); otherwise
-- the result is @False@.  Each entry of @dcons@ pairs a constructor with one
-- coercion per field, which fixes the type that field is compared at.
-- @(/=)@ negates @(==)@.  Returns the dictionary together with the fields'
-- @Eq@ wanteds.
--
-- The inner match has a single constructor alternative plus a leading
-- @DEFAULT -> False@, so the generated Core grows linearly with the number of
-- constructors instead of having one alternative per constructor pair.
synthEq :: Class -> CtLoc -> Type -> Type -> Coercion -> [(DataCon, [Coercion])]
        -> TcPluginM (EvTerm, [Ct])
synthEq cls loc wrappedTy innerTy co dcons = do
  let true_      = Var (dataConWorkId trueDataCon)
      false_     = Var (dataConWorkId falseDataCon)
      scrut v    = Cast (Var v) co                -- (v |> co) :: innerTy
      realFts dc = fieldTysAt innerTy dc          -- field's real (bind) type
      -- Fresh binders pfx0, pfx1, … at a constructor's real field types.
      freshFields pfx dc =
        zipWithM (\n ft -> freshId ft (pfx ++ show n)) [0 :: Int ..] (realFts dc)
      -- Mismatched constructors fall through to DEFAULT -> False.  It is only
      -- emitted when another constructor exists, and Core requires the
      -- DEFAULT alternative to come first in the list.
      mismatch = [Alt DEFAULT [] false_ | length dcons > 1]

  aId <- freshId wrappedTy "a"
  bId <- freshId wrappedTy "b"

  -- case (a|>co) of { Ci xs -> case (b|>co) of { DEFAULT -> False; Ci ys -> xs == ys } }
  outer <- forM dcons \(dc, fieldCos) -> do
    xs <- freshFields "x" dc
    ys <- freshFields "y" dc
    (same, ws) <- conj loc (zip3 xs ys fieldCos)
    innerBndr <- freshId innerTy "cb"
    let ialts = mismatch ++ [Alt (DataAlt dc) ys same]
    pure (Alt (DataAlt dc) xs (Case (scrut bId) innerBndr boolTy ialts), ws)

  outerBndr <- freshId innerTy "ca"
  let (oalts, ows) = unzip outer
      eqImpl = mkLams [aId, bId] (Case (scrut aId) outerBndr boolTy oalts)

  -- (/=) = \a b -> case (==) a b of { False -> True; True -> False }
  na <- freshId wrappedTy "a"
  nb <- freshId wrappedTy "b"
  ns <- freshId boolTy "c"
  let neqImpl = mkLams [na, nb] $
        Case (mkApps eqImpl [Var na, Var nb]) ns boolTy
          [ Alt (DataAlt falseDataCon) [] true_
          , Alt (DataAlt trueDataCon)  [] false_ ]
      dict = mkClassDict cls wrappedTy [eqImpl, neqImpl]
  pure (EvExpr dict, concat ows)

-- | Conjoin per-field equalities — @and [x0 == y0, x1 == y1, …]@ — via 'andE'
-- (the short-circuiting @&&@ chain), emitting one @Eq@ wanted per field.
-- Each triple is @(x, y, fieldCo)@: the field is compared at the coercion's
-- right-hand type (@coercionRKind fieldCo@, the real field type when the
-- coercion is 'Refl'), with both bound values cast along @fieldCo@ first.
-- Returns the conjunction together with the new wanteds.
conj :: CtLoc -> [(Id, Id, Coercion)] -> TcPluginM (CoreExpr, [Ct])
conj loc triples = do
  eqCls <- tcLookupClass eqClassName
  let eqSel = classMethod "==" eqCls              -- (==)
      cmpTy (_, _, fco) = coercionRKind fco
      compareField t@(x, y, fco) ev =
        mkApps (Var eqSel)
          [ Type (cmpTy t)
          , ctEvExpr ev
          , castInto (Var x) fco
          , castInto (Var y) fco ]
  evs <- forM triples \t -> newWanted loc (mkClassPred eqCls [cmpTy t])
  conjunction <- andE (zipWith compareField triples evs)
  pure (conjunction, fmap mkNonCanonical evs)

-- | Synthesize a full @Ord (Stock Inner)@ dictionary for any single-level
-- algebraic type: tag order between constructors, lexicographic within.  Every
-- comparison is derived from a single @compare@.  Returns the field @Ord@ and
-- @Eq@-superclass wanteds.