{-# LANGUAGE OverloadedStrings   #-}
{-# LANGUAGE ScopedTypeVariables #-}

module MCP.Server.Transport.Stdio
  ( -- * STDIO Transport
    transportRunStdio
  ) where

import           Control.Monad          (when)
import           Control.Monad.IO.Class (MonadIO, liftIO)
import           Data.Aeson
import qualified Data.ByteString.Lazy   as BSL
import qualified Data.Text              as T
import qualified Data.Text.Encoding     as TE
import qualified Data.Text.IO           as TIO
import           System.IO              (hFlush, hPutStrLn, isEOF, stderr, stdin, stdout)

import           MCP.Server.Handlers
import           MCP.Server.JsonRpc
import           MCP.Server.Types


-- | Transport-specific implementation for STDIO
import           System.IO              (hSetEncoding, utf8)

-- | Run an MCP server over standard input and standard output.
--
-- Each line read from stdin is decoded as JSON, parsed as a JSON-RPC
-- message and passed to 'handleMcpMessage'. When the handler produces a
-- response it is written to stdout as a single line of JSON and flushed
-- immediately. All diagnostics are written to stderr, so stdout carries
-- only protocol traffic.
--
-- Lines that fail to decode or parse are reported on stderr and otherwise
-- ignored; the loop then continues with the next line.
--
-- The loop ends when stdin reaches end-of-file or when a blank
-- (whitespace-only) line is read.
transportRunStdio :: (MonadIO m) => McpServerInfo -> McpServerHandlers m -> m ()
transportRunStdio serverInfo handlers = do
  -- Force UTF-8 on every handle we touch, including stdin, so that the
  -- text we read and write does not depend on the process locale.
  liftIO $ do
    hSetEncoding stdin utf8
    hSetEncoding stdout utf8
    hSetEncoding stderr utf8
  loop
  where
    -- Write one diagnostic line to stderr.
    logLine msg = liftIO $ TIO.hPutStrLn stderr msg

    -- Read and process lines until end-of-file or a blank line.
    loop = do
      -- 'TIO.getLine' throws at end-of-file, so check first: the client
      -- closing stdin is a normal way for the session to end.
      eof <- liftIO isEOF
      when (not eof) $ do
        input <- liftIO TIO.getLine
        -- NOTE(review): a blank line ends the loop, as in the original
        -- code; confirm whether blank lines should be skipped instead.
        when (not $ T.null $ T.strip input) $ do
          processLine input
          loop

    -- Decode one raw input line and dispatch it if it is a valid message.
    processLine input = do
      logLine $ "Received request: " <> input
      case eitherDecode (BSL.fromStrict $ TE.encodeUtf8 input) of
        Left err -> logLine $ "Parse error: " <> T.pack err
        Right jsonValue ->
          case parseJsonRpcMessage jsonValue of
            Left err      -> logLine $ "JSON-RPC parse error: " <> T.pack err
            Right message -> dispatch message

    -- Hand a parsed message to the MCP handlers and emit the response, if any.
    dispatch message = do
      let summary = T.pack (show (getMessageSummary message))
      logLine $ "Processing message: " <> summary
      response <- handleMcpMessage serverInfo handlers message
      case response of
        Nothing -> logLine $ "No response needed for: " <> summary
        Just responseMsg -> do
          logLine $ "Sending response for: " <> summary
          liftIO $ do
            TIO.putStrLn $ TE.decodeUtf8 $ BSL.toStrict $ encode $
              encodeJsonRpcMessage responseMsg
            -- Flush so the client sees the reply even when stdout is a
            -- block-buffered pipe.
            hFlush stdout