{-# LANGUAGE DuplicateRecordFields #-}
{-# LANGUAGE FlexibleContexts      #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE NamedFieldPuns        #-}
{-# LANGUAGE OverloadedStrings     #-}
{-# LANGUAGE ScopedTypeVariables   #-}
module Module.Capnp.Rpc (rpcTests) where

import Control.Concurrent.STM
import Data.Word
import Test.Hspec

import Control.Concurrent.Async (concurrently_, race_)
import Control.Exception.Safe   (bracket, try)
import Control.Monad            (replicateM, void, (>=>))
import Control.Monad.Catch      (throwM)
import Control.Monad.IO.Class   (liftIO)
import Data.Foldable            (for_)
import Data.Mutable             (freeze)
import System.Timeout           (timeout)

import qualified Data.ByteString.Builder as BB
import qualified Data.Text               as T
import qualified Network.Socket          as Socket
import qualified Supervisors

import Capnp
    ( createPure
    , def
    , defaultLimit
    , evalLimitT
    , lbsToMsg
    , msgToValue
    , valueToMsg
    )
import Capnp.Bits       (WordCount)
import Capnp.Rpc.Errors (eFailed)

import Capnp.Gen.Aircraft.Pure  hiding (Left, Right)
import Capnp.Gen.Capnp.Rpc.Pure
import Capnp.Rpc
import Capnp.Rpc.Untyped

import qualified Capnp.Gen.Echo.Pure as E
import qualified Capnp.Pointer       as P

-- | All of the RPC tests in this module, run in order: the echo tests, the
-- aircraft tests, and the tests for unusual message patterns.
rpcTests :: Spec
rpcTests = sequence_
    [ echoTests
    , aircraftTests
    , unusualTests
    ]

-------------------------------------------------------------------------------
-- Tests using echo.capnp.
-------------------------------------------------------------------------------

-- | Connects a client vat to a vat serving 'TestEchoServer', sends a couple
-- of queries one after another, and checks that each reply carries the text
-- of the corresponding query.
echoTests :: Spec
echoTests =
    describe "Echo server & client" $
        it "Should echo back the same message." $
            runVatPair
                (\sup -> E.export_Echo sup TestEchoServer)
                (\_sup echoSrv -> do
                    -- Each call is awaited before the next one is made.
                    replies <- mapM
                        (\params -> E.echo'echo echoSrv ? params >>= wait)
                        queries
                    replies `shouldBe` expected
                )
  where
    greetings = ["Hello #1", "Hello #2"]
    queries = [ def { E.query = text } | text <- greetings ]
    expected = [ def { E.reply = text } | text <- greetings ]

-- | An echo server: the reply to each call is the text of its query.
data TestEchoServer = TestEchoServer

instance Server IO TestEchoServer
instance E.Echo'server_ IO TestEchoServer where
    echo'echo = pureHandler (const (pure . toReply))
      where
        -- Copy the query text into the reply.
        toReply params = def { E.reply = E.query params }

-------------------------------------------------------------------------------
-- Tests using aircraft.capnp.
--
-- These use the 'CallSequence' interface as a counter.
-------------------------------------------------------------------------------

-- | @'bumpN' ctr n@ calls @getNumber@ on @ctr@ @n@ times and then waits for
-- every result, returning them in the order the calls were made.
bumpN :: CallSequence -> Int -> IO [CallSequence'getNumber'results]
bumpN ctr n = do
    pending <- bumpNPromise ctr n
    mapM wait pending

-- | Like 'bumpN', but only issues the calls: the result is the list of
-- promises for the replies, none of which has been waited on.
bumpNPromise :: CallSequence -> Int -> IO [Promise CallSequence'getNumber'results]
bumpNPromise ctr n =
    mapM (\_ -> callSequence'getNumber ctr ? def) [1 .. n]

-- | Tests built on the interfaces generated from aircraft.capnp. They cover
-- call ordering through a promise client, how the different server-side
-- failure modes show up at the caller, counter state across calls, and
-- capabilities passed as method results and as method parameters.
aircraftTests :: Spec
aircraftTests = describe "aircraft.capnp rpc tests" $ do
    describe "newPromiseClient" $
        it "Should preserve E-order" $
            Supervisors.withSupervisor $ \sup -> do
                (promised, fulfiller) <- newPromiseClient
                -- Two calls are made before the promise is resolved ...
                early <- bumpNPromise promised 2
                ctrSrv <- atomically (newTestCtr 0)
                resolved <- export_CallSequence sup ctrSrv
                fulfill fulfiller resolved
                -- ... and two more afterwards.
                late <- bumpN promised 2
                earlyResults <- mapM wait early
                earlyResults `shouldBe` numbered [1, 2]
                late `shouldBe` numbered [3, 4]
    it "Should propogate server-side exceptions to client method calls" $ runVatPair
        (\sup -> export_CallSequence sup ExnCtrServer)
        (getNumberFailsWith def
            { type_ = Exception'Type'failed
            , reason = "Something went sideways."
            }
        )
    it "Should receive unimplemented when calling a method on a null cap." $ runVatPair
        (const (pure (CallSequence nullClient)))
        (getNumberFailsWith methodUnimplemented)
    it "Should throw an unimplemented exception if the server doesn't implement a method" $ runVatPair
        (\sup -> export_CallSequence sup NoImplServer)
        (getNumberFailsWith methodUnimplemented)
    it "Should throw an opaque exception when the server throws a non-rpc exception" $ runVatPair
        (\sup -> export_CallSequence sup NonRpcExnServer)
        (getNumberFailsWith def
            { type_ = Exception'Type'failed
            , reason = "Unhandled exception"
            }
        )
    it "A counter should maintain state" $ runVatPair
        (\sup -> do
            ctrSrv <- newTestCtr 0
            export_CallSequence sup ctrSrv
        )
        (\_sup ctr -> expectBumps ctr [1 .. 4])
    it "Methods returning interfaces work" $ runVatPair
        (\sup -> export_CounterFactory sup (TestCtrFactory sup))
        (\_sup factory -> do
            -- Ask the factory for a new counter starting at the given value.
            let counterFrom initial = do
                    CounterFactory'newCounter'results{counter = ctr} <-
                        counterFactory'newCounter factory ? def { start = initial }
                        >>= wait
                    pure ctr

            ctrA <- counterFrom 2
            ctrB <- counterFrom 0

            expectBumps ctrA [3 .. 6]
            expectBumps ctrB [1, 2]

            ctrC <- counterFrom 30

            -- The counters keep separate state.
            expectBumps ctrA [7 .. 9]
            expectBumps ctrC [31]
        )
    it "Methods with interface parameters work" $ do
        servers <- mapM (atomically . newTestCtr) [2, 0, 30]
        runVatPair
            (\sup -> export_CounterAcceptor sup TestCtrAcceptor)
            (\sup acceptor -> do
                for_ servers $ \ctrSrv -> do
                    ctr <- atomically $ export_CallSequence sup ctrSrv
                    counterAcceptor'accept acceptor ? CounterAcceptor'accept'params { counter = ctr }
                        >>= wait
                -- The acceptor bumps each counter five times in total.
                finalValues <- mapM (\(TestCtrServer var) -> readTVarIO var) servers
                finalValues `shouldBe` [7, 5, 35]
            )
  where
    -- The results a counter is expected to report for the given values.
    numbered :: [Word32] -> [CallSequence'getNumber'results]
    numbered = map (\i -> def { n = i })

    -- The exception expected for a method that has no implementation.
    methodUnimplemented :: Exception
    methodUnimplemented = def
        { type_ = Exception'Type'unimplemented
        , reason = "Method unimplemented"
        }

    -- Client side of a vat pair: call @getNumber@ on the bootstrap counter
    -- and require that the call fails with the given exception.
    getNumberFailsWith :: Exception -> Supervisor -> CallSequence -> IO ()
    getNumberFailsWith wantExn _sup =
        expectException (\cap -> callSequence'getNumber cap ? def) wantExn

    -- Bump the counter once per expected value and compare the results.
    expectBumps :: CallSequence -> [Word32] -> IO ()
    expectBumps ctr want = do
        got <- bumpN ctr (length want)
        got `shouldBe` numbered want

-- | A server whose @accept@ method takes a counter capability, bumps it once
-- to learn its current value, then bumps it four more times and checks that
-- the results continue consecutively from that value.
data TestCtrAcceptor = TestCtrAcceptor

instance Server IO TestCtrAcceptor
instance CounterAcceptor'server_ IO TestCtrAcceptor where
    counterAcceptor'accept =
        pureHandler $ \_ CounterAcceptor'accept'params{counter = ctr} -> do
            [base] <- fmap n <$> bumpN ctr 1
            followUps <- bumpN ctr 4
            followUps `shouldBe`
                [ def { n = base + offset } | offset <- [1 .. 4] ]
            pure def

-------------------------------------------------------------------------------
-- Implementations of various interfaces for testing purposes.
-------------------------------------------------------------------------------

-- | A counter factory. Each @newCounter@ call creates a fresh 'TestCtrServer'
-- holding the requested start value and exports it on the factory's
-- supervisor.
newtype TestCtrFactory = TestCtrFactory { sup :: Supervisor }

instance Server IO TestCtrFactory
instance CounterFactory'server_ IO TestCtrFactory where
    counterFactory'newCounter =
        pureHandler $ \(TestCtrFactory supervisor) CounterFactory'newCounter'params{start = initial} -> do
            -- Create and export the counter in a single transaction.
            exported <- atomically $ do
                ctrSrv <- newTestCtr initial
                export_CallSequence supervisor ctrSrv
            pure CounterFactory'newCounter'results { counter = exported }

-- | Create a counter server whose stored value starts at the given number.
newTestCtr :: Word32 -> STM TestCtrServer
newTestCtr initial = do
    var <- newTVar initial
    pure (TestCtrServer var)

-- | A 'CallSequence' server backed by a 'TVar'. Each @getNumber@ call
-- increments the stored value and returns the new value.
newtype TestCtrServer = TestCtrServer (TVar Word32)

instance Server IO TestCtrServer
instance CallSequence'server_ IO TestCtrServer where
    callSequence'getNumber = pureHandler $ \(TestCtrServer counterVar) _ -> do
        -- Increment and read back inside one transaction.
        bumped <- atomically $ do
            current <- readTVar counterVar
            let next = current + 1
            -- Store the evaluated value, as modifyTVar' does.
            writeTVar counterVar $! next
            pure next
        pure def { n = bumped }

-- | A 'CallSequence' whose @getNumber@ method always throws an rpc exception
-- of type @failed@.
data ExnCtrServer = ExnCtrServer

instance Server IO ExnCtrServer
instance CallSequence'server_ IO ExnCtrServer where
    callSequence'getNumber = pureHandler (\_ _ -> throwM failure)
      where
        -- The exception thrown on every call.
        failure = def
            { type_ = Exception'Type'failed
            , reason = "Something went sideways."
            }

-- | A 'CallSequence' which doesn't implement its methods: the instance below
-- is deliberately left without a body, so calls fall through to whatever the
-- class provides by default. The tests in this module expect a @getNumber@
-- call on this server to fail with an @unimplemented@ exception.
data NoImplServer = NoImplServer

instance Server IO NoImplServer
instance CallSequence'server_ IO NoImplServer -- TODO: can we silence the warning somehow?

-- | A 'CallSequence' whose @getNumber@ method fails with a plain Haskell
-- 'error' rather than an rpc exception.
data NonRpcExnServer = NonRpcExnServer

instance Server IO NonRpcExnServer
instance CallSequence'server_ IO NonRpcExnServer where
    callSequence'getNumber = pureHandler explode
      where
        -- Ignore both the server and the parameters and fail.
        explode _ _ = error "OOPS"

-------------------------------------------------------------------------------
-- Tests for unusual patterns of messages.
--
-- Some of these will never come up when talking to a correct implementation of
-- capnproto, and others just won't come up when talking to the Haskell
-- implementation. Accordingly, these tests start a vat in one thread and
-- directly manipulate the transport in the other.
-------------------------------------------------------------------------------

-- | Tests that drive a vat with hand-built messages.
--
-- Each test runs 'handleConn' on one end of a transport pair (the "vat") and
-- sends and receives raw messages on the other end (the "probe"), then checks
-- how 'handleConn' terminated and/or what the probe received.
unusualTests :: Spec
unusualTests = describe "Tests for unusual message patterns" $ do
    it "Should raise ReceivedAbort in response to an abort message." $ do
        -- Send an abort message to the remote vat, and verify that
        -- the vat actually aborts.
        let exn = def
                { type_ = Exception'Type'failed
                , reason = "Testing abort"
                }
        withTransportPair $ \(vatTrans, probeTrans) -> do
            -- The failure of the vat side is captured by 'try', so it can be
            -- compared against the expected 'ReceivedAbort'.
            ret <- try $ concurrently_
                (handleConn (vatTrans defaultLimit) def { debugMode = True})
                $ do
                    msg <- createPure maxBound $ valueToMsg $ Message'abort exn
                    sendMsg (probeTrans defaultLimit) msg
            ret `shouldBe` Left (ReceivedAbort exn)
    -- The next four cases each send one message referring to something the
    -- vat never sent or created, and expect an abort with the given reason
    -- (see 'triggerAbort').

    -- An 'unimplemented' reply to an abort message the vat never sent.
    triggerAbort (Message'unimplemented $ Message'abort def) $
        "Your vat sent an 'unimplemented' message for an abort message " <>
        "that its remote peer never sent. This is likely a bug in your " <>
        "capnproto library."
    -- A call targeting an imported cap id for which the vat has no export.
    triggerAbort
        (Message'call def
            { target = MessageTarget'importedCap 443
            }
        )
        "No such export: 443"
    -- A call targeting the promised answer of a question the vat has no
    -- answer for.
    triggerAbort
        (Message'call def
            { target = MessageTarget'promisedAnswer def { questionId=300 }
            }
        )
        "No such answer: 300"
    -- A return for a question the vat never asked.
    triggerAbort
        (Message'return def { answerId = 234 })
        "No such question: 234"
    it "Should respond with an abort if sent junk data" $ do
        let wantAbortExn = def
                { reason = "Unhandled exception: TraversalLimitError"
                , type_ = Exception'Type'failed
                }
        withTransportPair $ \(vatTrans, probeTrans) ->
            concurrently_
                -- Vat side: 'handleConn' must fail with 'SentAbort'.
                (do
                    Left (e :: RpcError) <- try $
                        handleConn (vatTrans defaultLimit) def { debugMode = True }
                    e `shouldBe` SentAbort wantAbortExn
                )
                -- Probe side: hand-assemble a raw message, send it, and check
                -- that the matching abort message comes back.
                (do
                    let bb = mconcat
                            [ BB.word32LE 0 -- 1 segment - 1 = 0
                            , BB.word32LE 2 -- 2 words in first segment
                            -- a pair of structs that point to each other:
                            , BB.word64LE (P.serializePtr (Just (P.StructPtr   0  0 1)))
                            , BB.word64LE (P.serializePtr (Just (P.StructPtr (-1) 0 1)))
                            ]
                        lbs = BB.toLazyByteString bb
                    msg <- lbsToMsg lbs
                    sendMsg (probeTrans defaultLimit) msg
                    msg' <- recvMsg (probeTrans defaultLimit)
                    resp <- msgToValue msg'
                    resp `shouldBe` Message'abort wantAbortExn
                )
    it "Should respond with an abort if erroneously sent return = resultsSentElsewhere" $

        withTransportPair $ \(vatTrans, probeTrans) ->
            let wantExn = eFailed $
                    "Received Return.resultsSentElswhere for a call "
                    <> "with sendResultsTo = caller."
            in concurrently_
                -- Vat side: request the bootstrap interface and call
                -- getNumber on it; 'handleConn' is expected to fail with
                -- 'SentAbort'.
                (do
                    Left (e :: RpcError) <- try $
                        handleConn (vatTrans defaultLimit) def
                            { debugMode = True
                            , withBootstrap = Just $ \_sup client ->
                                let ctr :: CallSequence = fromClient client
                                in void $ (callSequence'getNumber ctr ? def) >>= wait
                            }
                    e `shouldBe` SentAbort wantExn
                )
                -- Probe side: plays the remote peer by hand. The order of the
                -- receives and the send below is the message exchange under
                -- test.
                (do
                    let send msg =
                            evalLimitT maxBound (valueToMsg msg >>= freeze)
                            >>= sendMsg (probeTrans defaultLimit)
                        recv = recvMsg (probeTrans defaultLimit) >>= msgToValue
                    -- The vat first asks for the bootstrap interface, then
                    -- makes its call.
                    Message'bootstrap Bootstrap{} <- recv
                    Message'call Call{questionId} <- recv
                    -- Answer that call with resultsSentElsewhere.
                    send $ Message'return def
                        { answerId = questionId
                        , union' = Return'resultsSentElsewhere
                        }
                    msg <- recv
                    msg `shouldBe` Message'abort wantExn
                )
    it "Should reply with unimplemented when sent a join (level 4 only)." $
        withTransportPair $ \(vatTrans, probeTrans) ->
        -- NOTE(review): this uses race_ rather than concurrently_, presumably
        -- because the vat keeps running after replying, so the test has to
        -- end when the probe finishes -- confirm.
        race_
            (handleConn (vatTrans defaultLimit) def { debugMode = True })
            $ do
                msg <- createPure maxBound $ valueToMsg $ Message'join def
                sendMsg (probeTrans defaultLimit) msg
                -- The join message should come back wrapped in 'unimplemented'.
                msg' <- recvMsg (probeTrans defaultLimit) >>= msgToValue
                msg' `shouldBe` Message'unimplemented (Message'join def)


-- | @'triggerAbort' msg wantReason@ is a test case which sends @msg@ to a
-- freshly started vat and verifies that the vat aborts the connection with a
-- @failed@ exception whose 'reason' field is @wantReason@.
--
-- Both ends are checked: 'handleConn' must fail with 'SentAbort', and the
-- probe must receive the matching @Message'abort@ over the transport. If no
-- message arrives within the timeout, the test is reported as a failure.
triggerAbort :: Message -> T.Text -> Spec
triggerAbort msg wantReason =
    it ("Should abort when sent the message " ++ show msg ++ " on startup") $ do
        let wantAbortExn = def
                { reason = wantReason
                , type_ = Exception'Type'failed
                }
        withTransportPair $ \(vatTrans, probeTrans) ->
            concurrently_
                -- Vat side: the connection handler must terminate with the
                -- abort it sent.
                (do
                    ret <- try $ handleConn (vatTrans defaultLimit) def { debugMode = True }
                    ret `shouldBe` Left (SentAbort wantAbortExn)
                )
                -- Probe side: send the offending message and wait for the
                -- abort to come back.
                (do
                    rawMsg <- createPure maxBound $ valueToMsg msg
                    sendMsg (probeTrans defaultLimit) rawMsg
                    -- 4 second timeout. The remote vat's timeout before killing the
                    -- connection is one second, so if this happens we're never going
                    -- to receive the message. In theory this is possible, but if it
                    -- happens something is very wrong.
                    r <- timeout 4000000 $ recvMsg (probeTrans defaultLimit)
                    case r of
                        Nothing ->
                            -- Use an hspec expectation so this shows up as a
                            -- test failure, not an uncaught 'ErrorCall'.
                            expectationFailure "Test timed out waiting on abort message."
                        Just rawResp -> do
                            resp <- msgToValue rawResp
                            resp `shouldBe` Message'abort wantAbortExn
                )

-------------------------------------------------------------------------------
-- Utilities used by the tests.
-------------------------------------------------------------------------------

-- | Run an action with a connected pair of unix stream sockets. Both sockets
-- are closed once the action finishes, whether it returns or throws.
withSocketPair :: ((Socket.Socket, Socket.Socket) -> IO a) -> IO a
withSocketPair action = bracket open closeBoth action
  where
    open = Socket.socketPair Socket.AF_UNIX Socket.Stream 0
    closeBoth (a, b) = do
        Socket.close a
        Socket.close b

-- | Like 'withSocketPair', but each socket is wrapped with 'socketTransport'
-- first. Each element of the pair still needs a 'WordCount' argument to
-- produce a 'Transport'.
withTransportPair
    :: ((WordCount -> Transport, WordCount -> Transport) -> IO a)
    -> IO a
withTransportPair body = withSocketPair runWithTransports
  where
    runWithTransports (sockA, sockB) =
        body (socketTransport sockA, socketTransport sockB)

-- | @'runVatPair' server client@ runs two vats connected to one another over
-- a transport pair. @server@ produces the bootstrap interface offered by one
-- vat (via the 'getBootstrap' config field); @client@ is run by the other vat
-- with that interface (via the 'withBootstrap' config field).
--
-- The two connection handlers are raced against each other, so the whole
-- action ends as soon as either of them finishes.
runVatPair :: IsClient c => (Supervisor -> STM c) -> (Supervisor -> c -> IO ()) -> IO ()
runVatPair mkBootstrap useBootstrap =
    withTransportPair $ \(clientTrans, serverTrans) ->
        let serverSide = handleConn (serverTrans defaultLimit) def
                { debugMode = True
                , getBootstrap = \sup -> Just . toClient <$> mkBootstrap sup
                }
            clientSide = handleConn (clientTrans defaultLimit) def
                { debugMode = True
                , withBootstrap = Just (\sup client -> useBootstrap sup (fromClient client))
                }
        in race_ serverSide clientSide

-- | @'expectException' callFn wantExn cap@ invokes @callFn@ on @cap@, waits
-- for the resulting promise, and requires that doing so throws an rpc
-- 'Exception' equal to @wantExn@.
--
-- If the call succeeds instead, the test fails with a message that shows the
-- value that was returned.
expectException :: Show a => (cap -> IO (Promise a)) -> Exception -> cap -> IO ()
expectException callFn wantExn cap = do
    ret <- try $ callFn cap >>= wait
    case ret of
        Left (e :: Exception) ->
            e `shouldBe` wantExn
        Right val ->
            -- Report through hspec, like the 'Left' branch does, instead of
            -- raising an 'ErrorCall'.
            expectationFailure $ "Should have received exn, but got " ++ show val