-- |
-- Module      : Crypto.PubKey.RSA.OAEP
-- License     : BSD-style
-- Maintainer  : Vincent Hanquez <vincent@snarc.org>
-- Stability   : experimental
-- Portability : Good
--
-- RSA OAEP mode
-- <http://en.wikipedia.org/wiki/Optimal_asymmetric_encryption_padding>
module Crypto.PubKey.RSA.OAEP (
    OAEPParams (..),
    defaultOAEPParams,

    -- * OAEP encryption
    encryptWithSeed,
    encrypt,

    -- * OAEP decryption
    decrypt,
    decryptSafer,
) where

import Crypto.Hash
import Crypto.PubKey.Internal (and')
import Crypto.PubKey.MaskGenFunction
import Crypto.PubKey.RSA (generateBlinder)
import Crypto.PubKey.RSA.Prim
import Crypto.PubKey.RSA.Types
import Crypto.Random.Types
import Data.Bits (xor)
import Data.ByteString (ByteString)
import qualified Data.ByteString as B

import Crypto.Internal.ByteArray (ByteArray, ByteArrayAccess)
import qualified Crypto.Internal.ByteArray as B (convert)

-- | Parameters controlling OAEP padding.
--
-- The same record is consumed by both encryption and decryption; the
-- decryption side recomputes the label hash and re-runs the mask generation
-- function, so both sides must agree on every field.
data OAEPParams hash seed output = OAEPParams
    { oaepHash :: hash
    -- ^ Hash algorithm. Its digest size fixes the seed length and the length
    -- of the label hash stored in the padded block.
    , oaepMaskGenAlg :: MaskGenAlgorithm seed output
    -- ^ Mask generation function applied to the seed and the data block.
    , oaepLabel :: Maybe ByteString
    -- ^ Optional label. Its hash (the hash of the empty string when
    -- 'Nothing') is embedded in the padded block and verified on decryption.
    }

-- | Default OAEP parameters for a given hash algorithm.
--
-- The hash is used both as 'oaepHash' and, through 'mgf1', as the mask
-- generation function; no label is set.
defaultOAEPParams
    :: (ByteArrayAccess seed, ByteArray output, HashAlgorithm hash)
    => hash
    -> OAEPParams hash seed output
defaultOAEPParams h =
    OAEPParams
        { oaepLabel = Nothing
        , oaepMaskGenAlg = mgf1 h
        , oaepHash = h
        }

-- | Encrypt a message using OAEP with a caller-supplied seed.
--
-- The seed must be exactly as long as the digest of 'oaepHash'; 'encrypt'
-- supplies a random one. The message is padded into the block
-- @0x00 || maskedSeed || maskedDB@ and then passed to 'ep'.
--
-- Results:
--
-- * @Left InvalidParameters@ when the key size @k@ is below @2*hLen + 2@,
--   or when the seed length differs from @hLen@;
-- * @Left MessageTooLong@ when the message exceeds @k - 2*hLen - 2@ bytes;
-- * @Right ciphertext@ otherwise.
encryptWithSeed
    :: HashAlgorithm hash
    => ByteString
    -- ^ Seed
    -> OAEPParams hash ByteString ByteString
    -- ^ OAEP params to use for encryption
    -> PublicKey
    -- ^ Public key.
    -> ByteString
    -- ^ Message to encrypt
    -> Either Error ByteString
encryptWithSeed seed oaep pk msg
    | keyLen < 2 * hLen + 2 = Left InvalidParameters
    | B.length seed /= hLen = Left InvalidParameters
    | msgLen > maxMsgLen = Left MessageTooLong
    | otherwise = Right (ep pk encoded)
  where
    hashAlg = oaepHash oaep
    mask = oaepMaskGenAlg oaep
    keyLen = public_size pk
    hLen = hashDigestSize hashAlg
    msgLen = B.length msg
    -- largest message that still fits in the block
    maxMsgLen = keyLen - 2 * hLen - 2
    label = case oaepLabel oaep of
        Nothing -> B.empty
        Just l -> l
    lHash = hashWith hashAlg label
    xorBytes a b = B.pack (B.zipWith xor a b)
    -- DB = lHash || PS || 0x01 || M, with PS made of zero bytes
    zeroPad = B.replicate (maxMsgLen - msgLen) 0
    dataBlock = B.concat [B.convert lHash, zeroPad, B.singleton 0x1, msg]
    maskedDB = xorBytes dataBlock (mask seed (keyLen - hLen - 1))
    maskedSeed = xorBytes seed (mask maskedDB hLen)
    -- EM = 0x00 || maskedSeed || maskedDB
    encoded = B.concat [B.singleton 0x0, maskedSeed, maskedDB]

-- | Encrypt a message using OAEP with a freshly generated random seed.
--
-- The seed length equals the digest size of 'oaepHash'. The possible errors
-- are those of 'encryptWithSeed'.
encrypt
    :: (HashAlgorithm hash, MonadRandom m)
    => OAEPParams hash ByteString ByteString
    -- ^ OAEP params to use for encryption.
    -> PublicKey
    -- ^ Public key.
    -> ByteString
    -- ^ Message to encrypt
    -> m (Either Error ByteString)
encrypt oaep pk msg = fmap withSeed (getRandomBytes seedLen)
  where
    seedLen = hashDigestSize (oaepHash oaep)
    withSeed seed = encryptWithSeed seed oaep pk msg

-- | Strip OAEP padding from an encoded block.
--
-- This does not apply the RSA decryption primitive; the input is the block
-- after decryption, laid out as @0x00 || maskedSeed || maskedDB@. The leading
-- zero byte, the label hash and the @0x01@ separator are checked together
-- via @and'@, and any failure yields 'MessageNotRecognized'.
unpad
    :: HashAlgorithm hash
    => OAEPParams hash ByteString ByteString
    -- ^ OAEP params to use
    -> Int
    -- ^ size of the key in bytes
    -> ByteString
    -- ^ encoded message (not encrypted)
    -> Either Error ByteString
unpad oaep k em
    | valid = Right payload
    | otherwise = Left MessageNotRecognized
  where
    hashAlg = oaepHash oaep
    mask = oaepMaskGenAlg oaep
    hLen = hashDigestSize hashAlg
    label = case oaepLabel oaep of
        Nothing -> B.empty
        Just l -> l
    expectedLHash = B.convert (hashWith hashAlg label)
    xorBytes a b = B.pack (B.zipWith xor a b)
    -- split EM = Y || maskedSeed || maskedDB
    (leading, rest) = B.splitAt 1 em
    (maskedSeed, maskedDB) = B.splitAt hLen rest
    -- undo the two masks in the reverse order of encoding
    seed = xorBytes maskedSeed (mask maskedDB hLen)
    dataBlock = xorBytes maskedDB (mask seed (k - hLen - 1))
    -- split DB = lHash' || zero bytes || 0x01 || M
    (foundLHash, afterHash) = B.splitAt hLen dataBlock
    afterZeros = B.dropWhile (== 0) afterHash
    (separator, payload) = B.splitAt 1 afterZeros
    valid =
        and'
            [ foundLHash == expectedLHash -- constant-time comparison not needed here
            , separator == B.singleton 0x1
            , leading == B.singleton 0x0
            ]

-- | Decrypt a ciphertext using OAEP.
--
-- When decryption does not happen in a context where an attacker could gain
-- information from the timing of the operation, the blinder can be set to
-- 'Nothing'. If unsure, always supply a blinder or use 'decryptSafer'.
--
-- Results:
--
-- * @Left MessageSizeIncorrect@ when the ciphertext is not exactly
--   'private_size' bytes long;
-- * @Left InvalidParameters@ when the key is smaller than
--   @2*hLen + 2@ bytes for the chosen hash;
-- * otherwise the outcome of removing the OAEP padding from the output of
--   'dp' (@Left MessageNotRecognized@ when the padding is invalid).
decrypt
    :: HashAlgorithm hash
    => Maybe Blinder
    -- ^ Optional blinder
    -> OAEPParams hash ByteString ByteString
    -- ^ OAEP params to use for decryption
    -> PrivateKey
    -- ^ Private key
    -> ByteString
    -- ^ Cipher text
    -> Either Error ByteString
decrypt blinder oaep pk cipher
    | B.length cipher /= keyLen = Left MessageSizeIncorrect
    | keyLen < 2 * hLen + 2 = Left InvalidParameters
    | otherwise = unpad oaep keyLen (dp blinder pk cipher)
  where
    keyLen = private_size pk
    hLen = hashDigestSize (oaepHash oaep)

-- | Decrypt a ciphertext using OAEP with an automatically generated blinder.
--
-- A blinder is drawn for the key's modulus ('private_n') and passed to
-- 'decrypt', whose results this returns unchanged.
decryptSafer
    :: (HashAlgorithm hash, MonadRandom m)
    => OAEPParams hash ByteString ByteString
    -- ^ OAEP params to use for decryption
    -> PrivateKey
    -- ^ Private key
    -> ByteString
    -- ^ Cipher text
    -> m (Either Error ByteString)
decryptSafer oaep pk cipher =
    fmap withBlinder (generateBlinder (private_n pk))
  where
    withBlinder b = decrypt (Just b) oaep pk cipher