{-# LANGUAGE MultiWayIf #-}
{-# LANGUAGE RankNTypes #-}
{-# OPTIONS_GHC -fno-warn-name-shadowing #-}

{-|
Helper module for finding free ports, with various options for port ranges, retries, and excluded ports.
-}

module Test.WebDriver.Util.Ports (
  findFreePort

  -- * Exception-throwing versions
  , findFreePortOrException
  , findFreePortOrException'

  -- * Lower-level
  , findFreePortInRange
  , findFreePortInRange'

  -- * Port probing and constants
  , isPortFree
  , tryOpenAndClosePort
  , ephemeralPortRange
  ) where

import Control.Exception (SomeAsyncException, bracket, fromException, throwIO)
import Control.Monad.Catch (MonadCatch, catch)
import Control.Monad.IO.Class
import Control.Retry
import Data.Maybe
import Network.Socket
import System.Random (randomRIO)
import UnliftIO.Exception (SomeException)


-- | Find an unused port in the ephemeral port range ('ephemeralPortRange'),
-- without excluding any ports.
-- See https://en.wikipedia.org/wiki/List_of_TCP_and_UDP_port_numbers.
--
-- Returns 'Nothing' if no free port is found before the default retry policy
-- used by 'findFreePortInRange' gives up.
findFreePort :: (MonadIO m, MonadCatch m) => m (Maybe PortNumber)
findFreePort = findFreePortInRange ephemeralPortRange noExclusions
  where
    -- No ports are excluded from the search.
    noExclusions = []

-- | Find a free port in the ephemeral range, throwing an exception if one isn't found.
--
-- Equivalent to 'findFreePortOrException'' with a predicate that accepts
-- every port.
findFreePortOrException :: (MonadIO m, MonadCatch m) => m PortNumber
findFreePortOrException = findFreePortOrException' acceptAny
  where
    -- Every port found is acceptable.
    acceptAny _ = True

-- | Same as 'findFreePortOrException', but with a callback to test if the port
-- is acceptable or not.
--
-- Whenever 'findFreePort' yields a port that the callback rejects, the whole
-- search starts over. There is no limit on the number of restarts: the loop
-- ends only when a port is accepted, or when 'findFreePort' returns 'Nothing',
-- in which case 'error' is called.
findFreePortOrException' :: (MonadIO m, MonadCatch m) => (PortNumber -> Bool) -> m PortNumber
findFreePortOrException' isAcceptable = do
  found <- findFreePort
  case found of
    Nothing -> error "Couldn't find free port"
    Just port
      | isAcceptable port -> return port
      -- Rejected: search again with the same predicate.
      | otherwise -> findFreePortOrException' isAcceptable

-- | Find an unused port in a given range, excluding certain ports.
--
-- Delegates to 'findFreePortInRange'' with the retry policy
-- @'limitRetries' 50@. If the retries run out without finding a port,
-- returns 'Nothing'.
findFreePortInRange
  :: (MonadIO m, MonadCatch m)
  => (PortNumber, PortNumber) -- ^ Candidate port range
  -> [PortNumber]             -- ^ Ports to exclude
  -> m (Maybe PortNumber)
findFreePortInRange range excluded =
  findFreePortInRange' (limitRetries 50) range excluded

-- | Same as 'findFreePortInRange', but with a configurable retry policy.
--
-- Each attempt draws one port at random from the inclusive range
-- @(start, end)@ and probes it with 'isPortFree'. An attempt fails when the
-- drawn port is in the exclusion list or is not free; the retry policy then
-- decides whether to try again. Returns 'Nothing' once the policy gives up.
findFreePortInRange'
  :: forall m. (MonadIO m, MonadCatch m)
  => RetryPolicy              -- ^ Retry policy
  -> (PortNumber, PortNumber) -- ^ Candidate port range
  -> [PortNumber]             -- ^ Ports to exclude
  -> m (Maybe PortNumber)
findFreePortInRange' retryPolicy (start, end) blacklist =
  retrying retryPolicy shouldRetry (const attempt)
  where
    -- Keep retrying for as long as an attempt comes back empty.
    shouldRetry :: RetryStatus -> Maybe PortNumber -> m Bool
    shouldRetry _retryStatus result = return (isNothing result)

    -- Draw a port from the inclusive range. The draw goes through Integer,
    -- presumably because PortNumber has no 'Random' instance.
    randomCandidate :: m PortNumber
    randomCandidate =
      liftIO (fromInteger <$> randomRIO (fromIntegral start, fromIntegral end))

    -- A single attempt. An excluded port counts as a failed attempt rather
    -- than being redrawn on the spot: redrawing immediately loops forever when
    -- the exclusion list covers the whole range, whereas this way the retry
    -- policy bounds the search.
    attempt :: m (Maybe PortNumber)
    attempt = do
      candidate <- randomCandidate
      if candidate `elem` blacklist
        then return Nothing
        else do
          free <- isPortFree candidate
          return (if free then Just candidate else Nothing)

-- | Test if a given 'PortNumber' is currently available, by probing it with
-- 'tryOpenAndClosePort'.
--
-- Returns 'True' if the probe succeeds and 'False' if it throws a synchronous
-- exception. Asynchronous exceptions (such as the calling thread being killed)
-- are re-thrown instead of being misreported as a busy port.
isPortFree :: (MonadIO m, MonadCatch m) => PortNumber -> m Bool
isPortFree candidate = catch probe onFailure
  where
    probe = tryOpenAndClosePort candidate >> return True

    -- 'Control.Monad.Catch.catch' on 'SomeException' also intercepts async
    -- exceptions; swallowing those would break cancellation and timeouts.
    onFailure (e :: SomeException)
      | isAsync e = liftIO (throwIO e)
      | otherwise = return False

    isAsync e = isJust (fromException e :: Maybe SomeAsyncException)

-- | Test a given 'PortNumber' availability by trying to open and close a socket on it.
--
-- Opens an IPv4 stream socket, resolves @127.0.0.1@ on the given port, and
-- binds the socket to the first resolved address. Returns the port on success.
-- Throws on failure, e.g. when the bind is refused or the address cannot be
-- resolved. The socket is closed whether or not the bind succeeds.
tryOpenAndClosePort :: MonadIO m => PortNumber -> m PortNumber
tryOpenAndClosePort port = liftIO $
  -- 'bracket' guarantees the socket is closed even when 'bind' throws, which
  -- is the expected outcome for a busy port. Closing only on the success path
  -- left the socket open whenever the bind failed.
  bracket (socket AF_INET Stream 0) close $ \sock -> do
    -- NOTE(review): presumably set so a port lingering in TIME_WAIT still
    -- counts as usable; confirm this is intended.
    setSocketOption sock ReuseAddr 1
    let hints = defaultHints { addrSocketType = Stream, addrFamily = AF_INET }
    addrs <- getAddrInfo (Just hints) (Just "127.0.0.1") (Just (show port))
    case addrs of
      addr : _ -> do
        bind sock (addrAddress addr)
        return port
      [] -> error "Couldn't resolve address 127.0.0.1"

-- | The ephemeral port range, 49152 through 65535 inclusive.
-- See https://en.wikipedia.org/wiki/List_of_TCP_and_UDP_port_numbers.
ephemeralPortRange :: (PortNumber, PortNumber)
ephemeralPortRange = (firstEphemeral, lastEphemeral)
  where
    -- 0xC000 = 49152 and 0xFFFF = 65535: the top quarter of the 16-bit port space.
    firstEphemeral = 0xC000
    lastEphemeral = 0xFFFF