-- | Client-side connection API over a 'Handle' or 'Socket', with optional
-- TLS and SOCKS proxy support.
module Network.Connection
    (
    -- * Connection type and parameters
      Connection
    , connectionID
    , ConnectionParams(..)
    , TLSSettings(..)
    , ProxySettings(..)
    , SockSettings
    -- * Exceptions
    , LineTooLong(..)
    , HostNotResolved(..)
    , HostCannotConnect(..)
    -- * Context initialization
    , initConnectionContext
    , ConnectionContext
    -- * Creating and closing connections
    , connectFromHandle
    , connectFromSocket
    , connectTo
    , connectionClose
    -- * Reading and writing
    , connectionGet
    , connectionGetExact
    , connectionGetChunk
    , connectionGetChunk'
    , connectionGetLine
    , connectionWaitForInput
    , connectionPut
    -- * TLS upgrade and inspection
    , connectionSetSecure
    , connectionIsSecure
    ) where
import Control.Applicative
import Control.Concurrent.MVar
import Control.Monad (join)
import qualified Control.Exception as E
import qualified System.IO.Error as E (mkIOError, eofErrorType)
import qualified Network.TLS as TLS
import qualified Network.TLS.Extra as TLS
import System.X509 (getSystemCertificateStore)
import Network.Socks5
import qualified Network as N
import Network.Socket
import Network.BSD (getProtocolNumber)
import qualified Network.Socket as N (close)
import qualified Network.Socket.ByteString as N
import Data.Default.Class
import Data.Data
import Data.ByteString (ByteString)
import qualified Data.ByteString as B
import qualified Data.ByteString.Char8 as BC
import qualified Data.ByteString.Lazy as L
import System.Environment
import System.Timeout
import System.IO
import qualified Data.Map as M
import Network.Connection.Types
-- | Mutable map from TLS session identifier to session data, used as the
-- storage behind 'connectionSessionManager'.
type Manager = MVar (M.Map TLS.SessionID TLS.SessionData)
-- | Thrown by 'connectionGetLine' when the accumulated bytes exceed the
-- caller-supplied limit before a line terminator is seen.
data LineTooLong = LineTooLong deriving (Show,Typeable)
-- | Thrown by 'connectTo' when address lookup yields no candidate address
-- for the given host name (carried as the 'String').
data HostNotResolved = HostNotResolved String deriving (Show,Typeable)
-- | Thrown by 'connectTo' when every candidate address failed to connect.
-- Carries the host name and the I/O exceptions collected from the attempts.
data HostCannotConnect = HostCannotConnect String [E.IOException] deriving (Show,Typeable)
instance E.Exception LineTooLong
instance E.Exception HostNotResolved
instance E.Exception HostCannotConnect
-- | Build a TLS session manager whose storage is the given 'Manager' map.
--
-- Resuming looks a session up, establishing inserts it, and invalidating
-- deletes it; each operation runs while holding the 'MVar'.
connectionSessionManager :: Manager -> TLS.SessionManager
connectionSessionManager sessions = TLS.SessionManager
    { TLS.sessionResume     = resume
    , TLS.sessionEstablish  = establish
    , TLS.sessionInvalidate = invalidate
    }
  where
    resume sid =
        withMVar sessions $ \table -> return (M.lookup sid table)
    establish sid sdata =
        modifyMVar_ sessions $ \table -> return (M.insert sid sdata table)
    invalidate sid =
        modifyMVar_ sessions $ \table -> return (M.delete sid table)
-- | Create a 'ConnectionContext' holding the system certificate store.
initConnectionContext :: IO ConnectionContext
initConnectionContext = do
    store <- getSystemCertificateStore
    return (ConnectionContext store)
