{-# LANGUAGE CPP #-}

{-# LANGUAGE DataKinds #-}
{-# LANGUAGE ExplicitNamespaces #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE TemplateHaskellQuotes #-}

module GHC.TypeLits.KnownNat.Compat
  ( KnownNatDefs(..), lookupKnownNatDefs
  , mkNaturalExpr

  , coercionRKind, classMethodTy
  , irrelevantMult
  )
  where

-- base
import Data.Type.Bool
  ( If )
#if MIN_VERSION_ghc(9,1,0)
import Data.Type.Ord
  ( OrdCond )
#else
import GHC.TypeNats
  ( type (<=) )
#endif


-- ghc-tcplugin-api
import GHC.TcPlugin.API
#if MIN_VERSION_ghc(9,3,0)
import GHC.TcPlugin.API.Internal ( unsafeLiftTcM )
#endif

-- ghc
import qualified GHC.Core.Make as GHC
  ( mkNaturalExpr )
#if MIN_VERSION_ghc(9,3,0)
import GHC.Tc.Utils.Monad
  ( getPlatform )
#endif
#if MIN_VERSION_ghc(8,11,0)
import GHC.Core.Coercion
  ( coercionRKind )
import GHC.Core.Predicate
  ( classMethodTy )
import GHC.Core.Type
  ( irrelevantMult )
#else
import GHC.Core.Coercion
  ( coercionKind )
import GHC.Core.Type
  ( dropForAlls, funResultTy, varType )
import GHC.Data.Pair
  ( Pair(..) )
#endif

-- ghc-typelits-knownnat
import GHC.TypeLits.KnownNat
  ( KnownNat1, KnownNat2, KnownNat3
  , KnownBool, KnownBoolNat2, KnownNat2Bool
  )

-- template-haskell
import qualified Language.Haskell.TH as TH
  ( Name )

--------------------------------------------------------------------------------

-- | Classes and type constructors that the plugin needs to recognise, as
-- looked up by 'lookupKnownNatDefs': the \"magic\" classes from
-- "GHC.TypeLits.KnownNat" plus the type-level comparison and conditional
-- type constructors from @base@.
data KnownNatDefs
  = KnownNatDefs
  { knownBool     :: Class
    -- ^ The 'KnownBool' class
  , knownBoolNat2 :: Class
    -- ^ The 'KnownBoolNat2' class
  , knownNat2Bool :: Class
    -- ^ The 'KnownNat2Bool' class
  , knownNatN     :: Int -> Maybe Class
    -- ^ @KnownNat{N}@ for arity @N@ (e.g. 'KnownNat1' for @1@), or 'Nothing'
    -- when no class of that arity was looked up
#if MIN_VERSION_ghc(9,1,0)
  , ordCondTyCon  :: TyCon
    -- ^ The 'OrdCond' type family (used instead of @(<=)@ on GHC >= 9.1)
#else
  , leqNatTyCon   :: TyCon
    -- ^ @<= :: Nat -> Nat -> Constraint@
#endif
  , ifTyCon       :: TyCon
    -- ^ The type-level 'If'
  }

-- | Find the \"magic\" classes and instances in "GHC.TypeLits.KnownNat"
lookupKnownNatDefs :: TcPluginM Init KnownNatDefs
lookupKnownNatDefs :: TcPluginM 'Init KnownNatDefs
lookupKnownNatDefs = do
    Class
kbC    <- Name -> TcPluginM 'Init Class
look ''KnownBool
    Class
kbn2C  <- Name -> TcPluginM 'Init Class
look ''KnownBoolNat2
    Class
kn2bC  <- Name -> TcPluginM 'Init Class
look ''KnownNat2Bool
    Class
kn1C   <- Name -> TcPluginM 'Init Class
look ''KnownNat1
    Class
kn2C   <- Name -> TcPluginM 'Init Class
look ''KnownNat2
    Class
kn3C   <- Name -> TcPluginM 'Init Class
look ''KnownNat3
#if MIN_VERSION_ghc(9,1,0)
    TyCon
ordcond <- Name -> TcPluginM 'Init Name
forall (s :: TcPluginStage).
(Monad (TcPluginM s), MonadTcPlugin (TcPluginM s)) =>
Name -> TcPluginM s Name
lookupTHName ''OrdCond TcPluginM 'Init Name
-> (Name -> TcPluginM 'Init TyCon) -> TcPluginM 'Init TyCon
forall a b.
TcPluginM 'Init a -> (a -> TcPluginM 'Init b) -> TcPluginM 'Init b
forall (m :: * -> *) a b. Monad m => m a -> (a -> m b) -> m b
>>= Name -> TcPluginM 'Init TyCon
forall (m :: * -> *). MonadTcPlugin m => Name -> m TyCon
tcLookupTyCon
#else
    leq     <- lookupTHName ''(<=) >>= tcLookupTyCon
#endif
    TyCon
ifTc <- Name -> TcPluginM 'Init Name
forall (s :: TcPluginStage).
(Monad (TcPluginM s), MonadTcPlugin (TcPluginM s)) =>
Name -> TcPluginM s Name
lookupTHName ''If TcPluginM 'Init Name
-> (Name -> TcPluginM 'Init TyCon) -> TcPluginM 'Init TyCon
forall a b.
TcPluginM 'Init a -> (a -> TcPluginM 'Init b) -> TcPluginM 'Init b
forall (m :: * -> *) a b. Monad m => m a -> (a -> m b) -> m b
>>= Name -> TcPluginM 'Init TyCon
forall (m :: * -> *). MonadTcPlugin m => Name -> m TyCon
tcLookupTyCon
    KnownNatDefs -> TcPluginM 'Init KnownNatDefs
forall a. a -> TcPluginM 'Init a
forall (m :: * -> *) a. Monad m => a -> m a
return KnownNatDefs
           { knownBool :: Class
knownBool     = Class
kbC
           , knownBoolNat2 :: Class
knownBoolNat2 = Class
kbn2C
           , knownNat2Bool :: Class
knownNat2Bool = Class
kn2bC
           , knownNatN :: Int -> Maybe Class
knownNatN     = \case { Int
1 -> Class -> Maybe Class
forall a. a -> Maybe a
Just Class
kn1C
                                   ; Int
2 -> Class -> Maybe Class
forall a. a -> Maybe a
Just Class
kn2C
                                   ; Int
3 -> Class -> Maybe Class
forall a. a -> Maybe a
Just Class
kn3C
                                   ; Int
_ -> Maybe Class
forall a. Maybe a
Nothing
                                   }
#if MIN_VERSION_ghc(9,1,0)
           , ordCondTyCon :: TyCon
ordCondTyCon  = TyCon
ordcond
#else
           , leqNatTyCon   = leq
#endif
           , ifTyCon :: TyCon
ifTyCon       = TyCon
ifTc
           }
  where
    look :: TH.Name -> TcPluginM Init Class
    look :: Name -> TcPluginM 'Init Class
look Name
nm = Name -> TcPluginM 'Init Name
forall (s :: TcPluginStage).
(Monad (TcPluginM s), MonadTcPlugin (TcPluginM s)) =>
Name -> TcPluginM s Name
lookupTHName Name
nm TcPluginM 'Init Name
-> (Name -> TcPluginM 'Init Class) -> TcPluginM 'Init Class
forall a b.
TcPluginM 'Init a -> (a -> TcPluginM 'Init b) -> TcPluginM 'Init b
forall (m :: * -> *) a b. Monad m => m a -> (a -> m b) -> m b
>>= Name -> TcPluginM 'Init Class
forall (m :: * -> *). MonadTcPlugin m => Name -> m Class
tcLookupClass

--------------------------------------------------------------------------------

-- | Build a Core expression for the given number as a @Natural@ literal. This
-- hides the differing signatures of "GHC.Core.Make"'s @mkNaturalExpr@ across
-- GHC versions.
--
-- NOTE(review): the argument is presumably expected to be non-negative (it
-- becomes a Natural); this wrapper does not check that.
mkNaturalExpr :: Integer -> TcPluginM Solve CoreExpr
mkNaturalExpr i = do
#if MIN_VERSION_ghc(9,3,0)
    -- GHC >= 9.3 needs the target 'Platform', obtained from 'TcM' here.
    platform <- unsafeLiftTcM getPlatform
    pure (GHC.mkNaturalExpr platform i)
#elif MIN_VERSION_ghc(8,11,0)
    -- Pure in this version range.
    pure (GHC.mkNaturalExpr i)
#else
    -- Monadic in older GHCs.
    GHC.mkNaturalExpr i
#endif

--------------------------------------------------------------------------------

#if !MIN_VERSION_ghc(8,11,0)
-- | For a coercion @co :: t1 ~ t2@, return the right-hand type @t2@.
-- Used on GHC < 8.11; newer compilers export it from "GHC.Core.Coercion".
coercionRKind :: Coercion -> Type
coercionRKind co =
  case coercionKind co of
    Pair _lhsTy rhsTy -> rhsTy
#endif

--------------------------------------------------------------------------------

#if !MIN_VERSION_ghc(8,11,0)
-- | Given a class-method selector @op :: forall a. C a => meth_ty@, return
-- @meth_ty@. It first strips the outer foralls, leaving @C a => meth_ty@, and
-- then drops the dictionary argument.
-- Used on GHC < 8.11; newer compilers export it from "GHC.Core.Predicate".
classMethodTy :: Id -> Type
classMethodTy = funResultTy . dropForAlls . varType
#endif

--------------------------------------------------------------------------------

#if !MIN_VERSION_ghc(8,11,0)
-- | Stand-in for @irrelevantMult@, which newer compilers (GHC >= 8.11) export
-- from "GHC.Core.Type". Here it simply returns its argument unchanged.
irrelevantMult :: a -> a
irrelevantMult x = x
#endif

--------------------------------------------------------------------------------