{-# LANGUAGE CPP #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE TupleSections #-}
{-# LANGUAGE UndecidableInstances #-}

module Servant.Auth.Server.Internal.AddSetCookie where

import Blaze.ByteString.Builder (toByteString)
import qualified Data.ByteString as BS
import Data.Kind (Type)
import qualified Network.HTTP.Types as HTTP
import Network.Wai (mapResponseHeaders)
import Servant
import Servant.API.Generic
import Servant.Server.Generic
import Web.Cookie

-- What is this module for? We want to add headers to the response, but those
-- headers come from the authentication check rather than from the handler.
-- To do that, we adapt the usual technique of recursing down the API tree:
-- this time we recurse down a variant of the API in which every endpoint
-- also returns Set-Cookie headers. As usual, this involves some type-level
-- machinery.
--
-- TODO: If the endpoints already have headers, this will not work as is.
-- NOTE(review): AddSetCookieApiVerb below does have an equation for endpoints
-- that already return Headers, so this TODO may be stale - confirm.

-- | Unary naturals, used promoted (as 'Z and 'S) to count how many
-- Set-Cookie headers are being added, e.g. as the index of SetCookieList.
data Nat = Z | S Nat

-- | Apply AddSetCookieApi to an API @n@ times, for @n >= 1@.
--
-- Equations are tried in order. The single-step case ('S 'Z) must come
-- before the general ('S n) case. There is no equation for 'Z, so
-- AddSetCookiesApi 'Z a does not reduce.
type family AddSetCookiesApi (n :: Nat) a where
  AddSetCookiesApi ('S 'Z) a = AddSetCookieApi a
  AddSetCookiesApi ('S n) a = AddSetCookiesApi n (AddSetCookieApi a)

-- | Add one Set-Cookie header to an endpoint response type.
--
-- A response that is already a Headers value gets the new header consed onto
-- the front of its header list. Any other response type is wrapped in a
-- Headers with just the Set-Cookie header. Equation order matters: the
-- Headers case must be tried before the catch-all.
type family AddSetCookieApiVerb a where
  AddSetCookieApiVerb (Headers ls a) = Headers (Header "Set-Cookie" SetCookie ': ls) a
  AddSetCookieApiVerb a = Headers '[Header "Set-Cookie" SetCookie] a

#if MIN_VERSION_servant_server(0,18,1)
-- | Apply AddSetCookieApiVerb to every element of a type-level list of
-- response types (the alternatives of a UVerb). Only defined for
-- servant-server >= 0.18.1.
type family MapAddSetCookieApiVerb (as :: [Type]) where
  MapAddSetCookieApiVerb '[] = '[]
  MapAddSetCookieApiVerb (a ': as) = (AddSetCookieApiVerb a ': MapAddSetCookieApiVerb as)
#endif

-- | Rewrite an API type so that every endpoint also returns one Set-Cookie
-- header. This is an open family: each servant combinator handled below gets
-- its own instance.
type family AddSetCookieApi a :: Type

-- Path/argument combinators: leave the left side alone, recurse to the right.
type instance AddSetCookieApi (a :> b) = a :> AddSetCookieApi b

-- Alternatives: rewrite both branches.
type instance AddSetCookieApi (a :<|> b) = AddSetCookieApi a :<|> AddSetCookieApi b
#if MIN_VERSION_servant_server(0,19,0)
-- Named (record-based) routes: rewrite their plain servant API form.
type instance AddSetCookieApi (NamedRoutes api) = AddSetCookieApi (ToServantApi api)
#endif
-- Ordinary endpoints: add the header to the response type.
type instance
  AddSetCookieApi (Verb method stat ctyps a) =
    Verb method stat ctyps (AddSetCookieApiVerb a)
#if MIN_VERSION_servant_server(0,18,1)
-- Union-verb endpoints: add the header to every possible response type.
type instance AddSetCookieApi (UVerb method ctyps as)
  = UVerb method ctyps (MapAddSetCookieApiVerb as)
#endif
-- Raw endpoints are left untouched at the type level.
type instance AddSetCookieApi Raw = Raw
#if MIN_VERSION_servant_server(0,15,0)
-- Streaming endpoints: add the header to the streamed response type.
type instance AddSetCookieApi (Stream method stat framing ctyps a)
  = Stream method stat framing ctyps (AddSetCookieApiVerb a)
#endif
-- A bare Headers value: add one more header to it.
type instance AddSetCookieApi (Headers hs a) = AddSetCookieApiVerb (Headers hs a)

-- | A list of optional cookies whose length @n@ is tracked in its type.
-- A Nothing entry stands for a cookie slot that carries no cookie.
data SetCookieList (n :: Nat) :: Type where
  SetCookieNil :: SetCookieList 'Z
  SetCookieCons :: Maybe SetCookie -> SetCookieList n -> SetCookieList ('S n)

-- | Value-level counterpart of AddSetCookieApi: turn a server value of type
-- @orig@ into one of type @new@ by attaching the @n@ cookies of a
-- SetCookieList. Instances below recurse through the shape of the server.
class AddSetCookies (n :: Nat) orig new where
  addSetCookies :: SetCookieList n -> orig -> new

-- | Handlers that still take an argument: post-compose the handler with the
-- recursive call, so the cookies are attached to whatever it eventually
-- returns.
instance
  {-# OVERLAPS #-}
  AddSetCookies ('S n) oldb newb
  => AddSetCookies ('S n) (a -> oldb) (a -> newb)
  where
  addSetCookies cookies oldfn = addSetCookies cookies . oldfn

-- | Base case: with no cookies left to add, the value is returned unchanged.
-- The head uses two distinct type variables plus an equality constraint, so
-- the instance matches any pair of types and then forces them to be equal.
instance orig1 ~ orig2 => AddSetCookies 'Z orig1 orig2 where
  addSetCookies _ = id

-- | Handler results inside some functor @m@ (presumably the handler monad).
--
-- The remaining @n@ cookies are attached recursively first. Then the head of
-- the list is added as a Set-Cookie header: addHeader' for a Just cookie,
-- noHeader' for Nothing. Both branches produce the same type @new@, so the
-- response type does not depend on whether the cookie is present. The
-- ('S n) index rules out SetCookieNil, so the single equation is exhaustive.
instance
  {-# OVERLAPPABLE #-}
  ( AddHeader mods "Set-Cookie" SetCookie cookied new
  , AddSetCookies n (m old) (m cookied)
  , Functor m
  )
  => AddSetCookies ('S n) (m old) (m new)
  where
  addSetCookies (mCookie `SetCookieCons` rest) oldVal =
    maybe noHeader' addHeader' mCookie <$> addSetCookies rest oldVal

-- | Alternatives: add the same cookies to both branches of the server.
instance
  {-# OVERLAPS #-}
  (AddSetCookies ('S n) a a', AddSetCookies ('S n) b b')
  => AddSetCookies ('S n) (a :<|> b) (a' :<|> b')
  where
  addSetCookies cookies (a :<|> b) =
    addSetCookies cookies a :<|> addSetCookies cookies b

-- | Alternatives whose left branch keeps its type after adding cookies.
-- This is the more specific (OVERLAPPING) variant of the general
-- alternatives instance. Both branches still receive the cookies.
instance
  {-# OVERLAPPING #-}
  (AddSetCookies ('S n) a a, AddSetCookies ('S n) b b')
  => AddSetCookies ('S n) (a :<|> b) (a :<|> b')
  where
  addSetCookies cookies (a :<|> b) =
    addSetCookies cookies a :<|> addSetCookies cookies b

-- | Record-of-handlers servers (servant generic routes): convert the record to
-- its plain server form with toServant, then add the cookies to that.
instance
  {-# OVERLAPS #-}
  ( AddSetCookies ('S n) (ServerT (ToServantApi api) m) cookiedApi
  , GServantProduct (Rep (api (AsServerT m)))
  , Generic (api (AsServerT m))
  , ToServant api (AsServerT m) ~ ServerT (ToServantApi api) m
  )
  => AddSetCookies ('S n) (api (AsServerT m)) cookiedApi
  where
  addSetCookies cookies = addSetCookies cookies . toServant

-- | for @servant <0.11@
--
-- A raw WAI Application: run it unchanged, but append the rendered
-- Set-Cookie headers after whatever headers its response already has.
instance AddSetCookies ('S n) Application Application where
  addSetCookies cookies app request respond =
    app request $ respond . mapResponseHeaders (++ mkHeaders cookies)

-- | for @servant >=0.11@
--
-- A Tagged raw Application: unwrap it, run it unchanged, and append the
-- rendered Set-Cookie headers after whatever headers its response already
-- has, then re-wrap it.
instance AddSetCookies ('S n) (Tagged m Application) (Tagged m Application) where
  addSetCookies cookies r = Tagged $ \request respond ->
    unTagged r request $ respond . mapResponseHeaders (++ mkHeaders cookies)

-- | Render each cookie in a SetCookieList as a @Set-Cookie@ response header,
-- in list order. Nothing entries produce no header.
mkHeaders :: SetCookieList x -> [HTTP.Header]
mkHeaders x = ("Set-Cookie",) <$> mkCookies x
  where
    -- The explicit polymorphic signature is needed: the recursion is at a
    -- different length index on each step.
    mkCookies :: forall y. SetCookieList y -> [BS.ByteString]
    mkCookies SetCookieNil = []
    mkCookies (SetCookieCons Nothing rest) = mkCookies rest
    mkCookies (SetCookieCons (Just y) rest) =
      toByteString (renderSetCookie y) : mkCookies rest