-- | Derive the TLS client parameters for a connection.
--
-- * For 'TLSSettingsSimple', start from the library defaults for the target
--   host/port, enable all cipher suites, use the context's certificate
--   store, and install an always-passing validation cache when certificate
--   validation is disabled in the settings.
-- * For 'TLSSettings', reuse the caller's parameters but overwrite the
--   server identification with the connection's host and port.
makeTLSParams :: ConnectionContext -> ConnectionID -> TLSSettings -> TLS.ClientParams
makeTLSParams cg cid settings =
    case settings of
        TLSSettings customParams ->
            customParams { TLS.clientServerIdentification = (hostName, portBytes) }
        TLSSettingsSimple {} ->
            (TLS.defaultParamsClient hostName portBytes)
                { TLS.clientSupported = def { TLS.supportedCiphers = TLS.ciphersuite_all }
                , TLS.clientShared    = def
                    { TLS.sharedCAStore         = globalCertificateStore cg
                    , TLS.sharedValidationCache = certCache
                    }
                }
  where
    hostName  = fst cid
    -- The port is rendered as its decimal representation in bytes.
    portBytes = BC.pack (show (snd cid))
    -- Only consulted in the 'TLSSettingsSimple' branch above.
    certCache
        | settingDisableCertificateValidation settings =
            TLS.ValidationCache (\_ _ _ -> return TLS.ValidationCachePass)
                                (\_ _ _ -> return ())
        | otherwise = def
-- | Read the connection's current backend and run an action on it.
withBackend :: (ConnectionBackend -> IO a) -> Connection -> IO a
withBackend f conn = do
    backend <- readMVar (connectionBackend conn)
    f backend
-- | Wrap a backend into a fresh 'Connection' with an empty read buffer.
connectionNew :: ConnectionID -> ConnectionBackend -> IO Connection
connectionNew cid backend = do
    backendVar <- newMVar backend
    bufferVar  <- newMVar (Just B.empty)
    return (Connection backendVar bufferVar cid)
-- | Build a 'Connection' on top of an already-open 'Handle'.
--
-- When the parameters request TLS, a handshake is performed over the handle
-- first; otherwise the handle is used as a plain stream.
connectFromHandle :: ConnectionContext
                  -> Handle
                  -> ConnectionParams
                  -> IO Connection
connectFromHandle cg h p =
    case connectionUseSecure p of
        Nothing          -> connectionNew cid (ConnectionStream h)
        Just tlsSettings -> do
            ctx <- tlsEstablish h (makeTLSParams cg cid tlsSettings)
            connectionNew cid (ConnectionTLS ctx)
  where
    cid = (connectionHostname p, connectionPort p)
-- | Build a 'Connection' on top of an already-connected 'Socket'.
--
-- When the parameters request TLS, a handshake is performed over the socket
-- first; otherwise the socket is used directly.
connectFromSocket :: ConnectionContext
                  -> Socket
                  -> ConnectionParams
                  -> IO Connection
connectFromSocket cg sock p =
    case connectionUseSecure p of
        Nothing          -> connectionNew cid (ConnectionSocket sock)
        Just tlsSettings -> do
            ctx <- tlsEstablish sock (makeTLSParams cg cid tlsSettings)
            connectionNew cid (ConnectionTLS ctx)
  where
    cid = (connectionHostname p, connectionPort p)
-- | Open a new connection described by the given parameters.
--
-- A socket is obtained either directly, through another proxy host, or
-- through a SOCKS server, depending on 'connectionUseSocks'; the socket is
-- then handed to 'connectFromSocket'. If 'connectFromSocket' throws, the
-- socket is closed.
--
-- Throws 'HostNotResolved' when no address is found for the host and
-- 'HostCannotConnect' when every resolved address fails to connect.
connectTo :: ConnectionContext -- ^ shared context (certificate store)
          -> ConnectionParams  -- ^ destination, TLS and proxy settings
          -> IO Connection     -- ^ the established connection
