{-# LANGUAGE DataKinds #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE OverloadedRecordDot #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE NoFieldSelectors #-}

module Skeletest.Internal.TestInfo (
  TestInfo (..),
  withTestInfo,
  getTestInfo,
  lookupTestInfo,
) where

import Control.Monad.IO.Class (MonadIO)
import Data.Map (Map)
import Data.Map qualified as Map
import Data.Text (Text)
import GHC.Stack (HasCallStack)
import Skeletest.Internal.Error (skeletestError)
import Skeletest.Internal.Markers (SomeMarker)
import System.IO.Unsafe (unsafePerformIO)
import UnliftIO (MonadUnliftIO)
import UnliftIO.Concurrent (ThreadId, myThreadId)
import UnliftIO.Exception (bracket_)
import UnliftIO.IORef (IORef, atomicModifyIORef', modifyIORef, newIORef, readIORef)

-- | Metadata about a single test, made available to the thread running it by
-- 'withTestInfo' and read back with 'lookupTestInfo' or 'getTestInfo'.
data TestInfo = TestInfo
  { contexts :: [Text]
  -- ^ Presumably the labels of the groups enclosing this test -- TODO confirm
  -- the contents and ordering against the code that builds 'TestInfo'.
  , name :: Text
  -- ^ Presumably the test's own name -- TODO confirm.
  , markers :: [SomeMarker]
  -- ^ Markers this test was tagged with or inherited from its groups.
  --
  -- May be queried with 'Skeletest.Plugin.findMarker' or 'Skeletest.Plugin.hasMarker'.
  , file :: FilePath
  -- ^ Path relative to the current working directory (presumably of the file
  -- defining the test -- TODO confirm).
  }
  deriving (Show)

-- | Per-thread test metadata: each entry is keyed by the 'ThreadId' of the
-- thread that registered it via 'withTestInfo'.
type TestInfoMap = Map ThreadId TestInfo

-- | Process-wide registry holding the 'TestInfo' of every thread that is
-- currently inside 'withTestInfo'.
--
-- This is the standard top-level mutable variable idiom: 'unsafePerformIO'
-- allocates the 'IORef' when the binding is first forced, and the @NOINLINE@
-- pragma stops GHC from inlining the definition, which could otherwise
-- allocate a separate ref at each use site.
testInfoMapRef :: IORef TestInfoMap
testInfoMapRef = unsafePerformIO $ newIORef Map.empty
{-# NOINLINE testInfoMapRef #-}

-- | Run an action with the given 'TestInfo' registered for the current thread.
--
-- The info is stored in 'testInfoMapRef' under the calling thread's 'ThreadId'
-- before the action starts and removed again when it finishes, whether it
-- returns normally or throws (see 'bracket_'). 'lookupTestInfo' and
-- 'getTestInfo' look the entry up by the calling thread's id, so they only see
-- it from this same thread.
--
-- NOTE: on exit the thread's entry is deleted rather than restored, so after a
-- nested 'withTestInfo' on the same thread returns, no info is registered for
-- the remainder of the outer action.
withTestInfo :: (MonadUnliftIO m) => TestInfo -> m a -> m a
withTestInfo info action = do
  tid <- myThreadId
  bracket_ (register tid) (unregister tid) action
 where
  register tid = update (Map.insert tid info)
  unregister tid = update (Map.delete tid)

  -- The registry is a single IORef shared by every thread, so a plain
  -- read-modify-write could drop another thread's concurrent insert or delete.
  -- The strict atomic variant also avoids piling up thunks inside the ref.
  update f = atomicModifyIORef' testInfoMapRef $ \infos -> (f infos, ())

-- | Look up the 'TestInfo' registered for the calling thread.
--
-- Returns 'Nothing' when 'testInfoMapRef' has no entry for the current
-- thread's 'ThreadId'.
lookupTestInfo :: (MonadIO m) => m (Maybe TestInfo)
lookupTestInfo = Map.lookup <$> myThreadId <*> readIORef testInfoMapRef

-- | Fetch the 'TestInfo' registered for the calling thread.
--
-- Like 'lookupTestInfo', but instead of returning 'Nothing' when no info is
-- registered it reports the problem through 'skeletestError'.
getTestInfo :: (MonadIO m, HasCallStack) => m TestInfo
getTestInfo = lookupTestInfo >>= maybe missing pure
 where
  missing = skeletestError "getTestInfo was called from outside a test context"