{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE DeriveDataTypeable #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE OverloadedStrings #-}
module Database.Redis.ConnectionContext (
ConnectionContext(..)
, ConnectTimeout(..)
, ConnectionLostException(..)
, ConnectAddr(..)
, connect
, disconnect
, send
, recv
, errConnClosed
, enableTLS
, flush
, ioErrorToConnLost
) where
import Control.Monad(when)
import qualified Data.ByteString as B
import qualified Data.ByteString.Char8 as Char8
import qualified Data.ByteString.Lazy as LB
import qualified Data.IORef as IOR
import Control.Concurrent.MVar(newMVar, readMVar, swapMVar)
import Control.Exception(bracketOnError, Exception, throwIO, try, finally, mask_)
import Data.Functor(void)
import qualified Network.Socket as NS
import qualified Network.TLS as TLS
import System.IO(Handle, hSetBinaryMode, hClose, IOMode(..), hFlush, hIsOpen)
import System.IO.Error(catchIOError)
import System.Timeout (timeout)
-- | The transport a connection runs over: either a plain 'Handle'
-- ('NormalHandle'), or a TLS context together with a 'Handle'
-- ('TLSContext').
data ConnectionContext = NormalHandle Handle | TLSContext TLS.Context Handle
-- | Shows only the constructor name; the wrapped handle and TLS context are
-- not part of the output.
instance Show ConnectionContext where
  show = \case
    NormalHandle{} -> "NormalHandle"
    TLSContext{} -> "TLSContext"
-- | A 'ConnectionContext' bundled with a mutable cell that can hold a chunk
-- of bytes.
data Connection = Connection
  { ctx :: ConnectionContext
    -- ^ The transport this connection runs over.
  , lastRecvRef :: IOR.IORef (Maybe B.ByteString)
    -- ^ Mutable slot for an optional 'B.ByteString'.
    -- NOTE(review): presumably the most recently received chunk; nothing in
    -- this module reads or writes it, so confirm against the users.
  }
-- | Shows the wrapped 'ConnectionContext'; the 'IOR.IORef' field has no
-- 'Show' instance, so it is rendered as the fixed placeholder @IORef@.
instance Show Connection where
  show conn =
    concat
      [ "Connection{ ctx = "
      , show (ctx conn)
      , ", lastRecvRef = IORef}"
      ]
-- | Label for the step of connection establishment that is in progress.
-- 'connect' keeps the current label in an 'MVar' and attaches it to the
-- 'ConnectTimeout' exception when the attempt times out.
data ConnectPhase
  = PhaseUnknown
    -- ^ Initial value, before any other phase has been recorded.
  | PhaseResolve
  | PhaseOpenSocket
  deriving (Show)
-- | Exception thrown when a connection attempt does not finish within the
-- requested time. Carries the 'ConnectPhase' that was recorded at the moment
-- the timeout was detected.
newtype ConnectTimeout = ConnectTimeout ConnectPhase
  deriving (Show)

instance Exception ConnectTimeout
-- | Exception signalling that the connection is no longer usable. It is what
-- 'errConnClosed' throws, and what 'ioErrorToConnLost' substitutes for any
-- 'IOError'.
data ConnectionLostException = ConnectionLost
  deriving (Show)

instance Exception ConnectionLostException
-- | Where to connect to.
data ConnectAddr
  = ConnectAddrHostPort NS.HostName NS.PortNumber
    -- ^ A host name and port; 'connect' resolves it with 'NS.getAddrInfo'.
  | ConnectAddrUnixSocket String
    -- ^ Path of a Unix domain socket (used as an 'NS.SockAddrUnix').
  deriving (Eq, Show)
-- | Open a connection to the given address.
--
-- * @timeoutOpt@: optional limit, in microseconds (it is handed to
--   'timeout'), on establishing the socket connection. It covers address
--   resolution and the socket connect, but /not/ the TLS handshake, which
--   runs afterwards. On expiry a 'ConnectTimeout' carrying the phase that was
--   in progress is thrown.
-- * @mTlsParams@: when present /and/ the address is a host\/port pair, the
--   connection is upgraded with 'enableTLS'; the parameters'
--   'TLS.clientServerIdentification' is overwritten with the host and the
--   textual port. For Unix socket addresses the TLS parameters are ignored.
--
-- If anything fails after the handle has been created, the handle is closed
-- before the exception propagates.
connect :: ConnectAddr -> Maybe Int -> Maybe TLS.ClientParams -> IO ConnectionContext
connect connectAddr timeoutOpt mTlsParams =
  bracketOnError hConnect hClose $ \h -> do
    hSetBinaryMode h True
    case (mTlsParams, connectAddr) of
      (Just baseParams, ConnectAddrHostPort host port) ->
        let tlsParams =
              baseParams
                { TLS.clientServerIdentification = (host, Char8.pack (show port))
                }
         in enableTLS tlsParams (NormalHandle h)
      _ -> return (NormalHandle h)
  where
    -- Establish the socket-level connection, honouring the optional timeout.
    hConnect = do
      phaseMVar <- newMVar PhaseUnknown
      let doConnect = hConnect' phaseMVar
      case timeoutOpt of
        Nothing -> doConnect
        Just micros ->
          timeout micros doConnect >>= \case
            Just h -> return h
            Nothing -> readMVar phaseMVar >>= errConnectTimeout

    -- Create and connect a socket, then turn it into a 'Handle'. The socket
    -- is closed if any step after its creation throws.
    hConnect' mvar =
      bracketOnError createSock NS.close $ \sock -> do
        NS.setSocketOption sock NS.KeepAlive 1
        NS.socketToHandle sock ReadWriteMode
      where
        setPhase phase = void (swapMVar mvar phase)

        -- The phase is recorded immediately *before* the step it names, so
        -- that a timeout firing during resolution or during the connect
        -- reports that step rather than a stale value.
        createSock = case connectAddr of
          ConnectAddrHostPort hostName portNumber -> do
            setPhase PhaseResolve
            addrInfo <- getHostAddrInfo hostName portNumber
            setPhase PhaseOpenSocket
            connectSocket addrInfo
          ConnectAddrUnixSocket path -> do
            setPhase PhaseOpenSocket
            bracketOnError
              (NS.socket NS.AF_UNIX NS.Stream NS.defaultProtocol)
              NS.close
              (\sock -> NS.connect sock (NS.SockAddrUnix path) >> return sock)
-- | Resolve a host name and port to candidate addresses, restricted to
-- stream sockets via the 'NS.addrSocketType' hint. The port is passed to
-- 'NS.getAddrInfo' as its decimal service string.
getHostAddrInfo :: NS.HostName -> NS.PortNumber -> IO [NS.AddrInfo]
getHostAddrInfo hostname port =
  NS.getAddrInfo (Just streamHints) (Just hostname) (Just (show port))
  where
    streamHints = NS.defaultHints {NS.addrSocketType = NS.Stream}
-- | Throw a 'ConnectTimeout' tagged with the given phase.
errConnectTimeout :: ConnectPhase -> IO a
errConnectTimeout = throwIO . ConnectTimeout
-- | Try each candidate address in order and return the first socket that
-- connects.
--
-- An 'IOError' from connecting to one address is swallowed as long as
-- further candidates remain; the error from the /last/ candidate is
-- rethrown. Calling this with an empty list is a programming error and
-- raises via 'error'.
connectSocket :: [NS.AddrInfo] -> IO NS.Socket
connectSocket [] = error "connectSocket: unexpected empty list"
connectSocket (addr : rest) = attempt >>= either onFailure return
  where
    -- Fall through to the next candidate, or give up with the current error
    -- when there is none left.
    onFailure err
      | null rest = throwIO err
      | otherwise = connectSocket rest

    -- Create a socket for this candidate and connect it. An 'IOError' from
    -- the connect is returned as 'Left' after closing the socket; any other
    -- exception propagates, with 'bracketOnError' closing the socket.
    attempt :: IO (Either IOError NS.Socket)
    attempt =
      bracketOnError openSock NS.close $ \sock -> do
        outcome <- try (NS.connect sock (NS.addrAddress addr))
        case outcome of
          Right () -> return (Right sock)
          Left err -> do
            NS.close sock
            return (Left err)

    openSock =
      NS.socket
        (NS.addrFamily addr)
        (NS.addrSocketType addr)
        (NS.addrProtocol addr)
-- | Write the given bytes to the connection. Any 'IOError' raised while
-- writing is replaced by 'ConnectionLost' (see 'ioErrorToConnLost').
send :: ConnectionContext -> B.ByteString -> IO ()
send connCtx requestData = ioErrorToConnLost $
  case connCtx of
    NormalHandle h -> B.hPut h requestData
    -- 'TLS.sendData' takes a lazy ByteString, hence the conversion.
    TLSContext tlsCtx _ -> TLS.sendData tlsCtx (LB.fromStrict requestData)
-- | Read the next chunk of bytes from the connection.
--
-- For a plain handle this reads at most 4096 bytes with 'B.hGetSome'; for a
-- TLS connection it returns whatever 'TLS.recvData' yields.
--
-- On both transports an 'IOError' is replaced by 'ConnectionLost', matching
-- 'send'. (Previously the TLS branch let raw 'IOError's escape, unlike every
-- other send\/receive path in this module.)
recv :: ConnectionContext -> IO B.ByteString
recv connCtx = ioErrorToConnLost $
  case connCtx of
    NormalHandle h -> B.hGetSome h 4096
    TLSContext tlsCtx _ -> TLS.recvData tlsCtx
-- | Run an action, replacing any 'IOError' it throws with 'ConnectionLost'.
-- The original error is discarded; other exception types pass through
-- untouched.
ioErrorToConnLost :: IO a -> IO a
ioErrorToConnLost action = catchIOError action (\_ -> errConnClosed)
-- | Throw 'ConnectionLost'.
errConnClosed :: IO a
errConnClosed = throwIO ConnectionLost
-- | Upgrade a plain connection to TLS.
--
-- For a 'NormalHandle', a TLS context is created over the handle with the
-- given parameters and the handshake is performed before returning; the
-- result keeps the same handle alongside the new context. A connection that
-- is already a 'TLSContext' is returned unchanged and the parameters are
-- ignored.
enableTLS :: TLS.ClientParams -> ConnectionContext -> IO ConnectionContext
enableTLS tlsParams (NormalHandle h) = do
  tlsCtx <- TLS.contextNew h tlsParams
  TLS.handshake tlsCtx
  return $! TLSContext tlsCtx h
enableTLS _ alreadyTls@(TLSContext _ _) = return alreadyTls
-- | Close the connection.
--
-- A plain handle is closed (if still open) with asynchronous exceptions
-- masked. For TLS, 'TLS.bye' is sent first, then the context is closed, then
-- the underlying handle is closed if still open; each later step runs even
-- if an earlier one throws.
disconnect :: ConnectionContext -> IO ()
disconnect = \case
  NormalHandle h -> mask_ (closeIfOpen h)
  TLSContext tlsCtx h ->
    -- Parenthesised to spell out the left-nested grouping that the bare
    -- backtick chain has by default.
    (TLS.bye tlsCtx `finally` TLS.contextClose tlsCtx) `finally` closeIfOpen h
  where
    closeIfOpen h = do
      open <- hIsOpen h
      when open (hClose h)
-- | Flush buffered output: 'hFlush' for a plain handle, 'TLS.contextFlush'
-- for a TLS connection.
flush :: ConnectionContext -> IO ()
flush = \case
  NormalHandle h -> hFlush h
  TLSContext tlsCtx _ -> TLS.contextFlush tlsCtx