connectTo cg cParams = do
    conFct <- getConFct (connectionUseSocks cParams)
    let doConnect = conFct (connectionHostname cParams) (N.PortNumber $ connectionPort cParams)
    -- Close the socket only if wrapping it into a Connection fails.
    E.bracketOnError doConnect N.close $ \h->
        connectFromSocket cg h cParams
  where
        -- Select the function used to obtain a connected socket.
        getConFct Nothing                            = return resolve'
        -- The requested destination is ignored: connect to the proxy host/port.
        getConFct (Just (OtherProxy h p))            = return $ \_ _ -> resolve' h (N.PortNumber p)
        getConFct (Just (SockSettingsSimple h p))    = return $ socksConnectTo' h (N.PortNumber p)
        getConFct (Just (SockSettingsEnvironment v)) = do
            -- Read the SOCKS server from the named environment variable,
            -- defaulting to SOCKS_SERVER when no name is given.
            let name = maybe "SOCKS_SERVER" id v
            evar <- E.try (getEnv name)
            case evar of
                -- Variable unreadable: fall back to a direct connection.
                Left (_ :: E.IOException) -> return resolve'
                Right var                 ->
                    case parseSocks var of
                        -- Unparseable value: fall back to a direct connection.
                        Nothing             -> return resolve'
                        Just (sHost, sPort) -> return $ socksConnectTo' sHost (N.PortNumber $ fromIntegral (sPort :: Int))
        -- Parse "host" or "host:port"; the port defaults to 1080 when absent.
        parseSocks s =
            case break (== ':') s of
                (sHost, "")        -> Just (sHost, 1080)
                (sHost, ':':portS) ->
                    case reads portS of
                        [(sPort,"")] -> Just (sHost, sPort)
                        _            -> Nothing
                _                  -> Nothing
        -- Resolve the host and connect to the first address that accepts.
        resolve' host portid = do
            let serv = case portid of
                            N.Service serv -> serv
                            N.PortNumber n -> show n
                            _              -> error "cannot resolve service" -- other port kinds are not handled
            proto <- getProtocolNumber "tcp"
            let hints = defaultHints { addrFlags = [AI_ADDRCONFIG]
                                     , addrProtocol = proto
                                     , addrSocketType = Stream }
            addrs <- getAddrInfo (Just hints) (Just host) (Just serv)
            firstSuccessful $ map tryToConnect addrs
          where
            -- Create a socket for one address; close it if connect fails.
            tryToConnect addr =
                E.bracketOnError
                    (socket (addrFamily addr) (addrSocketType addr) (addrProtocol addr))
                    (N.close)
                    (\sock -> connect sock (addrAddress addr) >> return sock)
            -- Run the actions in order, returning the first success and
            -- collecting the I/O exceptions of the failures (most recent first).
            firstSuccessful = go []
              where
                go :: [E.IOException] -> [IO a] -> IO a
                -- No action was available at all.
                go []      [] = E.throwIO $ HostNotResolved host
                -- Every action failed.
                go l@(_:_) [] = E.throwIO $ HostCannotConnect host l
                go acc     (act:followingActs) = do
                    er <- E.try act
                    case er of
                        Left err -> go (err:acc) followingActs
                        Right r  -> return r
-- | Send a strict 'ByteString' over the connection.
--
-- Handle-backed connections are flushed after the write.
connectionPut :: Connection -> ByteString -> IO ()
connectionPut connection content = withBackend send connection
  where
    send backend = case backend of
        ConnectionStream h -> do
            B.hPut h content
            hFlush h
        ConnectionSocket s -> N.sendAll s content
        ConnectionTLS ctx  -> TLS.sendData ctx (L.fromChunks [content])
-- | Read exactly @x@ bytes from the connection.
--
-- Repeatedly calls 'connectionGet' for the number of bytes still missing
-- until @x@ bytes have been gathered. A request for 0 bytes returns the
-- empty string without reading. Errors raised by 'connectionGet' (for
-- example a negative size, or end of input before @x@ bytes are available)
-- propagate to the caller.
connectionGetExact :: Connection -> Int -> IO ByteString
connectionGetExact conn x = loop [] 0
  where
    -- 'chunks' holds the pieces read so far, most recent first;
    -- 'received' is their total length.
    loop chunks received
      | received == x = return (B.concat (reverse chunks))
      | otherwise = do
          -- Ask only for what is still missing so we never read past @x@.
          next <- connectionGet conn (x - received)
          loop (next : chunks) (received + B.length next)
-- | Read up to @size@ bytes from the connection.
--
-- A negative size fails, a size of 0 returns the empty string without
-- reading, and otherwise at most @size@ bytes of the next available chunk
-- are returned, with the remainder left in the connection buffer.
connectionGet :: Connection -> Int -> IO ByteString
connectionGet conn size =
    case compare size 0 of
        LT -> fail "Network.Connection.connectionGet: size < 0"
        EQ -> return B.empty
        GT -> connectionGetChunkBase "connectionGet" conn (B.splitAt size)
