{-# LANGUAGE CPP #-}
{-# LANGUAGE DeriveDataTypeable #-}
{-# LANGUAGE DeriveFunctor #-}
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE OverloadedStrings #-}
#if __GLASGOW_HASKELL__ < 710
{-# LANGUAGE OverlappingInstances #-}
#endif
{-# LANGUAGE ScopedTypeVariables #-}
module Servant.Server.Internal.BasicAuth where
#if __GLASGOW_HASKELL__ < 710
import           Data.Functor           ((<$>))
#endif
import           Control.Monad          (guard)
import qualified Data.ByteString        as BS
import           Data.ByteString.Base64 (decodeLenient)
import           Data.CaseInsensitive   (CI(..))
import           Data.Monoid            ((<>))
import           Data.Typeable          (Typeable)
import           Data.Word8             (isSpace, toLower, _colon)
import           GHC.Generics
import           Snap.Core
import           Servant.API.BasicAuth (BasicAuthData(BasicAuthData))
import           Servant.Server.Internal.RoutingApplication
import           Servant.Server.Internal.ServantErr
-- | Outcome of checking a set of HTTP Basic credentials.
--
-- In 'runBasicAuth', 'BadPassword' and 'NoSuchUser' are both answered with a
-- 401 carrying a @WWW-Authenticate@ challenge, 'Unauthorized' with a 403, and
-- 'Authorized' lets the request proceed with the contained user value.
data BasicAuthResult usr
  = Unauthorized    -- ^ Rejected by 'runBasicAuth' with 'err403'.
  | BadPassword     -- ^ Answered by 'runBasicAuth' with a 401 challenge.
  | NoSuchUser      -- ^ Answered by 'runBasicAuth' with a 401 challenge.
  | Authorized usr  -- ^ Success; 'runBasicAuth' returns the @usr@ value.
  deriving (Eq, Show, Read, Generic, Typeable, Functor)
-- | A credential check supplied by the application: given the
-- 'BasicAuthData' decoded from the request, run an action in @m@ that
-- reports a 'BasicAuthResult'. The derived 'Functor' instance maps over the
-- @usr@ type carried by 'Authorized'.
newtype BasicAuthCheck m usr = BasicAuthCheck
  { unBasicAuthCheck :: BasicAuthData
                     -> m (BasicAuthResult usr)
  }
  deriving (Generic, Typeable, Functor)
-- | Build the @WWW-Authenticate@ challenge header sent with a 401 response,
-- advertising the @Basic@ scheme for the given realm.
--
-- The realm is emitted as an HTTP quoted-string, so any double-quote or
-- backslash byte inside it is escaped with a preceding backslash. Without
-- this, a realm containing a double quote would close the quoted-string
-- early and produce a malformed header. Realms without those bytes are
-- emitted unchanged.
mkBAChallengerHdr :: BS.ByteString -> (CI BS.ByteString, BS.ByteString)
mkBAChallengerHdr realm =
  ("WWW-Authenticate", "Basic realm=\"" <> escapeQuoted realm <> "\"")
  where
    -- Turn each '"' (0x22) and '\' (0x5C) into a quoted-pair.
    escapeQuoted = BS.concatMap escapeByte
    escapeByte w
      | w == 0x22 || w == 0x5C = BS.pack [0x5C, w]
      | otherwise              = BS.singleton w
-- | Extract HTTP Basic credentials from the request's @Authorization@ header.
--
-- Yields 'Nothing' when the header is missing, when its scheme token is not
-- @basic@ (compared after lower-casing), or when the decoded payload has no
-- colon. The payload is split at the first colon only, so the password may
-- itself contain colons. Decoding uses 'decodeLenient', so malformed base64
-- is not rejected at this stage.
decodeBAHdr :: Request -> Maybe BasicAuthData
decodeBAHdr req = getHeader "Authorization" req >>= parseBasic
  where
    parseBasic hdr = do
      let (scheme, afterScheme) = BS.break isSpace hdr
          credentials           = decodeLenient (BS.dropWhile isSpace afterScheme)
          (user, sepAndPass)    = BS.break (== _colon) credentials
      guard (BS.map toLower scheme == "basic")
      -- Drop the separating colon; fails (Nothing) if there was none.
      BasicAuthData user . snd <$> BS.uncons sepAndPass
-- | Authenticate a request with HTTP Basic auth.
--
-- A missing or unparseable @Authorization@ header, 'BadPassword' and
-- 'NoSuchUser' all fail via 'delayedFailFatal' with 'err401' plus a
-- @WWW-Authenticate@ challenge for @realm@. 'Unauthorized' fails with
-- 'err403'. 'Authorized' yields the user value.
runBasicAuth :: MonadSnap m => Request -> BS.ByteString -> BasicAuthCheck m usr -> DelayedM m usr
runBasicAuth req realm (BasicAuthCheck ba) =
  maybe plzAuthenticate checkCreds (decodeBAHdr req)
  where
    plzAuthenticate = delayedFailFatal err401 { errHeaders = [mkBAChallengerHdr realm] }
    -- Run the user-supplied check, ignoring the request argument DelayedM passes.
    checkCreds creds = DelayedM (\_ -> Route <$> ba creds) >>= handleResult
    handleResult (Authorized usr) = return usr
    handleResult Unauthorized     = delayedFailFatal err403
    handleResult BadPassword      = plzAuthenticate
    handleResult NoSuchUser       = plzAuthenticate