{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE LambdaCase #-} 

{- |
Module      :  Neovim.RPC.Common
Description :  Common functions for the RPC module
Copyright   :  (c) Sebastian Witte
License     :  Apache-2.0

Maintainer  :  woozletoff@gmail.com
Stability   :  experimental
-}
module Neovim.RPC.Common where

import Neovim.OS (getSocketUnix)

import Control.Applicative (Alternative ((<|>)))
import Control.Monad (unless)
import Data.Int (Int64)
import Data.Map (Map)
import Data.MessagePack (Object)
import Data.Streaming.Network (getSocketTCP)
import Data.String (IsString (fromString))
import Data.Time (UTCTime)
import Neovim.Compat.Megaparsec as P (
    MonadParsec (eof, try),
    Parser,
    anySingle,
    anySingleBut,
    many,
    parse,
    single,
    some,
 )
import Network.Socket as N (socketToHandle)
import System.Log.Logger (errorM, warningM)

import Data.List (intercalate)
import qualified Data.List as List
import Data.Maybe (catMaybes)
import qualified Text.Megaparsec.Char.Lexer as L
import UnliftIO.Environment (lookupEnv)
import UnliftIO

import Prelude

-- | State shared between the socket reader and the event handler.
data RPCConfig = RPCConfig
    { recipients ::
        TVar (Map Int64 (UTCTime, TMVar (Either Object Object)))
    -- ^ Associates each message identifier (as per RPC spec) with a
    -- timestamp and the 'TMVar' through which the result is passed back to
    -- the calling thread.
    , nextMessageId :: TVar Int64
    -- ^ The message identifier (as per RPC spec) to assign to the next
    -- message.
    }

{- | Create a fresh 'RPCConfig'.

 No recipients are registered yet (the map in 'recipients' is empty) and the
 'nextMessageId' counter starts at @1@.
-}
newRPCConfig :: (Applicative io, MonadUnliftIO io) => io RPCConfig
newRPCConfig = do
    -- 'newTVarIO' from UnliftIO already works in any 'MonadIO', so no
    -- explicit 'liftIO' is needed here.
    pendingRecipients <- newTVarIO mempty
    messageIdCounter <- newTVarIO 1
    pure
        RPCConfig
            { recipients = pendingRecipients
            , nextMessageId = messageIdCounter
            }

-- | Describes which kind of connection the socket reader should use.
data SocketType
    = Stdout Handle
    -- ^ Use the given handle for msgpack-rpc messages. This is suitable for
    -- an embedded neovim, as used in test cases.
    | Environment
    -- ^ Take the connection information from the environment variable
    -- @NVIM@ (with @NVIM_LISTEN_ADDRESS@ as a legacy fallback).
    | UnixSocket FilePath
    -- ^ Connect to the unix socket at the given path.
    | TCP Int String
    -- ^ Connect via an IP socket; the first field is the port, the second
    -- the host name.

{- | Create a 'Handle' from the given socket description.

 The returned handle is switched to block buffering. Socket variants are
 connected first and then treated like 'Stdout'.

 For 'Environment', the variables @NVIM@ and (for backwards compatibility)
 @NVIM_LISTEN_ADDRESS@ are consulted in that order. The first value that
 'parseNvimEnvironmentVariable' accepts is used. Unset or unparsable
 values are skipped.

 The handle is not automatically closed.

 If 'Environment' is requested and no variable yields a usable address, the
 problem is logged via 'errorM' and an exception is raised with 'error'.
-}
createHandle ::
    (Functor io, MonadUnliftIO io) =>
    SocketType ->
    io Handle
createHandle = \case
    Stdout h -> do
        liftIO $ hSetBuffering h (BlockBuffering Nothing)
        return h
    UnixSocket f ->
        liftIO $
            createHandle . Stdout
                =<< flip socketToHandle ReadWriteMode
                =<< getSocketUnix f
    TCP p h ->
        createHandle . Stdout =<< createTCPSocketHandle p h
    Environment ->
        createHandle . Stdout =<< createSocketHandleFromEnvironment
  where
    -- Connect to host @h@ on port @p@ and wrap the socket in a read/write handle.
    createTCPSocketHandle :: (MonadUnliftIO io) => Int -> String -> io Handle
    createTCPSocketHandle p h =
        liftIO $
            getSocketTCP (fromString h) p
                >>= flip socketToHandle ReadWriteMode . fst

    -- Resolve the socket address from the environment and connect to it.
    createSocketHandleFromEnvironment = liftIO $ do
        -- NVIM_LISTEN_ADDRESS is for backwards compatibility
        envValues <- catMaybes <$> mapM lookupEnv ["NVIM", "NVIM_LISTEN_ADDRESS"]
        -- Parse in the Maybe monad: a value that cannot be parsed (e.g. a
        -- variable that is set but empty) is skipped instead of aborting the
        -- whole lookup with a raw parse error via 'fail' in IO.
        let listenAddresses =
                [addr | Just addr <- map parseNvimEnvironmentVariable envValues]
        case listenAddresses of
            (s : _) -> createHandle s
            [] -> do
                let errMsg :: String
                    errMsg =
                        unlines
                            [ "Unhandled socket type from environment variable: "
                            , "\t" <> intercalate ", " envValues
                            ]
                errorM "createHandle" errMsg
                error errMsg

{- | Interpret a connection string (such as the value of @NVIM@) as a
 'SocketType'.

 A value whose last @:@-separated segment is a decimal number becomes 'TCP'.
 Any other non-empty value becomes 'UnixSocket'. If parsing fails, the
 shown parse error is passed to 'fail'.
-}
parseNvimEnvironmentVariable :: MonadFail m => String -> m SocketType
parseNvimEnvironmentVariable envValue =
    case parse addressParser envValue envValue of
        Left parseError -> fail (show parseError)
        Right socketType -> pure socketType
  where
    -- Backtrack after a failed TCP attempt so the unix socket parser sees
    -- the whole input again.
    addressParser = P.try pTcpAddress <|> pUnixSocket

-- | Accept the entire (non-empty) input as the path of a 'UnixSocket'.
pUnixSocket :: P.Parser SocketType
pUnixSocket = do
    socketPath <- P.some anySingle
    P.eof
    pure (UnixSocket socketPath)

{- | Parse an address of the form @host:port@.

 Everything before the final @:@ is the host, and it may itself contain
 colons. The text after the final @:@ must be a decimal port number that
 extends to the end of the input.
-}
pTcpAddress :: P.Parser SocketType
pTcpAddress = do
    hostSegments <- P.some (P.try hostSegment)
    port <- L.decimal <* P.eof
    pure (TCP port (List.intercalate ":" hostSegments))
  where
    -- One colon-terminated piece of the host; the colon itself is dropped
    -- and re-inserted by 'List.intercalate' above.
    hostSegment :: P.Parser String
    hostSegment = P.many (P.anySingleBut ':') <* P.single ':'

{- | Close the given handle.

 Pass 'False' for the second argument when processing stopped before it
 completed (e.g. the conduit chain was interrupted prematurely). In that case
 a warning is logged via 'warningM' after the handle is closed.
-}
cleanUpHandle :: (MonadUnliftIO io) => Handle -> Bool -> io ()
cleanUpHandle handle completed = liftIO $ do
    hClose handle
    if completed
        then pure ()
        else warningM "cleanUpHandle" "Cleanup called on uncompleted handle."