{-# LANGUAGE CPP, OverloadedStrings #-}
{-# LANGUAGE BlockArguments #-}
module PubSubTest (testPubSubThreaded) where

import Control.Concurrent
import Control.Monad
import Control.Concurrent.Async
import Control.Exception
import Data.Function (fix)
import qualified Data.List
import Data.Text (Text)
import Data.Typeable
import Data.ByteString
import Control.Concurrent.STM
import System.Timeout (timeout)
import qualified Test.Framework as Test
import qualified Test.Framework.Providers.HUnit as Test (testCase)
import qualified Test.HUnit as HUnit

import Database.Redis

-- | All multithreaded Pub/Sub test cases, each parameterised by the shared
-- Redis 'Connection'.
testPubSubThreaded :: [Connection -> Test.Test]
testPubSubThreaded =
  [ removeAllTest
  , callbackErrorTest
  , removeFromUnregister
  , pendingChannelsTrackingTest
  , subscribeReplyDecodingTracksPendingSets
  , withPubSubTest
  , withPubSubTimeoutTest
  , withPubSubTestBoth
  ]

-- | A handler label to be able to distinguish the handlers from one another
-- to help make sure we unregister the correct handler.
type HandlerLabel = Text

-- | A message as recorded by the test handlers, tagged with the label of the
-- handler that received it.
data TestMsg
  = MsgFromChannel HandlerLabel ByteString
    -- ^ payload delivered to a plain channel handler
  | MsgFromPChannel HandlerLabel RedisChannel ByteString
    -- ^ concrete channel and payload delivered to a pattern handler
  deriving (Show, Eq)

-- | Log of every message recorded by 'handler' and 'phandler', oldest first.
type MsgVar = TVar [TestMsg]

-- | A handler that just appends the message (tagged with its label) to the TVar.
handler :: HandlerLabel -> MsgVar -> MessageCallback
handler label ref msg =
  atomically $ modifyTVar ref (++ [MsgFromChannel label msg])

-- | A pattern handler that just appends the message (tagged with its label
-- and the concrete channel) to the TVar.
phandler :: HandlerLabel -> MsgVar -> PMessageCallback
phandler label ref chan msg =
  atomically $ modifyTVar ref (++ [MsgFromPChannel label chan msg])

-- | Upper bound, in microseconds, on how long the wait helpers below block
-- before failing the test.
waitTimeoutMicros :: Int
waitTimeoutMicros = 5 * 1000 * 1000

-- | Block until @expected@ shows up in the message log, then drop every copy
-- of it so that waiting for the same message again needs a new delivery.
-- Fails the test (instead of hanging) if the message has not arrived within
-- 'waitTimeoutMicros'.
waitForTestMsg :: MsgVar -> TestMsg -> IO ()
waitForTestMsg ref expected = do
  found <- timeout waitTimeoutMicros $ atomically $ do
    lst <- readTVar ref
    check (expected `Prelude.elem` lst)
    writeTVar ref $ Prelude.filter (/= expected) lst
  case found of
    Just () -> return ()
    Nothing -> HUnit.assertFailure $ "Timed out waiting for message: " ++ show expected

-- | Wait for a given message to be received by the handler labelled @label@.
waitForMessage :: MsgVar -> HandlerLabel -> ByteString -> IO ()
waitForMessage ref label msg = waitForTestMsg ref (MsgFromChannel label msg)

-- | Wait for a given pattern message (published on @chan@) to be received by
-- the pattern handler labelled @label@.
waitForPMessage :: MsgVar -> HandlerLabel -> RedisChannel -> ByteString -> IO ()
waitForPMessage ref label chan msg = waitForTestMsg ref (MsgFromPChannel label chan msg)

-- | Run @probe@ repeatedly, sleeping 10 ms between attempts, until it returns
-- 'True'. If it is still 'False' after about 'waitTimeoutMicros' worth of
-- retries, run @onTimeout@ instead (expected to report a test failure).
--
-- The loop is bounded by an attempt count rather than wrapped in 'timeout'
-- so that a probe which talks to Redis is never interrupted mid-request.
pollUntil :: IO () -> IO Bool -> IO ()
pollUntil onTimeout probe = go (waitTimeoutMicros `div` pollIntervalMicros)
  where
    pollIntervalMicros = 10 * 1000
    go attemptsLeft = do
      done <- probe
      unless done $
        if attemptsLeft <= 0
          then onTimeout
          else threadDelay pollIntervalMicros >> go (attemptsLeft - 1)

-- | Assert that the channel list Redis reports for @PUBSUB CHANNELS@ equals
-- @expected@, ignoring order. Pattern subscriptions are not part of that list.
expectRedisChannels :: Connection -> [RedisChannel] -> IO ()
expectRedisChannels conn expected = do
  actual <- runRedis conn $ sendRequest ["PUBSUB", "CHANNELS"]
  case actual of
    Left err -> HUnit.assertFailure $ "Error getting channels: " ++ show err
    -- assertEqual takes the expected value before the actual one.
    Right s -> HUnit.assertEqual "redis channels" (Data.List.sort expected) (Data.List.sort s)

-- | Poll @PUBSUB CHANNELS@ until Redis reports exactly @expected@ (ignoring
-- order). If that does not happen in time, fall back to 'expectRedisChannels'
-- so the failure message shows what Redis actually reported.
waitForRedisChannels :: Connection -> [RedisChannel] -> IO ()
waitForRedisChannels conn expected =
  pollUntil (expectRedisChannels conn expected) $ do
    actual <- runRedis conn $ sendRequest ["PUBSUB", "CHANNELS"]
    return $ case actual of
      Right s -> Data.List.sort s == Data.List.sort expected
      Left _ -> False

-- | Test basic messages, plus using removeChannels
removeAllTest :: Connection -> Test.Test
removeAllTest conn = Test.testCase "Multithreaded Pub/Sub - basic" $ do
  msgVar <- newTVarIO []
  initialComplete <- newTVarIO False
  ctrl <- newPubSubController
    [("foo1", handler "InitialFoo1" msgVar), ("foo2", handler "InitialFoo2" msgVar)]
    [("bar1:*", phandler "InitialBar1" msgVar), ("bar2:*", phandler "InitialBar2" msgVar)]
  withAsync (pubSubForever conn ctrl (atomically $ writeTVar initialComplete True)) $ \_ -> do
    -- wait until pubSubForever has run its initial-load callback
    atomically $ readTVar initialComplete >>= check
    expectRedisChannels conn ["foo1", "foo2"]

    runRedis conn $ publish "foo1" "Hello"
    waitForMessage msgVar "InitialFoo1" "Hello"

    runRedis conn $ publish "bar2:zzz" "World"
    waitForPMessage msgVar "InitialBar2" "bar2:zzz" "World"

    -- subscribe to foo1 and bar1 again; the unregister action is not needed here
    void $ addChannelsAndWait ctrl
      [("foo1", handler "NewFoo1" msgVar)]
      [("bar1:*", phandler "NewBar1" msgVar)]
    expectRedisChannels conn ["foo1", "foo2"]

    runRedis conn $ publish "foo1" "abcdef"
    waitForMessage msgVar "InitialFoo1" "abcdef"
    waitForMessage msgVar "NewFoo1" "abcdef"

    -- unsubscribe from foo1 and bar1, also passing names that were never subscribed
    removeChannelsAndWait ctrl ["foo1", "unused"] ["bar1:*", "unused:*"]
    expectRedisChannels conn ["foo2"]

    -- foo2 and bar2 are still subscribed
    runRedis conn $ publish "foo2" "12345"
    waitForMessage msgVar "InitialFoo2" "12345"

    runRedis conn $ publish "bar2:aaa" "0987"
    waitForPMessage msgVar "InitialBar2" "bar2:aaa" "0987"

-- | Exception thrown by the message handler in 'callbackErrorTest'.
data TestError = TestError ByteString
  deriving (Eq, Show)

instance Exception TestError

-- | Test an error thrown from a message handler
callbackErrorTest :: Connection -> Test.Test
callbackErrorTest conn = Test.testCase "Multithreaded Pub/Sub - error in handler" $ do
  initialComplete <- newTVarIO False
  ctrl <- newPubSubController [("foo", throwIO . TestError)] []
  -- withAsync guarantees the Pub/Sub thread is cancelled if this test aborts early.
  withAsync (pubSubForever conn ctrl (atomically $ writeTVar initialComplete True)) $ \thread -> do
    atomically $ readTVar initialComplete >>= check
    runRedis conn $ publish "foo" "Hello"
    -- The handler's exception is expected to terminate the Pub/Sub thread.
    ret <- waitCatch thread
    case ret of
      Left (SomeException e) | cast e == Just (TestError "Hello") -> return ()
      _ -> HUnit.assertFailure $ "Did not properly throw error from message thread " ++ show ret

-- | Test removing channels by using the return value of 'addChannelsAndWait'.
removeFromUnregister :: Connection -> Test.Test
removeFromUnregister conn = Test.testCase "Multithreaded Pub/Sub - unregister handlers" $ do
  msgVar <- newTVarIO []
  initialComplete <- newTVarIO False
  ctrl <- newPubSubController [] []
  withAsync (pubSubForever conn ctrl (atomically $ writeTVar initialComplete True)) $ \_ -> do
    atomically $ readTVar initialComplete >>= check

    -- register to some channels
    void $ addChannelsAndWait ctrl
      [("abc", handler "InitialAbc" msgVar), ("xyz", handler "InitialXyz" msgVar)]
      [("def:*", phandler "InitialDef" msgVar), ("uvw", phandler "InitialUvw" msgVar)]
    expectRedisChannels conn ["abc", "xyz"]

    runRedis conn $ publish "abc" "Hello"
    waitForMessage msgVar "InitialAbc" "Hello"

    -- register to some more channels
    unreg <- addChannelsAndWait ctrl
      [("abc", handler "SecondAbc" msgVar), ("123", handler "Second123" msgVar)]
      [("def:*", phandler "SecondDef" msgVar), ("890:*", phandler "Second890" msgVar)]
    expectRedisChannels conn ["abc", "xyz", "123"]

    -- check messages on all channels
    runRedis conn $ publish "abc" "World"
    waitForMessage msgVar "InitialAbc" "World"
    waitForMessage msgVar "SecondAbc" "World"

    runRedis conn $ publish "123" "World2"
    waitForMessage msgVar "Second123" "World2"

    runRedis conn $ publish "def:bbbb" "World3"
    waitForPMessage msgVar "InitialDef" "def:bbbb" "World3"
    waitForPMessage msgVar "SecondDef" "def:bbbb" "World3"

    runRedis conn $ publish "890:tttt" "World4"
    waitForPMessage msgVar "Second890" "890:tttt" "World4"

    -- unregister
    unreg

    -- There is no direct signal for when the unregistration reaches Redis, so
    -- poll PUBSUB CHANNELS until "123" (whose only handler was just removed)
    -- is gone, instead of sleeping for a fixed time.
    waitForRedisChannels conn ["abc", "xyz"]

    -- now only initial should be around. In particular, abc should still be subscribed
    runRedis conn $ publish "abc" "World5"
    waitForMessage msgVar "InitialAbc" "World5"

    runRedis conn $ publish "def:cccc" "World6"
    waitForPMessage msgVar "InitialDef" "def:cccc" "World6"

-- | Wait until both 'pendingChannels' and 'pendingPatternChannels' of the
-- controller are empty, failing the test if that takes longer than about
-- 'waitTimeoutMicros'.
waitUntilPendingEmpty :: PubSubController -> IO ()
waitUntilPendingEmpty ctrl =
  pollUntil (HUnit.assertFailure "Timed out waiting for pending PubSub channels to be cleared") $ do
    pendingCh <- pendingChannels ctrl
    pendingPCh <- pendingPatternChannels ctrl
    return (Prelude.null pendingCh && Prelude.null pendingPCh)

-- | Run @action@ for up to 0.7 seconds and fail the test (mentioning @label@)
-- if it completes within that window.
assertDoesNotHappen :: String -> IO a -> IO ()
assertDoesNotHappen label action = do
  ret <- timeout (700 * 1000) action
  case ret of
    Nothing -> return ()
    Just _ -> HUnit.assertFailure $ "Unexpectedly observed: " ++ label

-- | Verify exported pending sets track add/remove operations before Redis acknowledges requests.
pendingChannelsTrackingTest :: Connection -> Test.Test
pendingChannelsTrackingTest _ =
  Test.testCase "Multithreaded Pub/Sub - pending channels tracking" $ do
    msgVar <- newTVarIO []
    -- No pubSubForever thread is started for this controller, so Redis never
    -- acknowledges anything; only the bookkeeping done by addChannels and
    -- removeChannels themselves is observed.
    ctrl <- newPubSubController [] []

    _ <- addChannels ctrl
      [("pending:chan", handler "PendingChan" msgVar)]
      [("pending:*", phandler "PendingPattern" msgVar)]
    (chansAfterAdd, patsAfterAdd) <- snapshot ctrl
    HUnit.assertBool "channel should be marked pending"
      ("pending:chan" `Prelude.elem` chansAfterAdd)
    HUnit.assertBool "pattern channel should be marked pending"
      ("pending:*" `Prelude.elem` patsAfterAdd)

    removeChannels ctrl ["pending:chan"] ["pending:*"]
    (chansAfterRemove, patsAfterRemove) <- snapshot ctrl
    HUnit.assertBool "removed channel should no longer be pending"
      ("pending:chan" `Prelude.notElem` chansAfterRemove)
    HUnit.assertBool "removed pattern channel should no longer be pending"
      ("pending:*" `Prelude.notElem` patsAfterRemove)
  where
    -- Read the pending plain channels, then the pending pattern channels.
    snapshot c = (,) <$> pendingChannels c <*> pendingPatternChannels c

-- | Exercise subscribe/unsubscribe decoding paths and ensure pending sets are drained per channel type.
subscribeReplyDecodingTracksPendingSets :: Connection -> Test.Test
subscribeReplyDecodingTracksPendingSets conn =
  Test.testCase "Multithreaded Pub/Sub - decode subscribe/unsubscribe replies" $ do
    msgVar <- newTVarIO []
    initialComplete <- newTVarIO False
    ctrl <- newPubSubController [] []
    withAsync (pubSubForever conn ctrl (atomically $ writeTVar initialComplete True)) $ \_ -> do
      -- Wait until pubSubForever has run its initial-load callback.
      atomically $ readTVar initialComplete >>= check

      -- Subscribe to one channel and one pattern, then wait until neither
      -- pending set has entries left.
      _ <- addChannels ctrl
        [("decode:chan", handler "DecodeChan" msgVar)]
        [("decode:*", phandler "DecodePattern" msgVar)]
      waitUntilPendingEmpty ctrl

      runRedis conn $ publish "decode:chan" "msg-1"
      waitForMessage msgVar "DecodeChan" "msg-1"

      runRedis conn $ publish "decode:abc" "msg-2"
      waitForPMessage msgVar "DecodePattern" "decode:abc" "msg-2"

      -- After unsubscribing, neither handler may see further messages.
      removeChannelsAndWait ctrl ["decode:chan"] ["decode:*"]
      waitUntilPendingEmpty ctrl

      runRedis conn $ publish "decode:chan" "msg-3"
      assertDoesNotHappen "channel callback after unsubscribe" $
        waitForMessage msgVar "DecodeChan" "msg-3"

      runRedis conn $ publish "decode:def" "msg-4"
      assertDoesNotHappen "pattern callback after unsubscribe" $
        waitForPMessage msgVar "DecodePattern" "decode:def" "msg-4"

-- | 'withPubSub' delivers a message published on a subscribed channel.
withPubSubTest :: Connection -> Test.Test
withPubSubTest conn = Test.testCase "Multithreaded Pub/Sub - withPubSub" $ do
  lock <- newEmptyMVar
  -- Publisher thread: waits until the subscriber callback is running, then publishes.
  _ <- forkIO $ do
    () <- takeMVar lock
    _ <- runRedis conn $ publish "foo9" "bar"
    pure ()
  result <- withPubSub conn ["foo9"] [] $ \messageSTM -> do
    putMVar lock ()
    -- Bounded wait so that a lost message fails the test instead of hanging it.
    timeout (5 * 1000 * 1000) $ atomically messageSTM
  case result of
    Just (Message "foo9" "bar") -> pure ()
    Nothing -> HUnit.assertFailure "No message received from withPubSub"
    Just x -> HUnit.assertFailure $ "Received unexpected message: " ++ show x

-- | 'withPubSub' delivers both a plain-channel message and a pattern message
-- when subscribed to a channel and a pattern at the same time.
withPubSubTestBoth :: Connection -> Test.Test
withPubSubTestBoth conn = Test.testCase "Multithreaded Pub/Sub - withPubSub (both chan and pchan)" $ do
  lock <- newEmptyMVar
  -- Publisher thread: waits until the subscriber callback is running, then publishes.
  _ <- forkIO $ do
    () <- takeMVar lock
    _ <- runRedis conn $ publish "foo100" "bar"
    _ <- runRedis conn $ publish "foo200" "bar"
    pure ()
  result <- withPubSub conn ["foo100"] ["foo2*"] $ \fetch -> do
    putMVar lock ()
    -- Keep reading until both the channel message and the pattern message
    -- have been seen, giving up after one second.
    timeout 1000000 $
      flip fix (False, False) \next (seenFoo100, seenFoo200) ->
        unless (seenFoo100 && seenFoo200) do
          msg <- atomically fetch
          case msg of
            Message "foo100" "bar" -> next (True, seenFoo200)
            PMessage "foo2*" "foo200" "bar" -> next (seenFoo100, True)
            other -> HUnit.assertFailure $ "Received unexpected message: " ++ show other
  case result of
    Nothing -> HUnit.assertFailure "Messages were not received"
    Just () -> pure ()

-- | 'withPubSub' yields no message when nothing is published on the channel.
withPubSubTimeoutTest :: Connection -> Test.Test
withPubSubTimeoutTest conn = Test.testCase "Multithreaded Pub/Sub - withPubSub with timeout" $ do
  result <- withPubSub conn ["foo100"] [] $ \messageSTM ->
    timeout 300000 $ atomically messageSTM
  case result of
    Nothing -> pure ()
    Just x -> HUnit.assertFailure $ "Expected to timeout without receiving a message, but received: " ++ show x