{-# LANGUAGE CPP                   #-}
{-# LANGUAGE ConstraintKinds       #-}
{-# LANGUAGE DataKinds             #-}
{-# LANGUAGE DeriveDataTypeable    #-}
{-# LANGUAGE FlexibleContexts      #-}
{-# LANGUAGE FlexibleInstances     #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE OverloadedStrings     #-}
{-# LANGUAGE PolyKinds             #-}
{-# LANGUAGE RankNTypes            #-}
{-# LANGUAGE ScopedTypeVariables   #-}
{-# LANGUAGE TupleSections         #-}
{-# LANGUAGE TypeFamilies          #-}
{-# LANGUAGE TypeOperators         #-}
{-# LANGUAGE UndecidableInstances  #-}
#if MIN_VERSION_base(4,9,0) && __GLASGOW_HASKELL__ >= 802
#define HAS_TYPE_ERROR
#endif
module Servant.Server.Internal
  ( module Servant.Server.Internal
  , module Servant.Server.Internal.BasicAuth
  , module Servant.Server.Internal.Context
  , module Servant.Server.Internal.Handler
  , module Servant.Server.Internal.Router
  , module Servant.Server.Internal.RoutingApplication
  , module Servant.Server.Internal.ServantErr
  ) where
import           Control.Monad
                 (join, when)
import           Control.Monad.Trans
                 (liftIO)
import           Control.Monad.Trans.Resource
                 (runResourceT)
import qualified Data.ByteString                            as B
import qualified Data.ByteString.Builder                    as BB
import qualified Data.ByteString.Char8                      as BC8
import qualified Data.ByteString.Lazy                       as BL
import           Data.Either
                 (partitionEithers)
import           Data.Maybe
                 (fromMaybe, isNothing, mapMaybe, maybeToList)
import           Data.Semigroup
                 ((<>))
import           Data.String
                 (IsString (..))
import           Data.String.Conversions
                 (cs)
import           Data.Tagged
                 (Tagged (..), retag, untag)
import qualified Data.Text                                  as T
import           Data.Typeable
import           GHC.TypeLits
                 (KnownNat, KnownSymbol, natVal, symbolVal)
import qualified Network.HTTP.Media                         as NHM
import           Network.HTTP.Types                         hiding
                 (Header, ResponseHeaders)
import           Network.Socket
                 (SockAddr)
import           Network.Wai
                 (Application, Request, httpVersion, isSecure, lazyRequestBody,
                 rawQueryString, remoteHost, requestBody, requestHeaders,
                 requestMethod, responseLBS, responseStream, vault)
import           Prelude ()
import           Prelude.Compat
import           Servant.API
                 ((:<|>) (..), (:>), Accept (..), BasicAuth, Capture',
                 CaptureAll, Description, EmptyAPI, FramingRender (..),
                 FramingUnrender (..), FromSourceIO (..), Header', If,
                 IsSecure (..), QueryFlag, QueryParam', QueryParams, Raw,
                 ReflectMethod (reflectMethod), RemoteHost, ReqBody',
                 SBool (..), SBoolI (..), SourceIO, Stream, StreamBody',
                 Summary, ToSourceIO (..), Vault, Verb, WithNamedContext)
import           Servant.API.ContentTypes
                 (AcceptHeader (..), AllCTRender (..), AllCTUnrender (..),
                 AllMime, MimeRender (..), MimeUnrender (..), canHandleAcceptH)
import           Servant.API.Modifiers
                 (FoldLenient, FoldRequired, RequestArgument,
                 unfoldRequestArgument)
import           Servant.API.ResponseHeaders
                 (GetHeaders, Headers, getHeaders, getResponse)
import qualified Servant.Types.SourceT                      as S
import           Web.HttpApiData
                 (FromHttpApiData, parseHeader, parseQueryParam,
                 parseUrlPieceMaybe, parseUrlPieces)
import           Servant.Server.Internal.BasicAuth
import           Servant.Server.Internal.Context
import           Servant.Server.Internal.Handler
import           Servant.Server.Internal.Router
import           Servant.Server.Internal.RoutingApplication
import           Servant.Server.Internal.ServantErr
#ifdef HAS_TYPE_ERROR
import           GHC.TypeLits
                 (ErrorMessage (..), TypeError)
#endif
-- | Interpretation of an API type as a router.
--
-- Each instance states what a handler for the API looks like ('ServerT') and
-- how a 'Delayed' server for it is turned into a 'Router' ('route').
class HasServer api context where
  -- | The type of a server for @api@ whose handlers run in the monad @m@.
  type ServerT api (m :: * -> *) :: *
  -- | Build the 'Router' for @api@ from a 'Context' and a 'Delayed' server
  -- whose handlers run in 'Handler'.
  route ::
       Proxy api
    -> Context context
    -> Delayed env (Server api)
    -> Router env
  -- | Move a server from monad @m@ to monad @n@ using the supplied
  -- transformation @forall x. m x -> n x@.
  hoistServerWithContext
      :: Proxy api
      -> Proxy context
      -> (forall x. m x -> n x)
      -> ServerT api m
      -> ServerT api n
-- | A server for @api@ whose handlers run in the 'Handler' monad.
type Server api = ServerT api Handler
-- | A server for an alternative of two APIs is a pair of servers, one per
-- branch; routing offers the request to both branches through 'choice'.
instance (HasServer a context, HasServer b context) => HasServer (a :<|> b) context where
  type ServerT (a :<|> b) m = ServerT a m :<|> ServerT b m

  route Proxy context server =
      choice (route (Proxy :: Proxy a) context (fmap leftBranch server))
             (route (Proxy :: Proxy b) context (fmap rightBranch server))
    where
      -- Projections out of the pair of handlers.
      leftBranch  (l :<|> _) = l
      rightBranch (_ :<|> r) = r

  -- Hoist each branch on its own and pair the results back up.
  hoistServerWithContext _ pc nt (l :<|> r) =
         hoistServerWithContext (Proxy :: Proxy a) pc nt l
    :<|> hoistServerWithContext (Proxy :: Proxy b) pc nt r
-- | A captured path segment is parsed with 'parseUrlPieceMaybe' and handed to
-- the handler as its first argument; a segment that does not parse fails the
-- route with 'err400' (non-fatally, via 'delayedFail').
instance (KnownSymbol capture, FromHttpApiData a, HasServer api context)
      => HasServer (Capture' mods capture a :> api) context where
  type ServerT (Capture' mods capture a :> api) m =
     a -> ServerT api m

  hoistServerWithContext _ pc nt s =
    \ captured -> hoistServerWithContext (Proxy :: Proxy api) pc nt (s captured)

  route Proxy context d =
      CaptureRouter (route (Proxy :: Proxy api) context (addCapture d parseSegment))
    where
      -- Decode one path segment, rejecting the route when decoding fails.
      parseSegment txt = maybe (delayedFail err400) return (parseUrlPieceMaybe txt)
-- | All remaining path segments are parsed with 'parseUrlPieces' and handed to
-- the handler as a list; if any segment does not parse the route fails with
-- 'err400' (non-fatally, via 'delayedFail').
instance (KnownSymbol capture, FromHttpApiData a, HasServer api context)
      => HasServer (CaptureAll capture a :> api) context where
  type ServerT (CaptureAll capture a :> api) m =
    [a] -> ServerT api m

  hoistServerWithContext _ pc nt s =
    \ captured -> hoistServerWithContext (Proxy :: Proxy api) pc nt (s captured)

  route Proxy context d =
      CaptureAllRouter (route (Proxy :: Proxy api) context (addCapture d parseSegments))
    where
      -- Decode every segment; the parse error text itself is discarded.
      parseSegments txts = either (const (delayedFail err400)) return (parseUrlPieces txts)
-- | True when the endpoint method is GET and the request method is HEAD.
allowedMethodHead :: Method -> Request -> Bool
allowedMethodHead method request
  | method /= methodGet = False
  | otherwise           = requestMethod request == methodHead
-- | True when the request method equals the endpoint method, or when the
-- request is a HEAD request against a GET endpoint.
allowedMethod :: Method -> Request -> Bool
allowedMethod method request
  | requestMethod request == method = True
  | otherwise                       = allowedMethodHead method request
-- | Fail (non-fatally) with 'err405' unless 'allowedMethod' accepts the
-- request for the given endpoint method.
methodCheck :: Method -> Request -> DelayedIO ()
methodCheck method request =
  when (not (allowedMethod method request)) $ delayedFail err405
-- | Fail (non-fatally) with 'err406' unless the content types in @list@ can
-- satisfy the given raw Accept header value.
acceptCheck :: (AllMime list) => Proxy list -> B.ByteString -> DelayedIO ()
acceptCheck proxy accH =
  when (not (canHandleAcceptH proxy (AcceptHeader accH))) $ delayedFail err406
-- | Leaf router shared by the 'Verb' instances.
--
-- Adds the method and Accept checks to the delayed handler, runs it, and
-- renders the handler result into a response with the given status. The first
-- argument splits a handler result into extra response headers and the value
-- to be rendered.
methodRouter :: (AllCTRender ctypes a)
             => (b -> ([(HeaderName, B.ByteString)], a))
             -> Method -> Proxy ctypes -> Status
             -> Delayed env (Handler b)
             -> Router env
methodRouter splitHeaders method proxy status action = leafRouter handler
  where
    handler env request respond =
        runAction checkedAction env request respond render
      where
        -- A request without an Accept header is treated as accepting anything.
        accH = fromMaybe ct_wildcard (lookup hAccept (requestHeaders request))

        checkedAction = action `addMethodCheck` methodCheck method request
                               `addAcceptCheck` acceptCheck proxy accH

        render output =
          let (headers, b) = splitHeaders output
          in case handleAcceptH proxy (AcceptHeader accH) b of
               -- The Accept check above already ran, so a miss here is fatal.
               Nothing -> FailFatal err406
               Just (contentT, body) ->
                 -- A HEAD request against a GET endpoint gets an empty body.
                 let bdy = if allowedMethodHead method request then "" else body
                 in Route $ responseLBS status ((hContentType, cs contentT) : headers) bdy
-- | Endpoint for a plain 'Verb': the handler result is rendered by
-- 'methodRouter' with no extra response headers.
instance {-# OVERLAPPABLE #-}
         ( AllCTRender ctypes a, ReflectMethod method, KnownNat status
         ) => HasServer (Verb method status ctypes a) context where
  type ServerT (Verb method status ctypes a) m = m a

  hoistServerWithContext _ _ nat handler = nat handler

  route Proxy _ =
      methodRouter (\ payload -> ([], payload)) verb (Proxy :: Proxy ctypes) code
    where
      -- HTTP method and status code, both read off the type-level arguments.
      verb = reflectMethod (Proxy :: Proxy method)
      code = toEnum (fromInteger (natVal (Proxy :: Proxy status)))
-- | Endpoint for a 'Verb' whose result is wrapped in 'Headers': the headers
-- are split off with 'getHeaders' and the payload with 'getResponse' before
-- 'methodRouter' renders the response.
instance {-# OVERLAPPING #-}
         ( AllCTRender ctypes a, ReflectMethod method, KnownNat status
         , GetHeaders (Headers h a)
         ) => HasServer (Verb method status ctypes (Headers h a)) context where
  type ServerT (Verb method status ctypes (Headers h a)) m = m (Headers h a)

  hoistServerWithContext _ _ nat handler = nat handler

  route Proxy _ =
      methodRouter (\ withHdrs -> (getHeaders withHdrs, getResponse withHdrs))
                   verb (Proxy :: Proxy ctypes) code
    where
      -- HTTP method and status code, both read off the type-level arguments.
      verb = reflectMethod (Proxy :: Proxy method)
      code = toEnum (fromInteger (natVal (Proxy :: Proxy status)))
-- | Streaming endpoint: the handler result is converted to a source and sent
-- by 'streamRouter' with no extra response headers.
instance {-# OVERLAPPABLE #-}
         ( MimeRender ctype chunk, ReflectMethod method, KnownNat status,
           FramingRender framing, ToSourceIO chunk a
         ) => HasServer (Stream method status framing ctype a) context where
  type ServerT (Stream method status framing ctype a) m = m a

  hoistServerWithContext _ _ nat handler = nat handler

  route Proxy _ =
      streamRouter (\ payload -> ([], payload)) verb code
                   (Proxy :: Proxy framing) (Proxy :: Proxy ctype)
    where
      -- HTTP method and status code, both read off the type-level arguments.
      verb = reflectMethod (Proxy :: Proxy method)
      code = toEnum (fromInteger (natVal (Proxy :: Proxy status)))
-- | Streaming endpoint whose result is wrapped in 'Headers': the headers are
-- split off with 'getHeaders' and the streamed value with 'getResponse' before
-- 'streamRouter' sends the response.
instance {-# OVERLAPPING #-}
         ( MimeRender ctype chunk, ReflectMethod method, KnownNat status,
           FramingRender framing, ToSourceIO chunk a,
           GetHeaders (Headers h a)
         ) => HasServer (Stream method status framing ctype (Headers h a)) context where
  type ServerT (Stream method status framing ctype (Headers h a)) m = m (Headers h a)

  hoistServerWithContext _ _ nat handler = nat handler

  route Proxy _ =
      streamRouter (\ withHdrs -> (getHeaders withHdrs, getResponse withHdrs))
                   verb code (Proxy :: Proxy framing) (Proxy :: Proxy ctype)
    where
      -- HTTP method and status code, both read off the type-level arguments.
      verb = reflectMethod (Proxy :: Proxy method)
      code = toEnum (fromInteger (natVal (Proxy :: Proxy status)))
-- | Leaf router shared by the 'Stream' instances.
--
-- Adds the method check and an Accept check for the single content type
-- @ctype@ to the delayed handler, runs it, and streams the handler result:
-- the result is turned into a source with 'toSourceIO', each chunk is rendered
-- with 'mimeRender' and framed with 'framingRender', and every framed piece is
-- written to the response and flushed.
--
-- The first argument splits a handler result into extra response headers and
-- the value to be streamed.
streamRouter :: forall ctype a c chunk env framing. (MimeRender ctype chunk, FramingRender framing, ToSourceIO chunk a) =>
                (c -> ([(HeaderName, B.ByteString)], a))
             -> Method
             -> Status
             -> Proxy framing
             -> Proxy ctype
             -> Delayed env (Handler c)
             -> Router env
streamRouter splitHeaders method status framingproxy ctypeproxy action = leafRouter $ \env request respond ->
          -- A request without an Accept header is treated as accepting anything.
          let accH    = fromMaybe ct_wildcard $ lookup hAccept $ requestHeaders request
              -- The media type of ctype matched against the Accept header, if any.
              cmediatype = NHM.matchAccept [contentType ctypeproxy] accH
              -- No match means the client cannot accept this endpoint: 406.
              accCheck = when (isNothing cmediatype) $ delayedFail err406
              contentHeader = (hContentType, NHM.renderHeader . maybeToList $ cmediatype)
          in runAction (action `addMethodCheck` methodCheck method request
                               `addAcceptCheck` accCheck
                       ) env request respond $ \ output ->
                let (headers, fa) = splitHeaders output
                    sourceT = toSourceIO fa
                    -- The annotation pins the chunk type, which appears in no argument.
                    S.SourceT kStepLBS = framingRender framingproxy (mimeRender ctypeproxy :: chunk -> BL.ByteString) sourceT
                in Route $ responseStream status (contentHeader : headers) $ \write flush -> do
                    -- Walk the steps of the source: flush at the end, fail in IO on
                    -- a source error, and write then flush each yielded piece.
                    let loop S.Stop          = flush
                        loop (S.Error err)   = fail err 
                        loop (S.Skip s)      = loop s
                        loop (S.Effect ms)   = ms >>= loop
                        loop (S.Yield lbs s) = do
                            write (BB.lazyByteString lbs)
                            flush
                            loop s
                    kStepLBS loop
-- | The request header named @sym@ is looked up, parsed with 'parseHeader' and
-- handed to the handler. 'unfoldRequestArgument' decides, from @mods@, how a
-- missing or unparsable header is treated; the two failure actions supplied
-- here both fail fatally with 'err400'.
instance
  (KnownSymbol sym, FromHttpApiData a, HasServer api context
  , SBoolI (FoldRequired mods), SBoolI (FoldLenient mods)
  )
  => HasServer (Header' mods sym a :> api) context where
  type ServerT (Header' mods sym a :> api) m =
    RequestArgument mods a -> ServerT api m

  hoistServerWithContext _ pc nt s =
    \ arg -> hoistServerWithContext (Proxy :: Proxy api) pc nt (s arg)

  route Proxy context subserver =
      route (Proxy :: Proxy api) context
            (addHeaderCheck subserver (withRequest headerCheck))
    where
      -- The header name, usable at any string-like type: it serves both as
      -- the lookup key and as part of the error messages.
      headerName :: IsString n => n
      headerName = fromString (symbolVal (Proxy :: Proxy sym))

      headerCheck :: Request -> DelayedIO (RequestArgument mods a)
      headerCheck req =
          unfoldRequestArgument (Proxy :: Proxy mods) missing malformed parsed
        where
          -- Nothing: header absent; Just (Left _): present but unparsable.
          parsed :: Maybe (Either T.Text a)
          parsed = parseHeader <$> lookup headerName (requestHeaders req)

          missing = delayedFailFatal err400
            { errBody = "Header " <> headerName <> " is required"
            }

          malformed e = delayedFailFatal err400
            { errBody = cs $ "Error parsing header "
                             <> headerName
                             <> " failed: " <> e
            }
-- | The query parameter named @sym@ is looked up in the parsed query string,
-- parsed with 'parseQueryParam' and handed to the handler.
-- 'unfoldRequestArgument' decides, from @mods@, how a missing or unparsable
-- parameter is treated; the two failure actions supplied here both fail
-- fatally with 'err400'.
instance
  ( KnownSymbol sym, FromHttpApiData a, HasServer api context
  , SBoolI (FoldRequired mods), SBoolI (FoldLenient mods)
  )
  => HasServer (QueryParam' mods sym a :> api) context where
  type ServerT (QueryParam' mods sym a :> api) m =
    RequestArgument mods a -> ServerT api m

  hoistServerWithContext _ pc nt s =
    \ arg -> hoistServerWithContext (Proxy :: Proxy api) pc nt (s arg)

  route Proxy context subserver =
      route (Proxy :: Proxy api) context checked
    where
      paramname = cs (symbolVal (Proxy :: Proxy sym))

      -- The server with the parameter check attached.
      checked = addParameterCheck subserver (withRequest parseParam)

      parseParam :: Request -> DelayedIO (RequestArgument mods a)
      parseParam req =
          unfoldRequestArgument (Proxy :: Proxy mods) missing malformed parsed
        where
          -- A key without a value is collapsed into Nothing by 'join', so it
          -- is handled like an absent parameter.
          parsed :: Maybe (Either T.Text a)
          parsed = parseQueryParam
                     <$> join (lookup paramname (parseQueryText (rawQueryString req)))

          missing = delayedFailFatal err400
            { errBody = cs $ "Query parameter " <> paramname <> " is required"
            }

          malformed e = delayedFailFatal err400
            { errBody = cs $ "Error parsing query parameter "
                             <> paramname <> " failed: " <> e
            }
-- | Every value given for the query parameter @sym@ (spelled either as the
-- bare name or as the name followed by @[]@) is parsed with 'parseQueryParam'
-- and the results are handed to the handler as a list. If any value fails to
-- parse, the request fails fatally with 'err400' listing all parse errors.
instance (KnownSymbol sym, FromHttpApiData a, HasServer api context)
      => HasServer (QueryParams sym a :> api) context where
  type ServerT (QueryParams sym a :> api) m =
    [a] -> ServerT api m

  hoistServerWithContext _ pc nt s =
    \ args -> hoistServerWithContext (Proxy :: Proxy api) pc nt (s args)

  route Proxy context subserver =
      route (Proxy :: Proxy api) context
            (addParameterCheck subserver (withRequest paramsCheck))
    where
      paramname = cs (symbolVal (Proxy :: Proxy sym))

      paramsCheck req =
          case partitionEithers (map parseQueryParam values) of
              ([], parsed) -> return parsed
              (errs, _)    -> delayedFailFatal err400
                  { errBody = cs $ "Error parsing query parameter(s) "
                                   <> paramname <> " failed: "
                                   <> T.intercalate ", " errs
                  }
        where
          -- Values of all matching keys, in request order; keys that carry no
          -- value are skipped.
          values :: [T.Text]
          values =
            [ v
            | (key, Just v) <- parseQueryText (rawQueryString req)
            , key == paramname || key == (paramname <> "[]")
            ]
-- | The handler receives a 'Bool' derived from the query parameter @sym@:
-- absent means False, present without a value means True, and present with a
-- value means True only for the values @true@, @1@ and the empty string.
instance (KnownSymbol sym, HasServer api context)
      => HasServer (QueryFlag sym :> api) context where
  type ServerT (QueryFlag sym :> api) m =
    Bool -> ServerT api m

  hoistServerWithContext _ pc nt s =
    \ flag -> hoistServerWithContext (Proxy :: Proxy api) pc nt (s flag)

  route Proxy context subserver =
      route (Proxy :: Proxy api) context (passToServer subserver flagOf)
    where
      paramname = cs (symbolVal (Proxy :: Proxy sym))

      flagOf r =
        case lookup paramname (parseQueryText (rawQueryString r)) of
          Nothing       -> False
          Just Nothing  -> True
          Just (Just v) -> v `elem` ["true", "1", ""]
-- | A raw endpoint is a WAI 'Application' (tagged with the monad). The delayed
-- checks are run first; on success the application handles the request
-- itself, otherwise the failure is passed to the responder unchanged.
instance HasServer Raw context where
  type ServerT Raw m = Tagged m Application

  hoistServerWithContext _ _ _ = retag

  route Proxy _ rawApplication = RawRouter handler
    where
      handler env request respond = runResourceT $ do
        outcome <- runDelayed rawApplication env request
        liftIO $ case outcome of
          Route app   -> untag app request (respond . Route)
          Fail a      -> respond (Fail a)
          FailFatal e -> respond (FailFatal e)
-- | The request body is decoded according to the Content-Type header and
-- handed to the handler. With a lenient modifier the handler receives the
-- 'Either' produced by the decoder; otherwise a decoding failure fails the
-- request fatally with 'err400' carrying the decoder message.
instance ( AllCTUnrender list a, HasServer api context, SBoolI (FoldLenient mods)
         ) => HasServer (ReqBody' mods list a :> api) context where
  type ServerT (ReqBody' mods list a :> api) m =
    If (FoldLenient mods) (Either String a) a -> ServerT api m

  hoistServerWithContext _ pc nt s =
    \ body -> hoistServerWithContext (Proxy :: Proxy api) pc nt (s body)

  route Proxy context subserver =
      route (Proxy :: Proxy api) context (addBodyCheck subserver ctCheck bodyCheck)
    where
      -- Content-type check: pick the decoder matching the Content-Type header,
      -- or fail (non-fatally) with 415 when none of the types in list matches.
      ctCheck = withRequest $ \ request ->
        let -- A request without the header counts as application/octet-stream.
            contentTypeH = fromMaybe "application/octet-stream"
                             (lookup hContentType (requestHeaders request))
            decoder :: Maybe (BL.ByteString -> Either String a)
            decoder = canHandleCTypeH (Proxy :: Proxy list) (cs contentTypeH)
        in maybe (delayedFail err415) return decoder

      -- Body check: receives the decoder chosen by ctCheck and applies it to
      -- the lazily read request body.
      bodyCheck decode = withRequest $ \ request -> do
        decoded <- decode <$> liftIO (lazyRequestBody request)
        case sbool :: SBool (FoldLenient mods) of
          STrue -> return decoded
          SFalse -> case decoded of
            Left e  -> delayedFailFatal err400 { errBody = cs e }
            Right v -> return v
-- | The request body is exposed to the handler as a stream: the raw body
-- chunks are unframed with 'framingUnrender', each frame is decoded with
-- 'mimeUnrender', and the resulting source is converted with 'fromSourceIO'.
instance
    ( FramingUnrender framing, FromSourceIO chunk a, MimeUnrender ctype chunk
    , HasServer api context
    ) => HasServer (StreamBody' mods framing ctype a :> api) context
  where
    type ServerT (StreamBody' mods framing ctype a :> api) m = a -> ServerT api m

    hoistServerWithContext _ pc nt s =
      \ body -> hoistServerWithContext (Proxy :: Proxy api) pc nt (s body)

    route Proxy context subserver =
        route (Proxy :: Proxy api) context (addBodyCheck subserver ctCheck bodyCheck)
      where
        -- No content-type inspection happens here; the check only supplies
        -- the conversion function used by bodyCheck.
        ctCheck :: DelayedIO (SourceIO chunk -> a)
        ctCheck = return fromSourceIO

        bodyCheck :: (SourceIO chunk -> a) -> DelayedIO a
        bodyCheck fromRS = withRequest $ \req ->
            -- The raw source is built from the body action with B.null as the
            -- stop predicate.
            return (fromRS (unframe (S.fromAction B.null (requestBody req))))
          where
            decodeChunk :: BL.ByteString -> Either String chunk
            decodeChunk = mimeUnrender (Proxy :: Proxy ctype)

            unframe :: SourceIO B.ByteString -> SourceIO chunk
            unframe = framingUnrender (Proxy :: Proxy framing) decodeChunk
-- | A static path segment: the rest of the API is routed underneath the
-- segment named by the type-level string; the handler type is unchanged.
instance (KnownSymbol path, HasServer api context) => HasServer (path :> api) context where
  type ServerT (path :> api) m = ServerT api m

  hoistServerWithContext _ pc nat server =
    hoistServerWithContext (Proxy :: Proxy api) pc nat server

  route Proxy context subserver = pathRouter segment inner
    where
      segment = cs (symbolVal (Proxy :: Proxy path))
      inner   = route (Proxy :: Proxy api) context subserver
-- | Passes the 'remoteHost' of the request to the handler.
instance HasServer api context => HasServer (RemoteHost :> api) context where
  type ServerT (RemoteHost :> api) m = SockAddr -> ServerT api m

  hoistServerWithContext _ pc nt s =
    \ addr -> hoistServerWithContext (Proxy :: Proxy api) pc nt (s addr)

  route Proxy context subserver =
    route (Proxy :: Proxy api) context $ passToServer subserver remoteHost
-- | Passes 'Secure' or 'NotSecure' to the handler, according to the
-- 'isSecure' field of the request.
instance HasServer api context => HasServer (IsSecure :> api) context where
  type ServerT (IsSecure :> api) m = IsSecure -> ServerT api m

  hoistServerWithContext _ pc nt s =
    \ sec -> hoistServerWithContext (Proxy :: Proxy api) pc nt (s sec)

  route Proxy context subserver =
      route (Proxy :: Proxy api) context $ passToServer subserver secure
    where
      secure req
        | isSecure req = Secure
        | otherwise    = NotSecure
-- | Passes the 'vault' of the request to the handler.
instance HasServer api context => HasServer (Vault :> api) context where
  type ServerT (Vault :> api) m = Vault -> ServerT api m

  hoistServerWithContext _ pc nt s =
    \ v -> hoistServerWithContext (Proxy :: Proxy api) pc nt (s v)

  route Proxy context subserver =
    route (Proxy :: Proxy api) context $ passToServer subserver vault
-- | Passes the 'httpVersion' of the request to the handler.
instance HasServer api context => HasServer (HttpVersion :> api) context where
  type ServerT (HttpVersion :> api) m = HttpVersion -> ServerT api m

  hoistServerWithContext _ pc nt s =
    \ ver -> hoistServerWithContext (Proxy :: Proxy api) pc nt (s ver)

  route Proxy context subserver =
    route (Proxy :: Proxy api) context $ passToServer subserver httpVersion
-- | 'Summary' has no effect on serving: routing and hoisting are delegated
-- unchanged to the wrapped API.
instance HasServer api ctx => HasServer (Summary desc :> api) ctx where
  type ServerT (Summary desc :> api) m = ServerT api m

  hoistServerWithContext _ pc nat server =
    hoistServerWithContext (Proxy :: Proxy api) pc nat server

  route _ context subserver = route (Proxy :: Proxy api) context subserver
-- | 'Description' has no effect on serving: routing and hoisting are
-- delegated unchanged to the wrapped API.
instance HasServer api ctx => HasServer (Description desc :> api) ctx where
  type ServerT (Description desc :> api) m = ServerT api m

  hoistServerWithContext _ pc nat server =
    hoistServerWithContext (Proxy :: Proxy api) pc nat server

  route _ context subserver = route (Proxy :: Proxy api) context subserver
-- | Single-constructor type used (inside 'Tagged') as the server for
-- 'EmptyAPI'.
data EmptyServer = EmptyServer deriving (Typeable, Eq, Show, Bounded, Enum)
-- | The server for 'EmptyAPI', usable in any monad @m@.
emptyServer :: ServerT EmptyAPI m
emptyServer = Tagged EmptyServer
-- | The empty API is served by a static router with no sub-routers and no
-- handlers (both components are 'mempty').
instance HasServer EmptyAPI context where
  type ServerT EmptyAPI m = Tagged m EmptyServer

  hoistServerWithContext _ _ _ server = retag server

  route Proxy _ _ = StaticRouter mempty mempty
-- | Basic authentication: the 'BasicAuthCheck' found in the context is run
-- against the request (with the realm taken from the type-level string) and
-- the resulting user value is handed to the handler.
instance ( KnownSymbol realm
         , HasServer api context
         , HasContextEntry context (BasicAuthCheck usr)
         )
    => HasServer (BasicAuth realm usr :> api) context where
  type ServerT (BasicAuth realm usr :> api) m = usr -> ServerT api m

  hoistServerWithContext _ pc nt s =
    \ user -> hoistServerWithContext (Proxy :: Proxy api) pc nt (s user)

  route Proxy context subserver =
      route (Proxy :: Proxy api) context (addAuthCheck subserver authCheck)
    where
      authCheck = withRequest $ \ req ->
        runBasicAuth req
                     (BC8.pack (symbolVal (Proxy :: Proxy realm)))
                     (getContextEntry context)
-- | The wildcard media type (any type, any subtype). The routers in this
-- module use it as the Accept value when a request has no Accept header.
--
-- NOTE(review): the literal is assembled from three pieces, presumably because
-- this module is run through CPP, where a slash directly followed by a star
-- would be read as the start of a C comment. Keep it split.
ct_wildcard :: B.ByteString
ct_wildcard = "*" <> "/" <> "*" 
-- | Routes the sub-API against the named sub-context obtained from the outer
-- context with 'descendIntoNamedContext'; the handler type is unchanged.
instance (HasContextEntry context (NamedContext name subContext), HasServer subApi subContext)
  => HasServer (WithNamedContext name subContext subApi) context where
  type ServerT (WithNamedContext name subContext subApi) m =
    ServerT subApi m

  -- Hoisting happens at the sub-context, so the outer context proxy is unused.
  hoistServerWithContext _ _ nat server =
    hoistServerWithContext (Proxy :: Proxy subApi) (Proxy :: Proxy subContext) nat server

  route Proxy context delayed =
      route (Proxy :: Proxy subApi) inner delayed
    where
      inner :: Context subContext
      inner = descendIntoNamedContext (Proxy :: Proxy name) context
#ifdef HAS_TYPE_ERROR
-- | Instance for a left operand of @:>@ whose kind is still a function kind
-- (@k -> l@). Its context and its 'ServerT' are both the custom type error
-- 'HasServerArrowKindError'.
instance TypeError (HasServerArrowKindError arr) => HasServer ((arr :: k -> l) :> api) context
  where
    type ServerT (arr :> api) m = TypeError (HasServerArrowKindError arr)
    -- Placeholder methods; presumably never run, since using this instance
    -- raises the type error above at compile time.
    route _ _ _ = error "servant-server panic: impossible happened in HasServer (arr :> api)"
    hoistServerWithContext _ _ _ = id
-- | Error message for a combinator of function kind on the left of @:>@; it
-- shows the offending type @arr@.
type HasServerArrowKindError arr =
    'Text "Expected something of kind Symbol or *, got: k -> l on the LHS of ':>'."
    ':$$: 'Text "Maybe you haven't applied enough arguments to"
    ':$$: 'ShowType arr
-- | Instance for a function type used as an API. Its context and its
-- 'ServerT' are both the custom type error 'HasServerArrowTypeError'.
instance TypeError (HasServerArrowTypeError a b) => HasServer (a -> b) context
  where
    type ServerT (a -> b) m = TypeError (HasServerArrowTypeError a b)
    -- Placeholder methods; presumably never run, since using this instance
    -- raises the type error above at compile time.
    route _ _ _ = error "servant-server panic: impossible happened in HasServer (a -> b)"
    hoistServerWithContext _ _ _ = id
-- | Error message for a function type used as an API; it shows both sides of
-- the arrow, @a@ and @b@.
type HasServerArrowTypeError a b =
    'Text "No instance HasServer (a -> b)."
    ':$$: 'Text "Maybe you have used '->' instead of ':>' between "
    ':$$: 'ShowType a
    ':$$: 'Text "and"
    ':$$: 'ShowType b
#endif