{-# LANGUAGE OverloadedStrings #-}
module Snap.Util.CORS
  ( 
    applyCORS
    
  , CORSOptions(..)
  , defaultOptions
    
  , OriginList(..)
  , OriginSet, mkOriginSet, origins
    
  , HashableURI(..), HashableMethod (..)
  ) where
import Control.Applicative
import Control.Monad (join, when)
import Data.CaseInsensitive (CI)
import Data.Char (toLower)
import Data.Hashable (Hashable(..))
import Data.Maybe (fromMaybe)
import Data.Text.Encoding (decodeUtf8, encodeUtf8)
import Network.URI (URI (..), URIAuth (..),  parseURI)
import qualified Data.Attoparsec.ByteString.Char8 as Attoparsec
import qualified Data.ByteString.Char8 as S
import qualified Data.CaseInsensitive as CI
import qualified Data.HashSet as HashSet
import qualified Data.Text as Text
import qualified Snap.Core as Snap
import Snap.Internal.Parsing (pTokens)
-- | A set of permitted origins. Construct one with 'mkOriginSet', which
-- passes every URI through 'simplifyURI' before inserting it.
newtype OriginSet = OriginSet
  { origins :: HashSet.HashSet HashableURI
    -- ^ The (simplified) origin URIs contained in the set.
  }
-- | Which request origins 'applyCORS' treats as permitted.
data OriginList
  = Everywhere
    -- ^ Every origin is permitted.
  | Nowhere
    -- ^ No origin is permitted; requests go straight to the wrapped
    -- handler without any CORS headers being added.
  | Origins OriginSet
    -- ^ Only origins contained in the given set are permitted.
-- | Configuration for 'applyCORS'. Every option is an action in the same
-- monad as the wrapped handler, so it may be computed per request.
data CORSOptions m = CORSOptions
  { corsAllowOrigin :: m OriginList
    -- ^ Origins permitted to make CORS requests. A request whose @Origin@
    -- is not permitted is passed to the wrapped handler with no CORS
    -- headers added.
  , corsAllowCredentials :: m Bool
    -- ^ When 'True', @Access-Control-Allow-Credentials: true@ is added to
    -- CORS responses (both preflight and non-preflight).
  , corsExposeHeaders :: m (HashSet.HashSet (CI S.ByteString))
    -- ^ Response headers listed in @Access-Control-Expose-Headers@ on
    -- non-preflight CORS responses. No header is sent when the set is
    -- empty.
  , corsAllowedMethods :: m (HashSet.HashSet HashableMethod)
    -- ^ Methods a preflight may request via
    -- @Access-Control-Request-Method@; the whole set is advertised in
    -- @Access-Control-Allow-Methods@ on a successful preflight.
  , corsAllowedHeaders :: HashSet.HashSet S.ByteString -> m (HashSet.HashSet S.ByteString)
    -- ^ Given the header names from a preflight's
    -- @Access-Control-Request-Headers@, return the header names to allow.
    -- The preflight only succeeds if every requested header is in the
    -- result, which is advertised in @Access-Control-Allow-Headers@.
  }
-- | Permissive default options:
--
-- * every origin is permitted ('Everywhere');
-- * @Access-Control-Allow-Credentials: true@ is sent;
-- * no response headers are exposed;
-- * preflights may request GET, POST, PUT, DELETE or HEAD;
-- * every header a preflight asks for is allowed (the requested set is
--   returned unchanged).
--
-- NOTE(review): 'applyCORS' echoes the request's @Origin@ back in
-- @Access-Control-Allow-Origin@, so permitting every origin /with/
-- credentials effectively lets any site make credentialed cross-origin
-- requests. Consider overriding 'corsAllowOrigin' or
-- 'corsAllowCredentials' for authenticated endpoints.
defaultOptions :: Monad m => CORSOptions m
defaultOptions = CORSOptions
  { corsAllowedHeaders   = return
  , corsAllowedMethods   = return $! defaultAllowedMethods
  , corsExposeHeaders    = return HashSet.empty
  , corsAllowCredentials = return True
  , corsAllowOrigin      = return Everywhere
  }
-- | The methods permitted by 'defaultOptions': GET, POST, PUT, DELETE and
-- HEAD.
defaultAllowedMethods :: HashSet.HashSet HashableMethod
defaultAllowedMethods =
  HashSet.fromList
    [ HashableMethod method
    | method <- [Snap.GET, Snap.POST, Snap.PUT, Snap.DELETE, Snap.HEAD]
    ]
-- | Wrap a handler with CORS (Cross-Origin Resource Sharing) support.
--
-- * If the request has no @Origin@ header, or it does not parse as a URI,
--   the wrapped handler runs unchanged.
-- * If the origin is not permitted by 'corsAllowOrigin', the wrapped
--   handler runs unchanged (no CORS headers are added).
-- * A permitted @OPTIONS@ request is treated as a preflight. When its
--   @Access-Control-Request-Method@ is in 'corsAllowedMethods' and every
--   header named in @Access-Control-Request-Headers@ is returned by
--   'corsAllowedHeaders' (compared case-insensitively), the
--   @Access-Control-Allow-*@ headers are added and the wrapped handler is
--   /not/ run. Otherwise the wrapped handler runs.
-- * Any other permitted request (including an @OPTIONS@ request whose
--   preflight branch fails, via '<|>') gets @Access-Control-Allow-Origin@,
--   optionally @Access-Control-Allow-Credentials@ and
--   @Access-Control-Expose-Headers@, and then runs the wrapped handler.
applyCORS :: Snap.MonadSnap m => CORSOptions m -> m () -> m ()
applyCORS options m =
  (join . fmap decodeOrigin <$> getHeader "Origin") >>= maybe m corsRequestFrom
 where
  -- Handle a request that carries a parsed Origin: try the preflight path
  -- for OPTIONS, the simple CORS path otherwise, or plain @m@ when the
  -- origin is not permitted.
  corsRequestFrom origin = do
    originList <- corsAllowOrigin options
    if origin `inOriginList` originList
       then Snap.method Snap.OPTIONS (preflightRequestFrom origin)
              <|> handleRequestFrom origin
       else m

  -- Answer a preflight request; any unmet condition falls back to @m@.
  preflightRequestFrom origin = do
    maybeMethod <- fmap (parseMethod . S.unpack) <$>
                     getHeader "Access-Control-Request-Method"
    case maybeMethod of
      Nothing -> m
      Just method -> do
        allowedMethods <- corsAllowedMethods options
        if not (method `HashSet.member` allowedMethods)
          then m
          else do
            -- A missing Access-Control-Request-Headers means no extra
            -- headers were requested; a malformed one yields Nothing.
            maybeHeaders <- maybe (Just HashSet.empty) splitHeaders
                              <$> getHeader "Access-Control-Request-Headers"
            case maybeHeaders of
              Nothing -> m
              Just headers -> do
                allowedHeaders <- corsAllowedHeaders options headers
                -- Header field names are case-insensitive (RFC 7230 3.2),
                -- so e.g. "content-type" must match an allowed
                -- "Content-Type".
                let requested = caseless headers
                    permitted = caseless allowedHeaders
                if not (HashSet.null (requested `HashSet.difference` permitted))
                   then m
                   else do
                     addAccessControlAllowOrigin origin
                     addAccessControlAllowCredentials
                     commaSepHeader
                       "Access-Control-Allow-Headers"
                       id (HashSet.toList allowedHeaders)
                     commaSepHeader
                       "Access-Control-Allow-Methods"
                       renderMethod (HashSet.toList allowedMethods)

  -- Add the CORS headers for a non-preflight response, then run @m@.
  handleRequestFrom origin = do
    addAccessControlAllowOrigin origin
    addAccessControlAllowCredentials
    exposeHeaders <- corsExposeHeaders options
    -- commaSepHeader adds nothing when the list is empty.
    commaSepHeader
      "Access-Control-Expose-Headers"
      CI.original (HashSet.toList exposeHeaders)
    m

  -- Echo the (simplified) request origin back to the client.
  addAccessControlAllowOrigin origin =
    addHeader "Access-Control-Allow-Origin"
              (encodeUtf8 $ Text.pack $ show origin)

  addAccessControlAllowCredentials = do
    allowCredentials <- corsAllowCredentials options
    when allowCredentials $
      addHeader "Access-Control-Allow-Credentials" "true"

  -- Parse an Origin header value. The raw bytes are unpacked directly
  -- instead of being UTF-8 decoded: 'decodeUtf8' throws on malformed
  -- input, which the client fully controls, while 'parseURI' only accepts
  -- RFC 3986 (ASCII) URIs, so every origin that parsed before still does.
  decodeOrigin :: S.ByteString -> Maybe URI
  decodeOrigin = fmap simplifyURI . parseURI . S.unpack

  addHeader k v = Snap.modifyResponse (Snap.addHeader k v)

  -- Add header @k@ whose value is @vs@ rendered with @f@ and joined by
  -- ", "; nothing is added for an empty list.
  commaSepHeader k f vs =
    case vs of
      [] -> return ()
      _  -> addHeader k $ S.intercalate ", " (map f vs)

  getHeader = Snap.getsRequest . Snap.getHeader

  -- Parse a comma-separated token list; Nothing if it does not parse.
  splitHeaders = either (const Nothing) (Just . HashSet.fromList) .
    Attoparsec.parseOnly pTokens

  -- Header names as case-insensitive values, for comparison only.
  caseless :: HashSet.HashSet S.ByteString -> HashSet.HashSet (CI S.ByteString)
  caseless = HashSet.map CI.mk

  -- Render a method as its bare token. Plain 'show' is wrong for
  -- extension methods: the derived Show yields e.g. @Method "PROPFIND"@.
  renderMethod :: HashableMethod -> S.ByteString
  renderMethod (HashableMethod (Snap.Method name)) = name
  renderMethod other = S.pack (show other)
-- | Build an 'OriginSet' from a list of URIs. Each URI is first passed
-- through 'simplifyURI', so any user info, path, query or fragment is
-- discarded before the URI is stored.
mkOriginSet :: [URI] -> OriginSet
mkOriginSet uris =
  OriginSet (HashSet.fromList [HashableURI (simplifyURI uri) | uri <- uris])
-- | Reduce a URI to the parts that identify an origin: scheme, host and
-- port. User info, path, query and fragment are cleared.
--
-- The scheme and host are lower-cased because both are case-insensitive
-- (RFC 3986 sections 3.1 and 3.2.2; RFC 6454 section 4 lower-cases them
-- when computing an origin). Without this, an 'OriginSet' built from
-- e.g. @https://Example.com@ could never match a request origin of
-- @https://example.com@.
simplifyURI :: URI -> URI
simplifyURI uri = uri
  { uriScheme    = map toLower (uriScheme uri)
  , uriAuthority = simplifyURIAuth <$> uriAuthority uri
  , uriPath      = ""
  , uriQuery     = ""
  , uriFragment  = ""
  }
 where
  simplifyURIAuth auth = auth
    { uriUserInfo = ""
    , uriRegName  = map toLower (uriRegName auth)
    }
-- | Map a method token to a 'HashableMethod'. Matching is exact
-- (case-sensitive); any token other than the nine standard names becomes
-- an extension method via 'Snap.Method'.
parseMethod :: String -> HashableMethod
parseMethod token = HashableMethod $ case token of
  "GET"     -> Snap.GET
  "POST"    -> Snap.POST
  "HEAD"    -> Snap.HEAD
  "PUT"     -> Snap.PUT
  "DELETE"  -> Snap.DELETE
  "TRACE"   -> Snap.TRACE
  "OPTIONS" -> Snap.OPTIONS
  "CONNECT" -> Snap.CONNECT
  "PATCH"   -> Snap.PATCH
  _         -> Snap.Method (S.pack token)
-- | A 'URI' wrapper carrying the 'Hashable' instance needed to store URIs
-- in a 'HashSet.HashSet'. Equality is that of the wrapped 'URI'.
newtype HashableURI = HashableURI URI
  deriving Eq
-- | Shows exactly as the wrapped 'URI' does.
instance Show HashableURI where
  show = show . unwrap
    where
      unwrap (HashableURI uri) = uri
-- | Mix every component of the wrapped 'URI' into the hash: scheme,
-- authority (user info, reg-name, port), path, query and fragment.
instance Hashable HashableURI where
  hashWithSalt salt (HashableURI uri) =
    let withScheme    = salt `hashWithSalt` uriScheme uri
        withAuthority = withScheme `hashWithSalt`
                          fmap authorityHash (uriAuthority uri)
        withPath      = withAuthority `hashWithSalt` uriPath uri
        withQuery     = withPath `hashWithSalt` uriQuery uri
    in  withQuery `hashWithSalt` uriFragment uri
    where
      -- The authority is first folded into a single Int, seeded with the
      -- outer salt, and that Int (wrapped in Maybe) is mixed in above.
      authorityHash auth =
        ((salt `hashWithSalt` uriUserInfo auth)
               `hashWithSalt` uriRegName auth)
               `hashWithSalt` uriPort auth
-- | Whether an origin is permitted by an 'OriginList'. For 'Origins', the
-- URI is looked up as-is, so callers are expected to pass an origin
-- already reduced by 'simplifyURI' (as 'applyCORS' does).
inOriginList :: URI -> OriginList -> Bool
inOriginList origin allowed = case allowed of
  Everywhere  -> True
  Nowhere     -> False
  Origins set -> HashSet.member (HashableURI origin) (origins set)
-- | A 'Snap.Method' wrapper carrying the 'Hashable' instance needed to
-- store methods in a 'HashSet.HashSet'. Equality is that of
-- 'Snap.Method'.
newtype HashableMethod = HashableMethod Snap.Method
  deriving Eq
-- | Hash a method by a small integer tag for each standard method, or by
-- tag 9 plus its name for a genuine extension method.
--
-- An extension method spelled with a standard name (e.g.
-- @Snap.Method "GET"@) is hashed exactly like the matching constructor
-- (@Snap.GET@). 'Hashable' requires values that compare equal to hash
-- equally, and 'HashableMethod' takes its 'Eq' from 'Snap.Method'.
-- NOTE(review): snap-core 1.0's 'Eq' for 'Snap.Method' normalises
-- @Method "GET"@ to @GET@; confirm for the snap-core version in use.
-- Normalising here is safe either way, since it only merges hashes.
instance Hashable HashableMethod where
  hashWithSalt s (HashableMethod method) = hashMethod method
    where
      tag :: Int -> Int
      tag n = s `hashWithSalt` n

      hashMethod Snap.GET     = tag 0
      hashMethod Snap.HEAD    = tag 1
      hashMethod Snap.POST    = tag 2
      hashMethod Snap.PUT     = tag 3
      hashMethod Snap.DELETE  = tag 4
      hashMethod Snap.TRACE   = tag 5
      hashMethod Snap.OPTIONS = tag 6
      hashMethod Snap.CONNECT = tag 7
      hashMethod Snap.PATCH   = tag 8
      hashMethod (Snap.Method name) = case name of
        "GET"     -> tag 0
        "HEAD"    -> tag 1
        "POST"    -> tag 2
        "PUT"     -> tag 3
        "DELETE"  -> tag 4
        "TRACE"   -> tag 5
        "OPTIONS" -> tag 6
        "CONNECT" -> tag 7
        "PATCH"   -> tag 8
        _         -> tag 9 `hashWithSalt` name
-- | Shows exactly as the wrapped 'Snap.Method' does.
instance Show HashableMethod where
  show = \(HashableMethod wrapped) -> show wrapped