module Network.TLS.Context
    (
    
      TLSParams
    
    , Context(..)
    , Hooks(..)
    , ctxEOF
    , ctxHasSSLv2ClientHello
    , ctxDisableSSLv2ClientHello
    , ctxEstablished
    , withLog
    , ctxWithHooks
    , contextModifyHooks
    , setEOF
    , setEstablished
    , contextFlush
    , contextClose
    , contextSend
    , contextRecv
    , updateMeasure
    , withMeasure
    , withReadLock
    , withWriteLock
    , withStateLock
    , withRWLock
    
    , Information(..)
    , contextGetInformation
    
    , contextNew
    
    , contextNewOnHandle
    , contextNewOnSocket
    
    , contextHookSetHandshakeRecv
    , contextHookSetCertificateRecv
    , contextHookSetLogging
    
    , throwCore
    , usingState
    , usingState_
    , runTxState
    , runRxState
    , usingHState
    , getHState
    , getStateRNG
    ) where
import Network.TLS.Backend
import Network.TLS.Context.Internal
import Network.TLS.Struct
import Network.TLS.Cipher (Cipher(..), CipherKeyExchangeType(..))
import Network.TLS.Credentials
import Network.TLS.State
import Network.TLS.Hooks
import Network.TLS.Record.State
import Network.TLS.Parameters
import Network.TLS.Measurement
import Network.TLS.Types (Role(..))
import Network.TLS.Handshake (handshakeClient, handshakeClientWith, handshakeServer, handshakeServerWith)
import Network.TLS.X509
import Data.Maybe (isJust)
import Crypto.Random
import Control.Concurrent.MVar
import Control.Monad.State
import Data.IORef
import Network.Socket (Socket)
import System.IO (Handle)
-- | Role-specific parameter sets (client or server) that a 'Context' can be
-- built from with 'contextNew'.
class TLSParams a where
    -- | The @(supported, shared)@ pair of parameters common to both roles.
    getTLSCommonParams :: a -> CommonParams

    -- | The role ('ClientRole' or 'ServerRole') the context will play.
    getTLSRole :: a -> Role

    -- | The ciphers the context may use with these parameters.
    -- 'contextNew' refuses to build a context when this list is empty.
    getCiphers :: a -> [Cipher]

    -- | The handshake routine for this role, stored in the context as
    -- 'ctxDoHandshake'.
    doHandshake :: a -> Context -> IO ()

    -- | The handshake routine for this role that is given a 'Handshake'
    -- message, stored in the context as 'ctxDoHandshakeWith'.
    -- NOTE(review): presumably a message already received from the peer; confirm.
    doHandshakeWith :: a -> Context -> Handshake -> IO ()
-- | Client contexts offer every cipher listed in their supported parameters
-- and run the client side of the handshake.
instance TLSParams ClientParams where
    getTLSRole _ = ClientRole
    getTLSCommonParams cparams = (clientSupported cparams, clientShared cparams)
    getCiphers = supportedCiphers . clientSupported
    doHandshakeWith = handshakeClientWith
    doHandshake = handshakeClient
-- | Server contexts only offer the ciphers whose key exchange they can
-- actually carry out with their configured credentials and DHE parameters.
instance TLSParams ServerParams where
    getTLSRole _ = ServerRole
    getTLSCommonParams sparams = (serverSupported sparams, serverShared sparams)
    getCiphers sparams =
        [ c | c <- supportedCiphers (serverSupported sparams)
            , keyExchangeUsable (cipherKeyExchange c) ]
      where
        -- Whether the server is able to perform a given key exchange.
        keyExchangeUsable CipherKeyExchange_RSA         = hasDecryptKey
        keyExchangeUsable CipherKeyExchange_DH_Anon     = hasDHE
        keyExchangeUsable CipherKeyExchange_DHE_RSA     = signsRSA && hasDHE
        keyExchangeUsable CipherKeyExchange_DHE_DSS     = signsDSS && hasDHE
        -- ECDHE_RSA only needs an RSA signing key; no DHE parameters are consulted.
        keyExchangeUsable CipherKeyExchange_ECDHE_RSA   = signsRSA
        -- Never offered by the server, whatever its configuration.
        -- NOTE(review): presumably these key exchanges are not implemented; confirm.
        keyExchangeUsable CipherKeyExchange_DH_DSS      = False
        keyExchangeUsable CipherKeyExchange_DH_RSA      = False
        keyExchangeUsable CipherKeyExchange_ECDH_ECDSA  = False
        keyExchangeUsable CipherKeyExchange_ECDH_RSA    = False
        keyExchangeUsable CipherKeyExchange_ECDHE_ECDSA = False

        -- Capabilities derived from the server's configuration.
        creds         = sharedCredentials (serverShared sparams)
        sigAlgs       = credentialsListSigningAlgorithms creds
        hasDHE        = isJust (serverDHEParams sparams)
        hasDecryptKey = isJust (credentialsFindForDecrypting creds)
        signsRSA      = SignatureRSA `elem` sigAlgs
        signsDSS      = SignatureDSS `elem` sigAlgs
    doHandshakeWith = handshakeServerWith
    doHandshake = handshakeServer
