{-# LANGUAGE TupleSections #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE OverloadedStrings #-}
module Database.Redis.Connection where

import Control.Exception
import qualified Control.Monad.Catch as Catch
import Control.Monad.IO.Class(liftIO, MonadIO)
import Control.Monad(when, forM_)
import Control.Concurrent.MVar(MVar, newMVar)
import qualified Data.ByteString as B
import qualified Data.ByteString.Char8 as Char8
import Data.Functor(void)
import qualified Data.IntMap.Strict as IntMap
import Data.Pool
import qualified Data.Time as Time
import Network.TLS (ClientParams)
import qualified Data.HashMap.Strict as HM
import qualified Data.Text as T

import qualified Database.Redis.ProtocolPipelining as PP
import Database.Redis.Core(Redis, Hooks, runRedisInternal, runRedisClusteredInternal, defaultHooks)
import Database.Redis.Protocol(Reply(..))
import Database.Redis.Cluster(ShardMap(..), Node, Shard(..))
import qualified Database.Redis.Cluster as Cluster
import qualified Database.Redis.ConnectionContext as CC
import Database.Redis.Commands
    ( ping
    , select
    , authOpts
    , defaultAuthOpts
    , AuthOpts(..)
    , clusterInfo
    , clusterSlots
    , command
    , ClusterInfoResponseState (..)
    , ClusterInfoResponse (..)
    , ClusterSlotsResponse(..)
    , ClusterSlotsResponseEntry(..)
    , ClusterSlotsNode(..))

--------------------------------------------------------------------------------
-- Connection
--

-- | A thread-safe pool of network connections to Redis. Build one with
--  'connect' for a single server or with 'connectCluster' for a Redis Cluster.
data Connection
    = NonClusteredConnection (Pool PP.Connection)
      -- ^ Pool of pipelined connections to one server.
    | ClusteredConnection (MVar ShardMap) (Pool Cluster.Connection)
      -- ^ Slot-to-shard map shared by all pooled cluster connections, plus
      --   the pool itself.

-- |Information for connecting to a Redis server.
--
-- It is recommended to not use the 'ConnInfo' data constructor directly.
-- Instead use 'defaultConnectInfo' and update it with record syntax. For
-- example to connect to a password protected Redis server running on localhost
-- and listening to the default port:
--
-- @
-- myConnectInfo :: ConnectInfo
-- myConnectInfo = defaultConnectInfo {connectAuth = Just \"secret\"}
-- @
--
-- Or better yet, use 'parseConnectInfo' to parse a URL.
--
data ConnectInfo = ConnInfo
    { connectAddr           :: !CC.ConnectAddr
    -- ^ Address of the server to connect to (the default is
    --   @CC.ConnectAddrHostPort \"localhost\" 6379@).
    , connectAuth           :: !(Maybe B.ByteString)
    -- ^ When the server is protected by a password, set 'connectAuth' to 'Just'
    --   the password. Each connection will then authenticate with the @AUTH@
    --   command.
    , connectUsername       :: !(Maybe B.ByteString)
    -- ^ When ACLs are used, set 'connectUsername' to the user name. It is sent
    --   together with the password from 'connectAuth' in the @AUTH@ command; for
    --   non-clustered connections it is only used when 'connectAuth' is set.
    , connectDatabase       :: !Integer
    -- ^ Each connection will 'select' the database with the given index.
    --   No @SELECT@ is sent when the index is 0.
    , connectMaxConnections :: !Int
    -- ^ Maximum number of connections to keep open. The smallest acceptable
    --   value is 1.
    , connectNumStripes     :: !(Maybe Int)
    -- ^ Number of stripes in the connection pool.
    , connectMaxIdleTime    :: !Time.NominalDiffTime
    -- ^ Amount of time for which an unused connection is kept open. The
    --   smallest acceptable value is 0.5 seconds. If the @timeout@ value in
    --   your redis.conf file is non-zero, it should be larger than
    --   'connectMaxIdleTime'.
    , connectTimeout        :: !(Maybe Time.NominalDiffTime)
    -- ^ Optional timeout until the connection to Redis gets
    --   established. 'ConnectTimeoutException' gets thrown if no socket
    --   gets connected in this interval of time.
    , connectTLSParams      :: !(Maybe ClientParams)
    -- ^ Optional TLS parameters. TLS will be enabled if this is provided.
    , connectHooks          :: !Hooks
    -- ^ Connection hooks. See "Database.Redis.Hooks" for usage and examples.
    , connectPoolLabel      :: !T.Text
    -- ^ Label of the connection pool for instrumentation.
    } deriving Show

-- | Raised while preparing a freshly opened connection, when the server
--  answers @AUTH@ ('ConnectAuthError') or @SELECT@ ('ConnectSelectError')
--  with an error reply. The offending 'Reply' is carried in the constructor.
data ConnectError
    = ConnectAuthError Reply
    | ConnectSelectError Reply
    deriving (Eq, Show)

instance Exception ConnectError

-- | Connection settings used when nothing else is specified:
--
-- @
--  connectAddr           = ConnectAddrHostPort \"localhost\" 6379 -- standard Redis port
--  connectAuth           = Nothing         -- no password
--  connectUsername       = Nothing         -- no ACL user
--  connectDatabase       = 0               -- database 0, so no SELECT is sent
--  connectMaxConnections = 50              -- at most 50 pooled connections
--  connectNumStripes     = Just 1          -- a single pool stripe
--  connectMaxIdleTime    = 30              -- idle connections live 30 seconds
--  connectTimeout        = Nothing         -- no connect timeout
--  connectTLSParams      = Nothing         -- plain TCP, no TLS
--  connectHooks          = defaultHooks    -- hooks that do nothing
--  connectPoolLabel      = ""              -- unlabeled pool
-- @
--
defaultConnectInfo :: ConnectInfo
defaultConnectInfo =
    ConnInfo
        { connectAddr           = CC.ConnectAddrHostPort "localhost" 6379
        , connectAuth           = Nothing
        , connectUsername       = Nothing
        , connectDatabase       = 0
        , connectMaxConnections = 50
        , connectNumStripes     = Just 1
        , connectMaxIdleTime    = 30
        , connectTimeout        = Nothing
        , connectTLSParams      = Nothing
        , connectHooks          = defaultHooks
        , connectPoolLabel      = ""
        }

-- | Opens one pipelined connection to the server described by the given
--  'ConnectInfo' and prepares it for use. It starts receiving replies, sends
--  @AUTH@ when 'connectAuth' is set, and sends @SELECT@ when 'connectDatabase'
--  is not 0.
--
--  Throws 'ConnectAuthError' or 'ConnectSelectError' when the respective
--  command is answered with an error reply. If this preparation fails for any
--  reason, the connection is closed with 'PP.disconnect' before the exception
--  propagates, so a rejected connection does not leak its socket.
createConnection :: ConnectInfo -> IO PP.Connection
createConnection ConnInfo{..} = do
    -- 'connectTimeout' is a NominalDiffTime in seconds; it is scaled to a
    -- whole number of microseconds for the pipelining layer.
    let timeoutOptUs =
          round . (1000000 *) <$> connectTimeout
    conn' <- PP.connectWithHooks connectAddr timeoutOptUs connectTLSParams connectHooks
    -- Close the socket if the handshake throws (error reply, I/O error, ...).
    initialize conn' `onException` PP.disconnect conn'
    return conn'
  where
    -- Start the reply reader, then authenticate and select the database.
    initialize conn' = do
        PP.beginReceiving conn'
        runRedisInternal conn' $ do
            -- AUTH
            forM_ connectAuth $ \pass -> do
                resp <- authOpts pass defaultAuthOpts{ authOptsUsername = connectUsername }
                case resp of
                    Left r -> liftIO $ throwIO $ ConnectAuthError r
                    _      -> return ()
            -- SELECT
            when (connectDatabase /= 0) $ do
                resp <- select connectDatabase
                case resp of
                    Left r -> liftIO $ throwIO $ ConnectSelectError r
                    _      -> return ()

-- | Builds a non-clustered 'Connection' pool for the server described by the
--  given 'ConnectInfo'.
--
-- No network connection is opened here. Pooled connections are produced
-- on demand by 'createConnection', so problems such as a wrong password
-- surface on first use rather than here. Use 'checkedConnect' for an eager
-- check.
connect :: ConnectInfo -> IO Connection
connect cInfo@ConnInfo{..} = do
    let idleSeconds = realToFrac connectMaxIdleTime
        config      = setPoolLabel connectPoolLabel
                    . setNumStripes connectNumStripes
                    $ defaultPoolConfig
                          (createConnection cInfo)
                          PP.disconnect
                          idleSeconds
                          connectMaxConnections
    pool <- newPool config
    pure (NonClusteredConnection pool)

-- | Like 'connect', but immediately sends a @PING@ through the new pool to
--  verify that the server is actually reachable.
--
--  Any exception raised while opening that first connection propagates to
--  the caller. For example, a 'ConnectError' is thrown when @AUTH@ or
--  @SELECT@ is rejected.
checkedConnect :: ConnectInfo -> IO Connection
checkedConnect connInfo = do
    pool <- connect connInfo
    -- Only the round trip matters; the PING reply itself is discarded.
    pool <$ runRedis pool (void ping)

-- |Constructs a 'Connection' pool to a Redis cluster designated by the
--  given 'ConnectInfo', then checks with @CLUSTER INFO@ that the cluster is
--  actually there and healthy.
--
--  Throws a 'ClusterConnectError' if @CLUSTER INFO@ is answered with an error
--  reply, or a 'ClusterDownError' if the cluster reports its state as down.
--  Exceptions from 'connectCluster' itself propagate unchanged. If the check
--  fails, the freshly created pool is disconnected before the exception
--  propagates.
checkedConnectCluster :: ConnectInfo -> IO Connection
checkedConnectCluster connInfo = do
  conn <- connectCluster connInfo
  -- Release the pool right away when the check fails, instead of leaving the
  -- connection used by the check idle (and open) in a pool nobody can reach.
  healthCheck conn `onException` disconnect conn
  pure conn
  where
    healthCheck :: Connection -> IO ()
    healthCheck conn = do
      res <- runRedis conn clusterInfo
      case res of
        Right r -> case clusterInfoResponseState r of
          OK   -> pure ()
          Down -> throwIO $ ClusterDownError r
        Left e -> throwIO $ ClusterConnectError e

-- | Thrown by 'checkedConnectCluster' when @CLUSTER INFO@ reports the cluster
--  state as down. It carries the complete 'ClusterInfoResponse' for inspection.
newtype ClusterDownError = ClusterDownError ClusterInfoResponse
  deriving (Eq, Show)

instance Exception ClusterDownError

-- | Destroys the idle connections held in the pool. Works for both kinds of
--  'Connection'; only the pool is touched, not the cluster's shard map.
disconnect :: Connection -> IO ()
disconnect conn = case conn of
    NonClusteredConnection pool -> destroyAllResources pool
    ClusteredConnection _ pool  -> destroyAllResources pool

-- | Bracket around 'connect' and 'disconnect'. The pool is created before the
--  action runs and torn down afterwards, even if the action throws. Usable in
--  any 'Catch.MonadMask' monad that can lift 'IO'.
withConnect :: (Catch.MonadMask m, MonadIO m) => ConnectInfo -> (Connection -> m c) -> m c
withConnect connInfo useConn = Catch.bracket acquire release useConn
  where
    acquire      = liftIO (connect connInfo)
    release conn = liftIO (disconnect conn)

-- | Bracket around 'checkedConnect' and 'disconnect'. The verified pool is
--  handed to the action and always disconnected afterwards, even on exceptions.
withCheckedConnect :: ConnectInfo -> (Connection -> IO c) -> IO c
withCheckedConnect connInfo useConn = bracket acquire disconnect useConn
  where
    acquire = checkedConnect connInfo

-- | Runs a 'Redis' action against the datastore behind the given 'Connection'.
--
--  Each call borrows one network connection from the pool for the duration of
--  the action. It may therefore block while every pooled connection is in use.
--  For clustered connections, the borrowed connection is also used to refresh
--  the shard map when needed.
runRedis :: Connection -> Redis a -> IO a
runRedis (NonClusteredConnection pool) redis =
    withResource pool (`runRedisInternal` redis)
runRedis (ClusteredConnection _ pool) redis =
    withResource pool runClustered
  where
    runClustered conn = runRedisClusteredInternal conn (refreshShardMap conn) redis

-- | Non-blocking variant of 'runRedis'.
--
--  If a pooled connection can be taken without waiting, the action runs and
--  its result is returned in 'Just'. If every connection in the pool is busy,
--  'Nothing' is returned immediately. This is handy for best-effort work such
--  as logging.
runRedisNonBlocking :: Connection -> Redis a -> IO (Maybe a)
runRedisNonBlocking conn redis = case conn of
    NonClusteredConnection pool ->
        tryWithResource pool (\c -> runRedisInternal c redis)
    ClusteredConnection _ pool ->
        tryWithResource pool (\c -> runRedisClusteredInternal c (refreshShardMap c) redis)

-- | Thrown when a command issued while setting up, checking or refreshing a
--  cluster connection (@CLUSTER SLOTS@, @COMMAND@, @CLUSTER INFO@) is answered
--  with an error reply. That 'Reply' is carried in the constructor.
newtype ClusterConnectError = ClusterConnectError Reply
    deriving (Eq, Show)

instance Exception ClusterConnectError

-- | Constructs a clustered 'Connection'. The argument is a 'ConnectInfo' for
-- any node in the cluster; it is used only to bootstrap.
--
-- A temporary connection to that node fetches the slot layout
-- (@CLUSTER SLOTS@) and the command table (@COMMAND@) and is closed again
-- afterwards. The result is a pool whose resources are 'Cluster.Connection's
-- sharing one 'ShardMap'. Throws 'ClusterConnectError' if either bootstrap
-- command is answered with an error reply.
--
-- Some Redis commands are currently not supported in cluster mode:
--
-- - CONFIG, AUTH
-- - SCAN
-- - MOVE, SELECT
-- - RESET
connectCluster :: ConnectInfo -> IO Connection
connectCluster bootstrapConnInfo =
    bracket (createConnection bootstrapConnInfo) PP.disconnect $ \bootstrapConn -> do
        slotsResponse <- runRedisInternal bootstrapConn clusterSlots
        shardMapVar <- case slotsResponse of
            Left err    -> throwIO (ClusterConnectError err)
            Right slots -> shardMapFromClusterSlotsResponse slots >>= newMVar
        commandInfos <- runRedisInternal bootstrapConn command
        case commandInfos of
            Left err    -> throwIO (ClusterConnectError err)
            Right infos ->
                ClusteredConnection shardMapVar <$> newPool (clusterPoolConfig infos shardMapVar)
  where
    -- 'connectTimeout' is in seconds; scale it to whole microseconds.
    timeoutOptUs = round . (1000000 *) <$> connectTimeout bootstrapConnInfo

    -- Pool settings: every resource is a cluster connection built from the
    -- bootstrap credentials, the discovered command table and the shared map.
    clusterPoolConfig infos shardMapVar =
        setPoolLabel (connectPoolLabel bootstrapConnInfo)
      . setNumStripes (connectNumStripes bootstrapConnInfo)
      $ defaultPoolConfig
            (Cluster.connectWith
                (connectUsername bootstrapConnInfo)
                (connectAuth bootstrapConnInfo)
                (connectTLSParams bootstrapConnInfo)
                infos
                shardMapVar
                timeoutOptUs
                (connectHooks bootstrapConnInfo))
            Cluster.disconnect
            (realToFrac (connectMaxIdleTime bootstrapConnInfo))
            (connectMaxConnections bootstrapConnInfo)

-- | Builds a 'ShardMap' from a @CLUSTER SLOTS@ reply. Every slot in an entry's
--  inclusive range maps to a 'Shard' made of that entry's master and replicas.
--  If several entries claim the same slot, the entry listed first wins,
--  because 'IntMap.unions' is left-biased.
shardMapFromClusterSlotsResponse :: ClusterSlotsResponse -> IO ShardMap
shardMapFromClusterSlotsResponse ClusterSlotsResponse{..} =
    pure . ShardMap . IntMap.unions $ map entrySlots clusterSlotsResponseEntries
  where
    -- All slots covered by one entry, each pointing at the same shard.
    entrySlots :: ClusterSlotsResponseEntry -> IntMap.IntMap Shard
    entrySlots ClusterSlotsResponseEntry{..} =
        let shard = Shard (toNode True clusterSlotsResponseEntryMaster)
                          (map (toNode False) clusterSlotsResponseEntryReplicas)
        in IntMap.fromList
               [ (slot, shard)
               | slot <- [clusterSlotsResponseEntryStartSlot .. clusterSlotsResponseEntryEndSlot]
               ]

    -- Converts one node description into a cluster 'Node' with the given role.
    toNode :: Bool -> ClusterSlotsNode -> Node
    toNode isMaster ClusterSlotsNode{..} =
        Cluster.Node clusterSlotsNodeID role (Char8.unpack clusterSlotsNodeIP) (toEnum clusterSlotsNodePort)
      where
        role | isMaster  = Cluster.Master
             | otherwise = Cluster.Slave

-- | Re-reads the cluster topology with @CLUSTER SLOTS@ and returns a fresh
--  'ShardMap'. The query is sent over the context of one of the node
--  connections held by the given 'Cluster.Connection', wrapped in a new
--  pipelined connection.
--
--  Throws 'ClusterConnectError' when the reply is an error. It also throws
--  'ClusterConnectError' when the cluster connection holds no node connections
--  at all, rather than failing with a partial-function error.
refreshShardMap :: Cluster.Connection -> IO ShardMap
refreshShardMap Cluster.Connection{connectionNodes=nodeConns} =
    case HM.elems nodeConns of
        [] ->
            throwIO $ ClusterConnectError
                (Error "refreshShardMap: cluster connection has no node connections")
        Cluster.NodeConnection{nodeConnectionContext=ctx} : _ -> do
            pipelineConn <- PP.fromCtx ctx
            _ <- PP.beginReceiving pipelineConn
            slotsResponse <- runRedisInternal pipelineConn clusterSlots
            case slotsResponse of
                Left e      -> throwIO $ ClusterConnectError e
                Right slots -> shardMapFromClusterSlotsResponse slots