-- | Return the next available chunk of input in full, leaving the
-- connection buffer empty.
connectionGetChunk :: Connection -> IO ByteString
connectionGetChunk conn =
    connectionGetChunkBase "connectionGetChunk" conn takeAll
  where
    -- Consume the whole chunk; nothing is pushed back.
    takeAll available = (available, B.empty)
-- | Like 'connectionGetChunk', but the caller's function decides what to
-- return and which bytes to push back into the connection buffer (the
-- second component of its result).
connectionGetChunk' :: Connection -> (ByteString -> (a, ByteString)) -> IO a
connectionGetChunk' conn consume =
    connectionGetChunkBase "connectionGetChunk'" conn consume
-- | Wait up to @timeout_ms@ milliseconds for the next chunk to become
-- available, without consuming it.
--
-- Returns 'False' if the timeout elapsed before the chunk read completed,
-- 'True' otherwise. Any chunk read is kept in the connection buffer.
connectionWaitForInput :: Connection -> Int -> IO Bool
connectionWaitForInput conn timeout_ms = do
    outcome <- timeout timeoutMicros peekChunk
    return $ case outcome of
        Nothing -> False
        Just _  -> True
  where
    -- 'timeout' takes microseconds.
    timeoutMicros = timeout_ms * 1000
    -- Read a chunk and hand it straight back to the buffer.
    peekChunk =
        connectionGetChunkBase "connectionWaitForInput" conn (\pending -> ((), pending))
-- | Core buffered read shared by the @connectionGet*@ functions.
--
-- Holds the buffer 'MVar' for the duration of the call. If the buffer has
-- been closed, an EOF error tagged with @loc@ is thrown. If the buffer is
-- empty, more data is fetched from the backend; an empty result from the
-- backend closes the buffer. The function @f@ receives the available bytes
-- and returns the result together with the bytes to keep buffered.
connectionGetChunkBase :: String -> Connection -> (ByteString -> (a, ByteString)) -> IO a
connectionGetChunkBase loc conn f =
    modifyMVar (connectionBuffer conn) step
  where
    step Nothing = throwEOF conn loc
    step (Just pending)
        | B.null pending = do
            fresh <- withBackend fetch conn
            if B.null fresh
                then markClosed fresh
                else keepOpen fresh
        | otherwise = keepOpen pending

    -- Pull the next batch of bytes from whichever backend is in use.
    fetch backend = case backend of
        ConnectionTLS tlsctx  -> TLS.recvData tlsctx
        ConnectionSocket sock -> N.recv sock 1500
        ConnectionStream h    -> B.hGetSome h (16 * 1024)

    -- Store the leftover bytes (forced before being stored) and return the result.
    keepOpen bytes = case f bytes of
        (result, !leftover) -> return (Just leftover, result)

    -- Mark the buffer as closed; any leftover from @f@ is discarded.
    markClosed bytes = case f bytes of
        (result, _leftover) -> return (Nothing, result)
-- | Read the next line from the connection, using byte 10 (ASCII LF) as the
-- terminator. The terminator is consumed but not included in the result.
--
-- Throws 'LineTooLong' when the bytes gathered without a terminator exceed
-- @limit@. The limit is only checked for chunks that contain no terminator.
--
-- NOTE(review): the initial EOF continuation (@throwEOF conn loc@) is
-- threaded through @more@ but never used; at end of input @more@ passes
-- @done dl@ to @getChunk@, so the bytes gathered so far are returned
-- instead. An EOF error is still raised by 'connectionGetChunkBase' once
-- the buffer has been closed. Confirm this is the intended behaviour.
connectionGetLine :: Int           -- ^ maximum number of bytes gathered before 'LineTooLong'
                  -> Connection    -- ^ connection to read from
                  -> IO ByteString -- ^ the line, without its terminator
