{-# LANGUAGE CPP #-}
{-# LANGUAGE MultiWayIf #-}
{-# LANGUAGE RankNTypes #-}

{-|

This module contains helper functions for waiting.

It can be very useful in tests to retry something, with a reasonable backoff policy to prevent the test from consuming lots of CPU while waiting.

-}


module Test.Sandwich.Waits (
  -- * General waits
  waitUntil
  , waitUntil'
  , defaultRetryPolicy
  ) where

import qualified Control.Exception as E
import Control.Monad.IO.Unlift
import Data.String.Interpolate
import Data.Time
import Data.Typeable
import GHC.Stack
import Test.Sandwich
import UnliftIO.Exception
import UnliftIO.Retry
import UnliftIO.Timeout

#if MIN_VERSION_base(4,14,0)
import System.Timeout (Timeout)
#endif


-- | Keep trying an action up to a timeout while it fails with a 'FailureReason'.
-- Uses 'defaultRetryPolicy' (exponential backoff, with delays capped at 1 second).
--
-- The 'Double' argument is the timeout in seconds; see 'waitUntil'' for the
-- exact retry and timeout semantics.
waitUntil :: forall m a. (HasCallStack, MonadUnliftIO m) => Double -> m a -> m a
waitUntil = waitUntil' defaultRetryPolicy

-- | The default retry policy used by 'waitUntil': exponential backoff starting
-- from a base delay of 1,000 microseconds (1 ms), with each individual delay
-- capped at 1,000,000 microseconds (1 second).
defaultRetryPolicy :: RetryPolicy
defaultRetryPolicy = capDelay 1_000_000 $ exponentialBackoff 1_000

-- | Same as 'waitUntil', but with a configurable retry policy.
--
-- Every attempt runs @action@ under its own 'timeout' of @timeInSeconds@; if
-- that timeout fires, the attempt fails via 'expectationFailure'. When an
-- attempt fails with a 'FailureReason', it is retried according to @policy@ as
-- long as no more than @timeInSeconds@ of wall-clock time have elapsed since
-- 'waitUntil'' was entered; after that the failure is rethrown. Exceptions
-- other than 'FailureReason' are not retried.
--
-- Because each attempt gets the full @timeInSeconds@ budget, the total time
-- spent here can exceed @timeInSeconds@.
waitUntil' :: forall m a. (HasCallStack, MonadUnliftIO m) => RetryPolicy -> Double -> m a -> m a
waitUntil' policy timeInSeconds action = do
  startTime <- liftIO getCurrentTime

  recoveringDynamic policy [handleFailureReasonException startTime] $ \_status ->
    rethrowTimeoutExceptionWithCallStack $
      timeout (round (timeInSeconds * 1_000_000)) action >>= \case
        Nothing -> expectationFailure [i|Action timed out in waitUntil|]
        Just x -> return x

  where
    -- Only 'FailureReason' exceptions are considered for retrying, and only
    -- while the overall deadline has not passed.
    handleFailureReasonException startTime _status = Handler $ \(_ :: FailureReason) ->
      retryUnlessTimedOut startTime

    -- Give up once more than @timeInSeconds@ have elapsed since @startTime@;
    -- otherwise defer to the retry policy (which may itself decide to stop).
    retryUnlessTimedOut :: UTCTime -> m RetryAction
    retryUnlessTimedOut startTime = do
      now <- liftIO getCurrentTime
      let thresh :: NominalDiffTime
          thresh = secondsToNominalDiffTime (realToFrac timeInSeconds)
      if | diffUTCTime now startTime > thresh -> return DontRetry
         | otherwise -> return ConsultPolicy

    -- Convert a 'Timeout' exception escaping the action into a 'FailureReason'
    -- carrying the caller's call stack; every other exception is rethrown as-is.
    -- We can only catch the timeout for base >= 4.14.0.0, since before that the Timeout exception wasn't exported
    rethrowTimeoutExceptionWithCallStack :: (HasCallStack) => m a -> m a
    rethrowTimeoutExceptionWithCallStack = handleSyncOrAsync $ \e@(SomeException inner) -> if
#if MIN_VERSION_base(4,14,0)
      | Just (_ :: Timeout) <- fromExceptionUnwrap e ->
          throwIO $ Reason (Just (popCallStack callStack)) "Timeout in waitUntil"
      -- A Timeout nested inside SyncExceptionWrapper/SomeException, presumably
      -- produced by an earlier rethrow through unliftio's throwIO.
      | Just (SyncExceptionWrapper (cast -> Just (SomeException (cast -> Just (SomeAsyncException (cast -> Just (_ :: Timeout))))))) <- cast inner ->
          throwIO $ Reason (Just (popCallStack callStack)) "Timeout in waitUntil"
#endif
      | otherwise ->
          -- Rethrow unchanged with base's throwIO. unliftio's throwIO would wrap
          -- an asynchronous exception (e.g. a cancellation) in
          -- SyncExceptionWrapper, turning it into a synchronous exception that
          -- catch-all handlers for synchronous exceptions would then swallow.
          liftIO (E.throwIO e)