{-# LANGUAGE OverloadedStrings #-}
module Network.WebSockets.Connection
    ( PendingConnection (..)
    , acceptRequest
    , AcceptRequest(..)
    , defaultAcceptRequest
    , acceptRequestWith
    , rejectRequest
    , RejectRequest(..)
    , defaultRejectRequest
    , rejectRequestWith
    , Connection (..)
    , ConnectionOptions (..)
    , defaultConnectionOptions
    , receive
    , receiveDataMessage
    , receiveData
    , send
    , sendDataMessage
    , sendDataMessages
    , sendTextData
    , sendTextDatas
    , sendBinaryData
    , sendBinaryDatas
    , sendClose
    , sendCloseCode
    , sendPing
    , forkPingThread
    , CompressionOptions (..)
    , PermessageDeflate (..)
    , defaultPermessageDeflate
    , SizeLimit (..)
    ) where
import qualified Data.ByteString.Builder                         as Builder
import           Control.Applicative                             ((<$>))
import           Control.Concurrent                              (forkIO,
                                                                  threadDelay)
import           Control.Exception                               (AsyncException,
                                                                  fromException,
                                                                  handle,
                                                                  throwIO)
import           Control.Monad                                   (foldM, unless,
                                                                  when)
import qualified Data.ByteString                                 as B
import qualified Data.ByteString.Char8                           as B8
import           Data.IORef                                      (IORef,
                                                                  newIORef,
                                                                  readIORef,
                                                                  writeIORef)
import           Data.List                                       (find)
import           Data.Maybe                                      (catMaybes)
import qualified Data.Text                                       as T
import           Data.Word                                       (Word16)
import           Prelude
import           Network.WebSockets.Connection.Options
import           Network.WebSockets.Extensions                   as Extensions
import           Network.WebSockets.Extensions.PermessageDeflate
import           Network.WebSockets.Extensions.StrictUnicode
import           Network.WebSockets.Http
import           Network.WebSockets.Protocol
import           Network.WebSockets.Stream                       (Stream)
import qualified Network.WebSockets.Stream                       as Stream
import           Network.WebSockets.Types
-- | A client that has connected but whose handshake has not been answered
-- yet. Turn it into a 'Connection' with 'acceptRequest' /
-- 'acceptRequestWith', or refuse it with 'rejectRequest' /
-- 'rejectRequestWith'.
data PendingConnection = PendingConnection
    { pendingOptions  :: !ConnectionOptions
    -- ^ Options; copied as-is into the accepted 'Connection'.
    , pendingRequest  :: !RequestHead
    -- ^ The client's handshake request. During acceptance it is matched
    -- against the supported protocols and its @Sec-WebSocket-Extensions@
    -- header is read.
    , pendingOnAccept :: !(Connection -> IO ())
    -- ^ Callback run by 'acceptRequestWith' after the handshake response has
    -- been written and the 'Connection' has been built, just before that
    -- connection is returned.
    , pendingStream   :: !Stream
    -- ^ Underlying stream. Handshake responses are written to it, and an
    -- accepted connection reads and writes frames over it.
    }
-- | Options controlling how a pending connection is accepted
-- (see 'acceptRequestWith').
data AcceptRequest = AcceptRequest
    { acceptSubprotocol :: !(Maybe B.ByteString)
    -- ^ Subprotocol to speak with the client. When present, it is sent back
    -- in a @Sec-WebSocket-Protocol@ response header.
    -- NOTE(review): presumably it should be one of the subprotocols the
    -- client offered; this module does not check that.
    , acceptHeaders     :: !Headers
    -- ^ Extra headers for the handshake response. They are placed after the
    -- subprotocol header and before any extension headers.
    }
-- | Accept without choosing a subprotocol and without extra headers.
defaultAcceptRequest :: AcceptRequest
defaultAcceptRequest = AcceptRequest
    { acceptSubprotocol = Nothing
    , acceptHeaders     = []
    }
-- | Serialise an HTTP response and write it to the pending connection's
-- stream with a single 'Stream.write'.
sendResponse :: PendingConnection -> Response -> IO ()
sendResponse pc rsp =
    let bytes = Builder.toLazyByteString (encodeResponse rsp)
    in  Stream.write (pendingStream pc) bytes
-- | Accept a pending connection using 'defaultAcceptRequest', i.e. with no
-- subprotocol and no extra response headers.
acceptRequest :: PendingConnection -> IO Connection
acceptRequest pending = pending `acceptRequestWith` defaultAcceptRequest
-- | Accept a pending connection and complete the WebSocket handshake.
--
-- Steps, in order:
--
-- 1. Pick the first entry of 'protocols' that is 'compatible' with the
--    request. If there is none, send 'response400' with a
--    @Sec-WebSocket-Version@ header listing every supported version, then
--    throw 'NotSupported'.
-- 2. Negotiate extensions (permessage-deflate and strict unicode) from the
--    connection options and the client's @Sec-WebSocket-Extensions@ header.
--    If deflate negotiation fails, reject the request with the error as the
--    status message and throw 'NotSupported'.
-- 3. Send the handshake response, whose headers are the subprotocol header,
--    then 'acceptHeaders', then the extension headers. If 'finishRequest'
--    fails, its error is thrown.
-- 4. Build the framing layer, wrap it with each extension, run
--    'pendingOnAccept', and return the new 'Connection'.
acceptRequestWith :: PendingConnection -> AcceptRequest -> IO Connection
acceptRequestWith pc ar = case find (flip compatible request) protocols of
    Nothing       -> do
        sendResponse pc $ response400 versionHeader ""
        throwIO NotSupported
    Just protocol -> do

        -- Parse the extensions requested by the client; a malformed header
        -- is rethrown as an exception.
        rqExts <- either throwIO return $
            getRequestSecWebSocketExtensions request

        -- Negotiate permessage-deflate, but only when compression is enabled
        -- in the options. A failed negotiation refuses the handshake.
        pmdExt <- case connectionCompressionOptions (pendingOptions pc) of
            NoCompression                     -> return Nothing
            PermessageDeflateCompression pmd0 ->
                case negotiateDeflate (connectionMessageDataSizeLimit options) (Just pmd0) rqExts of
                    Left err   -> do
                        rejectRequestWith pc defaultRejectRequest {rejectMessage = B8.pack err}
                        throwIO NotSupported
                    Right pmd1 -> return (Just pmd1)

        -- Enable the strict UTF-8 extension when configured.
        let unicodeExt =
                if connectionStrictUnicode (pendingOptions pc)
                    then Just strictUnicode else Nothing

        -- Final list of active extensions, in wrapping order.
        let exts = catMaybes [pmdExt, unicodeExt]
        -- Build and send the handshake response. Header order: subprotocol,
        -- user-supplied headers, extension headers.
        let subproto = maybe [] (\p -> [("Sec-WebSocket-Protocol", p)]) $ acceptSubprotocol ar
            headers = subproto ++ acceptHeaders ar ++ concatMap extHeaders exts
            response = finishRequest protocol request headers
        either throwIO (sendResponse pc) response
        -- Raw framing layer over the stream, bounded by the configured
        -- frame-payload and message-data size limits.
        parseRaw <- decodeMessages
            protocol
            (connectionFramePayloadSizeLimit options)
            (connectionMessageDataSizeLimit options)
            (pendingStream pc)
        writeRaw <- encodeMessages protocol ServerConnection (pendingStream pc)
        -- Each extension wraps the result of the previous one, so the last
        -- extension in 'exts' is the outermost layer.
        write <- foldM (\x ext -> extWrite ext x) writeRaw exts
        parse <- foldM (\x ext -> extParse ext x) parseRaw exts
        -- No close message has been sent on a fresh connection.
        sentRef    <- newIORef False
        let connection = Connection
                { connectionOptions   = options
                , connectionType      = ServerConnection
                , connectionProtocol  = protocol
                , connectionParse     = parse
                , connectionWrite     = write
                , connectionSentClose = sentRef
                }
        -- The callback runs only after the response has been sent.
        pendingOnAccept pc connection
        return connection
  where
    options       = pendingOptions pc
    request       = pendingRequest pc
    -- Lists every supported version; sent when no protocol matches.
    versionHeader = [("Sec-WebSocket-Version",
        B.intercalate ", " $ concatMap headerVersions protocols)]
-- | Parameters of the HTTP response used to refuse a pending connection
-- (see 'rejectRequestWith'). Start from 'defaultRejectRequest' and override
-- fields as needed.
data RejectRequest = RejectRequest
    { -- | The status code, 400 by default.
      rejectCode    :: !Int
    , -- | The status message (reason phrase), @"Bad Request"@ by default.
      rejectMessage :: !B.ByteString
    , -- | Extra headers to send with the response; none by default.
      -- Strict, like every other field here and like 'acceptHeaders', so a
      -- stored request never carries an unevaluated header thunk.
      rejectHeaders :: !Headers
    , -- | The response body, empty by default.
      rejectBody    :: !B.ByteString
    }
-- | A plain @400 Bad Request@ response with no extra headers and an empty
-- body.
defaultRejectRequest :: RejectRequest
defaultRejectRequest = RejectRequest
    400            -- status code
    "Bad Request"  -- reason phrase
    []             -- no extra headers
    ""             -- empty body
-- | Refuse a pending connection by writing the HTTP response described by
-- the 'RejectRequest' to the client.
rejectRequestWith
    :: PendingConnection  -- ^ Connection to refuse
    -> RejectRequest      -- ^ Status, headers and body of the response
    -> IO ()
rejectRequestWith pc reject = sendResponse pc (Response hd body)
  where
    hd   = ResponseHead
        { responseCode    = rejectCode reject
        , responseMessage = rejectMessage reject
        , responseHeaders = rejectHeaders reject
        }
    body = rejectBody reject
-- | Refuse a pending connection with 'defaultRejectRequest', using the given
-- bytes as the response body.
rejectRequest
    :: PendingConnection  -- ^ Connection to refuse
    -> B.ByteString       -- ^ Response body
    -> IO ()
rejectRequest pc body =
    let reject = defaultRejectRequest {rejectBody = body}
    in  rejectRequestWith pc reject
-- | An open WebSocket connection. On the server side one is built by
-- 'acceptRequestWith'.
data Connection = Connection
    { connectionOptions   :: !ConnectionOptions
    -- ^ Options the connection was created with.
    , connectionType      :: !ConnectionType
    -- ^ Which end of the connection this is ('acceptRequestWith' sets
    -- 'ServerConnection').
    , connectionProtocol  :: !Protocol
    -- ^ Protocol selected during the handshake.
    , connectionParse     :: !(IO (Maybe Message))
    -- ^ Reads the next message, with extensions applied. 'receive' treats
    -- 'Nothing' as the connection having been closed.
    , connectionWrite     :: !([Message] -> IO ())
    -- ^ Writes a batch of messages, with extensions applied.
    , connectionSentClose :: !(IORef Bool)
    -- ^ Whether we have already sent a close control message. 'sendAll' sets
    -- it whenever a 'Close' message is written. When the peer sends a close,
    -- 'receiveDataMessage' reads the flag and echoes the close back only if
    -- we have not sent one ourselves, completing the close handshake.
    }
-- | Read the next message of any kind (data or control) from the connection.
--
-- Throws 'ConnectionClosed' when the parser yields 'Nothing'.
receive :: Connection -> IO Message
receive conn =
    connectionParse conn >>= maybe (throwIO ConnectionClosed) return
-- | Receive the next data message, handling control messages along the way.
--
-- * A 'Ping' is answered with a 'Pong' carrying the same payload.
-- * A 'Pong' runs the 'connectionOnPong' action from the options.
-- * A 'Close' is echoed back unless we already sent a close, and is then
--   raised as a 'CloseRequest' exception.
--
-- After a 'Ping' or 'Pong' the function keeps waiting for a data message.
receiveDataMessage :: Connection -> IO DataMessage
receiveDataMessage conn = receive conn >>= dispatch
  where
    dispatch (DataMessage _ _ _ am) = return am
    dispatch msg@(ControlMessage (Close i closeMsg)) = do
        alreadySent <- readIORef (connectionSentClose conn)
        unless alreadySent $ send conn msg
        throwIO (CloseRequest i closeMsg)
    dispatch (ControlMessage (Pong _)) = do
        connectionOnPong (connectionOptions conn)
        receiveDataMessage conn
    dispatch (ControlMessage (Ping pl)) = do
        send conn (ControlMessage (Pong pl))
        receiveDataMessage conn
-- | Receive the next data message and decode it with 'fromDataMessage'.
-- Control messages are handled exactly as in 'receiveDataMessage'.
receiveData :: WebSocketsData a => Connection -> IO a
receiveData = (fromDataMessage <$>) . receiveDataMessage
-- | Send a single message of any kind (data or control).
send :: Connection -> Message -> IO ()
send conn msg = sendAll conn [msg]
-- | Write a batch of messages with one 'connectionWrite' call. An empty
-- batch writes nothing. If the batch contains a 'Close' control message,
-- the sent-close flag is set before the batch is written.
sendAll :: Connection -> [Message] -> IO ()
sendAll conn msgs
    | null msgs = return ()
    | otherwise = do
        when (any isClose msgs) $
            writeIORef (connectionSentClose conn) True
        connectionWrite conn msgs
  where
    isClose m = case m of
        ControlMessage (Close _ _) -> True
        _                          -> False
-- | Send a single data message.
sendDataMessage :: Connection -> DataMessage -> IO ()
sendDataMessage conn dm = sendDataMessages conn [dm]
-- | Send several data messages in one batch. Each one is wrapped in
-- 'DataMessage' with its three boolean flags set to 'False'.
sendDataMessages :: Connection -> [DataMessage] -> IO ()
sendDataMessages conn dms =
    sendAll conn [DataMessage False False False dm | dm <- dms]
-- | Send one value as a text message.
sendTextData :: WebSocketsData a => Connection -> a -> IO ()
sendTextData conn x = sendTextDatas conn [x]
-- | Send several values as text messages in one batch. Each value is encoded
-- with 'toLazyByteString'; the second 'Text' field is left as 'Nothing'.
sendTextDatas :: WebSocketsData a => Connection -> [a] -> IO ()
sendTextDatas conn xs =
    sendDataMessages conn [Text (toLazyByteString x) Nothing | x <- xs]
-- | Send one value as a binary message.
sendBinaryData :: WebSocketsData a => Connection -> a -> IO ()
sendBinaryData conn x = sendBinaryDatas conn [x]
-- | Send several values as binary messages in one batch, each encoded with
-- 'toLazyByteString'.
sendBinaryDatas :: WebSocketsData a => Connection -> [a] -> IO ()
sendBinaryDatas conn xs =
    sendDataMessages conn [Binary (toLazyByteString x) | x <- xs]
-- | Send a close control message with status code 1000 (normal closure,
-- RFC 6455 section 7.4.1) and the given reason payload.
sendClose :: WebSocketsData a => Connection -> a -> IO ()
sendClose conn reason = sendCloseCode conn 1000 reason
-- | Send a close control message with an explicit status code and a reason
-- payload encoded with 'toLazyByteString'. This also marks the connection as
-- having sent a close (via 'send').
sendCloseCode :: WebSocketsData a => Connection -> Word16 -> a -> IO ()
sendCloseCode conn code reason =
    send conn (ControlMessage (Close code (toLazyByteString reason)))
-- | Send a ping control message whose payload is encoded with
-- 'toLazyByteString'.
sendPing :: WebSocketsData a => Connection -> a -> IO ()
sendPing conn payload =
    send conn (ControlMessage (Ping (toLazyByteString payload)))
-- | Fork a thread that sends a ping every @n@ seconds. If @n@ is not
-- positive, no thread is forked.
--
-- Ping payloads are an increasing counter (@"1"@, @"2"@, ...) rendered as
-- text. If a synchronous exception escapes the loop (for example a failed
-- send), the thread ends quietly. 'AsyncException's such as 'ThreadKilled'
-- are re-thrown instead.
forkPingThread :: Connection -> Int -> IO ()
forkPingThread conn n
    | n > 0     = forkIO (handle swallowSync (pingLoop 1)) >> return ()
    | otherwise = return ()
  where
    pingLoop :: Int -> IO ()
    pingLoop counter = do
        threadDelay (n * 1000 * 1000)  -- n seconds, expressed in microseconds
        sendPing conn (T.pack (show counter))
        pingLoop (counter + 1)

    swallowSync e = case fromException e of
        Nothing    -> return ()
        Just async -> throwIO (async :: AsyncException)