{-# LANGUAGE NumericUnderscores #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE ViewPatterns #-}

module Test.WebDriver.Commands.BiDi.Session (
  withBiDiSession
  , withBiDiSession'

  , BiDiEvent(..)
  , BiDiResponse(..)

  , BiDiOptions(..)
  , defaultBiDiOptions
  ) where

import Control.Monad (forever)
import Control.Monad.Fix (fix)
import Control.Monad.IO.Unlift
import Control.Monad.Logger (MonadLogger, logDebugN, logErrorN)
import Data.Aeson
import Data.Aeson.TH
import qualified Data.List as L
import Data.String.Interpolate
import Data.Text (Text)
import qualified Network.URI as URI
import qualified Network.WebSockets as WS
import Test.WebDriver.Capabilities.Aeson
import Test.WebDriver.Types
import Text.Read (readMaybe)
import UnliftIO.Async (withAsync)
import UnliftIO.Exception
import UnliftIO.STM (atomically, stateTVar)
import UnliftIO.Timeout (timeout)


-- | A message received over the BiDi WebSocket that decodes as an event,
-- i.e. one carrying a type, a method name and a parameter payload.
--
-- The 'FromJSON' instance is generated by Template Haskell with the
-- 'toCamel2' options (presumably these strip the @biDi@ prefix from the field
-- names to obtain the JSON keys -- confirm against
-- "Test.WebDriver.Capabilities.Aeson").
data BiDiEvent = BiDiEvent
  { biDiType :: Text
    -- ^ Message type tag.
  , biDiMethod :: Text
    -- ^ Name of the event method.
  , biDiParams :: Value
    -- ^ Event payload, left as undecoded JSON.
  }
  deriving (Show)

deriveFromJSON toCamel2 ''BiDiEvent

-- | A reply to a command previously sent over the BiDi WebSocket.
--
-- The 'FromJSON' instance is generated by Template Haskell with the
-- 'toCamel3' options (presumably these strip the @biDiResponse@ prefix from
-- the field names to obtain the JSON keys -- confirm against
-- "Test.WebDriver.Capabilities.Aeson").
data BiDiResponse = BiDiResponse
  { biDiResponseType :: Text
    -- ^ Response type tag; this module checks it against @success@ and @error@.
  , biDiResponseId :: Int
    -- ^ ID of the command this response answers.
  , biDiResponseResult :: Maybe Value
    -- ^ Result payload, if any, left as undecoded JSON.
  , biDiResponseError :: Maybe Value
    -- ^ Error payload, if any, left as undecoded JSON.
  }
  deriving (Show)

deriveFromJSON toCamel3 ''BiDiResponse

-- | Tunable settings for establishing a BiDi session.
data BiDiOptions = BiDiOptions
  { -- | How long to wait for the reply to the @session.subscribe@ request,
    -- in microseconds (the value is handed directly to 'timeout').
    biDiSubscriptionRequestTimeoutUs :: Int
  }
-- | The stock 'BiDiOptions': wait up to 15 seconds for the subscription
-- response.
defaultBiDiOptions :: BiDiOptions
defaultBiDiOptions =
  BiDiOptions
    { biDiSubscriptionRequestTimeoutUs = fifteenSecondsUs
    }
  where
    -- 15 seconds expressed in microseconds.
    fifteenSecondsUs :: Int
    fifteenSecondsUs = 15_000_000

-- | Run an action with a BiDi subscription, taking the WebSocket URL from the
-- current 'Session' and delegating to 'withBiDiSession''.
--
-- The session must have been created with '_capabilitiesWebSocketUrl' set to
-- @Just True@, otherwise no WebSocket URL is available and an 'IOError' is
-- thrown. An 'IOError' is also thrown if the stored URL cannot be parsed.
-- This will not work with Selenium 3.
--
-- The request ID used for the subscription is drawn from the session's ID
-- counter, which is incremented atomically.
withBiDiSession :: (WebDriver m, MonadLogger m) => BiDiOptions -> [Text] -> (BiDiEvent -> m ()) -> m a -> m a
withBiDiSession biDiOptions events cb action = do
  session <- getSession

  webSocketUrl <-
    maybe
      (throwIO $ userError [i|Session wasn't configured with a BiDi WebSocket URL when trying to record logs. Make sure to enable _capabilitiesWebSocketUrl.|])
      pure
      (sessionWebSocketUrl session)

  uri <-
    maybe
      (throwIO $ userError [i|Couldn't parse WebSocket URL: #{webSocketUrl}|])
      pure
      (URI.parseURI webSocketUrl)

  -- Hand out the current counter value and bump the counter in one transaction.
  bidiSessionId <- atomically $ stateTVar (sessionIdCounter session) (\current -> (current, current + 1))

  withBiDiSession' biDiOptions bidiSessionId uri events cb action

-- | Connect to WebSocket URL and subscribe to the given events using the W3C BiDi protocol; see
-- <https://w3c.github.io/webdriver-bidi/>.
--
-- Arguments, in order:
--
-- * the 'BiDiOptions' (supplies the timeout for the subscription response);
-- * the ID to put on the @session.subscribe@ request;
-- * the WebSocket 'URI.URI', which must have an authority with an explicit
--   numeric port;
-- * the event names to subscribe to;
-- * a callback invoked for every incoming message that decodes as a 'BiDiEvent';
-- * the action to run while the subscription is active.
--
-- Throws an 'IOError' when the URI has no authority, when it has no numeric
-- port, when the subscription response does not arrive within the timeout, or
-- when the remote end answers the subscription with an error.
--
-- A close frame is sent once the wrapped action finishes, whether it returned
-- normally or threw.
withBiDiSession' :: (MonadUnliftIO m, MonadLogger m) => BiDiOptions -> Int -> URI.URI -> [Text] -> (BiDiEvent -> m ()) -> m a -> m a
withBiDiSession' biDiOptions bidiSessionId uri events cb action =
  case URI.uriAuthority uri of
    Nothing ->
      throwIO $ userError [i|WebSocket URL didn't contain an authority: #{uri}|]
    -- 'URI.uriPort' includes the leading colon, hence the 'L.drop' 1.
    Just auth -> case readMaybe (L.drop 1 (URI.uriPort auth)) of
      Nothing ->
        -- Distinguish this from the missing-authority case so the error
        -- points at the real problem.
        throwIO $ userError [i|WebSocket URL didn't contain a numeric port: #{uri}|]
      Just (port :: Int) -> do
        let host = URI.uriRegName auth
            -- The request target is the path plus the query string;
            -- 'URI.uriQuery' already carries its leading @?@ when non-empty.
            resource = URI.uriPath uri <> URI.uriQuery uri

        logDebugN [i|BiDi: Connecting to #{host}:#{port}#{resource}|]

        withRunInIO $ \runInIO -> liftIO $ WS.runClient host port resource $ \conn -> runInIO $ do
          logDebugN [i|BiDi: Connected successfully, sending subscription request with ID #{bidiSessionId}|]
          liftIO $ WS.sendTextData conn $ encode $ object [
            "id" .= bidiSessionId
            , "method" .= ("session.subscribe" :: Text)
            , "params" .= object [
                "events" .= (events :: [Text])
              ]
            ]

          logDebugN "BiDi: Sent subscription request, waiting for response..."
          timeout (biDiSubscriptionRequestTimeoutUs biDiOptions) (waitForSubscriptionResult bidiSessionId conn) >>= \case
            Nothing -> throwIO $ userError "BiDi: Subscription response timed out"
            Just (Left err) ->
              throwIO $ userError [i|BiDi: got exception (URI #{uri}): #{err}|]
            Just (Right ()) -> do
              logDebugN "BiDi: Starting log event listener"
              -- The listener runs for the duration of the wrapped action and
              -- is cancelled by 'withAsync' when the inner block exits.
              withAsync (messageListener conn) $ \_messageListenerAsy -> do
                finally action $ do
                  logDebugN [i|BiDi: finished wrapped action|]
                  liftIO $ WS.sendClose conn ([i|Finishing session #{bidiSessionId}|] :: Text)
  where
    -- Receive messages forever, passing decodable events to the callback and
    -- logging the raw text of anything else.
    messageListener conn = forever $ do
      msg <- liftIO $ WS.receiveData conn
      case decode msg of
        Just (event :: BiDiEvent) -> cb event
        Nothing -> logDebugN [i|BiDi: Ignoring non-log event message: #{msg}|]


-- | Read messages from the connection until one decodes as a 'BiDiResponse'
-- whose ID equals the given request ID and whose type is @success@ or @error@.
--
-- Returns @Right ()@ on success and @Left@ (wrapping an 'IOError' that
-- describes the response) on error. Every other message is logged and
-- skipped. This function never gives up on its own; the caller is expected to
-- bound it with a timeout.
waitForSubscriptionResult :: (
  MonadIO m, MonadLogger m
  ) => Int -> WS.Connection -> m (Either SomeException ())
waitForSubscriptionResult bidiSessionId conn = fix $ \keepWaiting -> do
  rawMessage <- liftIO (WS.receiveData conn)
  logDebugN [i|BiDi: Waiting for subscription response: #{rawMessage}|]
  case decode rawMessage of
    Nothing -> do
      logDebugN [i|BiDi: Not a BiDiResponse, continuing to wait for subscription response (#{rawMessage})|]
      keepWaiting
    Just reply -> do
      let replyType = biDiResponseType reply
          replyId = biDiResponseId reply
      -- Only a response carrying our own request ID can settle the wait.
      case (replyId == bidiSessionId, replyType) of
        (True, "success") -> do
          logDebugN "BiDi: Subscription successful!"
          pure (Right ())
        (True, "error") -> do
          let failure = "BiDi subscription failed: " <> show reply
          logErrorN [i|BiDi: #{failure}|]
          pure (Left (toException (userError failure)))
        _ -> do
          logDebugN [i|BiDi: Ignoring response with type #{replyType}, ID #{replyId}|]
          keepWaiting

-- * Better WebSocket ping/pong
-- NOTE: I've run into problems with the ping/pong support in the websockets
-- library, so I came up with this alternative version that I use in my
-- websockets projects.
--
-- However, from what I could find, a BiDi client doesn't need to do ping/pong
-- itself, since we can count on the browser driver to do it. So this is left
-- commented out for now. It might still be useful at some point, e.g. to
-- detect dead drivers.

-- data BetterPongTimeout =
--   BetterPongTimeoutUnexpectedResponse BL.ByteString
--   | BetterPongTimeoutNoResponse
--   deriving Show
-- instance Exception BetterPongTimeout

-- data PingPongOptions = PingPongOptions {
--   pingInterval :: Int -- ^ Interval in seconds
--   , pongTimeout :: Int -- ^ Timeout in seconds
--   , pingAction :: Int -> IO () -- ^ Action to perform after sending a ping
--   , pingMessage :: Text -> IO () -- ^ Message to log
-- }
-- defaultPingPongOptions :: PingPongOptions
-- defaultPingPongOptions = PingPongOptions {
--   pingInterval = 15
--   , pongTimeout = 30
--   , pingAction = const $ return ()
--   , pingMessage = const $ return ()
-- }

-- withBetterPingPong :: MonadUnliftIO m => PingPongOptions -> WS.Connection -> (WS.Connection -> m ()) -> m ()
-- withBetterPingPong (PingPongOptions {..}) connection app = void $
--   withAsync (app connection) $ \appAsync -> do
--     withAsync (liftIO pingPongThread) $ \pingAsync -> do
--       waitAnyCancel [appAsync, pingAsync]
--     where
--       pingPongThread = do
--         -- Make sure the heartbeat MVar is empty
--         _ <- tryTakeMVar (WS.connectionHeartbeat connection)

--         flip withException reportPingPongException $ flip fix (0 :: Int) $ \loop n -> do
--           let bytes :: BL.ByteString = Builder.toLazyByteString $ Builder.int64BE $ fromIntegral n

--           WS.sendPing connection bytes

--           timeout (pongTimeout * 1000 * 1000) (takeMVar (WS.connectionHeartbeat connection)) >>= \case
--             Just _ -> return ()
--             Nothing -> throwIO $ BetterPongTimeoutNoResponse

--           threadDelay (pongTimeout * 1000 * 1000)
--           loop (n + 1)

--       reportPingPongException :: SomeException -> IO ()
--       reportPingPongException (fromException -> Just (AsyncCancelled {})) = return ()
--       reportPingPongException e = pingMessage [i|Ping pong thread had exception: #{e}|]