{-# LANGUAGE TemplateHaskell #-}

-- |
-- Module      :   Grisette.Internal.TH.Ctor.UnifiedConstructor
-- Copyright   :   (c) Sirui Lu 2024
-- License     :   BSD-3-Clause (see the LICENSE file)
--
-- Maintainer  :   siruilu@cs.washington.edu
-- Stability   :   Experimental
-- Portability :   GHC only
module Grisette.Internal.TH.Ctor.UnifiedConstructor
  ( makeUnifiedCtorWith,
    makePrefixedUnifiedCtor,
    makeNamedUnifiedCtor,
    makeUnifiedCtor,
  )
where

import Control.Monad (join, replicateM, when, zipWithM)
import Data.Maybe (catMaybes)
import Grisette.Internal.Core.Data.Class.Mergeable (Mergeable, Mergeable1, Mergeable2)
import Grisette.Internal.TH.Ctor.Common
  ( decapitalizeTransformer,
    prefixTransformer,
    withNameTransformer,
  )
import Grisette.Internal.TH.Derivation.Common (ctxForVar)
import Grisette.Internal.TH.Util (constructorInfoToType, putHaddock, tvIsMode)
import Grisette.Internal.Unified.EvalModeTag (EvalModeTag)
import Grisette.Internal.Unified.UnifiedData
  ( GetData,
    UnifiedData,
    wrapData,
  )
import Language.Haskell.TH (conT, pprint, varT)
import Language.Haskell.TH.Datatype
  ( ConstructorInfo (constructorFields, constructorName),
    DatatypeInfo (datatypeCons, datatypeVars),
    reifyDatatype,
    tvKind,
    tvName,
  )
import Language.Haskell.TH.Datatype.TyVarBndr (TyVarBndrSpec, kindedTVSpecified)
import Language.Haskell.TH.Lib (appE, appTypeE, lamE, varE, varP)
import Language.Haskell.TH.Syntax
  ( Body (NormalB),
    Clause (Clause),
    Dec (FunD, SigD),
    Exp (ConE),
    Name,
    Pred,
    Q,
    Type (AppT, ArrowT, ConT, ForallT, VarT),
    mkName,
    newName,
  )

-- | Generate smart constructors to create unified values, naming each one by
-- applying the provided name transformer to its constructor's name.
--
-- For a type @T mode a b c@ with constructors @T1@, @T2@, etc., this function
-- will generate smart constructors with transformed names. For example, given
-- the name transformer @(\\name -> \"mk\" ++ name)@, it will generate @mkT1@,
-- @mkT2@, etc.
--
-- The generated smart constructors will construct values of type
-- @GetData mode (T mode a b c)@.
makeUnifiedCtorWith ::
  -- | Classes used to constrain the mode variable of the generated wrappers
  [Name] ->
  -- | Transformer from a constructor name to its wrapper's name
  (String -> String) ->
  -- | The type to generate the wrappers for
  Name ->
  Q [Dec]
makeUnifiedCtorWith modeCtx =
  withNameTransformer (makeNamedUnifiedCtor modeCtx)

-- | Generate smart constructors to create unified values, naming each one by
-- prefixing its constructor's name.
--
-- For a type @T mode a b c@ with constructors @T1@, @T2@, etc., this function
-- will generate smart constructors with the given prefix, e.g., @mkT1@,
-- @mkT2@, etc.
--
-- The generated smart constructors will construct values of type
-- @GetData mode (T mode a b c)@.
makePrefixedUnifiedCtor ::
  -- | Classes used to constrain the mode variable of the generated wrappers
  [Name] ->
  -- | Prefix for generated wrappers
  String ->
  -- | The type to generate the wrappers for
  Name ->
  Q [Dec]
makePrefixedUnifiedCtor modeCtx prefix =
  makeUnifiedCtorWith modeCtx (prefixTransformer prefix)

-- | Generate smart constructors to create unified values, naming each one by
-- decapitalizing its constructor's name.
--
-- For a type @T mode a b c@ with constructors @T1@, @T2@, etc., this function
-- will generate smart constructors with the names decapitalized, e.g.,
-- @t1@, @t2@, etc.
--
-- The generated smart constructors will construct values of type
-- @GetData mode (T mode a b c)@.
makeUnifiedCtor ::
  -- | Classes used to constrain the mode variable of the generated wrappers
  [Name] ->
  -- | The type to generate the wrappers for
  Name ->
  Q [Dec]
makeUnifiedCtor modeCtx typName =
  makeUnifiedCtorWith modeCtx decapitalizeTransformer typName

-- | Generate smart constructors to create unified values.
--
-- For a type @T mode a b c@ with constructors @T1@, @T2@, etc., this function
-- will generate smart constructors with the given names.
--
-- The generated smart constructors will construct values of type
-- @GetData mode (T mode a b c)@.
--
-- The number of names must equal the number of constructors of the type, and
-- the names are paired with the constructors in declaration order. The type
-- may have at most one type parameter of kind 'EvalModeTag'. If it has none,
-- every generated smart constructor quantifies over a fresh mode variable
-- instead.
makeNamedUnifiedCtor ::
  -- | Classes used to constrain the mode variable of the generated wrappers
  [Name] ->
  -- | Names for generated wrappers
  [String] ->
  -- | The type to generate the wrappers for
  Name ->
  Q [Dec]
makeNamedUnifiedCtor modeCtx names typName = do
  d <- reifyDatatype typName
  let constructors = datatypeCons d
      numNames = length names
      numCtors = length constructors
  when (numNames /= numCtors) $
    fail $
      "Number of names does not match the number of constructors: got "
        ++ show numNames
        ++ " names for "
        ++ show numCtors
        ++ " constructors of "
        ++ show typName
  -- Select the mode variable shared by all wrappers. It is either the
  -- datatype's own 'EvalModeTag' parameter (no fresh binder needed), or a
  -- fresh binder that each wrapper's signature quantifies over.
  (freshModeBndr, modeTy) <-
    case filter ((== ConT ''EvalModeTag) . tvKind) (datatypeVars d) of
      [mode] -> return (Nothing, VarT $ tvName mode)
      [] -> do
        n <- newName "mode"
        return (Just $ kindedTVSpecified n (ConT ''EvalModeTag), VarT n)
      _ -> fail "Expected one or zero EvalModeTag variable in the datatype."
  concat
    <$> zipWithM
      (mkSingleWrapper modeCtx d freshModeBndr modeTy)
      names
      constructors

-- | Rewrite the final result type of a curried function type.
--
-- Given @a1 -> ... -> an -> T@, this returns @a1 -> ... -> an -> GetData mode T@
-- together with the single constraint @UnifiedData mode T@. Only top-level
-- arrows (@AppT (AppT ArrowT _) _@) are peeled off. Anything else is treated
-- as the final type @T@.
augmentFinalType :: Type -> Type -> Q ([Pred], Type)
augmentFinalType mode = go
  where
    go (AppT arrow@(AppT ArrowT _) rest) = do
      (constraints, result) <- go rest
      pure (constraints, AppT arrow result)
    go final = do
      wrapped <- [t|GetData $(pure mode) $(pure final)|]
      constraint <- [t|UnifiedData $(pure mode) $(pure final)|]
      pure ([constraint], wrapped)

-- | Rewrite a constructor's type into the type of its unified smart
-- constructor.
--
-- The result type @T@ becomes @GetData mode T@, and the constraint
-- @UnifiedData mode T@ is added (see 'augmentFinalType'). In addition, every
-- class in @modeCtx@ is applied to the mode variable as a constraint.
--
-- When the type is a @forall@, each variable it binds also gets a
-- 'Mergeable' \/ 'Mergeable1' \/ 'Mergeable2' constraint whenever 'ctxForVar'
-- produces one for that variable's kind. The mode variable is then either:
--
-- * the single mode binder already present in the @forall@, when
--   @freshModeBndr@ is 'Nothing'; or
-- * the fresh binder supplied by the caller, which is then prepended to the
--   @forall@. This is only accepted when the @forall@ has no mode binder of
--   its own.
--
-- When the type is not a @forall@, a fresh binder must be supplied. It becomes
-- the only quantified variable, and the @modeCtx@ constraints are added
-- consistently with the @forall@ case.
augmentConstructorType ::
  [Name] -> Maybe TyVarBndrSpec -> Type -> Type -> Q Type
augmentConstructorType
  modeCtx
  freshModeBndr
  mode
  (ForallT tybinders ctx ty1) = do
    (unifiedPreds, augmentedTyp) <- augmentFinalType mode ty1
    mergeablePreds <-
      catMaybes
        <$> traverse
          ( \bndr ->
              ctxForVar
                (ConT <$> [''Mergeable, ''Mergeable1, ''Mergeable2])
                (VarT $ tvName bndr)
                (tvKind bndr)
          )
          tybinders
    -- Exactly one source of the mode variable is allowed: either the
    -- constructor's own forall or the caller-provided fresh binder.
    modeBndr <-
      case (filter tvIsMode tybinders, freshModeBndr) of
        ([bndr], Nothing) -> return bndr
        ([], Just bndr) -> return bndr
        _ -> fail "Unsupported constructor type."
    modePreds <- modeCtxPreds modeCtx modeBndr
    let binders = maybe tybinders (: tybinders) freshModeBndr
    return $
      ForallT
        binders
        (modePreds ++ mergeablePreds ++ unifiedPreds ++ ctx)
        augmentedTyp
augmentConstructorType modeCtx freshModeBndr mode ty = do
  (unifiedPreds, augmentedTyp) <- augmentFinalType mode ty
  case freshModeBndr of
    Just bndr -> do
      -- Previously the modeCtx classes were dropped here, unlike in the
      -- forall case with a fresh binder; apply them for consistency.
      modePreds <- modeCtxPreds modeCtx bndr
      return $ ForallT [bndr] (modePreds ++ unifiedPreds) augmentedTyp
    Nothing ->
      fail $
        "augmentConstructorType: unsupported constructor type: " ++ pprint ty

-- | Apply each class named in the mode context to the given mode variable,
-- producing constraints of the form @C mode@.
modeCtxPreds :: [Name] -> TyVarBndrSpec -> Q [Pred]
modeCtxPreds modeCtx bndr =
  traverse (\nm -> [t|$(conT nm) $(varT $ tvName bndr)|]) modeCtx

-- | Build the body of a smart constructor.
--
-- For an arity @n@ and expression @f@, this produces
-- @\\x1 ... xn -> wrapData \@mode (f x1 ... xn)@, where the argument names
-- are freshly generated.
augmentExpr :: Type -> Int -> Exp -> Q Exp
augmentExpr mode arity ctor = do
  argNames <- replicateM arity (newName "x")
  let applied = foldl appE (pure ctor) (varE <$> argNames)
      wrapper = appTypeE [|wrapData|] (pure mode)
  lamE (varP <$> argNames) (appE wrapper applied)

-- | Generate the declarations for one smart constructor.
--
-- Returns a type signature (the constructor's type rewritten by
-- 'augmentConstructorType') and a clause-less definition whose body is built
-- by 'augmentExpr'. The wrapper is named by the given string and is attached
-- a Haddock comment that references the original constructor.
mkSingleWrapper ::
  -- | Classes used to constrain the mode variable
  [Name] ->
  -- | The reified datatype the constructor belongs to
  DatatypeInfo ->
  -- | Fresh mode binder to quantify over, if the datatype has no mode parameter
  Maybe TyVarBndrSpec ->
  -- | The mode type used in @GetData@ / @wrapData@
  Type ->
  -- | Name of the generated wrapper
  String ->
  -- | The constructor to wrap
  ConstructorInfo ->
  Q [Dec]
mkSingleWrapper modeCtx dataType freshModeBndr mode name info = do
  wrapperTy <-
    constructorInfoToType dataType info
      >>= augmentConstructorType modeCtx freshModeBndr mode
  let ctorName = constructorName info
      wrapperName = mkName name
      arity = length (constructorFields info)
  body <- augmentExpr mode arity (ConE ctorName)
  putHaddock wrapperName $
    "Smart constructor for v'"
      <> show ctorName
      <> "' to construct unified value."
  return
    [ SigD wrapperName wrapperTy,
      FunD wrapperName [Clause [] (NormalB body) []]
    ]