module Test.WebDriver.Util.Sockets (
  waitForSocket
  , connectRetryPolicy
  ) where

import Control.Exception.Safe
import Control.Monad
import Control.Monad.IO.Class
import Control.Monad.IO.Unlift
import Control.Monad.Logger
import Control.Retry
import Data.String.Interpolate
import Network.Socket
import UnliftIO.Timeout



-- * Socket waits

-- | Policy that sizes the timeout of each individual connect attempt made by
-- 'waitForSocket'.
--
-- A single @connect@ call can hang, so every attempt is bounded by a timeout.
-- Rather than using a fixed value, the timeout is looked up from this policy
-- using the 'RetryStatus' of the /outer/ retry loop. Early attempts therefore
-- get short timeouts and later attempts get progressively longer ones.
--
-- The policy is built from a 100ms base via 'fullJitterBackoff' and capped
-- at 3 seconds via 'capDelay'. This serves two goals:
--
-- (1) If a connect hangs during one of the first attempts, it is cut off
--     quickly and retried. On a healthy network you are therefore not
--     penalised much by the hang. The outer policy controls the spacing
--     between attempts, so the caller can tune it to make this hold.
--
-- (2) On a slow network, the per-attempt timeout eventually reaches the
--     3-second ceiling, which should be ample. For comparison, the popular
--     wait-for script uses 1-second timeouts, so this is deliberately
--     conservative:
--     <https://github.com/eficode/wait-for/blob/7586b3622f010808bb2027c19aaf367221b4ad54/wait-for#L72>
connectRetryPolicy :: MonadIO m => RetryPolicyM m
connectRetryPolicy = capDelay ceilingUs (fullJitterBackoff baseUs)
  where
    baseUs, ceilingUs :: Int
    -- Starting point for the jittered backoff: 100ms, in microseconds.
    baseUs = 100000
    -- Upper bound on any single connect timeout: 3s, in microseconds.
    ceilingUs = 3000000

-- | Repeatedly try to connect to the given address until a connection
-- succeeds, with retries governed by @policy@ through 'recoverAll'.
--
-- Each attempt works as follows:
--
-- * A fresh socket is created from the 'AddrInfo'. 'bracket' closes it
--   again whether or not the attempt succeeds.
-- * The 'connect' call is bounded by a timeout, in microseconds, taken from
--   'connectRetryPolicy' for the current 'RetryStatus'. Later attempts
--   therefore get longer timeouts.
-- * If the attempt throws, the exception is logged with 'logErrorN' and then
--   rethrown, so that 'recoverAll' can apply @policy@ to it.
-- * Attempts after the first also emit a 'logDebugN' line carrying the
--   iteration number.
waitForSocket :: (
  MonadUnliftIO m, MonadLogger m, MonadMask m
  )
  => RetryPolicyM m
  -> AddrInfo
  -> m ()
waitForSocket policy addr =
  recoverAll policy (\status -> attemptConnect status `withException` reportFailure)
  where
    -- Log whatever made this attempt fail. 'withException' rethrows it
    -- afterwards, so nothing is swallowed here.
    reportFailure (e :: SomeException) =
      logErrorN [i|waitForSocket: failed to connect to #{addr}: #{e}|]

    -- One connection attempt, parameterised by the outer retry status.
    attemptConnect status@RetryStatus {..} = do
      when (rsIterNumber > 0) $
        logDebugN [i|waitForSocket: attempt \##{rsIterNumber} to connect to #{addr}|]

      bracket (liftIO initSocket) (liftIO . close) $ \sock -> do
        -- The per-attempt timeout grows with the outer iteration count.
        budget <- getRetryPolicyM connectRetryPolicy status
        connectTimeoutUs <-
          maybe (throwIO $ userError "Timeout due to connect retry policy") pure budget

        outcome <- timeout connectTimeoutUs (liftIO $ connect sock (addrAddress addr))
        maybe (throwIO $ userError "Timeout in connect attempt") pure outcome

    -- Open an unconnected socket matching the address family, socket type
    -- and protocol of the target.
    initSocket :: IO Socket
    initSocket = socket (addrFamily addr) (addrSocketType addr) (addrProtocol addr)