connectionGetLine limit conn = more (throwEOF conn loc) 0 id
  where
    loc = "connectionGetLine"
    lineTooLong = E.throwIO LineTooLong
    -- Accumulate chunks in a difference list ('dl') and concatenate them
    -- once the line is complete. 'currentSz' is the number of bytes so far.
    more eofK !currentSz !dl =
        getChunk (\s -> let len = B.length s
                         in if currentSz + len > limit
                               then lineTooLong
                               else more eofK (currentSz + len) (dl . (s:)))
                 (\s -> done (dl . (s:)))
                 (done dl)
    done :: ([ByteString] -> [ByteString]) -> IO ByteString
    done dl = return $! B.concat $ dl []
    -- Read one chunk and dispatch to the matching continuation.
    getChunk :: (ByteString -> IO r) -- moreK: chunk had no terminator, keep reading
             -> (ByteString -> IO r) -- doneK: bytes before the terminator; line complete
             -> IO r                 -- eofK: the chunk was empty (end of input)
             -> IO r
    getChunk moreK doneK eofK =
      join $ connectionGetChunkBase loc conn $ \s ->
        if B.null s
          then (eofK, B.empty)
          else case B.break (== 10) s of
                 (a, b)
                   | B.null b  -> (moreK a, B.empty)
                   -- Drop the terminator; the rest stays in the buffer.
                   | otherwise -> (doneK a, B.tail b)
-- | Throw an end-of-file 'IOError' whose location is the given function
-- name prefixed with the module name, and whose path is the connection's
-- @host:port@.
throwEOF :: Connection -> String -> IO a
throwEOF conn loc = E.throwIO eofError
  where
    eofError = E.mkIOError E.eofErrorType qualifiedLoc Nothing (Just peer)
    qualifiedLoc = "Network.Connection." ++ loc
    peer = case connectionID conn of
        (host, port) -> host ++ ":" ++ show port
-- | Close the connection's backend.
--
-- For TLS, a close notification is sent first (I/O errors while sending it
-- are ignored) and the TLS context is closed regardless of the outcome.
connectionClose :: Connection -> IO ()
connectionClose conn = withBackend shutdown conn
  where
    shutdown backend = case backend of
        ConnectionTLS ctx     -> E.finally (quietly (TLS.bye ctx)) (TLS.contextClose ctx)
        ConnectionSocket sock -> N.close sock
        ConnectionStream h    -> hClose h
    -- Run an action, discarding any 'E.IOException' it raises.
    quietly action = E.catch action (\(_ :: E.IOException) -> return ())
-- | Upgrade a plain connection to TLS in place.
--
-- Both the buffer and the backend are held while the handshake runs. On a
-- successful upgrade the read buffer is reset to empty. A connection that
-- is already using TLS is left untouched.
connectionSetSecure :: ConnectionContext
                    -> Connection
                    -> TLSSettings
                    -> IO ()
connectionSetSecure cg connection params =
    modifyMVar_ (connectionBuffer connection) $ \pending ->
        modifyMVar (connectionBackend connection) $ \current ->
            case current of
                ConnectionStream h -> upgrade h
                ConnectionSocket s -> upgrade s
                ConnectionTLS _    -> return (current, pending)
  where
    -- Handshake over the raw transport; yields the new backend and a fresh buffer.
    upgrade :: TLS.HasBackend b => b -> IO (ConnectionBackend, Maybe ByteString)
    upgrade raw = do
        ctx <- tlsEstablish raw (makeTLSParams cg (connectionID connection) params)
        return (ConnectionTLS ctx, Just B.empty)
-- | Report whether the connection currently runs over TLS.
connectionIsSecure :: Connection -> IO Bool
connectionIsSecure conn = withBackend (return . usesTLS) conn
  where
    usesTLS (ConnectionTLS _)    = True
    usesTLS (ConnectionStream _) = False
    usesTLS (ConnectionSocket _) = False
-- | Create a TLS context over the given transport, run the handshake, and
-- return the context.
tlsEstablish :: TLS.HasBackend backend => backend -> TLS.ClientParams -> IO TLS.Context
tlsEstablish transport clientParams =
    TLS.contextNew transport clientParams >>= \ctx ->
        TLS.handshake ctx >> return ctx