{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeOperators #-}
{-# OPTIONS_GHC -Wno-unrecognised-pragmas #-}

{-# HLINT ignore "Avoid lambda" #-}

-- |
-- Module      :   Grisette.Internal.SymPrim.FunInstanceGen
-- Copyright   :   (c) Sirui Lu 2024
-- License     :   BSD-3-Clause (see the LICENSE file)
--
-- Maintainer  :   siruilu@cs.washington.edu
-- Stability   :   Experimental
-- Portability :   GHC only
module Grisette.Internal.SymPrim.FunInstanceGen
  ( supportedPrimFun,
    supportedPrimFunUpTo,
  )
where

import qualified Data.SBV as SBV
import Grisette.Internal.SymPrim.Prim.Internal.Term
  ( IsSymbolKind,
    SupportedNonFuncPrim,
    SupportedPrim
      ( castTypedSymbol,
        conSBVTerm,
        defaultValue,
        funcDummyConstraint,
        parseSMTModelResult,
        pevalDistinctTerm,
        pevalEqTerm,
        pevalITETerm,
        sameCon,
        sbvDistinct,
        sbvEq,
        symSBVName,
        symSBVTerm,
        withPrim
      ),
    TypedSymbol (unTypedSymbol),
    decideSymbolKind,
    translateTypeError,
    typedAnySymbol,
    withNonFuncPrim,
  )
import Language.Haskell.TH
  ( Cxt,
    Dec (InstanceD),
    DecsQ,
    Exp,
    ExpQ,
    Name,
    Overlap (Overlapping),
    Q,
    Type,
    TypeQ,
    forallT,
    lamE,
    newName,
    sigD,
    stringE,
    varE,
    varP,
    varT,
  )
import Language.Haskell.TH.Datatype.TyVarBndr
  ( plainTVInferred,
    plainTVSpecified,
  )
import Type.Reflection (TypeRep, typeRep, type (:~~:) (HRefl))

-- | Build a single instance declaration from its separately generated parts.
--
-- The three 'Q' computations run in this order:
--
-- 1. the instance context;
-- 2. every method-declaration group, in list order;
-- 3. the instance head.
--
-- The method groups are concatenated to form the instance body. The result
-- is a one-element declaration list carrying the given overlap annotation.
instanceWithOverlapDescD ::
  Maybe Overlap -> Q Cxt -> Q Type -> [DecsQ] -> DecsQ
instanceWithOverlapDescD overlap mkContext mkHead mkBodies =
  assemble
    <$> mkContext
    <*> sequence mkBodies
    <*> mkHead
  where
    assemble context bodies instHead =
      [InstanceD overlap context instHead (concat bodies)]

-- | Generate an instance of 'SupportedPrim' for a curried function type with a
-- given number of type arguments.
--
-- With @numArg = n@, the instance is for
-- @f a0 (f a1 (... (f a(n-2) a(n-1))))@. Here @f@ is the type named by
-- @funTypeName@, and every @ai@ is constrained by 'SupportedNonFuncPrim'. The
-- count includes the result type, so @numArg = 2@ describes @f a0 a1@. The
-- @numArg = 2@ instance carries no overlap pragma; every higher arity is
-- marked @OVERLAPPING@.
--
-- Arguments:
--
-- * @dv@: implementation of 'defaultValue'.
-- * @ite@: implementation of 'pevalITETerm'.
-- * @parse@: implementation of 'parseSMTModelResult'.
-- * @consbv@: given the type variables @[a0, ..., a(n-1)]@, produces the
--   implementation of 'conSBVTerm'.
-- * @funNameInError@: inserted into the error messages of the unsupported
--   equality methods.
-- * @funNamePrefix@: 'symSBVName' yields
--   @funNamePrefix <> show numArg <> "_" <> show num@.
-- * @funTypeName@: the function type constructor @f@.
-- * @numArg@: number of type arguments, including the result type.
--
-- The equality-related methods ('pevalEqTerm', 'pevalDistinctTerm', 'sbvEq',
-- 'sbvDistinct') are generated as calls to 'translateTypeError'. Each call
-- reports that equality comparison is not supported for this type.
--
-- Fails in 'Q' (a compile-time error at the splice site) when @numArg < 2@.
-- A function type needs at least one argument and a result.
supportedPrimFun ::
  ExpQ ->
  ExpQ ->
  ExpQ ->
  ([TypeQ] -> ExpQ) ->
  String ->
  String ->
  Name ->
  Int ->
  DecsQ
supportedPrimFun
  dv
  ite
  parse
  consbv
  funNameInError
  funNamePrefix
  funTypeName
  numArg
    | numArg < 2 =
        fail $
          "supportedPrimFun: a function type needs at least 2 type arguments "
            <> "(one argument and the result), but got "
            <> show numArg
    | otherwise = do
        -- Fresh type variables a0, a1, ..., a(numArg - 1). The first one is
        -- bound on its own, so the rest of the code never takes the head of
        -- a possibly-empty list.
        argName <- newName "a0"
        restNames <- traverse (newName . ("a" <>) . show) [1 .. numArg - 1]
        let argTy = varT argName
            restTys = varT <$> restNames
            tyVars = argTy : restTys
            -- The full function type f a0 (f a1 (... a(numArg - 1))).
            fullTy = funType argTy restTys
        knd <- newName "knd"
        knd' <- newName "knd'"
        let kndty = varT knd
            knd'ty = varT knd'
        instanceWithOverlapDescD
          -- The binary instance is the base case; higher arities overlap it.
          (if numArg == 2 then Nothing else Just Overlapping)
          (constraints tyVars)
          [t|SupportedPrim $fullTy|]
          [ [d|$(varP 'sameCon) = (==)|],
            [d|$(varP 'defaultValue) = $dv|],
            [d|$(varP 'pevalITETerm) = $ite|],
            unsupportedEquality fullTy 'pevalEqTerm,
            unsupportedEquality fullTy 'pevalDistinctTerm,
            [d|$(varP 'conSBVTerm) = $(consbv tyVars)|],
            [d|
              $(varP 'symSBVName) = \_ num ->
                $(stringE $ funNamePrefix <> show numArg <> "_") <> show num
              |],
            [d|
              $(varP 'symSBVTerm) = \r ->
                withPrim @($fullTy) $ return $ SBV.uninterpret r
              |],
            [d|$(varP 'withPrim) = $(withPrims tyVars)|],
            unsupportedEquality fullTy 'sbvEq,
            unsupportedEquality fullTy 'sbvDistinct,
            [d|$(varP 'parseSMTModelResult) = $parse|],
            (: [])
              <$> sigD
                'castTypedSymbol
                ( forallT
                    [plainTVInferred knd, plainTVSpecified knd']
                    ((: []) <$> [t|IsSymbolKind $knd'ty|])
                    [t|
                      TypedSymbol $kndty $fullTy ->
                      Maybe (TypedSymbol $knd'ty $fullTy)
                      |]
                ),
            [d|
              $(varP 'castTypedSymbol) = \sym ->
                case decideSymbolKind @($knd'ty) of
                  Left HRefl -> Nothing
                  Right HRefl -> Just $ typedAnySymbol $ unTypedSymbol sym
              |],
            funcDummyConstraintDecs argTy restTys fullTy
          ]
    where
      -- An expression that calls 'translateTypeError' for the function type
      -- @ty@. The message starts with a bug-report request, followed by
      -- @funNameInError@ and @finalMsg@.
      translateError :: Q Type -> String -> ExpQ
      translateError ty finalMsg =
        [|
          translateTypeError
            ( Just
                $( stringE $
                     "BUG. Please send a bug report. "
                       <> funNameInError
                       <> " "
                       <> finalMsg
                 )
            )
            (typeRep :: TypeRep $ty)
          |]

      -- A binding for the given equality-related method. The binding always
      -- reports that equality comparison is unsupported for @ty@. All four
      -- equality methods share this exact message.
      unsupportedEquality :: Q Type -> Name -> DecsQ
      unsupportedEquality ty method =
        [d|
          $(varP method) =
            $(translateError ty "does not support equality comparison.")
          |]

      -- Instance context: every type variable is a non-function primitive.
      constraints :: [Q Type] -> Q Cxt
      constraints = traverse (\ty -> [t|SupportedNonFuncPrim $ty|])

      -- Right-nested application of the function type constructor:
      -- funType t0 [t1, t2] = f t0 (f t1 t2), and funType t0 [] = t0.
      funType :: Q Type -> [Q Type] -> Q Type
      funType ty [] = ty
      funType ty (nextTy : moreTys) =
        [t|$(varT funTypeName) $ty $(funType nextTy moreTys)|]

      -- Implementation of 'withPrim':
      -- \r -> withNonFuncPrim @a0 (withNonFuncPrim @a1 (... r)).
      withPrims :: [Q Type] -> Q Exp
      withPrims tys = do
        r <- newName "r"
        lamE [varP r] $
          foldr
            (\ty body -> [|withNonFuncPrim @($ty) $body|])
            (varE r)
            tys

      -- Implementation of 'funcDummyConstraint', chosen by the number of
      -- remaining type variables.
      --
      -- * Exactly one remaining (the binary case): compare two applications
      --   of @f@ to the argument's default value with @SBV..==@.
      -- * More than one remaining: apply @f@ to the first argument's default
      --   value, then delegate to 'funcDummyConstraint' of the remaining
      --   function type.
      funcDummyConstraintDecs :: Q Type -> [Q Type] -> Q Type -> DecsQ
      funcDummyConstraintDecs argTy restTys fullTy = case restTys of
        [resTy] ->
          [d|
            $(varP 'funcDummyConstraint) = \f ->
              withPrim @($fullTy) $
                withNonFuncPrim @($resTy) $
                  f (conSBVTerm (defaultValue :: $argTy))
                    SBV..== f (conSBVTerm (defaultValue :: $argTy))
            |]
        (nextTy : moreTys) ->
          [d|
            $(varP 'funcDummyConstraint) = \f ->
              withNonFuncPrim @($argTy) $
                funcDummyConstraint @($(funType nextTy moreTys))
                  (f (conSBVTerm (defaultValue :: $argTy)))
            |]
        [] ->
          fail "supportedPrimFun: unreachable, numArg >= 2 was checked"

-- | Generate instances of 'SupportedPrim' for functions with up to a given
-- number of arguments.
--
-- Calls 'supportedPrimFun' once for each arity from 2 up to and including
-- @numArg@, in increasing order. All other arguments are forwarded unchanged,
-- and the generated declarations are concatenated. When @numArg < 2@, no
-- declarations are produced.
supportedPrimFunUpTo ::
  ExpQ -> ExpQ -> ExpQ -> ([TypeQ] -> ExpQ) -> String -> String -> Name -> Int -> DecsQ
supportedPrimFunUpTo
  dv
  ite
  parse
  consbv
  funNameInError
  funNamePrefix
  funTypeName
  numArg =
    concat <$> traverse genForArity [2 .. numArg]
    where
      -- 'supportedPrimFun' with everything but the arity already supplied.
      genForArity :: Int -> DecsQ
      genForArity =
        supportedPrimFun
          dv
          ite
          parse
          consbv
          funNameInError
          funNamePrefix
          funTypeName