{-# LANGUAGE OverloadedStrings #-}
module WebMock.Util (requestBodyToByteString) where

import Imports

import Data.ByteString.Builder qualified as Builder
import Data.ByteString.Lazy qualified as L
import Data.Int
import Data.IORef
import Network.HTTP.Client.Internal

-- | Materialise a 'RequestBody' as a single lazy 'ByteString'.
--
-- * In-memory bodies ('RequestBodyLBS', 'RequestBodyBS') are returned as-is,
--   converting the strict variant to a lazy one.
-- * 'RequestBodyBuilder' and 'RequestBodyStream' carry a declared length;
--   the produced bytes are validated against it with 'checkLength', which
--   throws when the sizes disagree.
-- * 'RequestBodyStreamChunked' has no declared length, so the stream is
--   drained without any size check.
-- * 'RequestBodyIO' runs the action and converts the body it yields.
requestBodyToByteString :: RequestBody -> IO LazyByteString
requestBodyToByteString = \ case
  RequestBodyLBS body -> return body
  RequestBodyBS body -> return (L.fromStrict body)
  RequestBodyBuilder n builder -> checkLength n (Builder.toLazyByteString builder)
  RequestBodyStream n stream -> streamToByteString stream >>= checkLength n
  RequestBodyStreamChunked stream -> streamToByteString stream
  RequestBodyIO body -> body >>= requestBodyToByteString

-- | Drain a streaming request body into a single lazy 'ByteString'.
--
-- The popper handed over by the stream is called repeatedly; reading stops
-- at the first empty chunk, and the chunks collected so far are
-- concatenated in the order they were produced.
streamToByteString :: GivesPopper () -> IO LazyByteString
streamToByteString givesPopper = do
  -- Seed the reference with the empty body rather than 'undefined': if the
  -- stream never runs the callback, the result is an empty body instead of
  -- a thunk that crashes when it is eventually forced.
  ref <- newIORef L.empty
  givesPopper (go [] ref)
  readIORef ref
  where
    -- Chunks are accumulated in reverse and flipped once at the end, so
    -- each step is a constant-time cons.
    go :: [ByteString] -> IORef LazyByteString -> Popper -> IO ()
    go xs ref get = get >>= \ case
      "" -> writeIORef ref (L.fromChunks $ reverse xs)
      x -> go (x : xs) ref get

-- | Verify that a body has the length its 'RequestBody' declared.
--
-- Returns the body unchanged when its length equals @n@; otherwise throws
-- 'WrongRequestBodyStreamSize' via 'throwHttp'.
checkLength :: Int64 -> LazyByteString -> IO LazyByteString
checkLength n xs
  | n == len = return xs
  -- Expected (declared) size goes first, actual size second.
  -- NOTE(review): order follows http-client's documentation of
  -- 'WrongRequestBodyStreamSize' ("expected and actual size") -- confirm.
  | otherwise = throwHttp $ WrongRequestBodyStreamSize (fromIntegral n) (fromIntegral len)
  where
    len :: Int64
    len = L.length xs