{-# LANGUAGE CPP #-}
{-# LANGUAGE TemplateHaskellQuotes #-}
{-# LANGUAGE Trustworthy #-}

-- |
-- Module      :   Grisette.Internal.TH.Ctor.SmartConstructor
-- Copyright   :   (c) Sirui Lu 2021-2024
-- License     :   BSD-3-Clause (see the LICENSE file)
--
-- Maintainer  :   siruilu@cs.washington.edu
-- Stability   :   Experimental
-- Portability :   GHC only
module Grisette.Internal.TH.Ctor.SmartConstructor
  ( makeSmartCtorWith,
    makePrefixedSmartCtor,
    makeNamedSmartCtor,
    makeSmartCtor,
  )
where

import Control.Monad (join, replicateM, when, zipWithM)
import Data.Bifunctor (Bifunctor (second))
import Grisette.Internal.Core.Data.Class.Mergeable (Mergeable)
import Grisette.Internal.Core.Data.Class.TryMerge (TryMerge, mrgSingle)
import Grisette.Internal.TH.Ctor.Common
  ( decapitalizeTransformer,
    prefixTransformer,
    withNameTransformer,
  )
import Grisette.Internal.TH.Util (constructorInfoToType, putHaddock)
import Language.Haskell.TH
  ( Body (NormalB),
    Clause (Clause),
    Dec (FunD, SigD),
    Exp (AppE, ConE, LamE, VarE),
    Name,
    Pat (VarP),
    Pred,
    Q,
    Type (AppT, ArrowT, ForallT, VarT),
    mkName,
    newName,
  )
import Language.Haskell.TH.Datatype
  ( ConstructorInfo
      ( constructorFields,
        constructorName
      ),
    DatatypeInfo (datatypeCons),
    reifyDatatype,
  )
import Language.Haskell.TH.Datatype.TyVarBndr
  ( Specificity (SpecifiedSpec),
    TyVarBndrSpec,
    plainTVFlag,
  )

-- | Generate constructor wrappers that wrap the result in a container with
-- `TryMerge`. The name of each wrapper is derived from the constructor's name
-- by the provided name transformer, as in the example below.
--
-- > makeSmartCtorWith (\name -> "mrg" ++ name) ''Maybe
--
-- generates
--
-- > mrgNothing :: (Mergeable (Maybe a), Applicative m, TryMerge m) => m (Maybe a)
-- > mrgNothing = mrgSingle Nothing
-- > mrgJust :: (Mergeable (Maybe a), Applicative m, TryMerge m) => a -> m (Maybe a)
-- > mrgJust = \x -> mrgSingle (Just x)
makeSmartCtorWith ::
  -- | Transformer from a constructor name to the generated wrapper name
  (String -> String) ->
  -- | The type to generate the wrappers for
  Name ->
  Q [Dec]
makeSmartCtorWith = withNameTransformer makeNamedSmartCtor

-- | Generate constructor wrappers that wrap the result in a container with
-- `TryMerge`. Each wrapper is named by prefixing the constructor's name with
-- the given prefix.
--
-- > makePrefixedSmartCtor "mrg" ''Maybe
--
-- generates
--
-- > mrgNothing :: (Mergeable (Maybe a), Applicative m, TryMerge m) => m (Maybe a)
-- > mrgNothing = mrgSingle Nothing
-- > mrgJust :: (Mergeable (Maybe a), Applicative m, TryMerge m) => a -> m (Maybe a)
-- > mrgJust = \x -> mrgSingle (Just x)
makePrefixedSmartCtor ::
  -- | Prefix for generated wrappers
  String ->
  -- | The type to generate the wrappers for
  Name ->
  Q [Dec]
makePrefixedSmartCtor = makeSmartCtorWith . prefixTransformer

-- | Generate constructor wrappers that wrap the result in a container with
-- `TryMerge`. Each wrapper is named after its constructor with the first
-- letter decapitalized.
--
-- > makeSmartCtor ''Maybe
--
-- generates
--
-- > nothing :: (Mergeable (Maybe a), Applicative m, TryMerge m) => m (Maybe a)
-- > nothing = mrgSingle Nothing
-- > just :: (Mergeable (Maybe a), Applicative m, TryMerge m) => a -> m (Maybe a)
-- > just = \x -> mrgSingle (Just x)
makeSmartCtor ::
  -- | The type to generate the wrappers for
  Name ->
  Q [Dec]
makeSmartCtor = makeSmartCtorWith decapitalizeTransformer

-- | Generate constructor wrappers that wrap the result in a container with
-- `TryMerge`, using the provided names.
--
-- The names are matched to the constructors positionally, in declaration
-- order. The splice fails if the number of names differs from the number of
-- constructors.
--
-- > makeNamedSmartCtor ["mrgTuple2"] ''(,)
--
-- generates
--
-- > mrgTuple2 :: (Mergeable (a, b), Applicative m, TryMerge m) => a -> b -> m (a, b)
-- > mrgTuple2 = \v1 v2 -> mrgSingle (v1, v2)
makeNamedSmartCtor ::
  -- | Names for generated wrappers
  [String] ->
  -- | The type to generate the wrappers for
  Name ->
  Q [Dec]
makeNamedSmartCtor names typName = do
  d <- reifyDatatype typName
  let constructors = datatypeCons d
  -- zipWithM truncates to the shorter list, so a mismatch would silently drop
  -- wrappers; reject it explicitly instead.
  when (length names /= length constructors) $
    fail "Number of names does not match the number of constructors"
  ds <- zipWithM (mkSingleWrapper d) names constructors
  return $ join ds

-- | Build a lambda expression that takes @n@ fresh arguments @x1 .. xn@,
-- applies @f@ (typically a data constructor) to them in order, and wraps the
-- result with 'mrgSingle'. For @n = 0@ the lambda has no binders.
augmentNormalCExpr :: Int -> Exp -> Q Exp
augmentNormalCExpr n f = do
  xs <- replicateM n (newName "x")
  mrgSingleFun <- [|mrgSingle|]
  -- f x1 x2 ... xn
  let applied = foldl AppE f (map VarE xs)
  return $ LamE (map VarP xs) (AppE mrgSingleFun applied)

-- | Wrap the final result type of a (possibly curried) function type.
--
-- Walks down the right spine of @a1 -> ... -> an -> t@ and replaces the final
-- result @t@ with @m t@, where @m@ is a fresh type variable. It also returns:
--
-- * the binder for @m@ (marked 'SpecifiedSpec'), and
-- * the constraints @Mergeable t@, @Applicative m@ and @TryMerge m@.
augmentFinalType :: Type -> Q (([TyVarBndrSpec], [Pred]), Type)
augmentFinalType (AppT argArrow@(AppT ArrowT _) rest) =
  -- argArrow is the partially applied arrow @(arg ->)@: keep it unchanged and
  -- rebuild it around the augmented result type.
  second (AppT argArrow) <$> augmentFinalType rest
augmentFinalType t = do
  mName <- newName "m"
  let mTy = VarT mName
  mergeable <- [t|Mergeable|]
  applicative <- [t|Applicative|]
  tryMerge <- [t|TryMerge|]
  return
    ( ( [plainTVFlag mName SpecifiedSpec],
        [AppT mergeable t, AppT applicative mTy, AppT tryMerge mTy]
      ),
      AppT mTy t
    )

-- | Turn a constructor's type into the type of its smart constructor.
--
-- The result type is wrapped in a fresh container type variable @m@ via
-- 'augmentFinalType'. The new binder is appended after any existing
-- outermost @forall@ binders. The new constraints are placed before any
-- existing context. If the type has no outermost @forall@, one is created.
augmentConstructorType :: Type -> Q Type
augmentConstructorType ty = do
  -- Treat a type without an outer forall as having no binders and no context,
  -- so both shapes share a single code path.
  let (outerBndrs, outerCtx, body) = case ty of
        ForallT bndrs ctx inner -> (bndrs, ctx, inner)
        _ -> ([], [], ty)
  ((newBndrs, newPreds), augmentedTyp) <- augmentFinalType body
  return $ ForallT (outerBndrs ++ newBndrs) (newPreds ++ outerCtx) augmentedTyp

-- | Generate the type signature and definition of a single smart constructor.
--
-- The smart constructor is named @name@ and wraps the constructor described
-- by @info@. It also attaches a Haddock comment to the generated name.
mkSingleWrapper :: DatatypeInfo -> String -> ConstructorInfo -> Q [Dec]
mkSingleWrapper dataType name info = do
  constructorTyp <- constructorInfoToType dataType info
  augmentedTyp <- augmentConstructorType constructorTyp
  let oriName = constructorName info
      retName = mkName name
      -- The lambda takes exactly as many arguments as the constructor has fields.
      arity = length (constructorFields info)
  expr <- augmentNormalCExpr arity (ConE oriName)
  putHaddock retName $
    "Smart constructor for v'"
      <> show oriName
      <> "' to construct values wrapped and possibly merged in a container."
  return
    [ SigD retName augmentedTyp,
      FunD retName [Clause [] (NormalB expr) []]
    ]