{-# LANGUAGE RecordWildCards, StandaloneDeriving, OverloadedStrings #-}
{-# LANGUAGE CPP, FlexibleContexts, TupleSections, TypeSynonymInstances #-}
{-# LANGUAGE MultiParamTypeClasses, FlexibleInstances, UndecidableInstances #-}
{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE NamedFieldPuns, ScopedTypeVariables #-}
#if (__GLASGOW_HASKELL__ >= 706)
{-# LANGUAGE RecursiveDo #-}
#else
{-# LANGUAGE DoRec #-}
#endif
module Database.MongoDB.Internal.Protocol (
    FullCollection,
    
    Pipe, newPipe, newPipeWith, send, call,
    
    Notice(..), InsertOption(..), UpdateOption(..), DeleteOption(..), CursorId,
    
    Request(..), QueryOption(..),
    
    Reply(..), ResponseFlag(..),
    
    Username, Password, Nonce, pwHash, pwKey,
    isClosed, close, ServerData(..), Pipeline(..)
) where
#if !MIN_VERSION_base(4,8,0)
import Control.Applicative ((<$>))
#endif
import Control.Monad (forM, replicateM, unless)
import Data.Binary.Get (Get, runGet)
import Data.Binary.Put (Put, runPut)
import Data.Bits (bit, testBit)
import Data.Int (Int32, Int64)
import Data.IORef (IORef, newIORef, atomicModifyIORef)
import System.IO (Handle)
import System.IO.Error (doesNotExistErrorType, mkIOError)
import System.IO.Unsafe (unsafePerformIO)
import Data.Maybe (maybeToList)
import GHC.Conc (ThreadStatus(..), threadStatus)
import Control.Monad (forever)
import Control.Monad.STM (atomically)
import Control.Concurrent (ThreadId, killThread, forkIOWithUnmask)
import Control.Concurrent.STM.TChan (TChan, newTChan, readTChan, writeTChan, isEmptyTChan)
import Control.Exception.Lifted (SomeException, mask_, onException, throwIO, try)
import qualified Data.ByteString.Lazy as L
import Control.Monad.Trans (MonadIO, liftIO)
import Data.Bson (Document)
import Data.Bson.Binary (getDocument, putDocument, getInt32, putInt32, getInt64,
                         putInt64, putCString)
import Data.Text (Text)
import qualified Crypto.Hash.MD5 as MD5
import qualified Data.Text as T
import qualified Data.Text.Encoding as TE
import Database.MongoDB.Internal.Util (bitOr, byteStringHex)
import Database.MongoDB.Transport (Transport)
import qualified Database.MongoDB.Transport as Tr
#if MIN_VERSION_base(4,6,0)
import Control.Concurrent.MVar.Lifted (MVar, newEmptyMVar, newMVar, withMVar,
                                       putMVar, readMVar, mkWeakMVar, isEmptyMVar)
#else
import Control.Concurrent.MVar.Lifted (MVar, newEmptyMVar, newMVar, withMVar,
                                         putMVar, readMVar, addMVarFinalizer,
                                         isEmptyMVar)
#endif
#if !MIN_VERSION_base(4,6,0)
-- | Compatibility shim for base versions older than 4.6, where this module
-- imports addMVarFinalizer instead of mkWeakMVar. Unlike the real
-- mkWeakMVar this returns () rather than a weak pointer; the call in
-- newPipeline discards the result, so both variants fit there.
mkWeakMVar :: MVar a -> IO () -> IO ()
mkWeakMVar = addMVarFinalizer
#endif
-- | A connection whose replies are read by a dedicated listener thread and
-- handed to waiting requests in FIFO order (see 'listen' and 'pcall').
data Pipeline = Pipeline
    { vStream :: MVar Transport -- ^ the transport; holding this MVar serializes writes
    , responseQueue :: TChan (MVar (Either IOError Response)) -- ^ one slot per outstanding request, filled by the listener
    , listenThread :: ThreadId -- ^ thread running 'listen'
    , finished :: MVar () -- ^ filled once the listener thread has stopped
    , serverData :: ServerData -- ^ server properties supplied by whoever built the pipeline
    }
-- | Role and size limits of the server at the other end of a pipe.
-- NOTE(review): field meanings are inferred from their names and this module
-- never reads them; presumably filled from the server handshake by the
-- caller -- confirm.
data ServerData = ServerData
                { isMaster            :: Bool
                , minWireVersion      :: Int
                , maxWireVersion      :: Int
                , maxMessageSizeBytes :: Int
                , maxBsonObjectSize   :: Int
                , maxWriteBatchSize   :: Int
                }
-- | Fork a thread that runs the action with asynchronous exceptions unmasked
-- and then hands its outcome (result or exception) to the finalizer.
--
-- The fork itself happens under 'mask_', so an asynchronous exception cannot
-- slip in before 'try' is in place and skip the finalizer.
forkUnmaskedFinally :: IO a -> (Either SomeException a -> IO ()) -> IO ThreadId
forkUnmaskedFinally body finalizer =
  mask_ $ forkIOWithUnmask $ \restore -> do
    outcome <- try (restore body)
    finalizer outcome
-- | Build a 'Pipeline' over an open transport and start its listener thread.
--
-- The listener ('listen') reads replies and hands each one to the next slot
-- in 'responseQueue'. When the listener stops for any reason, 'finished' is
-- filled and every slot still queued receives a "Handle has been closed"
-- IOError. A finalizer attached to 'vStream' kills the listener and closes
-- the transport once that MVar is garbage collected.
newPipeline :: ServerData -> Transport -> IO Pipeline
newPipeline serverData stream = do
    vStream <- newMVar stream
    responseQueue <- atomically newTChan
    finished <- newEmptyMVar
    -- Fail every request still waiting in the queue. Runs only after the
    -- listener (the other reader of responseQueue in this module) has stopped.
    let drainReplies = do
          chanEmpty <- atomically $ isEmptyTChan responseQueue
          if chanEmpty
            then return ()
            else do
              var <- atomically $ readTChan responseQueue
              putMVar var $ Left $ mkIOError
                                        doesNotExistErrorType
                                        "Handle has been closed"
                                        Nothing
                                        Nothing
              drainReplies
    -- RecursiveDo: the record needs the listener ThreadId, and the listener
    -- needs the record.
    rec
        let pipe = Pipeline{..}
        -- On listener exit: mark the pipeline finished first (pcall checks
        -- this flag before sending), then fail the pending slots.
        listenThread <- forkUnmaskedFinally (listen pipe) $ \_ -> do
                                                              putMVar finished ()
                                                              drainReplies
    -- Cleanup when vStream becomes unreachable: stop the listener and close
    -- the transport.
    _ <- mkWeakMVar vStream $ do
        killThread listenThread
        Tr.close stream
    return pipe
-- | True once the listener thread has stopped, i.e. its exit handler has
-- filled 'finished'.
isFinished :: Pipeline -> IO Bool
isFinished Pipeline{finished} = not <$> isEmptyMVar finished
-- | Kill the listener thread and close the underlying transport.
-- Blocks while another thread is holding 'vStream' (e.g. mid-write).
close :: Pipeline -> IO ()
close Pipeline{listenThread, vStream} = do
    killThread listenThread
    transport <- readMVar vStream
    Tr.close transport
-- | True when the listener thread has finished or died. A running or
-- blocked listener counts as open.
isClosed :: Pipeline -> IO Bool
isClosed Pipeline{listenThread} = stopped <$> threadStatus listenThread
  where
    stopped ThreadFinished = True
    stopped ThreadDied = True
    stopped _ = False
-- | Listener loop: read each reply (or the IOError raised while reading it)
-- and hand it to the oldest waiting slot in 'responseQueue'.
--
-- After a read error has been delivered, the transport is closed and the
-- error is rethrown, which ends the thread.
listen :: Pipeline -> IO ()
listen Pipeline{vStream, responseQueue} = readMVar vStream >>= forever . serveOne
  where
    serveOne stream = do
        outcome <- try (readMessage stream)
        waiter <- atomically (readTChan responseQueue)
        putMVar waiter outcome
        case outcome of
            Left err -> Tr.close stream >> ioError err
            Right _ -> return ()
-- | Send a message that expects no reply. The message is forced to WHNF
-- before the transport is locked; if the write fails, the pipeline is closed
-- and the exception is rethrown.
psend :: Pipeline -> Message -> IO ()
psend p@Pipeline{vStream} !message =
    withMVar vStream (\stream -> writeMessage stream message) `onException` close p
-- | Send a message that expects exactly one reply and return a promise for it.
--
-- The write and the enqueueing of the reply slot both happen while holding
-- 'vStream', so slots are queued in the order messages are written. The
-- listener fills slots in that same order. If the listener has already
-- stopped, an IOError is thrown immediately. If sending fails, the pipeline
-- is closed and the exception propagates. Running the promise rethrows any
-- IOError the listener recorded for this slot.
--
-- NOTE(review): the isFinished check and the enqueue are not atomic. If the
-- listener stops after the check and its drain of responseQueue completes
-- before this slot is enqueued, nothing fills the slot and the promise would
-- block forever -- confirm, and consider re-checking after enqueueing.
pcall :: Pipeline -> Message -> IO (IO Response)
pcall p@Pipeline{vStream, responseQueue} message = do
  stopped <- isFinished p
  if not stopped
    then withMVar vStream enqueueAfterWrite `onException` close p
    else ioError $ mkIOError doesNotExistErrorType "Handle has been closed" Nothing Nothing
  where
    enqueueAfterWrite stream = do
        writeMessage stream message
        slot <- newEmptyMVar
        liftIO $ atomically $ writeTChan responseQueue slot
        return (readMVar slot >>= either throwIO return)
-- | Public name for a 'Pipeline'.
type Pipe = Pipeline
-- | Wrap a 'Handle' as a transport and build a pipe over it.
newPipe :: ServerData -> Handle -> IO Pipe
newPipe sd handle = do
    transport <- Tr.fromHandle handle
    newPipeWith sd transport
-- | Build a pipe over an already-open transport.
newPipeWith :: ServerData -> Transport -> IO Pipe
newPipeWith = newPipeline
-- | Send notices without waiting for any reply.
send :: Pipe -> [Notice] -> IO ()
send pipe notices = psend pipe message
  where
    message = (notices, Nothing)
-- | Send the notices followed by the request and return a promise for the reply.
--
-- Running the promise waits for the reply. It throws (via 'error') if the
-- reply answers a different request id than the one generated here.
-- Connection and read errors are rethrown as IOError by the promise that
-- 'pcall' returns.
call :: Pipe -> [Notice] -> Request -> IO (IO Reply)
call pipe notices request = do
    requestId <- genRequestId
    promise <- pcall pipe (notices, Just (request, requestId))
    -- Force the id check with '$!'. A mismatch (or a reply that fails to
    -- decode) is then raised when the promise is run, instead of hiding in a
    -- lazy Reply that fails wherever it is first inspected.
    return $ promise >>= \response -> return $! check requestId response
 where
    check requestId (responseTo, reply) = if requestId == responseTo then reply else
        error $ "expected response id (" ++ show responseTo ++ ") to match request id (" ++ show requestId ++ ")"
-- | Notices to send, optionally followed by one request paired with its request id.
type Message = ([Notice], Maybe (Request, RequestId))
-- | Serialize the notices and the optional request, each framed with an
-- Int32 length prefix. All frames go out in a single write, followed by a flush.
--
-- Every notice gets a freshly generated request id; the request keeps the id
-- it was paired with in the message.
writeMessage :: Transport -> Message -> IO ()
writeMessage conn (notices, mRequest) = do
    framedNotices <- forM notices frameNotice
    let framedRequest = frameRequest <$> mRequest
    Tr.write conn . L.toStrict . L.concat $ framedNotices ++ maybeToList framedRequest
    Tr.flush conn
 where
    frameNotice n = do
        requestId <- genRequestId
        return $ withLength (runPut $ putNotice n requestId)
    frameRequest (request, requestId) = withLength (runPut $ putRequest request requestId)
    -- The length field counts its own 4 bytes as well as the body.
    withLength body = runPut (putInt32 (toEnum (fromEnum (L.length body)) + 4)) `L.append` body
-- | A decoded reply together with the request id it answers.
type Response = (ResponseTo, Reply)
-- | Read one length-prefixed reply from the transport.
--
-- The 4-byte prefix counts itself, so 4 is subtracted to get the body size.
-- The reply is returned undecoded-until-forced: decoding errors surface when
-- the result is evaluated, not inside this action.
readMessage :: Transport -> IO Response
readMessage conn = do
    prefix <- Tr.read conn 4
    let bodyLength = fromEnum (runGet getInt32 (L.fromStrict prefix) - 4)
    body <- Tr.read conn bodyLength
    return (runGet getReply (L.fromStrict body))
-- | Full collection name as written on the wire.
-- NOTE(review): presumably "database.collection" -- confirm with callers.
type FullCollection = Text
-- | Wire-protocol operation code written into every message header.
type Opcode = Int32
-- | Identifier written into each outgoing message header.
type RequestId = Int32
-- | Request id that a reply refers back to.
type ResponseTo = RequestId
-- | Next id from a process-wide counter that starts at 0.
--
-- The counter is updated atomically, so concurrent callers receive distinct
-- ids (until the Int32 wraps around). The new value is forced before the call
-- returns, so the IORef never holds a chain of unevaluated (+ 1) thunks.
genRequestId :: (MonadIO m) => m RequestId
genRequestId = liftIO $ do
    requestId <- atomicModifyIORef counter $ \n -> let next = n + 1 in next `seq` (next, n)
    -- Forcing the returned id forces the tuple built above, and with it the
    -- stored value. This mirrors the strict variant of atomicModifyIORef from
    -- base 4.6, which is avoided because this module still supports older base.
    requestId `seq` return requestId
  where
    -- A single global counter shared by every pipe. NOINLINE ensures the
    -- unsafePerformIO runs once, yielding exactly one IORef.
    counter :: IORef RequestId
    counter = unsafePerformIO (newIORef 0)
    {-# NOINLINE counter #-}
-- | Write the message header fields that follow the length prefix:
-- request id, a zero responseTo field, then the opcode.
putHeader :: Opcode -> RequestId -> Put
putHeader opcode requestId = mapM_ putInt32 [requestId, 0, opcode]
-- | Read the header fields that follow the length prefix, returning the
-- opcode and the id of the request being answered. The reply's own request
-- id is skipped.
getHeader :: Get (Opcode, ResponseTo)
getHeader = do
    _ <- getInt32
    responseTo <- getInt32
    flip (,) responseTo <$> getInt32
-- | An operation sent without a reply slot of its own: 'send' writes notices
-- and waits for nothing. Wire opcodes are given by 'nOpcode'.
data Notice =
      -- | Insert documents (opcode 2002); options become flags via 'iBits'.
      Insert {
        iFullCollection :: FullCollection,
        iOptions :: [InsertOption],
        iDocuments :: [Document]}
      -- | Update documents matching the selector (opcode 2001).
    | Update {
        uFullCollection :: FullCollection,
        uOptions :: [UpdateOption],
        uSelector :: Document,
        uUpdater :: Document}
      -- | Delete documents matching the selector (opcode 2006).
    | Delete {
        dFullCollection :: FullCollection,
        dOptions :: [DeleteOption],
        dSelector :: Document}
      -- | Kill the given server cursors (opcode 2007).
    | KillCursors {
        kCursorIds :: [CursorId]}
    deriving (Show, Eq)
-- | Options for 'Insert'.
data InsertOption = KeepGoing  -- ^ flag bit 0 (see 'iBit')
    deriving (Show, Eq)
-- | Options for 'Update'.
data UpdateOption =
      Upsert  -- ^ flag bit 0 (see 'uBit')
    | MultiUpdate  -- ^ flag bit 1 (see 'uBit')
    deriving (Show, Eq)
-- | Options for 'Delete'.
data DeleteOption = SingleRemove  -- ^ flag bit 0 (see 'dBit')
    deriving (Show, Eq)
-- | Server-side cursor identifier, written and read as an Int64.
type CursorId = Int64
-- | Wire opcode for each kind of notice.
nOpcode :: Notice -> Opcode
nOpcode notice = case notice of
    Update{} -> 2001
    Insert{} -> 2002
    Delete{} -> 2006
    KillCursors{} -> 2007
-- | Serialize a notice: the message header followed by the body for its
-- constructor. The length prefix is added by the caller ('writeMessage').
putNotice :: Notice -> RequestId -> Put
putNotice notice requestId = do
    putHeader (nOpcode notice) requestId
    putBody notice
  where
    putBody Insert{..} = do
        putInt32 (iBits iOptions)
        putCString iFullCollection
        mapM_ putDocument iDocuments
    putBody Update{..} = do
        putInt32 0
        putCString uFullCollection
        putInt32 (uBits uOptions)
        putDocument uSelector
        putDocument uUpdater
    putBody Delete{..} = do
        putInt32 0
        putCString dFullCollection
        putInt32 (dBits dOptions)
        putDocument dSelector
    putBody KillCursors{..} = do
        putInt32 0
        putInt32 (toEnum (length kCursorIds))
        mapM_ putInt64 kCursorIds
-- | Flag bit contributed by a single insert option.
iBit :: InsertOption -> Int32
iBit option = bit $ case option of
    KeepGoing -> 0
-- | Combine insert options into one flags word.
iBits :: [InsertOption] -> Int32
iBits options = bitOr (map iBit options)
-- | Flag bit contributed by a single update option.
uBit :: UpdateOption -> Int32
uBit option = bit $ case option of
    Upsert -> 0
    MultiUpdate -> 1
-- | Combine update options into one flags word.
uBits :: [UpdateOption] -> Int32
uBits options = bitOr (map uBit options)
-- | Flag bit contributed by a single delete option.
dBit :: DeleteOption -> Int32
dBit option = bit $ case option of
    SingleRemove -> 0
-- | Combine delete options into one flags word.
dBits :: [DeleteOption] -> Int32
dBits options = bitOr (map dBit options)
-- | An operation that gets exactly one reply slot; sent via 'call'.
-- Wire opcodes are given by 'qOpcode'.
data Request =
      Query {
        qOptions :: [QueryOption],
        qFullCollection :: FullCollection,
        qSkip :: Int32,  
        qBatchSize :: Int32,  
        qSelector :: Document,  
        qProjector :: Document  -- ^ left out of the message entirely when empty (see 'putRequest')
    } | GetMore {
        gFullCollection :: FullCollection,
        gBatchSize :: Int32,
        gCursorId :: CursorId}
    deriving (Show, Eq)
-- | Flags for a 'Query'; the bit each one sets is given by 'qBit'.
data QueryOption =
      TailableCursor  -- ^ bit 1
    | SlaveOK  -- ^ bit 2
    | NoCursorTimeout  -- ^ bit 4
    | AwaitData  -- ^ bit 5
    | Partial  -- ^ bit 7
    deriving (Show, Eq)
-- | Wire opcode for each kind of request.
qOpcode :: Request -> Opcode
qOpcode request = case request of
    Query{} -> 2004
    GetMore{} -> 2005
-- | Serialize a request: the message header followed by the body for its
-- constructor. An empty projector document is not written at all.
putRequest :: Request -> RequestId -> Put
putRequest request requestId = do
    putHeader (qOpcode request) requestId
    putBody request
  where
    putBody Query{..} = do
        putInt32 (qBits qOptions)
        putCString qFullCollection
        putInt32 qSkip
        putInt32 qBatchSize
        putDocument qSelector
        unless (null qProjector) $ putDocument qProjector
    putBody GetMore{..} = do
        putInt32 0
        putCString gFullCollection
        putInt32 gBatchSize
        putInt64 gCursorId
-- | Flag bit contributed by a single query option.
qBit :: QueryOption -> Int32
qBit option = bit $ case option of
    TailableCursor -> 1
    SlaveOK -> 2
    NoCursorTimeout -> 4
    AwaitData -> 5
    Partial -> 7
-- | Combine query options into one flags word.
qBits :: [QueryOption] -> Int32
qBits options = bitOr (map qBit options)
-- | A decoded reply message; fields appear in the order 'getReply' reads them.
data Reply = Reply {
    rResponseFlags :: [ResponseFlag], -- ^ flags whose bits were set in the reply
    rCursorId :: CursorId,  
    rStartingFrom :: Int32,
    rDocuments :: [Document] -- ^ as many documents as the reply announced
    } deriving (Show, Eq)
-- | Flags decoded from a reply; the bit each one is read from is given by 'rBit'.
data ResponseFlag =
      CursorNotFound  -- ^ bit 0
    | QueryError  -- ^ bit 1
    | AwaitCapable  -- ^ bit 3
    deriving (Show, Eq, Enum)
-- | Opcode every reply must carry; 'getReply' fails on anything else.
replyOpcode :: Opcode
replyOpcode = 1
-- | Decode a reply (everything after the length prefix) together with the
-- request id it answers. Fails unless the header opcode is 'replyOpcode'.
getReply :: Get (ResponseTo, Reply)
getReply = do
    (opcode, responseTo) <- getHeader
    unless (opcode == replyOpcode) $
        fail ("expected reply opcode (1) but got " ++ show opcode)
    flagBits <- getInt32
    cursorId <- getInt64
    startingFrom <- getInt32
    docCount <- getInt32
    docs <- replicateM (fromIntegral docCount) getDocument
    return ( responseTo
           , Reply { rResponseFlags = rFlags flagBits
                   , rCursorId = cursorId
                   , rStartingFrom = startingFrom
                   , rDocuments = docs
                   } )
-- | Every 'ResponseFlag' whose bit (per 'rBit') is set in the flags word.
rFlags :: Int32 -> [ResponseFlag]
rFlags bits = [flag | flag <- [CursorNotFound ..], testBit bits (rBit flag)]
-- | Bit position of each response flag within the reply flags word.
rBit :: ResponseFlag -> Int
rBit flag = case flag of
    CursorNotFound -> 0
    QueryError -> 1
    AwaitCapable -> 3
-- | User name, as used by 'pwHash' and 'pwKey'.
type Username = Text
-- | Clear-text password, as used by 'pwHash' and 'pwKey'.
type Password = Text
-- | Nonce mixed into the key computed by 'pwKey'.
type Nonce = Text
-- | MD5 digest of the UTF-8 bytes of "username:mongo:password", rendered as
-- text with byteStringHex.
pwHash :: Username -> Password -> Text
pwHash u p = T.pack (byteStringHex (MD5.hash (TE.encodeUtf8 (T.concat [u, ":mongo:", p]))))
-- | MD5 digest (via byteStringHex) of the nonce, username and 'pwHash'
-- concatenated in that order and UTF-8 encoded.
pwKey :: Nonce -> Username -> Password -> Text
pwKey n u p = T.pack (byteStringHex (MD5.hash (TE.encodeUtf8 (T.concat [n, u, pwHash u p]))))