-- | Create a new context using the backend and parameters specified.
--
-- The cipher list derived from @params@ (see 'getCiphers') is checked
-- /before/ 'initializeBackend' is run. An unusable configuration therefore
-- fails without touching the caller's backend.
--
-- Calls 'error' with @"no ciphers available with those parameters"@ when
-- 'getCiphers' yields an empty list.
contextNew :: (MonadIO m, CPRG rng, HasBackend backend, TLSParams params)
           => backend   -- ^ Backend the context communicates over (stored via 'getBackend').
           -> params    -- ^ Client or server parameters of the context.
           -> rng       -- ^ Random number generator used to build the initial TLS state.
           -> m Context
contextNew backend params rng = liftIO $ do
    let role                = getTLSRole params
        st                  = newTLSState rng role
        (supported, shared) = getTLSCommonParams params
        ciphers             = getCiphers params

    -- Validate first: the check is pure, so there is no reason to perform
    -- the backend's side-effecting initialisation for a context that can
    -- never be created.
    when (null ciphers) $ error "no ciphers available with those parameters"

    initializeBackend backend

    stvar <- newMVar st
    eof   <- newIORef False
    established <- newIORef False
    stats <- newIORef newMeasurement
    -- Reception of SSLv2-style ClientHello messages is enabled only for
    -- server contexts, presumably to interoperate with old/compat clients.
    sslv2Compat <- newIORef (role == ServerRole)
    needEmptyPacket <- newIORef False
    hooks <- newIORef defaultHooks
    tx    <- newMVar newRecordState
    rx    <- newMVar newRecordState
    hs    <- newMVar Nothing
    -- Separate locks for writing, reading and state access.
    lockWrite <- newMVar ()
    lockRead  <- newMVar ()
    lockState <- newMVar ()
    return $ Context
            { ctxConnection   = getBackend backend
            , ctxShared       = shared
            , ctxSupported    = supported
            , ctxCiphers      = ciphers
            , ctxState        = stvar
            , ctxTxState      = tx
            , ctxRxState      = rx
            , ctxHandshake    = hs
            , ctxDoHandshake  = doHandshake params
            , ctxDoHandshakeWith  = doHandshakeWith params
            , ctxMeasurement  = stats
            , ctxEOF_         = eof
            , ctxEstablished_ = established
            , ctxSSLv2ClientHello = sslv2Compat
            , ctxNeedEmptyPacket  = needEmptyPacket
            , ctxHooks            = hooks
            , ctxLockWrite        = lockWrite
            , ctxLockRead         = lockRead
            , ctxLockState        = lockState
            }
-- | Create a new context on a 'Handle'. This is 'contextNew' specialised to
-- a 'Handle' backend.
contextNewOnHandle :: (MonadIO m, CPRG rng, TLSParams params)
                   => Handle -- ^ Handle of the connection.
                   -> params -- ^ Parameters of the context.
                   -> rng    -- ^ Random number generator for the context.
                   -> m Context
contextNewOnHandle = contextNew
-- | Create a new context on a 'Socket'. This is 'contextNew' specialised to
-- a 'Socket' backend.
contextNewOnSocket :: (MonadIO m, CPRG rng, TLSParams params)
                   => Socket -- ^ Socket of the connection.
                   -> params -- ^ Parameters of the context.
                   -> rng    -- ^ Random number generator for the context.
                   -> m Context
contextNewOnSocket = contextNew
-- | Replace the context's 'hookRecvHandshake' hook with @f@.
-- NOTE(review): presumably applied to each received handshake message; confirm.
contextHookSetHandshakeRecv :: Context -> (Handshake -> IO Handshake) -> IO ()
contextHookSetHandshakeRecv context f = contextModifyHooks context install
  where install hooks = hooks { hookRecvHandshake = f }
-- | Replace the context's 'hookRecvCertificates' hook with @f@.
-- NOTE(review): presumably called with the peer's certificate chain; confirm.
contextHookSetCertificateRecv :: Context -> (CertificateChain -> IO ()) -> IO ()
contextHookSetCertificateRecv context f = contextModifyHooks context install
  where install hooks = hooks { hookRecvCertificates = f }
-- | Replace the context's 'hookLogging' callbacks with @loggingCallbacks@.
contextHookSetLogging :: Context -> Logging -> IO ()
contextHookSetLogging context loggingCallbacks = contextModifyHooks context install
  where install hooks = hooks { hookLogging = loggingCallbacks }