{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE UndecidableInstances #-}
{-# LANGUAGE BlockArguments #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE TypeFamilyDependencies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# OPTIONS_GHC -Wno-redundant-constraints #-}
{-# LANGUAGE MultiParamTypeClasses #-}

-- | This module provides types and functions for representing mock parameters.
-- Parameters are used both for setting up expectations and for verification.
module Test.MockCat.Param
  ( Param(..),
    WrapParam(..),
    value,
    param,
    ConsGen(..),
    expect,
    expect_,
    any,
    ArgsOf,
    ProjectionArgs,
    projArgs,
    ReturnOf,
    ProjectionReturn,
    projReturn,
    returnValue,
    Normalize,
    ToParamArg(..)
  )
where

import Test.MockCat.Cons ((:>) (..), Head(..))
import Unsafe.Coerce (unsafeCoerce)
import Prelude hiding (any)
import Data.Typeable (Typeable, typeOf)
import Foreign.Ptr (Ptr, ptrToIntPtr, castPtr, IntPtr)
import qualified Data.Text as T (Text)

infixr 0 ~>

-- | A mock parameter: a concrete value or a predicate, always paired with a
-- 'String' label (the label is what the 'Show' instance prints).
data Param v where
  -- | A parameter that expects a specific value.
  ExpectValue :: (Show v, Eq v) => v -> String -> Param v
  -- | A parameter that expects a value satisfying a condition.
  ExpectCondition :: (v -> Bool) -> String -> Param v
  -- | A parameter that wraps a value without Eq or Show constraints.
  ValueWrapper :: v -> String -> Param v

-- | Class for wrapping raw values into 'Param'.
-- For types with Show and Eq, it uses 'ExpectValue' to enable comparison and display.
-- For other types, it uses 'ValueWrapper'.
class WrapParam a where
  -- | Wrap a raw value as a 'Param' of the same type.
  wrap :: a -> Param a

-- | 'String' values are compared by equality and labelled with their 'show' output.
instance {-# OVERLAPPING #-} WrapParam String where
  wrap s = ExpectValue s (show s)

-- | 'Int' values are compared by equality and labelled with their 'show' output.
instance {-# OVERLAPPING #-} WrapParam Int where
  wrap v = ExpectValue v (show v)

-- | 'Integer' values are compared by equality and labelled with their 'show' output.
instance {-# OVERLAPPING #-} WrapParam Integer where
  wrap v = ExpectValue v (show v)

-- | 'Bool' values are compared by equality and labelled with their 'show' output.
instance {-# OVERLAPPING #-} WrapParam Bool where
  wrap v = ExpectValue v (show v)

-- | 'Double' values are compared by equality and labelled with their 'show' output.
instance {-# OVERLAPPING #-} WrapParam Double where
  wrap v = ExpectValue v (show v)

-- | 'Float' values are compared by equality and labelled with their 'show' output.
instance {-# OVERLAPPING #-} WrapParam Float where
  wrap v = ExpectValue v (show v)

-- | 'Char' values are compared by equality and labelled with their 'show' output.
instance {-# OVERLAPPING #-} WrapParam Char where
  wrap v = ExpectValue v (show v)

-- | 'T.Text' values are compared by equality and labelled with their 'show' output.
instance {-# OVERLAPPING #-} WrapParam T.Text where
  wrap v = ExpectValue v (show v)

-- | Lists whose elements have 'Show' and 'Eq' are compared by equality and
-- labelled with their 'show' output.
instance {-# OVERLAPPING #-} (Show a, Eq a) => WrapParam [a] where
  wrap v = ExpectValue v (show v)

-- | 'Maybe' values whose payload has 'Show' and 'Eq' are compared by equality
-- and labelled with their 'show' output.
instance {-# OVERLAPPING #-} (Show a, Eq a) => WrapParam (Maybe a) where
  wrap v = ExpectValue v (show v)

-- | Fallback for every type without a more specific instance: the value is
-- stored without requiring 'Eq' or 'Show' and carries the fixed label
-- @"ValueWrapper"@.
instance {-# OVERLAPPABLE #-} WrapParam a where
  wrap v = ValueWrapper v "ValueWrapper"

-- | Matching-oriented equality between parameters.
--
-- * Two concrete values ('ExpectValue' / 'ValueWrapper', at least one of them
--   an 'ExpectValue') are compared with the 'Eq' dictionary captured by the
--   'ExpectValue' constructor.
-- * A concrete value against an 'ExpectCondition' is equal when the value
--   satisfies the predicate.
-- * Two 'ExpectCondition's are compared by their labels only; the predicates
--   themselves cannot be compared.
-- * Two 'ValueWrapper's are never equal, because no 'Eq' dictionary is
--   available for them. Note that this makes '==' non-reflexive for
--   'ValueWrapper'.
instance Eq (Param a) where
  ExpectValue a _ == ExpectValue b _ = a == b
  ExpectValue a _ == ExpectCondition m2 _ = m2 a
  ExpectCondition m1 _ == ExpectValue b _ = m1 b
  ExpectCondition _ l1 == ExpectCondition _ l2 = l1 == l2
  ValueWrapper a _ == ExpectCondition m _ = m a
  ExpectCondition m _ == ValueWrapper a _ = m a
  ExpectValue a _ == ValueWrapper b _ = a == b
  ValueWrapper a _ == ExpectValue b _ = a == b
  ValueWrapper _ _ == ValueWrapper _ _ = False

-- | A parameter is shown as its label, regardless of the constructor.
instance Show (Param v) where
  show (ExpectValue _ l) = l
  show (ExpectCondition _ l) = l
  show (ValueWrapper _ l) = l

-- | Extract the value stored in an 'ExpectValue' or 'ValueWrapper'.
--
-- An 'ExpectCondition' holds only a predicate, so there is no value to
-- return; in that case this function calls 'error'.
value :: Param v -> v
value (ExpectValue a _) = a
value (ValueWrapper a _) = a
value (ExpectCondition _ _) = error "not implemented"

-- | Create a 'Param' from a value. Requires 'Eq' and 'Show'.
-- The label of the resulting parameter is the 'show' output of the value.
param :: (Show v, Eq v) => v -> Param v
param v = ExpectValue v (show v)


-- | Type family to untie the knot for 'ConsGen' instances.
-- It maps a raw type to the 'Param' that represents it; ':>' chains and
-- types that are already a 'Param' are left unchanged.
type family Normalize a where
  Normalize (a :> b) = a :> b
  Normalize (Param a) = Param a
  Normalize a = Param a

-- | Conversion of the left-hand operand of '~>' into its 'Normalize'd form.
class ToParamArg a where
  -- | Convert an argument to its normalized representation.
  toParamArg :: a -> Normalize a

-- | A function argument becomes an 'ExpectCondition' whose predicate is
-- 'compareFunction' applied to the given function and whose label is
-- produced by 'showFunction'.
instance {-# OVERLAPPING #-} (Typeable (a -> b)) => ToParamArg (a -> b) where
  toParamArg f = ExpectCondition (compareFunction f) (showFunction f)

-- | An argument that is already a 'Param' is passed through unchanged.
instance {-# OVERLAPPING #-} ToParamArg (Param a) where
  toParamArg = id

-- | Any other argument is wrapped with 'wrap'. The equality constraint makes
-- the result type @Param a@ visible to the type checker.
instance {-# OVERLAPPABLE #-} (Normalize a ~ Param a, WrapParam a) => ToParamArg a where
  toParamArg = wrap

-- | Conversion of the right-hand operand of '~>' into its 'Normalize'd form.
class ToParamResult b where
  -- | Convert a result (or the remaining chain) to its normalized representation.
  toParamResult :: b -> Normalize b

-- | A result that is already a 'Param' is passed through unchanged.
instance {-# OVERLAPPING #-} ToParamResult (Param a) where
  toParamResult = id

-- | A right-hand side that is already a ':>' chain is passed through unchanged.
instance {-# OVERLAPPING #-} ToParamResult (a :> b) where
  toParamResult = id

-- | Any other result is wrapped with 'wrap'. The equality constraint makes
-- the result type @Param b@ visible to the type checker.
instance {-# OVERLAPPABLE #-} (Normalize b ~ Param b, WrapParam b) => ToParamResult b where
  toParamResult = wrap

-- | Class providing the '~>' operator, which joins two operands into a ':>'
-- chain after normalizing each of them with 'Normalize'.
class ConsGen a b where
  -- | Join an argument and a result (or remaining chain) into a ':>' chain.
  (~>) :: a -> b -> Normalize a :> Normalize b

-- | The single 'ConsGen' instance: normalize the left operand with
-- 'toParamArg', the right operand with 'toParamResult', and join them
-- with ':>'.
instance (ToParamArg a, ToParamResult b) => ConsGen a b where
  a ~> b = toParamArg a :> toParamResult b

-- | Make a parameter to which any value is expected to apply.
--   Use with type application to specify the type: @any \@String@
--
--   > f <- mock $ any ~> True
--
--   The predicate always returns 'True' and the label is @"any"@.
any :: forall a. Param a
any = ExpectCondition (const True) "any"

{- | Create a conditional parameter with a label.
    When calling a mock function, if the argument does not satisfy this condition, an error occurs.
    In this case, the specified label is included in the error message.

    > expect (>5) ">5"
-}
expect :: (a -> Bool) -> String -> Param a
expect = ExpectCondition

{- | Create a conditional parameter without a label.
  The error message is displayed as "[some condition]".

  > expect_ (>5)
-}
expect_ :: (a -> Bool) -> Param a
expect_ f = ExpectCondition f "[some condition]"

-- | The argument part of a parameter chain, i.e. the chain without its
-- final (return) 'Param'.
type family ArgsOf params where
  ArgsOf (Head :> Param r) = ()                        -- Constant value has no arguments
  ArgsOf (Param a :> Param r) = Param a                -- Exactly one argument: that parameter alone
  ArgsOf (Param a :> rest) = Param a :> ArgsOf rest    -- Otherwise keep the first argument and recurse

-- | Class for projecting the arguments out of a parameter chain.
class ProjectionArgs params where
  -- | Drop the final (return) parameter and keep the argument parameters.
  projArgs :: params -> ArgsOf params

-- | A constant (a 'Head' followed by the return parameter) has no arguments.
instance {-# OVERLAPPING #-} ProjectionArgs (Head :> Param r) where
  projArgs (_ :> _) = ()

-- | Base case: one argument followed by the return parameter yields that
-- single argument.
instance {-# OVERLAPPING #-} ProjectionArgs (Param a :> Param r) where
  projArgs (a :> _) = a

-- | Recursive case: keep the first argument and project the arguments of the
-- remaining chain. The equality constraint tells the type checker which
-- 'ArgsOf' equation applies for an unknown @rest@.
instance
  {-# OVERLAPPABLE #-}
  (ProjectionArgs rest, ArgsOf (Param a :> rest) ~ (Param a :> ArgsOf rest)) =>
  ProjectionArgs (Param a :> rest) where
  projArgs (a :> rest) = a :> projArgs rest

-- | The type of the return parameter of a parameter chain, i.e. its last 'Param'.
type family ReturnOf params where
  ReturnOf (Head :> Param r) = Param r                 -- Constant value returns Param r
  ReturnOf (Param a :> Param r) = Param r              -- One argument: the return is the right-hand Param
  ReturnOf (Param a :> rest) = ReturnOf rest           -- Otherwise skip the argument and recurse

-- | Class for projecting the return parameter out of a parameter chain.
class ProjectionReturn param where
  -- | Return the last parameter of the chain.
  projReturn :: param -> ReturnOf param

-- | For a constant (a 'Head' followed by the return parameter) the return
-- parameter is the right-hand side.
instance {-# OVERLAPPING #-} ProjectionReturn (Head :> Param r) where
  projReturn (_ :> r) = r

-- | Base case: one argument followed by the return parameter yields the
-- return parameter.
instance {-# OVERLAPPING #-} ProjectionReturn (Param a :> Param r) where
  projReturn (_ :> r) = r

-- | Recursive case: skip the first argument and project the return parameter
-- of the remaining chain. The equality constraint tells the type checker
-- which 'ReturnOf' equation applies for an unknown @rest@.
instance
  {-# OVERLAPPABLE #-}
  (ProjectionReturn rest, ReturnOf (Param a :> rest) ~ ReturnOf rest) =>
  ProjectionReturn (Param a :> rest) where
  projReturn (_ :> rest) = projReturn rest

-- | Extract the raw return value of a parameter chain: project the return
-- parameter with 'projReturn', then unwrap it with 'value'.
--
-- Because it goes through 'value', this calls 'error' when the return
-- parameter is an 'ExpectCondition'.
returnValue :: (ProjectionReturn params, ReturnOf params ~ Param r) => params -> r
returnValue = value . projReturn

-- | Get the pointer address of a value (used for both comparison and display).
--
-- The value is reinterpreted as a @Ptr ()@ with 'unsafeCoerce' and the
-- result is converted to an 'IntPtr'.
--
-- NOTE(review): 'unsafeCoerce' treats @x@ as if it were a 'Ptr' box rather
-- than asking the runtime for the address of @x@. What 'ptrToIntPtr' then
-- reads for a non-'Ptr' value presumably depends on GHC's heap layout and is
-- not covered by any documented guarantee -- confirm this is the intended
-- notion of identity (a supported alternative is "System.Mem.StableName").
getPtrAddr :: forall a. a -> IntPtr
getPtrAddr x = ptrToIntPtr (castPtr (unsafeCoerce x :: Ptr ()))

-- | Helper function to compare function values using pointer equality.
-- Uses the same pointer calculation as 'showFunction' ('getPtrAddr') for
-- consistency: two values are equal when 'getPtrAddr' yields the same
-- 'IntPtr' for both.
compareFunction :: forall a. a -> a -> Bool
compareFunction x y = getPtrAddr x == getPtrAddr y

-- | Show a function using type information and a pointer hash.
-- The result has the form @\<type\>\@\<address\>@, where the type comes from
-- 'typeOf' and the address from 'getPtrAddr'.
showFunction :: forall a. Typeable a => a -> String
showFunction x =
  let typeStr = show (typeOf x)
      -- Use the same pointer address calculation as compareFunction
      ptrAddr = show (getPtrAddr x)
   in typeStr ++ "@" ++ ptrAddr