module Web.Simple.Auth where
import Control.Monad.IO.Class (liftIO)
import Data.ByteString.Base64
import qualified Data.ByteString.Char8 as S8
import Network.HTTP.Types
import Network.Wai
import Web.Simple.Responses
import Web.Simple.Router
type AuthRouter r = Routeable r
=> (Request -> S8.ByteString
-> S8.ByteString
-> IO (Maybe Request))
-> r
-> Route ()
basicAuthRoute :: String -> AuthRouter r
basicAuthRoute realm testAuth rt = Route (\req -> do
didAuthenticate <-
case lookup hAuthorization (requestHeaders req) of
Nothing -> return Nothing
Just authStr
| S8.take 5 authStr /= "Basic" -> return Nothing
| otherwise -> do
let up = fmap (S8.split ':') $ decode $ S8.drop 6 authStr
case up of
Right (user:pass:[]) -> liftIO $ testAuth req user pass
_ -> return Nothing
case didAuthenticate of
Nothing -> return $ Just $ requireBasicAuth realm
Just finReq -> runRoute rt finReq
) ()
authRewriteReq :: Routeable r
=> AuthRouter r
-> (S8.ByteString -> S8.ByteString -> IO Bool)
-> r
-> Route ()
authRewriteReq authRouter testAuth rt =
authRouter (\req user pwd -> do
success <- testAuth user pwd
if success then
return $ Just $ transReq req user
else return Nothing) rt
where transReq req user = req { requestHeaders = ("X-User", user):(requestHeaders req)}
basicAuth :: Routeable r => String -> S8.ByteString -> S8.ByteString -> r -> Route ()
basicAuth realm user pass = authRewriteReq (basicAuthRoute realm)
(\u p -> return $ u == user && p == pass)