{-# LANGUAGE CPP, RankNTypes, RecordWildCards, ScopedTypeVariables #-}
module Data.Acid.Remote
(
acidServer
, acidServerSockAddr
, acidServer'
, openRemoteState
, openRemoteStateSockAddr
, skipAuthenticationCheck
, skipAuthenticationPerform
, sharedSecretCheck
, sharedSecretPerform
, AcidRemoteException(..)
, CommChannel(..)
, process
, processRemoteState
) where
import Prelude hiding ( catch )
import Control.Concurrent.STM ( atomically )
import Control.Concurrent.STM.TMVar ( newEmptyTMVar, readTMVar, takeTMVar, tryTakeTMVar, putTMVar )
import Control.Concurrent.STM.TQueue
import Control.Exception ( AsyncException(ThreadKilled)
, Exception(fromException), IOException, Handler(..)
, SomeException, catch, catches, throw, bracketOnError )
import Control.Exception ( throwIO, finally )
import Control.Monad ( forever, liftM, join, when )
import Control.Concurrent ( ThreadId, forkIO, threadDelay, killThread, myThreadId )
import Control.Concurrent.MVar ( MVar, newEmptyMVar, putMVar, takeMVar )
import Control.Concurrent.Chan ( newChan, readChan, writeChan )
import Data.Acid.Abstract
import Data.Acid.Core
import Data.Acid.Common
#if !MIN_VERSION_base(4,11,0)
import Data.Monoid ((<>))
#endif
import qualified Data.ByteString as Strict
import Data.ByteString.Char8 ( pack )
import qualified Data.ByteString.Lazy as Lazy
import Data.IORef ( newIORef, readIORef, writeIORef )
import Data.Serialize
import Data.Set ( Set, member )
import GHC.IO.Exception ( IOErrorType(..) )
import Network.BSD ( PortNumber, getProtocolNumber, getHostByName, hostAddress )
import Network.Socket
import Network.Socket.ByteString as NSB ( recv, sendAll )
import System.Directory ( removeFile )
import System.IO ( Handle, hPrint, hFlush, hClose, stderr, IOMode(..) )
import System.IO.Error ( ioeGetErrorType, isFullError, isDoesNotExistError )
-- | Tracing hook used throughout this module.
--
-- Currently a no-op: the message is accepted and discarded.  Swap in a real
-- logger (e.g. @putStrLn@) here when tracing the remote protocol.
debugStrLn :: String -> IO ()
debugStrLn _msg = pure ()
-- | A bidirectional byte channel that the remote protocol runs over.
-- This module builds one from a 'Handle' or from a 'Socket'.
data CommChannel = CommChannel
    { ccPut     :: Strict.ByteString -> IO ()
      -- ^ Send the given bytes to the peer.
    , ccGetSome :: Int -> IO Strict.ByteString
      -- ^ Receive at most the given number of bytes.  The server loop treats
      -- an empty result as end of input.
    , ccClose   :: IO ()
      -- ^ Close the underlying transport.
    }
-- | Exceptions raised by the remote acid-state client and server.
data AcidRemoteException
    = RemoteConnectionError
      -- ^ Presumably signals a lost connection to the remote state.
    | AcidStateClosed
      -- ^ An event was submitted after the remote state was closed.
    | SerializeError String
      -- ^ The server could not decode a command it received; carries the
      -- decoder's message.
    | AuthenticationError String
      -- ^ The authentication handshake was rejected.
    deriving (Eq, Show)

instance Exception AcidRemoteException
-- | View a 'Handle' as a 'CommChannel'.
--
-- Every write is followed by 'hFlush', so each message is pushed to the peer
-- immediately rather than sitting in the handle's buffer.
handleToCommChannel :: Handle -> CommChannel
handleToCommChannel h = CommChannel
    { ccPut     = \bytes -> do
        Strict.hPut h bytes
        hFlush h
    , ccGetSome = \n -> Strict.hGetSome h n
    , ccClose   = hClose h
    }
-- | View a connected 'Socket' as a 'CommChannel'.
socketToCommChannel :: Socket -> CommChannel
socketToCommChannel sock = CommChannel
    { ccPut     = \bytes -> sendAll sock bytes
    , ccGetSome = \n -> NSB.recv sock n
    , ccClose   = close sock
    }
-- | Server-side authentication check that accepts every client without
-- reading anything from the channel.
skipAuthenticationCheck :: CommChannel -> IO Bool
skipAuthenticationCheck = const (pure True)
-- | Client-side counterpart of 'skipAuthenticationCheck'.  It performs no
-- handshake at all.
skipAuthenticationPerform :: CommChannel -> IO ()
skipAuthenticationPerform = const (pure ())
-- | Server-side shared-secret check.
--
-- Reads one chunk of at most 1024 bytes from the client.  The client is
-- accepted iff that chunk is a member of @secrets@.  The server replies
-- @"OK"@ on success and @"FAIL"@ otherwise, and returns the verdict.
--
-- NOTE(review): the secret must arrive in a single read; a secret split
-- across TCP segments would be rejected.
sharedSecretCheck :: Set Strict.ByteString
                  -> (CommChannel -> IO Bool)
sharedSecretCheck secrets cc = do
    offered <- ccGetSome cc 1024
    let accepted = offered `member` secrets
    ccPut cc (pack (if accepted then "OK" else "FAIL"))
    pure accepted
-- | Client-side shared-secret handshake.
--
-- Sends @pw@, then reads one chunk of at most 1024 bytes.  Anything other
-- than exactly @"OK"@ raises 'AuthenticationError'.
sharedSecretPerform :: Strict.ByteString
                    -> (CommChannel -> IO ())
sharedSecretPerform pw cc = do
    ccPut cc pw
    reply <- ccGetSome cc 1024
    when (reply /= pack "OK") $
        throwIO (AuthenticationError "shared secret authentication failed.")
-- | Serve @acidState@ to remote clients on the given socket address.
--
-- Listens with 'listenOn' and delegates to 'acidServer''.  When that returns
-- or throws, the listening socket is closed.  On POSIX builds, a Unix-domain
-- socket file is also removed.
acidServerSockAddr :: (CommChannel -> IO Bool)
                   -> SockAddr
                   -> AcidState st
                   -> IO ()
acidServerSockAddr checkAuth sockAddr acidState = do
    listenSocket <- listenOn sockAddr
    acidServer' checkAuth listenSocket acidState
        `finally` releaseListener listenSocket
  where
    -- Close the listener and, for Unix-domain addresses, delete the socket
    -- file it created.
    releaseListener :: Socket -> IO ()
    releaseListener sock = do
        close sock
#if !defined(mingw32_HOST_OS) && !defined(cygwin32_HOST_OS) && !defined(_WIN32)
        case sockAddr of
            SockAddrUnix path -> removeFile path
            _                 -> pure ()
#endif
-- | Serve @acidState@ over TCP on the given port, on every IPv4 interface.
acidServer :: (CommChannel -> IO Bool)
           -> PortNumber
           -> AcidState st
           -> IO ()
acidServer checkAuth port = acidServerSockAddr checkAuth anyIPv4
  where
    -- Host address 0 is INADDR_ANY.
    anyIPv4 = SockAddrInet port 0
-- | Create a stream socket bound to @sockAddr@ and put it in listening mode.
--
-- The socket has 'ReuseAddr' enabled and a backlog of 'maxListenQueue'.  If
-- setup fails after the socket is created, the socket is closed before the
-- exception propagates.
listenOn :: SockAddr -> IO Socket
listenOn sockAddr = do
    proto <- protocolFor sockAddr
    bracketOnError (socket listenFamily Stream proto) close $ \sock -> do
        setSocketOption sock ReuseAddr 1
        bind sock sockAddr
        listen sock maxListenQueue
        pure sock
  where
    listenFamily :: Family
    listenFamily = case sockAddr of
        SockAddrInet {}  -> AF_INET
        SockAddrInet6 {} -> AF_INET6
#if !defined(mingw32_HOST_OS) && !defined(cygwin32_HOST_OS) && !defined(_WIN32)
        SockAddrUnix {}  -> AF_UNIX
#endif

    -- Unix-domain sockets use protocol 0; everything else is looked up as TCP.
    protocolFor :: SockAddr -> IO ProtocolNumber
#if !defined(mingw32_HOST_OS) && !defined(cygwin32_HOST_OS) && !defined(_WIN32)
    protocolFor (SockAddrUnix {}) = pure 0
#endif
    protocolFor _ = getProtocolNumber "tcp"
-- | Accept and serve remote clients on an already-listening socket.
--
-- Never returns normally.  Each accepted connection is served on its own
-- thread:
--
--   * @checkAuth@ runs first;
--   * only if it returns 'True' is the client served by 'process';
--   * the connection is always closed when that thread finishes, including
--     when @checkAuth@ or 'process' throws.
--
-- In the accept loop, IO errors matching 'isFullError',
-- 'isDoesNotExistError' or a resource-vanished error type are swallowed and
-- the loop restarts.  Any other exception propagates to the caller.
acidServer' :: (CommChannel -> IO Bool)
            -> Socket
            -> AcidState st
            -> IO ()
acidServer' checkAuth listenSocket acidState
  = do
       let loop = forever $
             do (conn, _sockAddr) <- accept listenSocket
                forkIO $ serveClient (socketToCommChannel conn)
           infi = loop `catchSome` logError >> infi
       infi
  where
    -- Authenticate and serve one client.  'finally' guarantees the channel is
    -- closed even when authentication or 'process' throws (e.g. a
    -- 'SerializeError' on malformed input), so the socket cannot leak.
    serveClient :: CommChannel -> IO ()
    serveClient commChannel =
        (do authorized <- checkAuth commChannel
            when authorized $
                process commChannel acidState)
        `finally` ccClose commChannel

    logError :: (Show e) => e -> IO ()
    logError e = hPrint stderr e

    isResourceVanishedError :: IOException -> Bool
    isResourceVanishedError = isResourceVanishedType . ioeGetErrorType

    isResourceVanishedType :: IOErrorType -> Bool
    isResourceVanishedType ResourceVanished = True
    isResourceVanishedType _                = False

    -- Run @op@, swallowing the expected IO errors listed in the header and
    -- rethrowing everything else.  The handler argument is currently unused.
    -- Logging each swallowed error could flood stderr.
    catchSome :: IO () -> (Show e => e -> IO ()) -> IO ()
    catchSome op _h =
        op `catches` [ Handler $ \(e :: IOException) ->
                         if isFullError e || isDoesNotExistError e || isResourceVanishedError e
                         then return ()
                         else throw e
                     ]
-- | A request sent from a remote client to the server.
--
-- Wire format: a one-byte tag (0-3) followed, for queries and updates, by the
-- serialised tagged event.
data Command = RunQuery (Tagged Lazy.ByteString)   -- ^ tag 0: run a query
             | RunUpdate (Tagged Lazy.ByteString)  -- ^ tag 1: run an update
             | CreateCheckpoint                    -- ^ tag 2
             | CreateArchive                       -- ^ tag 3

instance Serialize Command where
  put cmd = case cmd of
              RunQuery query   -> do putWord8 0; put query
              RunUpdate update -> do putWord8 1; put update
              CreateCheckpoint ->    putWord8 2
              CreateArchive    ->    putWord8 3
  -- An unknown tag is reported with 'fail' rather than 'error'.  The
  -- incremental decoder then yields a 'Fail' result that the caller can
  -- handle; the server turns it into a 'SerializeError'.  This avoids a
  -- bottom that escapes as an 'ErrorCall' on untrusted network input.
  get = do tag <- getWord8
           case tag of
             0 -> liftM RunQuery get
             1 -> liftM RunUpdate get
             2 -> return CreateCheckpoint
             3 -> return CreateArchive
             _ -> fail $ "Data.Acid.Remote: Serialize.get for Command, invalid tag: " ++ show tag
-- | A reply sent from the server back to a remote client.
--
-- Wire format: a one-byte tag (0-2) followed, for 'Result', by the encoded
-- result bytes.
data Response = Result Lazy.ByteString
                -- ^ tag 0: encoded result of a query or update
              | Acknowledgement
                -- ^ tag 1: a checkpoint/archive request completed
              | ConnectionError
                -- ^ tag 2: the client synthesises this for requests that
                -- were outstanding on a dropped connection

instance Serialize Response where
  put resp = case resp of
               Result result   -> do putWord8 0; put result
               Acknowledgement ->    putWord8 1
               ConnectionError ->    putWord8 2
  -- 'fail' (rather than 'error') makes an unknown tag surface as a decoder
  -- 'Fail' result instead of a bottom value.
  get = do tag <- getWord8
           case tag of
             0 -> liftM Result get
             1 -> return Acknowledgement
             2 -> return ConnectionError
             _ -> fail $ "Data.Acid.Remote: Serialize.get for Response, invalid tag: " ++ show tag
-- | Serve one client connection against a local 'AcidState'.
--
-- Commands are decoded incrementally from the channel and executed in
-- arrival order.  A forked writer thread sends the 'Response's back in the
-- same order.  An update's response is sent once its scheduled result is
-- available.
--
-- Returns when 'ccGetSome' yields an empty chunk (end of input).  Throws
-- 'SerializeError' when the input cannot be decoded.
--
-- NOTE(review): the writer thread is not stopped when this function returns.
process :: CommChannel
        -> AcidState st
        -> IO ()
process cc@CommChannel{} acidState = do
    replies <- newChan
    _ <- forkIO $ forever $ do
        reply <- join (readChan replies)
        ccPut cc (encode reply)
    decodeLoop replies (runGetPartial get Strict.empty)
  where
    -- Feed the incremental decoder until a whole Command is available, run
    -- it, then continue decoding the leftover bytes.
    decodeLoop _       (Fail msg _)     = throwIO (SerializeError msg)
    decodeLoop replies (Partial resume) = do
        chunk <- ccGetSome cc 1024
        if Strict.null chunk
            then return ()
            else decodeLoop replies (resume chunk)
    decodeLoop replies (Done cmd rest)  = do
        dispatch replies cmd
        decodeLoop replies (runGetPartial get rest)

    -- Execute one command and queue the action that produces its reply.
    dispatch replies cmd = case cmd of
        RunQuery query -> do
            answer <- queryCold acidState query
            writeChan replies (return (Result answer))
        RunUpdate update -> do
            pending <- scheduleColdUpdate acidState update
            writeChan replies (Result <$> takeMVar pending)
        CreateCheckpoint -> acknowledgeAfter replies (createCheckpoint acidState)
        CreateArchive    -> acknowledgeAfter replies (createArchive acidState)

    acknowledgeAfter replies action = do
        action
        writeChan replies (return Acknowledgement)
-- | Client-side view of a remote state.  It holds two fields:
--
--   * a function that submits a 'Command' and returns the 'MVar' that will
--     receive its 'Response';
--   * the action that shuts the client down.
data RemoteState st
    = RemoteState
        (Command -> IO (MVar Response))  -- submit a command
        (IO ())                          -- shut down
-- | Connect to a remote acid-state server by host name and TCP port.
--
-- The host is resolved with 'getHostByName' and connected to as an IPv4
-- ('SockAddrInet') address.  @performAuthorization@ runs on every
-- (re)connection; see 'openRemoteStateSockAddr'.
openRemoteState :: IsAcidic st =>
                   (CommChannel -> IO ())
                -> HostName
                -> PortNumber
                -> IO (AcidState st)
openRemoteState performAuthorization host port = do
    entry <- getHostByName host
    let addr = SockAddrInet port (hostAddress entry)
    openRemoteStateSockAddr performAuthorization addr
-- | Connect to a remote acid-state server at the given socket address.
--
-- The returned 'AcidState' reconnects automatically.  Each (re)connection
-- opens a fresh stream socket, wraps it in a 'Handle'-backed 'CommChannel'
-- and runs @performAuthorization@ on it.
--
-- Any 'IOError' during a connection attempt causes a retry after one second.
-- Other exceptions, such as 'AuthenticationError', propagate.
openRemoteStateSockAddr :: IsAcidic st =>
                           (CommChannel -> IO ())
                        -> SockAddr
                        -> IO (AcidState st)
openRemoteStateSockAddr performAuthorization sockAddr
  = withSocketsDo (processRemoteState reconnect)
  where
    af :: Family
    af = case sockAddr of
           (SockAddrInet {})  -> AF_INET
           (SockAddrInet6 {}) -> AF_INET6
#if !defined(mingw32_HOST_OS) && !defined(cygwin32_HOST_OS) && !defined(_WIN32)
           (SockAddrUnix {})  -> AF_UNIX
#endif

    -- Keep trying until a connection is established and authorised.
    reconnect :: IO CommChannel
    reconnect = attempt `catch` retry
      where
        retry :: IOError -> IO CommChannel
        retry _ = threadDelay 1000000 >> reconnect

        attempt :: IO CommChannel
        attempt = do
            debugStrLn "Reconnecting."
#if !defined(mingw32_HOST_OS) && !defined(cygwin32_HOST_OS) && !defined(_WIN32)
            proto <- case sockAddr of
                       (SockAddrUnix {}) -> pure 0
                       _                 -> getProtocolNumber "tcp"
#else
            proto <- getProtocolNumber "tcp"
#endif
            h <- bracketOnError
                     (socket af Stream proto)
                     close
                     (\sock -> do
                         connect sock sockAddr
                         socketToHandle sock ReadWriteMode)
            let cc = handleToCommChannel h
            -- If the handshake throws (an IOError, which leads to a retry,
            -- or e.g. an 'AuthenticationError'), close the connection first
            -- so that a failed attempt does not leak the handle.
            bracketOnError (return cc) ccClose performAuthorization
            debugStrLn "Reconnected."
            return cc
-- | Build a client-side 'AcidState' on top of a (re)connect action.
--
-- Shared state:
--
--   * a queue of submitted commands, each paired with the 'MVar' for its
--     reply;
--   * a 'TMVar' holding the current connection, its queue of pending reply
--     callbacks and the 'ThreadId' of its response-reader thread.  It is
--     empty while a reconnect is in progress;
--   * an 'IORef' flag set once the state has been closed.
--
-- Threads:
--
--   * A sender thread pops commands, registers each reply slot as a callback
--     on the current connection, and sends the command.
--   * A reader thread decodes responses and hands each one to the oldest
--     registered callback.
--
-- Recovery: an exception in either thread (other than 'ThreadKilled') tears
-- the connection down.  Every pending callback then receives
-- 'ConnectionError', and a new connection is opened with @reconnect@.
processRemoteState :: IsAcidic st =>
                      IO CommChannel
                   -> IO (AcidState st)
processRemoteState reconnect = do
    pending    <- atomically newTQueue
    connection <- atomically newEmptyTMVar
    closedRef  <- newIORef False

    let
        -- Enqueue a command for the sender thread and return the MVar that
        -- will receive its Response.  Throws 'AcidStateClosed' after
        -- shutdown.
        submit :: Command -> IO (MVar Response)
        submit command = do
            debugStrLn "actor: begin."
            closed <- readIORef closedRef
            when closed $ throwIO AcidStateClosed
            slot <- newEmptyMVar
            atomically $ writeTQueue pending (command, slot)
            debugStrLn "actor: end."
            return slot

        -- Answer every callback still queued on a dead connection with
        -- ConnectionError, oldest first.
        failWaiting :: TQueue (Response -> IO ()) -> IO ()
        failWaiting waiting = do
            next <- atomically $ tryReadTQueue waiting
            case next of
                Nothing       -> return ()
                Just callback -> do
                    callback ConnectionError
                    failWaiting waiting

        -- Replace the current connection after a failure.  Does nothing if
        -- the calling thread was killed, or if another thread is already
        -- reconnecting (the TMVar is empty).
        recover :: SomeException -> IO ()
        recover err
            | Just ThreadKilled <- fromException err =
                debugStrLn "handleReconnect: ThreadKilled. Not attempting to reconnect."
            | otherwise = do
                debugStrLn "handleReconnect begin."
                current <- atomically $ tryTakeTMVar connection
                case current of
                    Nothing -> do
                        debugStrLn "handleReconnect: error handling already in progress."
                        debugStrLn "handleReconnect end."
                    Just (oldCC, oldWaiting, oldReaderTID) -> do
                        self <- myThreadId
                        when (self /= oldReaderTID) $ killThread oldReaderTID
                        ccClose oldCC
                        failWaiting oldWaiting
                        freshCC      <- reconnect
                        freshWaiting <- atomically newTQueue
                        freshReader  <- forkIO $ responseReader freshCC freshWaiting
                        atomically $ putTMVar connection (freshCC, freshWaiting, freshReader)
                        debugStrLn "handleReconnect end."

        -- Decode Responses from @cc@ forever, handing each to the oldest
        -- waiting callback.  Any exception is routed to 'recover'.
        responseReader :: CommChannel -> TQueue (Response -> IO ()) -> IO ()
        responseReader cc waiting = readResponses Strict.empty `catch` recover
          where
            readResponses leftover = do
                debugStrLn "listener: listening for Response."
                rest <- feed (runGetPartial get leftover)
                readResponses rest

            feed (Fail msg _)     = error $ "Data.Acid.Remote: " <> msg
            feed (Partial more)   = do
                debugStrLn "listener: ccGetSome"
                chunk <- ccGetSome cc 1024
                feed (more chunk)
            feed (Done resp rest) = do
                debugStrLn "listener: getting callback"
                callback <- atomically $ readTQueue waiting
                debugStrLn "listener: passing Response to callback"
                callback (resp :: Response)
                return rest

        -- Forever: take the next command, register its reply slot on the
        -- current connection (in the same transaction), then send it.
        sender :: IO ()
        sender = forever $ do
            debugStrLn "actorThread: waiting for something to do."
            (cc, command) <- atomically $ do
                (command, slot)  <- readTQueue pending
                (cc, waiting, _) <- readTMVar connection
                writeTQueue waiting (putMVar slot)
                return (cc, command)
            debugStrLn "actorThread: sending command."
            ccPut cc (encode command) `catch` recover
            debugStrLn "actorThread: sent."

        -- Mark the state closed, stop both threads, fail outstanding
        -- callbacks and close the connection.
        -- NOTE(review): 'takeTMVar' blocks while a reconnect is in progress.
        closeDown :: ThreadId -> IO ()
        closeDown senderTID = do
            debugStrLn "shutdown: update isClosed IORef to True."
            writeIORef closedRef True
            debugStrLn "shutdown: killing actor thread."
            killThread senderTID
            debugStrLn "shutdown: taking ccTMV."
            (cc, waiting, readerTID) <- atomically $ takeTMVar connection
            debugStrLn "shutdown: killing listener thread."
            killThread readerTID
            debugStrLn "shutdown: expiring listen queue."
            failWaiting waiting
            debugStrLn "shutdown: closing connection."
            ccClose cc

    initialCC      <- reconnect
    initialWaiting <- atomically newTQueue
    senderTID      <- forkIO sender
    readerTID      <- forkIO $ responseReader initialCC initialWaiting
    atomically $ putTMVar connection (initialCC, initialWaiting, readerTID)
    return (toAcidState $ RemoteState submit (closeDown senderTID))
-- | Run a query event against a remote state.
--
-- Encodes the event with its hot-method serialiser and sends it via
-- 'remoteQueryCold'.  The result is decoded lazily; a decode failure becomes
-- an 'error' when the result is forced.
remoteQuery :: QueryEvent event
            => RemoteState (EventState event)
            -> MethodMap (EventState event)
            -> event
            -> IO (EventResult event)
remoteQuery acidState mmap event = do
    let (_, serialiser) = lookupHotMethodAndSerialiser mmap event
        tagged          = (methodTag event, encodeMethod serialiser event)
    raw <- remoteQueryCold acidState tagged
    return $ either failDecode id (decodeResult serialiser raw)
  where
    failDecode msg = error $ "Data.Acid.Remote: " <> msg
-- | Send an already-encoded query and wait for its encoded result.
--
-- Queries are side-effect free, so a 'ConnectionError' reply just resends
-- the query.  An 'Acknowledgement' reply is a protocol violation and raises
-- an 'error'.
remoteQueryCold :: RemoteState st -> Tagged Lazy.ByteString -> IO Lazy.ByteString
remoteQueryCold (RemoteState submit _shutdown) event = attempt
  where
    attempt = do
        response <- submit (RunQuery event) >>= takeMVar
        case response of
            Result bytes    -> return bytes
            ConnectionError -> do
                debugStrLn "retrying query event."
                attempt
            Acknowledgement ->
                error "Data.Acid.Remote: remoteQueryCold got Acknowledgement. That should never happen."
-- | Schedule an update event on a remote state.
--
-- Returns an 'MVar' that is filled once the server replies.  The value in the
-- 'MVar' is lazy; forcing it throws in the following cases:
--
--   * 'RemoteConnectionError' if the connection was lost before a reply
--     arrived.  The update may or may not have been applied on the server,
--     so it is deliberately not retried.
--   * An 'error' if the reply cannot be decoded.
updates are not idempotent, so unlike queries they are never resent.
scheduleRemoteUpdate :: UpdateEvent event => RemoteState (EventState event) -> MethodMap (EventState event) -> event -> IO (MVar (EventResult event))
scheduleRemoteUpdate (RemoteState fn _shutdown) mmap event
  = do let encoded = encodeMethod ms event
       parsed <- newEmptyMVar
       respRef <- fn (RunUpdate (methodTag event, encoded))
       -- Every reply (including a synthesised ConnectionError) is translated,
       -- so 'parsed' is always filled and the caller can never block forever
       -- on it.
       _ <- forkIO $ do resp <- takeMVar respRef
                        putMVar parsed (fromResponse resp)
       return parsed
  where
    (_, ms) = lookupHotMethodAndSerialiser mmap event

    fromResponse (Result bytes) =
        case decodeResult ms bytes of
          Left msg     -> error $ "Data.Acid.Remote: " <> msg
          Right result -> result
    fromResponse ConnectionError = throw RemoteConnectionError
    fromResponse Acknowledgement =
        error "Data.Acid.Remote: scheduleRemoteUpdate got Acknowledgement. That should never happen."
-- | Schedule a cold (already serialised, 'Tagged') update on the remote state.
--
-- Returns an 'MVar' that a background thread fills with the raw serialised
-- result.
--
-- NOTE(review): only a 'Result' reply is matched. Any other reply makes the
-- background thread fail, and the returned 'MVar' stays empty.
scheduleRemoteColdUpdate :: RemoteState st -> Tagged Lazy.ByteString -> IO (MVar Lazy.ByteString)
scheduleRemoteColdUpdate (RemoteState fn _shutdown) event = do
    out     <- newEmptyMVar
    pending <- fn (RunUpdate event)
    _ <- forkIO $ do
      Result bytes <- takeMVar pending
      putMVar out bytes
    return out
-- | Run the shutdown action stored in the 'RemoteState'.
closeRemoteState :: RemoteState st -> IO ()
closeRemoteState remote =
    case remote of
      RemoteState _send shutdownAction -> shutdownAction
-- | Send a 'CreateCheckpoint' command to the remote side and block until it
-- replies.
--
-- Only an 'Acknowledgement' reply is accepted. Any other reply fails the
-- pattern match in 'IO', which raises an 'IOError'.
createRemoteCheckpoint :: RemoteState st -> IO ()
createRemoteCheckpoint (RemoteState fn _shutdown) = do
    replyVar <- fn CreateCheckpoint
    Acknowledgement <- takeMVar replyVar
    pure ()
-- | Send a 'CreateArchive' command to the remote side and block until it
-- replies.
--
-- Only an 'Acknowledgement' reply is accepted. Any other reply fails the
-- pattern match in 'IO', which raises an 'IOError'.
createRemoteArchive :: RemoteState st -> IO ()
createRemoteArchive (RemoteState fn _shutdown) = do
    replyVar <- fn CreateArchive
    Acknowledgement <- takeMVar replyVar
    pure ()
-- | Present a 'RemoteState' as an ordinary 'AcidState'. Every operation is
-- forwarded to the corresponding remote helper.
--
-- Typed (hot) events are resolved through a method map built from the
-- state's 'acidEvents'. Cold events are passed through already serialised.
toAcidState :: forall st . IsAcidic st => RemoteState st -> AcidState st
toAcidState remote =
    AcidState
      { _scheduleUpdate    = scheduleRemoteUpdate remote methods
      , scheduleColdUpdate = scheduleRemoteColdUpdate remote
      , _query             = remoteQuery remote methods
      , queryCold          = remoteQueryCold remote
      , createCheckpoint   = createRemoteCheckpoint remote
      , createArchive      = createRemoteArchive remote
      , closeAcidState     = closeRemoteState remote
      , acidSubState       = mkAnyState remote
      }
  where
    -- The explicit signature relies on ScopedTypeVariables and the forall
    -- above, so that 'st' is the same type as in the record.
    methods :: MethodMap st
    methods = mkMethodMap (eventsToMethods acidEvents)