{-# LANGUAGE FlexibleContexts #-}

module Network.OAuth2.Experiment.Flows.TokenRequest where

import Control.Monad.IO.Class (MonadIO (..))
import Control.Monad.Trans.Except (ExceptT (..), throwE)
import Data.Aeson (FromJSON)
import Network.HTTP.Conduit
import Network.OAuth.OAuth2 (
  ClientAuthenticationMethod (..),
  OAuth2,
  OAuth2Token,
  PostBody,
  uriToRequest,
 )
import Network.OAuth.OAuth2.TokenRequest (
  TokenResponseError,
  addBasicAuth,
  addDefaultRequestHeaders,
  handleOAuth2TokenResponse,
  parseResponseFlexible,
 )
import Network.OAuth2.Experiment.Pkce
import Network.OAuth2.Experiment.Types
import Network.OAuth2.Experiment.Utils
import URI.ByteString (URI)

-- | Applications that can say which 'ClientAuthenticationMethod' should be
-- used when they call the token endpoint.
class HasTokenRequestClientAuthenticationMethod a where
  -- | The client authentication method for the given application value.
  getClientAuthenticationMethod :: a -> ClientAuthenticationMethod

-- | Marker value for flows that have no exchange token.
--
-- Only the Authorization Code Grant involves an exchange token (the
-- authorization code). Resource Owner Password and Client Credentials make
-- the token request directly.
data NoNeedExchangeToken = NoNeedExchangeToken

-- | Grant-type flows that can build a request for the token endpoint.
--
-- An instance must also be able to produce an 'OAuth2' key ('HasOAuth2Key')
-- and a client authentication method
-- ('HasTokenRequestClientAuthenticationMethod').
class (HasOAuth2Key a, HasTokenRequestClientAuthenticationMethod a) => HasTokenRequest a where
  -- | Request parameters sent to the token endpoint. Each grant-type flow has
  -- a slightly different set of parameters.
  data TokenRequest a

  -- | Extra input needed to build the 'TokenRequest' for this flow.
  --
  -- Only the Authorization Code flow (but neither Resource Owner Password nor
  -- Client Credentials) uses an 'ExchangeToken' in the token request; this
  -- type family makes that explicit per flow.
  -- NOTE(review): presumably flows without an exchange token instantiate this
  -- to 'NoNeedExchangeToken' -- confirm against the instances.
  type ExchangeTokenInfo a

  -- | Build the token request parameters from the application and its
  -- flow-specific 'ExchangeTokenInfo'.
  --
  -- Design note: an alternative encoding that was sketched here is a type
  -- family @WithExchangeToken a b@, where
  -- @type instance WithExchangeToken a b = b@ implies no exchange token and
  -- @type instance WithExchangeToken a b = ExchangeToken -> b@ implies that
  -- an exchange token is needed. That family is not declared; the
  -- 'ExchangeTokenInfo' argument below is what is used.
  mkTokenRequestParam :: a -> ExchangeTokenInfo a -> TokenRequest a

-------------------------------------------------------------------------------
--                               Token Request                               --
-------------------------------------------------------------------------------

-- | Request an access token from the IdP's token endpoint without a PKCE
-- code verifier.
--
-- This is 'conduitTokenRequestInternal' with the code verifier fixed to
-- 'Nothing'. Failures are reported as a 'TokenResponseError' in 'ExceptT'.
--
-- https://www.rfc-editor.org/rfc/rfc6749#section-4.1.3
conduitTokenRequest ::
  (HasTokenRequest a, ToQueryParam (TokenRequest a), MonadIO m) =>
  IdpApplication i a ->
  Manager ->
  ExchangeTokenInfo a ->
  ExceptT TokenResponseError m OAuth2Token
conduitTokenRequest idpApp mgr exchangeToken =
  conduitTokenRequestInternal idpApp mgr (exchangeToken, Nothing)

-------------------------------------------------------------------------------
--                             PKCE Token Request                            --
-------------------------------------------------------------------------------

-- | Request an access token from the IdP's token endpoint, passing a PKCE
-- 'CodeVerifier' along with the flow's exchange token info.
--
-- This is 'conduitTokenRequestInternal' with the code verifier wrapped in
-- 'Just'. Failures are reported as a 'TokenResponseError' in 'ExceptT'.
--
-- https://datatracker.ietf.org/doc/html/rfc7636#section-4.5
conduitPkceTokenRequest ::
  (HasTokenRequest a, ToQueryParam (TokenRequest a), MonadIO m) =>
  IdpApplication i a ->
  Manager ->
  (ExchangeTokenInfo a, CodeVerifier) ->
  ExceptT TokenResponseError m OAuth2Token
conduitPkceTokenRequest idpApp mgr (exchangeToken, codeVerifier) =
  conduitTokenRequestInternal idpApp mgr (exchangeToken, Just codeVerifier)

-------------------------------------------------------------------------------
--                              Internal helpers                             --
-------------------------------------------------------------------------------

-- | Shared implementation behind 'conduitTokenRequest' and
-- 'conduitPkceTokenRequest'.
--
-- The POST body is the union of the query parameters of the flow's
-- 'TokenRequest' (built with 'mkTokenRequestParam') and those of the optional
-- 'CodeVerifier'. The request is sent to the IdP's token endpoint via
-- 'doTokenRequestInternal', using the application's 'OAuth2' key and client
-- authentication method.
conduitTokenRequestInternal ::
  (HasTokenRequest a, ToQueryParam (TokenRequest a), MonadIO m) =>
  IdpApplication i a ->
  Manager ->
  (ExchangeTokenInfo a, Maybe CodeVerifier) ->
  ExceptT TokenResponseError m OAuth2Token
conduitTokenRequestInternal
  -- Explicit field bindings instead of a record wildcard, so it is visible
  -- where the IdP and the application come from.
  IdpApplication {idp = provider, application = app}
  mgr
  (exchangeToken, codeVerifier) =
    doTokenRequestInternal
      (getClientAuthenticationMethod app)
      mgr
      (mkOAuth2Key app)
      (idpTokenEndpoint provider)
      body
    where
      -- Flow-specific parameters merged with the (possibly absent) verifier.
      body =
        unionMapsToQueryParams
          [ toQueryParam (mkTokenRequestParam app exchangeToken)
          , toQueryParam codeVerifier
          ]

-- | POST a form-encoded body to a token endpoint and decode the reply.
--
-- The request is created from the URI with 'uriToRequest', extended with
-- 'addDefaultRequestHeaders' and (depending on the client authentication
-- method) 'addBasicAuth', and sent with 'httpLbs'. The raw response is passed
-- through 'handleOAuth2TokenResponse' and then 'parseResponseFlexible'; a
-- 'Left' from either step ends the computation with that
-- 'TokenResponseError'.
--
-- Exceptions raised in 'IO' by 'uriToRequest' or 'httpLbs' are not caught
-- here.
doTokenRequestInternal ::
  (MonadIO m, FromJSON a) =>
  -- | How the client authenticates against the token endpoint.
  ClientAuthenticationMethod ->
  -- | HTTP connection manager.
  Manager ->
  -- | OAuth options
  OAuth2 ->
  -- | URL
  URI ->
  -- | Request body.
  PostBody ->
  -- | Response decoded from the reply body.
  ExceptT TokenResponseError m a
doTokenRequestInternal clientAuthMethod manager oa url body = do
  rawBody <- ExceptT . liftIO $ handleOAuth2TokenResponse <$> sendRequest
  either throwE pure (parseResponseFlexible rawBody)
  where
    -- Only the basic-auth method changes the request here; the other methods
    -- leave it untouched. Every constructor is listed explicitly (no
    -- wildcard) so that a newly added method is flagged by the
    -- incomplete-pattern warning.
    -- NOTE(review): presumably the post/JWT methods carry their credentials
    -- in the request body built by the caller -- confirm.
    updateAuthHeader :: Request -> Request
    updateAuthHeader =
      case clientAuthMethod of
        ClientSecretBasic -> addBasicAuth oa
        ClientSecretPost -> id
        ClientAssertionJwt -> id

    -- Build the request, decorate it, attach the URL-encoded body and send it.
    sendRequest = do
      baseReq <- uriToRequest url
      let preparedReq = updateAuthHeader (addDefaultRequestHeaders baseReq)
      httpLbs (urlEncodedBody body preparedReq) manager