{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE GADTs #-}
module Test.MockCat.Internal.MockRegistry
  ( attachVerifierToFn
  , lookupVerifierForFn
  , register
  , registerUnitMeta
  , lookupUnitMeta
  , UnitMeta
  , withUnitGuard
  , withAllUnitGuards
  , markUnitUsed
  , isGuardActive
  ) where

import Test.MockCat.Internal.Registry.Core
  ( attachVerifierToFn
  , lookupVerifierForFn
  , registerUnitMeta
  , lookupUnitMeta
  , UnitMeta
  , withUnitGuard
  , withAllUnitGuards
  , markUnitUsed
  , isGuardActive
  )
import GHC.IO (evaluate)
import Control.Concurrent.STM (TVar, atomically, writeTVar)
import Test.MockCat.Internal.Types (MockName, InvocationRecorder(..), InvocationRecord, perform)
import Data.Proxy (Proxy(..))
import Data.Dynamic
import Test.MockCat.Internal.Builder (invocationRecord, appendCalledParams)
import Type.Reflection (TyCon, splitApps, typeRep, typeRepTyCon)
import Data.Typeable (eqT)
import Data.Type.Equality ((:~:) (Refl))

-- | The type constructor of 'IO'.
--
-- It is taken from the unapplied @IO@ type itself; applying it to an argument
-- (e.g. @IO ()@) would yield the same constructor. 'isIOType' compares the
-- head constructor of a type against this value.
ioTyCon :: TyCon
ioTyCon = typeRepTyCon (typeRep @IO)

-- | Is the outermost type constructor of @a@ equal to 'IO'?
--
-- Only the head constructor is inspected and its arguments are ignored, so
-- @IO ()@ and @IO Int@ both qualify. A function type such as @Int -> IO ()@,
-- or a stack like @ReaderT r IO a@, has a different head constructor and
-- therefore does not. The 'Proxy' argument only serves to fix @a@.
isIOType :: forall a. Typeable a => Proxy a -> Bool
isIOType _ = headTyCon == ioTyCon
  where
    (headTyCon, _) = splitApps (typeRep @a)

-- | Wrap the value of a unit-parameter (@()@) stub so that using it is
-- recorded as a call.
--
-- The result is the value produced by running an 'IO' action through
-- 'perform'. That action first asks 'isGuardActive' about the 'UnitMeta':
--
-- * if the guard is active, or @fn@ is itself an 'IO' type ('isIOType'),
--   nothing is recorded;
-- * otherwise the meta is marked used ('markUnitUsed') and a @()@ invocation
--   is appended to @ref@ ('appendCalledParams').
--
-- In both cases the original @value@ is returned unchanged. The guard lets a
-- caller handle the wrapper (e.g. attach verifiers to it) while the guard is
-- held, without that handling being counted as a call.
--
-- NOTE(review): this assumes 'perform' behaves like 'unsafePerformIO', i.e.
-- the action runs when the returned thunk is first forced and the result is
-- then shared — confirm against 'perform''s definition.
wrapUnitStub ::
  forall fn.
  Typeable fn =>
  TVar (InvocationRecord ()) ->
  UnitMeta ->
  fn ->
  fn
wrapUnitStub ref meta value = perform recordThenReturn
  where
    -- Record one unit invocation unless recording is suppressed, then
    -- hand back the wrapped value.
    recordThenReturn :: IO fn
    recordThenReturn = do
      skipRecording <- (|| isIOType (Proxy :: Proxy fn)) <$> isGuardActive meta
      if skipRecording
        then pure ()
        else markUnitUsed meta >> appendCalledParams ref ()
      pure value


-- | Register a recorder for a mock function in the global mock registry and
-- return the value that should be used as the mock.
--
-- The supplied value is first forced to weak head normal form ('evaluate').
--
-- * When @params@ is @()@, a 'UnitMeta' is registered for the recorder's
--   invocation 'TVar', and that 'TVar' is overwritten with 'invocationRecord'
--   (presumably an empty record — confirm). The value is then wrapped with
--   'wrapUnitStub' so that using it is tracked. The verifier is attached to
--   both the wrapped and the original value, so a StableName-based lookup
--   succeeds whichever of the two closures is later passed for verification.
--   The wrapped value is returned.
-- * For any other @params@, the verifier is attached to the evaluated value,
--   which is returned unchanged.
register ::
  forall fn params.
  ( Typeable params
  , Typeable (InvocationRecorder params)
  , Typeable fn
  ) =>
  Maybe MockName ->
  InvocationRecorder params ->
  fn ->
  IO fn
register name recorder@(InvocationRecorder {invocationRef = ref}) fn = do
  baseValue <- evaluate fn
  case eqT :: Maybe (params :~: ()) of
    Nothing -> do
      attachTo baseValue
      pure baseValue
    Just Refl -> do
      meta <- registerUnitMeta ref
      atomically (writeTVar ref invocationRecord)
      let trackedValue :: fn
          trackedValue = wrapUnitStub ref meta baseValue
      -- Attach while the guard is held: if attaching forces the wrapper,
      -- wrapUnitStub sees the active guard and does not record a call.
      withUnitGuard meta (attachTo trackedValue >> attachTo baseValue)
      pure trackedValue
  where
    -- Attach this mock's name and recorder as the verifier for a closure.
    attachTo :: fn -> IO ()
    attachTo target = attachVerifierToFn target (name, recorder)