{-# LANGUAGE BangPatterns #-}

-- |
-- Module      : Crypto.KDF.HKDF
-- License     : BSD-style
-- Maintainer  : Vincent Hanquez <vincent@snarc.org>
-- Stability   : experimental
-- Portability : unknown
--
-- Key Derivation Function based on HMAC
--
-- See RFC5869
module Crypto.KDF.HKDF (
    PRK,
    extract,
    extractSkip,
    expand,
    toPRK,
) where

import Crypto.Hash
import Crypto.Internal.ByteArray (
    ByteArray,
    ByteArrayAccess,
    ScrubbedBytes,
 )
import qualified Crypto.Internal.ByteArray as B
import Crypto.MAC.HMAC
import Data.Word

-- | Pseudo Random Key.
--
-- Holds either an HMAC value ('PRK', as built by 'extract' and 'toPRK') or
-- input keying material stored as-is ('PRK_NoExpand', as built by
-- 'extractSkip'). Equality is derived, so keys built with different
-- constructors never compare equal.
data PRK a
    = PRK (HMAC a)
    | PRK_NoExpand ScrubbedBytes
    deriving (Eq)

-- | Expose the raw bytes of a key: the HMAC value for 'PRK', or the stored
-- key material for 'PRK_NoExpand'. Both methods simply delegate.
instance ByteArrayAccess (PRK a) where
    length prk = case prk of
        PRK hm -> B.length hm
        PRK_NoExpand sb -> B.length sb
    withByteArray prk act = case prk of
        PRK hm -> B.withByteArray hm act
        PRK_NoExpand sb -> B.withByteArray sb act

-- | Extract a Pseudo Random Key using the parameters and the underlying hash
-- mechanism.
--
-- The resulting key is @HMAC(key = salt, message = ikm)@ under hash @a@.
extract
    :: (HashAlgorithm a, ByteArrayAccess salt, ByteArrayAccess ikm)
    => salt
    -- ^ Salt
    -> ikm
    -- ^ Input Keying Material
    -> PRK a
    -- ^ Pseudo random key
extract salt = PRK . hmac salt

-- | Create a PRK directly from the input key material.
--
-- Only use this when the input is guaranteed to be good-quality random data
-- that can be used directly as a key. It skips the HMAC with key=salt and
-- data=ikm that 'extract' performs. The material is copied into
-- 'ScrubbedBytes', and 'expand' then uses it as the HMAC key.
extractSkip
    :: ByteArrayAccess ikm
    => ikm
    -- ^ Input Keying Material
    -> PRK a
extractSkip = PRK_NoExpand . B.convert

-- | Expand key material of a specific length out of the parameters.
--
-- Output blocks are chained as @T(i) = HMAC(key, T(i-1) || info || i)@. The
-- chain starts from an empty @T(0)@ with a one-byte counter @i = 1@. Blocks
-- are concatenated until @outputLength@ bytes are produced, and the last
-- block is truncated as needed.
--
-- A non-positive @outputLength@ yields an empty result.
--
-- Because the counter is a single octet, at most 255 blocks
-- (@255 * HashLen@ bytes, as RFC 5869 requires) can be produced. Asking for
-- more raises an 'error' when the result is evaluated. Previously the counter
-- silently wrapped back to 0.
expand
    :: (HashAlgorithm a, ByteArrayAccess info, ByteArray out)
    => PRK a
    -- ^ Pseudo Random Key
    -> info
    -- ^ Optional context and application specific information
    -> Int
    -- ^ Output length in bytes
    -> out
    -- ^ Output data
expand prkAt infoAt outputLength =
    let hF = hFGet prkAt
     in B.concat $ loop hF B.empty outputLength 1
  where
    -- Pick the HMAC key: the extracted HMAC value, or the raw key material
    -- stored by 'extractSkip'.
    hFGet :: (HashAlgorithm a, ByteArrayAccess b) => PRK a -> (b -> HMAC a)
    hFGet prk = case prk of
        PRK hmacKey -> hmac hmacKey
        PRK_NoExpand ikm -> hmac ikm

    -- The info converted to ScrubbedBytes, so each block's HMAC input is a
    -- single ScrubbedBytes concatenation.
    info :: ScrubbedBytes
    info = B.convert infoAt

    -- Emit blocks until n more bytes have been produced.
    -- tim1 is the previous block T(i-1), and i is the counter for this block.
    loop
        :: HashAlgorithm a
        => (ScrubbedBytes -> HMAC a)
        -> ScrubbedBytes
        -> Int
        -> Word8
        -> [ScrubbedBytes]
    loop hF tim1 n i
        | n <= 0 = []
        | otherwise =
            let input = B.concat [tim1, info, B.singleton i] :: ScrubbedBytes
                ti = B.convert (hF input) :: ScrubbedBytes
                hashLen = B.length ti
                r = n - hashLen
                rest
                    | r <= 0 = []
                    -- i is already 255: producing another block would wrap
                    -- the one-octet counter to 0, which RFC 5869 forbids.
                    | i == maxBound =
                        error "Crypto.KDF.HKDF.expand: output length exceeds 255 * HashLen (RFC 5869)"
                    | otherwise = loop hF ti r (i + 1)
             in (if n >= hashLen then ti else B.take n ti) : rest

-- | Wrap raw bytes as a 'PRK' for hash @a@.
--
-- The bytes are interpreted as a digest of @a@ via 'digestFromByteString' and
-- stored as an HMAC value. 'Nothing' is returned when that conversion fails
-- (presumably when the length differs from the digest size of @a@; see
-- 'digestFromByteString').
toPRK :: (HashAlgorithm a, ByteArrayAccess ba) => ba -> Maybe (PRK a)
toPRK = fmap (PRK . HMAC) . digestFromByteString