{-# LANGUAGE CPP #-}
{-# LANGUAGE BlockArguments #-}
{-# LANGUAGE ViewPatterns #-}
{-# LANGUAGE DerivingVia #-}
{-# OPTIONS_GHC -Wno-x-partial -Wno-incomplete-uni-patterns -Wno-unused-imports #-}
module Stock.Ord where
import GHC.Plugins hiding (TcPlugin)
import GHC.Tc.Plugin
import GHC.Tc.Types
import GHC.Tc.Types.Constraint
#if MIN_VERSION_ghc(9,12,0)
import GHC.Tc.Types.CtLoc (CtLoc)
#else
import GHC.Tc.Types.Constraint (CtLoc)
#endif
import GHC.Tc.Types.Evidence
import GHC.Tc.Utils.Monad (addErrTc)
import GHC.Tc.Errors.Types (mkTcRnUnknownMessage)
import GHC.Types.Error (mkPlainError, noHints)
import GHC.Core.Class (Class, className, classMethods, classOpItems, classTyCon)
import GHC.Core.Predicate (classifyPredType, Pred(ClassPred), mkClassPred)
import GHC.Builtin.Types.Prim (intPrimTy)
import GHC.Builtin.PrimOps (PrimOp(TagToEnumOp))
import GHC.Builtin.PrimOps.Ids (primOpId)
import GHC.Builtin.Names ( eqClassName, ordClassName, appendName
, enumClassName, mapName, numClassName
, enumFromToName, enumFromThenToName
, eqStringName
, genClassName, repTyConName, u1TyConName, k1TyConName
, prodTyConName, sumTyConName
, monoidClassName, foldableClassName, functorClassName
, semigroupClassName )
import Stock.Compat ( gHC_INTERNAL_SHOW, gHC_INTERNAL_READ
, gHC_INTERNAL_LIST, gHC_INTERNAL_GENERICS )
import GHC.Core.Reduction (mkReduction)
import GHC.Core.TyCo.Rep (UnivCoProvenance(PluginProv))
import GHC.Rename.Fixity (lookupFixityRn)
import GHC.Types.Fixity (Fixity(..), defaultFixity)
import GHC.Core.TyCo.Compare (eqType)
import GHC.Core.Multiplicity (scaledThing)
import GHC.Core.SimpleOpt (defaultSimpleOpts)
import GHC.Core.Unfold.Make (mkInlineUnfoldingWithArity)
import GHC.Core.InstEnv (classInstances, is_dfun, is_tys)
import GHC.Runtime.Loader (getValueSafely)
import Stock.Derive
import Data.Maybe (catMaybes, fromJust, isJust, fromMaybe)
import qualified Data.Monoid as Mon (Alt(..))
import Stock.Trans (MaybeT(..))
import Control.Monad (forM, zipWithM, unless, guard)
import Data.IORef (IORef, newIORef, readIORef, modifyIORef')
import Stock.Internal
import Stock.Eq
-- | Build a Core implementation of 'compare' for values of @wrappedTy@ that
-- are inspected through the coercion @co@ (each argument is cast by @co@ to
-- @innerTy@ before being scrutinised).
--
-- The result is a two-argument lambda @\\a b -> ...@ returning 'Ordering':
--
-- * values built with different constructors are ordered by the
--   constructor's position in @dcons@ (earlier is 'LT');
-- * values built with the same constructor are compared field by field,
--   left to right, with the field type's own 'compare', stopping at the
--   first result that is not 'EQ'; a constructor with no fields yields 'EQ'.
--
-- Each @(dc, fieldCos)@ entry pairs a constructor with one coercion per
-- field: the field value is cast with 'castInto' by its coercion, and the
-- type it is compared at is the right-hand kind of that coercion.
--
-- Returns the lambda together with one @Ord@ wanted per compared field.
--
-- NOTE(review): Core requires case alternatives in constructor-tag order;
-- this assumes @dcons@ lists constructors in declaration order — confirm at
-- the call sites.
buildCompare :: CtLoc                    -- ^ location attached to emitted wanteds
             -> Type                     -- ^ type of the two arguments
             -> Type                     -- ^ type the arguments are cast to before scrutinising
             -> Coercion                 -- ^ coercion applied to each argument
             -> [(DataCon, [Coercion])]  -- ^ constructors with per-field coercions
             -> TcPluginM (CoreExpr, [Ct])
buildCompare loc wrappedTy innerTy co dcons = do
  ordCls <- tcLookupClass ordClassName
  let ordTy = mkTyConTy orderingTyCon
      -- Ordering's constructors, in declaration (= tag) order.
      [ltC, eqC, gtC] = tyConDataCons orderingTyCon
      ltE = Var (dataConWorkId ltC)
      eqE = Var (dataConWorkId eqC)
      gtE = Var (dataConWorkId gtC)
      cmpSel = classMethod "compare" ordCls
      scrut v = Cast (Var v) co
      indexed = zip [0 :: Int ..] dcons
      realFts dc = fieldTysAt innerTy dc

      -- Lexicographic comparison of paired fields; no fields left means EQ.
      lexCmp [] = pure (eqE, [])
      lexCmp ((fco, x, y) : more) = do
        let ft :: Type
            ft = coercionRKind fco
        ev <- newWanted loc (mkClassPred ordCls [ft])
        (restE, ws) <- lexCmp more
        scr <- freshId ordTy "o"
        let cmp = mkApps (Var cmpSel)
                    [ Type ft
                    , ctEvExpr ev
                    , castInto (Var x) fco
                    , castInto (Var y) fco ]
            -- Only an EQ on this field falls through to the remaining fields.
            e = Case cmp scr ordTy
                  [ Alt (DataAlt ltC) [] ltE
                  , Alt (DataAlt eqC) [] restE
                  , Alt (DataAlt gtC) [] gtE ]
        pure (e, mkNonCanonical ev : ws)

  aId <- freshId wrappedTy "a"
  bId <- freshId wrappedTy "b"
  -- Outer case on the first argument; for each of its constructors an inner
  -- case on the second argument.
  (outerAlts, wss) <- fmap unzip $ forM indexed \(i, (dci, cosI)) -> do
    xs <- zipWithM (\n ft -> freshId ft ("x" ++ show n)) [0 :: Int ..] (realFts dci)
    (innerAlts, iwss) <- fmap unzip $ forM indexed \(j, (dcj, _)) -> do
      ys <- zipWithM (\n ft -> freshId ft ("y" ++ show n)) [0 :: Int ..] (realFts dcj)
      (body, ws) <-
        if i == j
          then lexCmp (zip3 cosI xs ys)
          else pure (if i < j then ltE else gtE, [])
      pure (Alt (DataAlt dcj) ys body, ws)
    innerBndr <- freshId innerTy "cb"
    pure ( Alt (DataAlt dci) xs (Case (scrut bId) innerBndr ordTy innerAlts)
         , concat iwss )
  outerBndr <- freshId innerTy "ca"
  let cmpImpl = mkLams [aId, bId] (Case (scrut aId) outerBndr ordTy outerAlts)
  pure (cmpImpl, concat wss)
-- | Build a Core implementation of one of Ord's relational operators
-- (@<@, @<=@, @>@ or @>=@) for values of @wrappedTy@ inspected through the
-- coercion @co@ (each argument is cast by @co@ before being scrutinised).
--
-- @asc@ selects the direction (@True@: @<@ / @<=@; @False@: @>@ / @>=@) and
-- @refl@ selects whether the relation holds on equal values (@<=@ / @>=@).
--
-- The result is a two-argument lambda returning 'Bool':
--
-- * different constructors are related by their position in @dcons@;
-- * the same constructor is related lexicographically on its fields, e.g.
--   for @<=@ with fields @x1, x2@ / @y1, y2@ the body is
--   @x1 < y1 || (x1 == y1 && x2 <= y2)@, using the field types' own Ord and
--   Eq methods; with no fields the result is @refl@.
--
-- Returns the lambda together with the Ord (and, for every field but the
-- last, Eq) wanteds emitted for the compared fields.
--
-- NOTE(review): Core requires case alternatives in constructor-tag order;
-- this assumes @dcons@ lists constructors in declaration order — confirm at
-- the call sites.
buildRel :: Class                    -- ^ the Ord class
         -> Class                    -- ^ the Eq class
         -> CtLoc                    -- ^ location attached to emitted wanteds
         -> Type                     -- ^ type of the two arguments
         -> Type                     -- ^ type the arguments are cast to before scrutinising
         -> Coercion                 -- ^ coercion applied to each argument
         -> [(DataCon, [Coercion])]  -- ^ constructors with per-field coercions
         -> Bool                     -- ^ @asc@: less-than family vs greater-than family
         -> Bool                     -- ^ @refl@: non-strict (@<=@/@>=@) vs strict
         -> TcPluginM (CoreExpr, [Ct])
buildRel ordCls eqCls loc wrappedTy innerTy co dcons asc refl = do
  let boolE :: Bool -> Expr b
      boolE b = Var (dataConWorkId (if b then trueDataCon else falseDataCon))

      -- Strict relation applied to every field except the last.
      ltName :: String
      ltName = if asc then "<" else ">"

      -- Relation applied to the last field: reflexive iff the whole one is.
      lastName :: String
      lastName
        | asc && not refl = "<"
        | asc             = "<="
        | not refl        = ">"
        | otherwise       = ">="

      scrut :: Id -> Expr b
      scrut v = Cast (Var v) co

      realFts :: DataCon -> [Type]
      realFts dc = fieldTysAt innerTy dc

      indexed :: [(Int, (DataCon, [Coercion]))]
      indexed = zip [0 :: Int ..] dcons

      -- Apply the Ord method named nm at the field's type to both fields.
      fieldRel :: String -> Coercion -> Id -> Id -> TcPluginM (CoreExpr, [Ct])
      fieldRel nm fco x y = do
        let ft :: Type
            ft = coercionRKind fco
        ev <- newWanted loc (mkClassPred ordCls [ft])
        pure ( mkApps (Var (classMethod nm ordCls))
                      [Type ft, ctEvExpr ev, castInto (Var x) fco, castInto (Var y) fco]
             , [mkNonCanonical ev] )

      -- Apply Eq's (==) at the field's type to both fields.
      fieldEq :: Coercion -> Id -> Id -> TcPluginM (CoreExpr, [Ct])
      fieldEq fco x y = do
        let ft :: Type
            ft = coercionRKind fco
        ev <- newWanted loc (mkClassPred eqCls [ft])
        pure ( mkApps (Var (classMethod "==" eqCls))
                      [Type ft, ctEvExpr ev, castInto (Var x) fco, castInto (Var y) fco]
             , [mkNonCanonical ev] )

      -- p || q as a Core case: False -> q, True -> True.
      orE :: CoreExpr -> CoreExpr -> TcPluginM CoreExpr
      orE p q = do
        s <- freshId boolTy "o"
        pure (Case p s boolTy [ Alt (DataAlt falseDataCon) [] q
                              , Alt (DataAlt trueDataCon) [] (boolE True) ])

      -- p && q as a Core case: False -> False, True -> q.
      andE2 :: CoreExpr -> CoreExpr -> TcPluginM CoreExpr
      andE2 p q = do
        s <- freshId boolTy "n"
        pure (Case p s boolTy [ Alt (DataAlt falseDataCon) [] (boolE False)
                              , Alt (DataAlt trueDataCon) [] q ])

      -- Lexicographic relation over paired fields.
      lexRel :: [(Coercion, Id, Id)] -> TcPluginM (CoreExpr, [Ct])
      lexRel [] = pure (boolE refl, [])
      lexRel [(fco, x, y)] = fieldRel lastName fco x y
      lexRel ((fco, x, y) : more) = do
        (ltE, w1) <- fieldRel ltName fco x y
        (eqE, w2) <- fieldEq fco x y
        (rest, w3) <- lexRel more
        ae <- andE2 eqE rest
        oe <- orE ltE ae
        pure (oe, w1 ++ w2 ++ w3)

  aId <- freshId wrappedTy "a"
  bId <- freshId wrappedTy "b"
  -- Outer case on the first argument; for each of its constructors an inner
  -- case on the second argument.
  (outerAlts, wss) <- fmap unzip $ forM indexed \(i, (dci, cosI)) -> do
    xs <- zipWithM (\n ft -> freshId ft ("x" ++ show n)) [0 :: Int ..] (realFts dci)
    (innerAlts, iwss) <- fmap unzip $ forM indexed \(j, (dcj, _)) -> do
      ys <- zipWithM (\n ft -> freshId ft ("y" ++ show n)) [0 :: Int ..] (realFts dcj)
      (body, ws) <-
        if i == j
          then lexRel (zip3 cosI xs ys)
          else pure (boolE (if asc then i < j else i > j), [])
      pure (Alt (DataAlt dcj) ys body, ws)
    cb <- freshId innerTy "cb"
    pure (Alt (DataAlt dci) xs (Case (scrut bId) cb boolTy innerAlts), concat iwss)
  cb2 <- freshId innerTy "ca"
  pure (mkLams [aId, bId] (Case (scrut aId) cb2 boolTy outerAlts), concat wss)
-- | Synthesise an @Ord@ dictionary for @wrappedTy@.
--
-- * 'compare' is always built by 'buildCompare'.
-- * The @Eq@ dictionary passed to 'recDictWith' comes from 'synthEq'.
-- * When the type is "small" — at most three constructors, or every
--   constructor has an empty field-coercion list — @<@, @<=@, @>@ and @>=@
--   also get direct implementations from 'buildRel'. Otherwise only
--   'compare' is supplied to 'recDictWith' (presumably the remaining methods
--   fall back to the class defaults — confirm in 'recDictWith').
--
-- If any field coercion in @dcons@ is non-reflexive, the 'compare' body is
-- placed directly in the dictionary. Otherwise it is let-bound to an 'Id'
-- carrying an inline unfolding of arity 2, and the dictionary refers to
-- that 'Id'.
--
-- Returns the dictionary evidence and every wanted emitted while building
-- it.
synthOrd :: Class                    -- ^ the Ord class
         -> CtLoc                    -- ^ location attached to emitted wanteds
         -> Type                     -- ^ type the instance is for
         -> Type                     -- ^ type the arguments are cast to before scrutinising
         -> Coercion                 -- ^ coercion applied to each argument
         -> [(DataCon, [Coercion])]  -- ^ constructors with per-field coercions
         -> TcPluginM (EvTerm, [Ct])
synthOrd ordCls loc wrappedTy innerTy co dcons = do
  (cmpImpl, cmpWs) <- buildCompare loc wrappedTy innerTy co dcons
  eqCls <- tcLookupClass eqClassName
  (eqDict0, eqWs) <- synthEq eqCls loc wrappedTy innerTy co dcons
  let eqDict = unwrapEv eqDict0
      overridden = any (not . isReflCo) (concatMap snd dcons)
      small = length dcons <= 3 || all (null . snd) dcons
      -- Position of each Ord method in 'classMethods', the index form the
      -- overrides passed to 'recDictWith' use.
      methodIndex :: [(String, Int)]
      methodIndex = zip (map (occNameString . occName) (classMethods ordCls)) [0 :: Int ..]
      -- Ord always declares these methods, so a miss is a plugin bug.
      -- Report the method by name instead of crashing on 'head' of [].
      idxOf :: String -> Int
      idxOf nm =
        fromMaybe (error ("Stock.Ord.synthOrd: Ord class has no method " ++ show nm))
                  (lookup nm methodIndex)
  (relOverrides, relWs) <-
    if not small
      then pure ([], [])
      else do
        let mk asc refl = buildRel ordCls eqCls loc wrappedTy innerTy co dcons asc refl
        (ltI, w1) <- mk True False
        (leI, w2) <- mk True True
        (gtI, w3) <- mk False False
        (geI, w4) <- mk False True
        pure ( [(idxOf "<", ltI), (idxOf "<=", leI), (idxOf ">", gtI), (idxOf ">=", geI)]
             , w1 ++ w2 ++ w3 ++ w4 )
  let allWs = cmpWs ++ eqWs ++ relWs
  if overridden
    then do
      dict <- recDictWith ordCls wrappedTy [eqDict]
                ((idxOf "compare", cmpImpl) : relOverrides)
      pure (EvExpr dict, allWs)
    else do
      let cmpTy = mkVisFunTyMany wrappedTy
                    (mkVisFunTyMany wrappedTy (mkTyConTy orderingTyCon))
          cmpUnf = mkInlineUnfoldingWithArity defaultSimpleOpts StableSystemSrc 2 cmpImpl
      cmpId0 <- freshId cmpTy "vvCompare"
      let cmpId = cmpId0 `setIdUnfolding` cmpUnf
      dictInner <- recDictWith ordCls wrappedTy [eqDict]
                     ((idxOf "compare", Var cmpId) : relOverrides)
      let dict = Let (NonRec cmpId cmpImpl) dictInner
      pure (EvExpr dict, allWs)