module Test.WebDriver.Util.Sockets (
waitForSocket
, connectRetryPolicy
) where
import Control.Exception.Safe
import Control.Monad
import Control.Monad.IO.Class
import Control.Monad.IO.Unlift
import Control.Monad.Logger
import Control.Retry
import Data.String.Interpolate
import Network.Socket
import UnliftIO.Timeout
-- | Retry policy for socket connection attempts.
--
-- It uses full-jitter exponential backoff with a base delay of 100 ms
-- (100000 µs). Every individual delay is capped at 3 s (3000000 µs). No retry
-- limit is applied here; combine it with e.g. 'limitRetries' if you need one.
--
-- 'waitForSocket' also uses the delay this policy yields for the current
-- attempt as that attempt's connect timeout.
connectRetryPolicy :: MonadIO m => RetryPolicyM m
connectRetryPolicy = capDelay 3000000 (fullJitterBackoff 100000)
-- | Block until a connection to @addr@ can be established, retrying
-- according to @policy@.
--
-- How each attempt works:
--
-- * It creates a fresh socket from @addr@'s family, socket type and protocol.
-- * It tries to 'connect' within a per-attempt timeout.
-- * It always closes the socket again via 'bracket', whether or not the
--   attempt succeeded.
--
-- Logging and failure:
--
-- * Retries after the first attempt are logged at debug level.
-- * Any failure is logged at error level and rethrown, so that 'recoverAll'
--   can schedule the next attempt.
-- * Once @policy@ gives up, the last exception propagates to the caller.
--
-- The per-attempt timeout is the delay that 'connectRetryPolicy' yields for
-- the current 'RetryStatus' (not the delay from @policy@). Its growth is
-- therefore governed by 'connectRetryPolicy', capped at 3 s.
waitForSocket
  :: (MonadUnliftIO m, MonadLogger m, MonadMask m)
  => RetryPolicyM m
  -- ^ Decides how many attempts are made and how long to wait between them.
  -> AddrInfo
  -- ^ Address to connect to; also supplies the socket family, type and protocol.
  -> m ()
waitForSocket policy addr = recoverAll policy attempt
  where
    -- A single attempt. Failures are logged, then rethrown so that
    -- 'recoverAll' can decide whether to retry.
    attempt retryStatus =
      tryConnect retryStatus `withException` \(e :: SomeException) ->
        logErrorN [i|waitForSocket: failed to connect to #{addr}: #{e}|]

    -- Open a socket, connect with a timeout, and always close the socket.
    tryConnect retryStatus = do
      let attemptNo = rsIterNumber retryStatus
      when (attemptNo > 0) $
        logDebugN [i|waitForSocket: attempt \##{attemptNo} to connect to #{addr}|]
      bracket (liftIO initSocket) (liftIO . close) $ \sock -> do
        connectTimeoutUs <- attemptTimeoutUs retryStatus
        timeout connectTimeoutUs (liftIO $ connect sock (addrAddress addr)) >>= \case
          Nothing -> throwIO $ userError "Timeout in connect attempt"
          Just () -> pure ()

    -- Connect timeout (in µs) for this attempt, taken from 'connectRetryPolicy'.
    attemptTimeoutUs retryStatus =
      getRetryPolicyM connectRetryPolicy retryStatus >>= \case
        -- As currently defined, 'connectRetryPolicy' never returns Nothing,
        -- so this branch is purely defensive.
        Nothing -> throwIO $ userError "Timeout due to connect retry policy"
        Just us -> pure us

    initSocket :: IO Socket
    initSocket = socket (addrFamily addr) (addrSocketType addr) (addrProtocol addr)