{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE RecordWildCards #-}
{-# OPTIONS_GHC -Wno-orphans #-}

module Skeletest.Assertions (
  shouldBe,
  shouldNotBe,
  shouldSatisfy,
  shouldNotSatisfy,
  context,
  failTest,
  AssertionFail (..),

  -- * Testable
  Testable,
  runTestable,
) where

import Data.Text (Text)
import Data.Text qualified as Text
import GHC.Stack (HasCallStack)
import GHC.Stack qualified as GHC
import System.IO.Unsafe (unsafePerformIO)
import UnliftIO.Exception (bracket_, throwIO)
import UnliftIO.IORef (IORef, modifyIORef, modifyIORef', newIORef, readIORef)

import Skeletest.Internal.Predicate (
  Predicate,
  PredicateResult (..),
  runPredicate,
 )
import Skeletest.Internal.Predicate qualified as P
import Skeletest.Internal.TestInfo (getTestInfo)
import Skeletest.Internal.TestRunner (
  AssertionFail (..),
  FailContext,
  Testable (..),
  testResultPass,
 )

-- | Plain 'IO' actions are testable directly.
--
-- Running the action to completion without an exception counts as a pass.
-- Failures are raised as 'AssertionFail' exceptions with 'throwIO', and
-- context labels are handled by 'contextIO'.
instance Testable IO where
  runTestable action = do
    action
    pure testResultPass
  throwFailure failure = throwIO failure
  context label = contextIO label

-- Low, non-associative precedence: both operands can be ordinary expressions
-- without parentheses, e.g. @x + 1 `shouldBe` 2@.
infix 1 `shouldBe`, `shouldNotBe`
infix 1 `shouldSatisfy`, `shouldNotSatisfy`

-- | Assert that @actual@ equals @expected@, failing the test otherwise.
--
-- Shorthand for @actual `shouldSatisfy` P.eq expected@. The call stack is
-- frozen, so the internal call to 'shouldSatisfy' adds no frame to the
-- stack reported on failure.
shouldBe :: (HasCallStack, Testable m, Eq a) => a -> a -> m ()
shouldBe actual expected =
  GHC.withFrozenCallStack (shouldSatisfy actual (P.eq expected))

-- | Assert that @actual@ differs from @expected@, failing the test if they
-- are equal.
--
-- Shorthand for @actual `shouldNotSatisfy` P.eq expected@. The call stack is
-- frozen, so the internal call to 'shouldNotSatisfy' adds no frame to the
-- stack reported on failure.
shouldNotBe :: (HasCallStack, Testable m, Eq a) => a -> a -> m ()
shouldNotBe actual expected =
  GHC.withFrozenCallStack (shouldNotSatisfy actual (P.eq expected))

-- | Run @predicate@ against @actual@.
--
-- On 'PredicateFail', fail the test with the predicate's message via
-- 'failTest''. On 'PredicateSuccess', do nothing. The call stack is frozen,
-- so the failure is reported against the caller of this assertion.
shouldSatisfy :: (HasCallStack, Testable m) => a -> Predicate m a -> m ()
shouldSatisfy actual predicate =
  GHC.withFrozenCallStack $ do
    result <- runPredicate predicate actual
    case result of
      PredicateFail msg -> failTest' msg
      PredicateSuccess -> pure ()

-- | Assert that @actual@ does NOT satisfy @predicate@.
--
-- Negates the predicate with 'P.not' and delegates to 'shouldSatisfy'. The
-- call stack is frozen, so the delegated call adds no frame to the reported
-- stack.
shouldNotSatisfy :: (HasCallStack, Testable m) => a -> Predicate m a -> m ()
shouldNotSatisfy actual predicate =
  GHC.withFrozenCallStack (shouldSatisfy actual negated)
  where
    negated = P.not predicate

-- | 'IO' implementation of 'context'.
--
-- Pushes @msg@ onto the failure-context stack ('failContextRef') for the
-- duration of the wrapped action. The label is popped again afterwards,
-- even if the action throws, because the push/pop pair runs under
-- 'bracket_'.
--
-- The stack is updated with the strict 'modifyIORef''. With the lazy
-- 'modifyIORef', nothing evaluates the list until a failure reads it. Every
-- push/pop would then leave an unevaluated @(x :)@ / @drop 1@ thunk in the
-- ref, and that chain grows for as long as no test fails (a space leak).
--
-- NOTE(review): the stack is a single process-wide 'IORef', so labels from
-- tests running concurrently would interleave. This presumably assumes
-- tests run sequentially; confirm.
contextIO :: String -> IO a -> IO a
contextIO msg =
  bracket_
    (modifyIORef' failContextRef (Text.pack msg :))
    (modifyIORef' failContextRef (drop 1))

-- | Fail the current test unconditionally with the given message.
--
-- The 'String' is packed to 'Text' and handed to 'failTest''. The call stack
-- is frozen first, so the recorded stack ends at the caller of 'failTest'.
failTest :: (HasCallStack, Testable m) => String -> m a
failTest msg = GHC.withFrozenCallStack (failTest' (Text.pack msg))

-- | Build an 'AssertionFail' and raise it via 'throwFailure'.
--
-- The failure carries:
--
-- * the current test's info (from 'getTestInfo');
-- * the message;
-- * the current contents of the context stack ('failContextRef');
-- * the caller's call stack.
failTest' :: (HasCallStack, Testable m) => Text -> m a
failTest' msg = do
  info <- getTestInfo
  contextStack <- readIORef failContextRef
  throwFailure
    AssertionFail
      { testFailMessage = msg
      , testFailContext = contextStack
      , testInfo = info
      , callStack = GHC.callStack
      }

-- | Process-wide stack of context labels.
--
-- 'contextIO' pushes and pops labels here; 'failTest'' reads the stack when
-- building an 'AssertionFail'. The ref is created once with
-- 'unsafePerformIO'. The NOINLINE pragma is required: without it, GHC could
-- inline the definition and create more than one 'IORef'.
failContextRef :: IORef FailContext
failContextRef = unsafePerformIO (newIORef [])
{-# NOINLINE failContextRef #-}