{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE DeriveDataTypeable #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE OverloadedStrings #-}
module Database.Redis.ConnectionContext (
    ConnectionContext(..)
  , ConnectTimeout(..)
  , ConnectionLostException(..)
  , ConnectAddr(..)
  , connect
  , disconnect
  , send
  , recv
  , errConnClosed
  , enableTLS
  , flush
  , ioErrorToConnLost
) where

import Control.Monad(when)
import qualified Data.ByteString as B
import qualified Data.ByteString.Char8 as Char8
import qualified Data.ByteString.Lazy as LB
import qualified Data.IORef as IOR
import Control.Concurrent.MVar(newMVar, readMVar, swapMVar)
import Control.Exception(bracketOnError, Exception, throwIO, try, finally, mask_)
import Data.Functor(void)
import qualified Network.Socket as NS
import qualified Network.TLS as TLS
import System.IO(Handle, hSetBinaryMode, hClose, IOMode(..), hFlush, hIsOpen)
import System.IO.Error(catchIOError)
import System.Timeout (timeout)

-- | An open transport to the server: either a plain 'Handle', or a TLS
-- context paired with the 'Handle' it was created over.
data ConnectionContext
  = NormalHandle Handle
  | TLSContext TLS.Context Handle

-- | Shows only which transport is in use; neither the handle nor the TLS
-- context is rendered.
instance Show ConnectionContext where
  show c = case c of
    NormalHandle{} -> "NormalHandle"
    TLSContext{} -> "TLSContext"

-- | A 'ConnectionContext' paired with a mutable slot holding an optional
-- 'B.ByteString'. Presumably the last received chunk, going by the field
-- name; it is not read or written in this section, so confirm against callers.
data Connection = Connection
  { ctx :: ConnectionContext
  , lastRecvRef :: IOR.IORef (Maybe B.ByteString)
  }

-- | Renders the context; an 'IOR.IORef' has no 'Show' instance, so a fixed
-- placeholder is printed in its place.
instance Show Connection where
  show (Connection context _) =
    concat ["Connection{ ctx = ", show context, ", lastRecvRef = IORef}"]

-- | A label for a stage of connection setup. It is carried by
-- 'ConnectTimeout' so that a timeout can report the stage last recorded.
-- 'PhaseUnknown' is the initial value before any stage has been recorded.
data ConnectPhase = PhaseUnknown | PhaseResolve | PhaseOpenSocket
  deriving Show

-- | Thrown (via 'errConnectTimeout') when a connect timeout was requested and
-- setup did not finish in time. It carries the last recorded 'ConnectPhase'.
newtype ConnectTimeout = ConnectTimeout ConnectPhase
  deriving Show

instance Exception ConnectTimeout

-- | Thrown by 'errConnClosed'. 'ioErrorToConnLost' raises it in place of any
-- 'IOError' from an action it wraps.
data ConnectionLostException
  = ConnectionLost
  deriving Show

instance Exception ConnectionLostException

-- | Where to connect: a TCP host name and port, or the filesystem path of a
-- Unix domain socket (used with 'NS.SockAddrUnix').
data ConnectAddr
  = ConnectAddrHostPort NS.HostName NS.PortNumber
  | ConnectAddrUnixSocket String
  deriving (Eq, Show)

-- | Open a connection to a Redis server.
--
-- The two address kinds are handled differently:
--
-- * 'ConnectAddrHostPort': the host name is resolved with 'getHostAddrInfo',
--   and each returned address is tried in turn by 'connectSocket'.
-- * 'ConnectAddrUnixSocket': a stream socket is connected to the given path.
--
-- The socket has 'NS.KeepAlive' set and is wrapped in a binary-mode 'Handle'.
--
-- TLS handshake: performed via 'enableTLS' only when TLS client parameters are
-- supplied and the address is host/port. The server identification is replaced
-- by this host and port. For Unix sockets the TLS parameters are ignored.
--
-- Timeout: when @timeoutOpt@ is @Just micros@, socket setup (resolution,
-- connect, and handle creation, but not the TLS handshake) is bounded by
-- 'timeout'. On expiry a 'ConnectTimeout' is thrown carrying the phase that
-- was in progress.
--
-- Cleanup: if a later step throws, the socket or handle already created is
-- closed before the exception propagates.
connect :: ConnectAddr -> Maybe Int -> Maybe TLS.ClientParams -> IO ConnectionContext
connect connectAddr timeoutOpt mTlsParams =
  bracketOnError hConnect hClose $ \h -> do
    hSetBinaryMode h True
    case (mTlsParams, connectAddr) of
      (Just defaultTlsParams, ConnectAddrHostPort host port) -> do
        -- The defaultTlsParams are used to connect to the first
        -- host in the cluster, other hosts have different
        -- hostnames and so require a different server
        -- identification params
        let tlsParams = defaultTlsParams
              { TLS.clientServerIdentification = (host, Char8.pack $ show port) }
        enableTLS tlsParams (NormalHandle h)
      _ -> return $ NormalHandle h
  where
    -- Run socket setup, optionally under 'timeout'. The MVar holds the
    -- current phase so that a timeout can report how far setup got.
    hConnect :: IO Handle
    hConnect = do
      phaseMVar <- newMVar PhaseUnknown
      let doConnect = hConnect' phaseMVar
      case timeoutOpt of
        Nothing -> doConnect
        Just micros -> do
          result <- timeout micros doConnect
          case result of
            Just h -> return h
            Nothing -> readMVar phaseMVar >>= errConnectTimeout

    -- Create and connect the socket, then convert it to a Handle. The socket
    -- is closed if any later step throws, including the exception 'timeout'
    -- uses to abort.
    hConnect' mvar =
      bracketOnError (createSock (void . swapMVar mvar)) NS.close $ \sock -> do
        NS.setSocketOption sock NS.KeepAlive 1
        NS.socketToHandle sock ReadWriteMode

    -- Each phase is recorded *before* the work it names starts, so a timeout
    -- during name resolution or connect reports that phase instead of
    -- 'PhaseUnknown'.
    createSock setPhase = case connectAddr of
      ConnectAddrHostPort hostName portNumber -> do
        setPhase PhaseResolve
        addrInfo <- getHostAddrInfo hostName portNumber
        setPhase PhaseOpenSocket
        connectSocket addrInfo
      ConnectAddrUnixSocket addr -> do
        setPhase PhaseOpenSocket
        bracketOnError
          (NS.socket NS.AF_UNIX NS.Stream NS.defaultProtocol)
          NS.close
          (\sock -> NS.connect sock (NS.SockAddrUnix addr) >> return sock)

-- | Resolve a host name and port to a list of candidate addresses, restricted
-- to stream sockets. The port is passed to the resolver as a service string
-- produced by 'show'.
getHostAddrInfo :: NS.HostName -> NS.PortNumber -> IO [NS.AddrInfo]
getHostAddrInfo hostname port =
  NS.getAddrInfo (Just streamHints) (Just hostname) (Just service)
  where
    service = show port
    streamHints = NS.defaultHints { NS.addrSocketType = NS.Stream }

-- | Throw a 'ConnectTimeout' tagged with the given phase.
errConnectTimeout :: ConnectPhase -> IO a
errConnectTimeout = throwIO . ConnectTimeout

-- | Connect a socket to the first reachable address in the list.
--
-- Addresses are tried in order. For each address, any 'IOError' raised while
-- creating the socket or connecting it is caught; the socket, if it was
-- created, is closed, and the next address is tried. If every address fails,
-- the error from the last attempt is rethrown. An empty list is reported as
-- an 'IOError' rather than through a pure 'error' call.
connectSocket :: [NS.AddrInfo] -> IO NS.Socket
connectSocket [] = ioError (userError "connectSocket: unexpected empty list")
connectSocket (addr:rest) = tryConnect >>= either onFailure return
  where
    onFailure :: IOError -> IO NS.Socket
    onFailure err
      | null rest = throwIO err
      | otherwise = connectSocket rest

    -- 'try' wraps the whole bracket, so a failure in socket creation (e.g. an
    -- address family this host cannot open) also falls through to the next
    -- address. If connect throws, bracketOnError closes the socket before
    -- the error reaches 'try'.
    tryConnect :: IO (Either IOError NS.Socket)
    tryConnect = try $ bracketOnError createSock NS.close $ \sock -> do
      NS.connect sock (NS.addrAddress addr)
      return sock

    createSock :: IO NS.Socket
    createSock = NS.socket (NS.addrFamily addr)
                           (NS.addrSocketType addr)
                           (NS.addrProtocol addr)

-- | Write a request to the connection: directly to the handle, or as TLS
-- application data (converted to a lazy 'LB.ByteString'). Any 'IOError'
-- raised while writing is rethrown as 'ConnectionLost'.
send :: ConnectionContext -> B.ByteString -> IO ()
send context requestData = ioErrorToConnLost $ case context of
  NormalHandle h -> B.hPut h requestData
  TLSContext tlsCtx _ -> TLS.sendData tlsCtx (LB.fromStrict requestData)

-- | Read the next chunk of bytes from the connection.
--
-- For a plain handle this reads at most 4096 bytes with 'B.hGetSome'. For TLS
-- it returns whatever 'TLS.recvData' yields. On both transports an 'IOError'
-- raised while reading is rethrown as 'ConnectionLost', matching how 'send'
-- treats both transports. Previously the TLS branch let raw IOErrors escape.
recv :: ConnectionContext -> IO B.ByteString
recv (NormalHandle h) = ioErrorToConnLost $ B.hGetSome h 4096
recv (TLSContext ctx _) = ioErrorToConnLost $ TLS.recvData ctx


-- | Run an action, replacing any 'IOError' it throws with 'ConnectionLost'.
-- Exceptions of other types are not caught.
ioErrorToConnLost :: IO a -> IO a
ioErrorToConnLost action = catchIOError action (\_ -> errConnClosed)

-- | Throw 'ConnectionLost' in 'IO'.
errConnClosed :: IO a
errConnClosed = throwIO ConnectionLost


-- | Upgrade a plain handle to TLS. A TLS context is created over the handle
-- with the given client parameters and the handshake is run. A context that
-- is already TLS is returned as-is, and the parameters are ignored.
enableTLS :: TLS.ClientParams -> ConnectionContext -> IO ConnectionContext
enableTLS _ existing@TLSContext{} = return existing
enableTLS tlsParams (NormalHandle h) = do
  tlsCtx <- TLS.contextNew h tlsParams
  TLS.handshake tlsCtx
  return $! TLSContext tlsCtx h


-- | Close the connection.
--
-- * Plain handle: closed only if still open, with async exceptions masked
--   around the check and the close.
-- * TLS: 'TLS.bye' is called, then 'TLS.contextClose', then the handle is
--   closed if still open. Each later step runs even if an earlier one throws
--   (via 'finally').
disconnect :: ConnectionContext -> IO ()
disconnect context = case context of
    NormalHandle h -> mask_ (closeIfOpen h)
    TLSContext tlsCtx h ->
      (TLS.bye tlsCtx `finally` TLS.contextClose tlsCtx) `finally` closeIfOpen h
  where
    closeIfOpen h = do
      open <- hIsOpen h
      when open (hClose h)

-- | Flush buffered output, on the plain handle or through the TLS context.
flush :: ConnectionContext -> IO ()
flush context = case context of
  NormalHandle h -> hFlush h
  TLSContext tlsCtx _ -> TLS.contextFlush tlsCtx