{-# LANGUAGE BangPatterns       #-}
{-# LANGUAGE DeriveAnyClass     #-}
{-# LANGUAGE NamedFieldPuns     #-}
{-# LANGUAGE OverloadedStrings  #-}
{-# LANGUAGE RecordWildCards    #-}
{-# LANGUAGE StrictData         #-}
{-# LANGUAGE TypeApplications   #-}
{-# LANGUAGE StandaloneDeriving #-}

-- | "Database.Redis" like interface with connection through Redis Sentinel.
--
-- More details here: <https://redis.io/topics/sentinel>.
--
-- Example:
--
-- @
-- conn <- 'connect' ('SentinelConnectInfo' (("localhost", PortNumber 26379) :| []) "mymaster" 'defaultConnectInfo')
--
-- 'runRedis' conn $ do
--   'set' "hello" "world"
-- @
--
-- When the connection is opened, the Sentinels are queried for the current master. Subsequent 'runRedis'
-- calls talk to that master.
--
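-- If a 'runRedis' call fails with a connection error, or the node answers with a @READONLY@ error
-- (it has been demoted to a replica), the next call queries the Sentinels for the current master again.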
--
-- This implementation is based on a Gist by Emanuel Borsboom
-- at <https://gist.github.com/borsboom/681d37d273d5c4168723>
module Database.Redis.Sentinel
  (
    -- * Connection
    SentinelConnectInfo(..)
  , SentinelConnection
  , connect
    -- * runRedis with Sentinel support
  , runRedis
  , RedisSentinelException(..)

    -- * Re-export Database.Redis
  , module Database.Redis
  ) where

import           Control.Concurrent
import           Control.Exception     (Exception, IOException, bracket, evaluate, throwIO)
import           Control.Monad
import           Control.Monad.Catch   (Handler (..), MonadCatch, catches, throwM)
import           Control.Monad.Except
import           Control.Monad.IO.Class(MonadIO(liftIO))
import           Data.ByteString       (ByteString)
import qualified Data.ByteString       as BS
import qualified Data.ByteString.Char8 as BS8
import           Data.Foldable         (toList)
import           Data.List             (delete)
import           Data.List.NonEmpty    (NonEmpty (..))
import           Data.Typeable         (Typeable)
import           Data.Unique
import           Network.Socket        (HostName)

import           Database.Redis hiding (Connection, connect, runRedis)
import qualified Database.Redis as Redis

-- | Interact with a Redis datastore.  See 'Database.Redis.runRedis' for details.
--
-- If an earlier call flagged a possible failover, the Sentinels are queried
-- again before the action runs. A new master connection is opened only when
-- the reported master's host or port differs from the current one.
--
-- A failover check is scheduled for the next call when the action throws an
-- 'IOException' or a 'ConnectionLostException' (the exception is re-thrown),
-- or when it returns an @Error@ reply starting with @READONLY @, meaning the
-- node has been demoted to a replica.
runRedis :: SentinelConnection
         -> Redis (Either Reply a)
         -> IO (Either Reply a)
runRedis (SentinelConnection connMVar) action = do
  (baseConn, preToken) <- modifyMVar connMVar refresh

  -- Use evaluate to make sure we catch exceptions from 'runRedis'.
  reply <- (Redis.runRedis baseConn action >>= evaluate)
    `catchRedisRethrow` (\_ -> setCheckSentinel preToken)
  case reply of
    Left (Error e) | "READONLY " `BS.isPrefixOf` e ->
        -- This means our connection has turned into a replica (slave).
        setCheckSentinel preToken
    _ -> return ()
  return reply

  where
    -- Hand out the master connection to use, together with the token of the
    -- state it was taken from. If a failover check is pending, re-query the
    -- Sentinels first and install a fresh token.
    refresh :: SentinelConnection' -> IO (SentinelConnection', (Redis.Connection, Unique))
    refresh oldState@SentinelConnection'
        { rcCheckFailover
        , rcToken               = oldToken
        , rcSentinelConnectInfo = oldConnectInfo
        , rcMasterConnectInfo   = oldMasterConnectInfo
        , rcBaseConnection      = oldBaseConnection
        }
      | not rcCheckFailover =
          return (oldState, (oldBaseConnection, oldToken))
      | otherwise = do
          (newConnectInfo, newMasterConnectInfo) <- updateMaster oldConnectInfo
          newToken <- newUnique
          -- NOTE(review): when a new master connection is opened below, the
          -- previous 'rcBaseConnection' is dropped without 'Redis.disconnect'.
          -- Other threads may still be using it; confirm before closing it here.
          (connInfo, conn) <-
            if sameHost newMasterConnectInfo oldMasterConnectInfo
              then return (oldMasterConnectInfo, oldBaseConnection)
              else do
                newConn <- Redis.connect newMasterConnectInfo
                return (newMasterConnectInfo, newConn)
          return
            ( SentinelConnection'
                { rcCheckFailover       = False
                , rcToken               = newToken
                , rcSentinelConnectInfo = newConnectInfo
                , rcMasterConnectInfo   = connInfo
                , rcBaseConnection      = conn
                }
            , (conn, newToken)
            )

    -- Two connect infos point at the same server when host and port agree.
    sameHost :: Redis.ConnectInfo -> Redis.ConnectInfo -> Bool
    sameHost l r = connectHost l == connectHost r && connectPort l == connectPort r

    -- Request a failover check on the next call, but only if the state has not
    -- changed since @preToken@ was handed out. Rotating the token ensures that
    -- several failures observed under the same token flag the check only once.
    setCheckSentinel :: Unique -> IO ()
    setCheckSentinel preToken = modifyMVar_ connMVar $ \conn@SentinelConnection'{rcToken} ->
      if preToken == rcToken
        then do
          newToken <- newUnique
          return conn{rcToken = newToken, rcCheckFailover = True}
        else return conn


-- | Open a 'SentinelConnection'.
--
-- The Sentinels listed in 'connectSentinels' are asked, in order, for the
-- address of 'connectMasterName'. A connection to that master is then opened
-- using 'connectBaseInfo' with its host and port replaced. The Sentinel that
-- answered is moved to the front of the list kept in the connection, so it is
-- asked first on the next failover check.
--
-- Throws 'NoSentinels' if no Sentinel reports an address for the master.
connect :: SentinelConnectInfo -> IO SentinelConnection
connect origConnectInfo = do
  (connectInfo, masterConnectInfo) <- updateMaster origConnectInfo
  conn <- Redis.connect masterConnectInfo
  token <- newUnique

  SentinelConnection <$> newMVar SentinelConnection'
    { rcCheckFailover       = False
    , rcToken               = token
    , rcSentinelConnectInfo = connectInfo
    , rcMasterConnectInfo   = masterConnectInfo
    , rcBaseConnection      = conn
    }

-- | Ask the Sentinels, in order, for the current address of the master named
-- 'connectMasterName'.
--
-- Returns the updated 'SentinelConnectInfo' together with the master's
-- 'Redis.ConnectInfo'. In the updated info, the Sentinel that answered is moved
-- to the front of 'connectSentinels' so it is asked first next time. The
-- master's info is 'connectBaseInfo' with host and port replaced.
--
-- A Sentinel is skipped if connecting to or querying it throws an
-- 'IOException' / 'ConnectionLostException', or if its reply is not a
-- @[host, port]@ pair with a valid port. Throws 'NoSentinels' if no Sentinel
-- yields an address.
updateMaster :: SentinelConnectInfo
             -> IO (SentinelConnectInfo, Redis.ConnectInfo)
updateMaster sci@SentinelConnectInfo{..} = do
    -- This uses ExceptT "backwards": Left means stop because a Sentinel has
    -- reported the master's address, Right means move on to the next Sentinel.
    resultEither <- runExceptT $ forM_ connectSentinels $ \(host, port) ->
      trySentinel host port `catchRedis` (\_ -> return ())

    case resultEither of
      Left (masterInfo, sentinelPair) -> return
        ( sci
            { connectSentinels =
                sentinelPair :| delete sentinelPair (toList connectSentinels)
            }
        , masterInfo
        )
      Right () -> throwIO $ NoSentinels connectSentinels
  where
    -- Query one Sentinel. On a usable reply, short-circuit the enclosing
    -- ExceptT with the master's connect info and this Sentinel's address;
    -- otherwise return normally so the next Sentinel is tried.
    trySentinel :: HostName -> PortID -> ExceptT (Redis.ConnectInfo, (HostName, PortID)) IO ()
    trySentinel sentinelHost sentinelPort = do
      let sentinelInfo = Redis.defaultConnectInfo
            { connectHost           = sentinelHost
            , connectPort           = sentinelPort
            , connectMaxConnections = 1
            }
      -- bang to ensure exceptions from runRedis get thrown immediately.
      -- bracket releases the short-lived Sentinel connection even when the
      -- query throws; previously it was never disconnected.
      !replyE <- liftIO $
        bracket (Redis.connect sentinelInfo) Redis.disconnect $ \sentinelConn ->
          Redis.runRedis sentinelConn $
            sendRequest ["SENTINEL", "get-master-addr-by-name", connectMasterName]

      case replyE of
        Right [host, portBytes]
          | Just masterPort <- parsePort portBytes ->
              throwError
                ( connectBaseInfo
                    { connectHost = BS8.unpack host
                    , connectPort = masterPort
                    }
                , (sentinelHost, sentinelPort)
                )
        _ -> return ()

    -- Parse a decimal TCP port. Reject trailing garbage and values outside
    -- 1..65535 instead of guessing a port. (The old fallback was 26379, the
    -- Sentinel port, which is not where a Redis master listens.)
    parsePort :: ByteString -> Maybe PortID
    parsePort bytes = case BS8.readInt bytes of
      Just (n, rest)
        | BS.null rest && n > 0 && n <= 65535 -> Just (PortNumber (fromIntegral n))
      _ -> Nothing

-- | Run @action@. If it throws an 'IOException' or a 'ConnectionLostException',
-- pass the exception's 'show' text to @handler@ and then re-throw the original
-- exception. Any other exception propagates without calling @handler@.
catchRedisRethrow :: MonadCatch m => m a -> (String -> m ()) -> m a
catchRedisRethrow action handler =
  action `catches`
    [ Handler $ \ex -> handler (show @IOException ex) >> throwM ex
    , Handler $ \ex -> handler (show @ConnectionLostException ex) >> throwM ex
    ]

-- | Run @action@. If it throws an 'IOException' or a 'ConnectionLostException',
-- recover by running @handler@ on the exception's 'show' text and returning
-- its result. Any other exception propagates.
catchRedis :: MonadCatch m => m a -> (String -> m a) -> m a
catchRedis action handler =
  action `catches`
    [ Handler $ \ex -> handler (show @IOException ex)
    , Handler $ \ex -> handler (show @ConnectionLostException ex)
    ]

-- | Handle returned by 'connect' and consumed by 'runRedis'. It wraps the
-- mutable connection state in an 'MVar' so concurrent callers share it.
newtype SentinelConnection
  = SentinelConnection (MVar SentinelConnection')

-- | Mutable state behind a 'SentinelConnection'.
data SentinelConnection'
  = SentinelConnection'
      { rcCheckFailover       :: Bool
        -- ^ When 'True', the next 'runRedis' re-queries the Sentinels before
        -- running its action.
      , rcToken               :: Unique
        -- ^ Replaced every time the state is modified. A failure report only
        -- sets 'rcCheckFailover' if it carries the current token, so reports
        -- from calls made against an older state are ignored.
      , rcSentinelConnectInfo :: SentinelConnectInfo
        -- ^ Sentinel configuration, with the Sentinel that last answered first.
      , rcMasterConnectInfo   :: Redis.ConnectInfo
        -- ^ Connection parameters of the master currently in use.
      , rcBaseConnection      :: Redis.Connection
        -- ^ Connection to that master.
      }

-- | Configuration of Sentinel hosts.
data SentinelConnectInfo
  = SentinelConnectInfo
      { connectSentinels  :: NonEmpty (HostName, PortID)
        -- ^ Sentinels to query, tried in order until one reports the master.
      , connectMasterName :: ByteString
        -- ^ Name of the master to connect to, as configured in the Sentinels.
      , connectBaseInfo   :: Redis.ConnectInfo
        -- ^ This is used to configure auth and other parameters for the Redis
        -- connection, but 'Redis.connectHost' and 'Redis.connectPort' are
        -- ignored: they are replaced by the address the Sentinel reports.
      }
  deriving (Show)

-- | Exception thrown by "Database.Redis.Sentinel".
data RedisSentinelException
  = NoSentinels (NonEmpty (HostName, PortID))
    -- ^ Thrown when none of the configured Sentinels reports an address for
    -- the master, whether because none could be reached or because none gave
    -- a usable reply. Carries the list of Sentinels that were tried.
  deriving (Show, Typeable)

deriving instance Exception RedisSentinelException