{-# LANGUAGE BangPatterns       #-}
{-# LANGUAGE CPP                #-}
{-# LANGUAGE DeriveDataTypeable #-}
{-# LANGUAGE MagicHash          #-}
{-# LANGUAGE OverloadedStrings  #-}
{-# OPTIONS_HADDOCK hide, not-home #-}
module Network.Http.Inconvenience (
    URL,
    modifyContextSSL,
    establishConnection,
    get,
    post,
    postForm,
    encodedFormBody,
    put,
    baselineContextSSL,
    concatHandler',
    TooManyRedirects(..),
    HttpClientError(..),
    ConnectionAddress(..),
    connectionAddressFromURI,
    connectionAddressFromURL,
    openConnectionAddress,
        
    splitURI
) where
import Blaze.ByteString.Builder (Builder)
import qualified Blaze.ByteString.Builder as Builder (fromByteString,
                                                      fromWord8, toByteString)
import qualified Blaze.ByteString.Builder.Char8 as Builder (fromString)
import Control.Exception (Exception, bracket, throw, throwIO)
import Control.Monad (when, unless)
import Data.Bits (Bits (..))
import Data.ByteString.Char8 (ByteString)
import qualified Data.ByteString.Char8 as S
import Data.ByteString.Internal (c2w, w2c)
import Data.Char (digitToInt, intToDigit, isDigit, isHexDigit, toLower)
import Data.IORef (IORef, newIORef, readIORef, writeIORef)
import Data.List (intersperse)
import Data.Set (Set)
import qualified Data.Set as Set
import qualified Data.Text as T
import qualified Data.Text.Encoding as T
import Data.Typeable (Typeable)
import Data.Word (Word16)
import GHC.Exts (Int(..),word2Int#, uncheckedShiftRL#)
import GHC.Word (Word8 (..))
import Network.URI (URI (..), URIAuth (..), isAbsoluteURI,
                    parseRelativeReference, parseURI, relativeTo,
                    uriToString)
import OpenSSL (withOpenSSL)
import OpenSSL.Session (SSLContext)
import qualified OpenSSL.Session as SSL
import System.IO.Streams (InputStream, OutputStream)
import qualified System.IO.Streams as Streams
import System.IO.Unsafe (unsafePerformIO)
#if !MIN_VERSION_base(4,8,0)
import Data.Monoid (Monoid (..), mappend)
#endif
import Network.Http.Connection
import Network.Http.RequestBuilder
import Network.Http.Types
#if defined(linux_HOST_OS) || defined(freebsd_HOST_OS)
import System.Directory (doesDirectoryExist)
#endif
-- | A URL as raw bytes. Functions such as 'parseURL' decode it as UTF-8
-- before handing it to the URI parser.
type URL = ByteString
-- | Percent-decode a string into raw bytes.
--
-- Each @%hh@ whose two following characters are hex digits becomes the
-- single byte @hh@. A @%@ that is not followed by two hex digits is kept
-- literally. All other characters are passed through and narrowed to
-- 8 bits by 'S.pack'.
unescBytes :: [Char] -> ByteString
unescBytes = S.pack . decode
  where
    decode ('%' : hi : lo : more)
      | isHexDigit hi && isHexDigit lo =
          toEnum (digitToInt hi * 16 + digitToInt lo) : decode more
    decode (ch : more) = ch : decode more
    decode []          = []
-- | Form-encode a 'ByteString' and render the result strictly.
-- This is a thin wrapper around 'urlEncodeBuilder'.
urlEncode :: ByteString -> URL
urlEncode s = Builder.toByteString (urlEncodeBuilder s)
{-# INLINE urlEncode #-}
-- | Form-encode a 'ByteString' into a 'Builder'.
--
-- Runs of characters in 'urlEncodeTable' are copied unchanged. A space
-- becomes @+@. Every other character is written as a @%hh@ escape by 'hexd'.
urlEncodeBuilder :: ByteString -> Builder
urlEncodeBuilder = loop mempty
  where
    loop !acc !input =
        let (safe, rest) = S.span (`Set.member` urlEncodeTable) input
            acc'         = acc `mappend` Builder.fromByteString safe
        in case S.uncons rest of
             Nothing        -> acc'
             Just (c, more) -> loop (acc' `mappend` encodeChar c) more
    encodeChar ' ' = Builder.fromWord8 (c2w '+')
    encodeChar c   = hexd c
-- | Percent-encode one character as @%hh@, using lowercase hex digits.
--
-- The character is first narrowed to a byte with 'c2w'. The high nibble is
-- extracted with the portable 'shiftR' from "Data.Bits" rather than GHC
-- primops. This keeps the code independent of the internal representation
-- of 'Word8', which changed in GHC 9.2 (it now wraps a Word8 primitive
-- rather than a machine word).
hexd :: Char -> Builder
hexd c0 = Builder.fromWord8 (c2w '%') `mappend` Builder.fromWord8 hi
                                      `mappend` Builder.fromWord8 low
  where
    !c  = c2w c0
    -- Map a nibble (0..15) to its ASCII hex digit.
    toDigit :: Word8 -> Word8
    toDigit = c2w . intToDigit . fromIntegral
    !low = toDigit (c .&. 0x0f)
    !hi  = toDigit (c `shiftR` 4)
-- | Characters that 'urlEncodeBuilder' copies through unchanged: ASCII
-- upper- and lowercase letters, ASCII digits, and the punctuation
-- characters listed in the string below. Every other character (of the 256
-- byte values) is escaped by 'urlEncodeBuilder'.
urlEncodeTable :: Set Char
urlEncodeTable = Set.fromList [ch | ch <- map w2c [0 .. 255], isSafe ch]
  where
    isSafe ch = between 'A' 'Z' || between 'a' 'z' || between '0' '9'
             || ch `elem` ("$-_.!~*'(),"::String)
      where
        between lo hi = lo <= ch && ch <= hi
-- | Process-wide 'SSLContext' used for @https:@ connections opened by
-- 'establish' and 'openConnectionAddress'.
--
-- It is created on first use from 'baselineContextSSL' via 'unsafePerformIO'.
-- The NOINLINE pragma stops GHC from duplicating the initialiser, so every
-- use shares a single 'IORef'. Callers replace the context with
-- 'modifyContextSSL'.
global :: IORef SSLContext
global = unsafePerformIO $ do
    ctx <- baselineContextSSL
    newIORef ctx
{-# NOINLINE global #-}
-- | Replace the process-wide SSL context used by later @https:@ connections.
-- The new context is the result of applying the given action to the
-- current one.
--
-- The read, update and write are separate steps and are not atomic with
-- respect to other threads calling this function concurrently.
modifyContextSSL :: (SSLContext -> IO SSLContext) -> IO ()
modifyContextSSL f = readIORef global >>= f >>= writeIORef global
-- | Parse a URL and open a 'Connection' to the server it names.
--
-- The scheme selects plain TCP, TLS or a unix domain socket (see
-- 'establish'). Fails with 'error' if the URL cannot be parsed.
establishConnection :: URL -> IO Connection
establishConnection url = establish (parseURL url)
{-# INLINE establishConnection #-}
-- | Open a 'Connection' to the server named by an already-parsed 'URI'.
--
-- The scheme is matched case-insensitively (RFC 3986 section 3.1):
--
-- * @http:@ opens plain TCP to the URI host. The port defaults to 80 when
--   it is absent or empty.
-- * @https:@ opens TLS to the URI host, using the process-wide context in
--   'global'. The port defaults to 443.
-- * @unix:@ connects to the socket named by the URI path, used as-is.
--
-- Without an authority component the host is @localhost@. Fails with
-- 'error' for any other scheme, or for a port that is not a decimal number
-- in the range 0 to 65535.
establish :: URI -> IO Connection
establish u =
    case map toLower scheme of
        "http:"  -> openConnection host (portOr 80)
        "https:" -> do
            ctx <- readIORef global
            openConnectionSSL ctx host (portOr 443)
        "unix:"  -> openConnectionUnix (uriPath u)
        _        -> error ("Unknown URI scheme " ++ scheme)
  where
    scheme = uriScheme u
    auth = case uriAuthority u of
        Just x  -> x
        Nothing -> URIAuth "" "localhost" ""
    host = S.pack (uriRegName auth)
    -- uriPort is either "" or a colon followed by zero or more digits. An
    -- empty port means "use the scheme default" (RFC 3986 section 3.2.3).
    -- The range check stops large values from silently wrapping in Word16.
    portOr :: Word16 -> Word16
    portOr def = case uriPort auth of
        ""  -> def
        ":" -> def
        ':' : ds
          | all isDigit ds, n <= 65535 -> fromInteger n
          where n = read ds :: Integer
        other -> error ("Invalid port in URI " ++ other)
-- | Where to connect, as extracted from a URL by 'connectionAddressFromURL'
-- or 'connectionAddressFromURI' and consumed by 'openConnectionAddress'.
data ConnectionAddress
  = ConnectionAddressHttp     !Hostname !Word16 -- ^ plain TCP: host and port
  | ConnectionAddressHttps    !Hostname !Word16 -- ^ TLS over TCP: host and port
  | ConnectionAddressHttpUnix !ByteString -- ^ path of a unix domain socket
  deriving Show
-- | Open a 'Connection' to a previously resolved 'ConnectionAddress'.
--
-- TLS connections use the process-wide context in 'global'. For unix domain
-- sockets, the connection's 'cHost' is reset to empty. NOTE(review):
-- presumably this avoids deriving a Host header from the socket path;
-- confirm.
openConnectionAddress :: ConnectionAddress -> IO Connection
openConnectionAddress (ConnectionAddressHttp host port) =
    openConnection host port
openConnectionAddress (ConnectionAddressHttps host port) =
    readIORef global >>= \ctx -> openConnectionSSL ctx host port
openConnectionAddress (ConnectionAddressHttpUnix fp) =
    (\conn -> conn { cHost = mempty }) <$> openConnectionUnix (S.unpack fp)
-- | Decode a URL as UTF-8, parse it as an absolute URI, and work out where
-- to connect (see 'connectionAddressFromURI').
--
-- Returns 'Left' with a short message for invalid UTF-8, invalid URI
-- syntax, or any error reported by 'connectionAddressFromURI'.
connectionAddressFromURL :: URL -> Either String (ConnectionAddress, String, ByteString, String)
connectionAddressFromURL r' =
    case T.decodeUtf8' r' of
        Left _    -> Left "invalid UTF-8 encoding"
        Right txt ->
            case parseURI (T.unpack txt) of
                Nothing -> Left "invalid URI syntax"
                Just u  -> connectionAddressFromURI u
-- | Work out where to connect for an already-parsed 'URI'.
--
-- Supported schemes, matched case-insensitively:
--
-- * @http:@ and @https:@ use TCP (and TLS for https) to the URI host. The
--   port defaults to 80 or 443 when absent or empty. The host must be
--   non-empty.
-- * @unix:@ treats the percent-decoded URI path as the name of a unix domain
--   socket. No host, port, query or fragment is allowed, and the returned
--   request path is empty.
-- * @http+unix:@ treats the percent-decoded host component as the socket
--   name. The URI path and query form the request path. No port is allowed.
--
-- On success, returns a tuple of:
--
-- * the address;
-- * the URI user-info;
-- * the request path (path followed by query, fragment excluded);
-- * the URI fragment.
--
-- Every problem, including a malformed or out-of-range port number, is
-- reported as 'Left' rather than thrown.
connectionAddressFromURI :: URI -> Either String (ConnectionAddress, String, ByteString, String)
connectionAddressFromURI u = fmap addxinfo $
    case map toLower (uriScheme u) of
        "http:"      -> do
          requireHost
          pn <- port 80
          return (ConnectionAddressHttp host pn, urlpath)
        "https:"     -> do
          requireHost
          pn <- port 443
          return (ConnectionAddressHttps host pn, urlpath)
        "unix:" -> do
          noPort
          noHost
          when (null (uriPath u)) $
            Left "invalid empty path in unix: URI"
          unless (null (uriQuery u)) $
            Left "invalid query component in unix: URI"
          unless (null (uriFragment u)) $
            Left "invalid fragment component in unix: URI"
          rfp <- checkSocketPath (unescBytes (uriPath u))
          return (ConnectionAddressHttpUnix rfp, "")
        "http+unix:" -> do
          noPort
          fp <- getUnixPath
          return (ConnectionAddressHttpUnix fp, urlpath)
        _ -> Left ("Unknown URI scheme " ++ uriScheme u)
  where
    addxinfo (ca,p) = (ca, uriUserInfo auth, p, uriFragment u)
    auth = case uriAuthority u of
        Just x  -> x
        Nothing -> URIAuth "" "" ""
    noPort = if null (uriPort auth) then Right () else Left "invalid port number in URI"
    noHost = case uriAuthority u of
               Nothing -> return ()
               Just (URIAuth "" "" "") -> return ()
               Just _ -> Left "invalid host component in uri"
    -- http: and https: need a host to connect to.
    requireHost :: Either String ()
    requireHost = when (null (uriRegName auth)) $
                    Left "missing/empty host component in uri"
    -- For http+unix: the (percent-encoded) socket path lives in the host.
    getUnixPath = case uriRegName auth of
                    "" -> Left "missing/empty host component in uri"
                    fp -> checkSocketPath (unescBytes fp)
    -- Validate a decoded socket path. The 104-byte limit is presumably the
    -- smallest sun_path size among supported platforms (NOTE(review): confirm).
    checkSocketPath :: ByteString -> Either String ByteString
    checkSocketPath rfp = do
        when (S.length rfp > 104) $
          Left "unix domain socket path must be at most 104 bytes long"
        when (S.elem '\x00' rfp) $
          Left "unix domain socket path must not contain NUL bytes"
        Right rfp
    urlpath = S.pack (uriPath u ++ uriQuery u)
    host = S.pack (uriRegName auth)
    -- uriPort is either "" or a colon followed by zero or more digits. An
    -- empty port means the scheme default (RFC 3986 section 3.2.3). Values
    -- outside Word16 are rejected instead of silently wrapping.
    port :: Word16 -> Either String Word16
    port def = case uriPort auth of
        ""  -> Right def
        ":" -> Right def
        ':' : ds
          | all isDigit ds, n <= 65535 -> Right (fromInteger n)
          where n = read ds :: Integer
        _ -> Left "invalid port number in URI"
-- | Create a fresh 'SSLContext' with platform-specific certificate settings.
--
-- Runs inside 'withOpenSSL' and always installs the OpenSSL default cipher
-- list. Certificate handling is then chosen at compile time by host OS:
--
-- * macOS (darwin) and Windows (mingw32): verification mode 'SSL.VerifyNone'.
-- * FreeBSD and OpenBSD: loads the system @cert.pem@ CA file and enables
--   'SSL.VerifyPeer'.
-- * any other OS: loads the Fedora-style CA bundle if the @\/etc\/pki\/tls@
--   directory exists, otherwise the @\/etc\/ssl\/certs@ CA directory, and
--   enables 'SSL.VerifyPeer'.
--
-- NOTE(review): with 'SSL.VerifyNone' on macOS and Windows, server
-- certificates are not checked. Callers who need authentication there
-- should install their own settings via 'modifyContextSSL'.
--
-- NOTE(review): the fallback branch calls 'doesDirectoryExist', which is
-- only imported when building for Linux or FreeBSD. Other operating systems
-- that reach this branch would fail to compile; confirm the import guard.
baselineContextSSL :: IO SSLContext
baselineContextSSL = withOpenSSL $ do
    ctx <- SSL.context
    SSL.contextSetDefaultCiphers ctx
#if defined(darwin_HOST_OS)
    SSL.contextSetVerificationMode ctx SSL.VerifyNone
#elif defined(mingw32_HOST_OS)
    SSL.contextSetVerificationMode ctx SSL.VerifyNone
#elif defined(freebsd_HOST_OS)
    SSL.contextSetCAFile ctx "/usr/local/etc/ssl/cert.pem"
    SSL.contextSetVerificationMode ctx $ SSL.VerifyPeer True True Nothing
#elif defined(openbsd_HOST_OS)
    SSL.contextSetCAFile ctx "/etc/ssl/cert.pem"
    SSL.contextSetVerificationMode ctx $ SSL.VerifyPeer True True Nothing
#else
    fedora <- doesDirectoryExist "/etc/pki/tls"
    if fedora
        then do
            SSL.contextSetCAFile ctx "/etc/pki/tls/certs/ca-bundle.crt"
        else do
            SSL.contextSetCADirectory ctx "/etc/ssl/certs"
    SSL.contextSetVerificationMode ctx $ SSL.VerifyPeer True True Nothing
#endif
    return ctx
-- | Decode a URL as UTF-8 and parse it as an absolute URI.
--
-- Fails with 'error' if the text is not a valid URI. 'T.decodeUtf8' also
-- throws when the bytes are not valid UTF-8.
parseURL :: URL -> URI
parseURL r' = maybe cannotParse id (parseURI text)
  where
    text        = T.unpack (T.decodeUtf8 r')
    cannotParse = error ("Can't parse URI " ++ text)
-- | Build the request target (origin-form, RFC 7230 section 5.3.1) for a URI.
--
-- The target is the URI path followed by its query string, UTF-8 encoded.
-- An empty path is sent as @\/@, including when only a query is present.
-- The fragment is deliberately omitted: it identifies a part of the
-- retrieved representation and is never sent to the server (RFC 3986
-- section 3.5).
path :: URI -> ByteString
path u = T.encodeUtf8 (T.pack (target ++ uriQuery u))
  where
    target = case uriPath u of
        "" -> "/"
        p  -> p
-- | Issue a @GET@ request for a URL and pass the response to a handler.
-- Redirect responses are followed through 'wrapRedirect'.
get :: URL
    -- ^ Resource to GET from.
    -> (Response -> InputStream ByteString -> IO β)
    -- ^ Handler function to receive the response from the server.
    -> IO β
get = getN 0
-- | Worker for 'get'. The first argument is the number of redirects already
-- followed.
--
-- It opens a connection for the URL and sends a @GET@ with an Accept header
-- allowing any media type. The response goes to 'wrapRedirect', which may
-- call back into 'getN' for a redirect target. That nested request runs
-- inside the response handler, so this connection stays open until it
-- completes. 'bracket' closes the connection even if the handler throws.
getN :: Int -> URL -> (Response -> InputStream ByteString -> IO β) -> IO β
getN n r' handler =
    bracket (establish u) closeConnection process
  where
    u = parseURL r'
    q = buildRequest1 $ do
            http GET (path u)
            setAccept "*/*"
    process c = do
        sendRequest c q emptyBody
        receiveResponse c (wrapRedirect u n handler)
-- | Response handler used by 'getN' to follow HTTP redirects.
--
-- A 301, 302, 303 or 307 response that carries a @Location@ header is
-- followed: the target is resolved against the request URI with 'splitURI'
-- and fetched with 'getN'. Any other response goes to the user handler
-- unchanged.
--
-- At most five redirects are followed. A sixth redirect response raises
-- 'TooManyRedirects'. The count is checked only when a redirect is actually
-- about to be followed, so a normal response that arrives after the fifth
-- redirect is still delivered to the handler.
wrapRedirect
    :: URI
    -> Int
    -> (Response -> InputStream ByteString -> IO β)
    -> Response
    -> InputStream ByteString
    -> IO β
wrapRedirect u n handler p i
    | isRedirect, Just l <- getHeader p "Location" =
        if n < 5
            then getN (n + 1) (splitURI u l) handler
            else throwIO (TooManyRedirects n)
    | otherwise = handler p i
  where
    s = getStatusCode p
    isRedirect = s == 301 || s == 302 || s == 303 || s == 307
-- | Resolve a redirect target (typically a @Location@ header value) against
-- the URI of the request that produced it.
--
-- An absolute URI is returned unchanged. A relative reference is resolved
-- with 'relativeTo' following RFC 3986 section 5.2. This covers relative
-- paths such as @next.html@, query-only references, and references that
-- carry their own authority. Input that is not a valid URI reference is
-- returned as-is.
--
-- The result is rendered with @uriToString id@, so any user-info in the
-- base URI is kept verbatim rather than masked.
splitURI :: URI -> URL -> URL
splitURI old new'
    | isAbsoluteURI new = new'
    | otherwise =
        case parseRelativeReference new of
            Nothing  -> new'
            Just rel -> S.pack (uriToString id (rel `relativeTo` old) "")
  where
    new = S.unpack new'
-- | Exception thrown by 'wrapRedirect', and hence by 'get', when a chain of
-- redirects is too long. The 'Int' is the redirect count at the point the
-- exception was raised.
data TooManyRedirects = TooManyRedirects Int
        deriving (Typeable, Show, Eq)
instance Exception TooManyRedirects
-- | Send a @POST@ request with a body produced by a callback, then pass the
-- response to a handler.
--
-- The request sets the given Content-Type and an Accept header allowing any
-- media type. The connection is always closed afterwards.
post :: URL
    -- ^ Resource to POST to.
    -> ContentType
    -- ^ MIME type of the request body being sent.
    -> (OutputStream Builder -> IO α)
    -- ^ Handler function to write content to server.
    -> (Response -> InputStream ByteString -> IO β)
    -- ^ Handler function to receive the response from the server.
    -> IO β
post r' t body handler =
    bracket (establish u) closeConnection $ \conn -> do
        _ <- sendRequest conn request body
        receiveResponse conn handler
  where
    u = parseURL r'
    request = buildRequest1 $ do
        http POST (path u)
        setAccept "*/*"
        setContentType t
-- | Send an HTML-form style @POST@ and pass the response to a handler.
--
-- The name/value pairs are encoded by 'encodedFormBody' and sent with
-- Content-Type @application\/x-www-form-urlencoded@. The connection is
-- always closed afterwards.
postForm
    :: URL
    -- ^ Resource to POST to.
    -> [(ByteString, ByteString)]
    -- ^ List of name=value pairs. Will be sent URL-encoded.
    -> (Response -> InputStream ByteString -> IO β)
    -- ^ Handler function to receive the response from the server.
    -> IO β
postForm r' nvs handler =
    bracket (establish u) closeConnection $ \conn -> do
        _ <- sendRequest conn request (encodedFormBody nvs)
        receiveResponse conn handler
  where
    u = parseURL r'
    request = buildRequest1 $ do
        http POST (path u)
        setAccept "*/*"
        setContentType "application/x-www-form-urlencoded"
-- | Write name/value pairs to an output stream as a single form-encoded
-- chunk: @name=value@ pairs joined with @&@. Names and values are encoded
-- with 'urlEncodeBuilder'.
encodedFormBody :: [(ByteString,ByteString)] -> OutputStream Builder -> IO ()
encodedFormBody nvs o = Streams.write (Just formData) o
  where
    formData = mconcat (intersperse ampersand (map encodePair nvs))
    ampersand = Builder.fromString "&"
    encodePair :: (ByteString, ByteString) -> Builder
    encodePair (name, value) =
        urlEncodeBuilder name `mappend` Builder.fromString "=" `mappend` urlEncodeBuilder value
-- | Send a @PUT@ request with a body produced by a callback, then pass the
-- response to a handler.
--
-- The request sets the Content-Type header to the given value and an Accept
-- header allowing any media type. The connection is always closed
-- afterwards.
put :: URL
    -- ^ Resource to PUT to.
    -> ContentType
    -- ^ MIME type of the request body being sent.
    -> (OutputStream Builder -> IO α)
    -- ^ Handler function to write content to server.
    -> (Response -> InputStream ByteString -> IO β)
    -- ^ Handler function to receive the response from the server.
    -> IO β
put r' t body handler =
    bracket (establish u) closeConnection $ \conn -> do
        _ <- sendRequest conn request body
        receiveResponse conn handler
  where
    u = parseURL r'
    request = buildRequest1 $ do
        http PUT (path u)
        setAccept "*/*"
        setHeader "Content-Type" t
-- | Like 'concatHandler', but any response with a status code of 300 or
-- above raises 'HttpClientError' instead of returning its body.
concatHandler' :: Response -> InputStream ByteString -> IO ByteString
concatHandler' p i
    | status >= 300 = throw (HttpClientError status (getStatusMessage p))
    | otherwise     = concatHandler p i
  where
    status = getStatusCode p
-- | Exception raised by concatHandler' for responses with a status code of
-- 300 or above. It carries the status code and the status message.
data HttpClientError = HttpClientError Int ByteString
        deriving (Typeable)
instance Exception HttpClientError
-- Rendered as the status code, a single space, then the status message.
instance Show HttpClientError where
    show (HttpClientError code message) = shows code (' ' : S.unpack message)