{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE ForeignFunctionInterface #-}

-- |
-- Module      : Crypto.KDF.Scrypt
-- License     : BSD-style
-- Maintainer  : Vincent Hanquez <vincent@snarc.org>
-- Stability   : experimental
-- Portability : unknown
--
-- Scrypt key derivation function as defined in Colin Percival's paper
-- "Stronger Key Derivation via Sequential Memory-Hard Functions"
-- <http://www.tarsnap.com/scrypt/scrypt.pdf>.
module Crypto.KDF.Scrypt (
    Parameters (..),
    generate,
) where

import Control.Monad (forM_)
import Data.Word
import Foreign.Marshal.Alloc
import Foreign.Ptr (Ptr, plusPtr)

import Crypto.Hash (SHA256 (..))
import Crypto.Internal.ByteArray (ByteArray, ByteArrayAccess)
import qualified Crypto.Internal.ByteArray as B
import Crypto.Internal.Compat (popCount, unsafeDoIO)
import qualified Crypto.KDF.PBKDF2 as PBKDF2

-- | Parameters for Scrypt
data Parameters = Parameters
    { n :: Word64
    -- ^ CPU/memory cost ratio, also known as N. Must be a power of 2
    -- (the scrypt paper additionally requires it to be greater than 1).
    , r :: Int
    -- ^ Block size factor: each mixed block is @128 * r@ bytes.
    -- Must be positive and satisfy @r * p < 2^30@.
    , p :: Int
    -- ^ Parallelization factor: the number of blocks that are mixed.
    -- Must be positive and satisfy @r * p < 2^30@.
    , outputLength :: Int
    -- ^ The number of bytes to generate out of Scrypt.
    }

-- | Binding to the C routine @crypton_scrypt_smix@, scrypt's per-block
-- mixing step.
--
-- As called from 'generate', the arguments are, in order:
--
-- * a pointer to one @128 * r@-byte block. 'generate' relies on this block
--   being updated in place, because the bytes are read back afterwards and
--   used as the salt of the final PBKDF2 pass;
-- * @r@, narrowed to 'Word32';
-- * @n@;
-- * a scratch buffer of @128 * n * r@ bytes;
-- * a scratch buffer of @256 * r + 64@ bytes.
--
-- NOTE(review): nothing on the Haskell side of this binding checks the
-- buffer sizes; presumably the C code trusts them, so callers must validate
-- r and n and allocate exactly these sizes.
foreign import ccall "crypton_scrypt_smix"
    ccrypton_scrypt_smix
        :: Ptr Word8 -> Word32 -> Word64 -> Ptr Word8 -> Ptr Word8 -> IO ()

-- | Generate the scrypt key derivation data.
--
-- The password and salt are first expanded with single-iteration
-- PBKDF2-HMAC-SHA256 into @p@ blocks of @128 * r@ bytes. Each block is then
-- mixed in place by @crypton_scrypt_smix@. A final single-iteration
-- PBKDF2-HMAC-SHA256 pass, using the mixed blocks as the salt, produces
-- 'outputLength' bytes.
--
-- Calls 'error' when the parameters are invalid:
--
-- * @r@ or @p@ is not positive;
-- * @r * p >= 2^30@;
-- * @n@ is not a power of 2;
-- * one of the buffer sizes (@128 * r * n@, @128 * r * p@ or
--   @256 * r + 64@ bytes) does not fit in an 'Int'.
generate
    :: (ByteArrayAccess password, ByteArrayAccess salt, ByteArray output)
    => Parameters
    -> password
    -> salt
    -> output
generate params password salt
    | rI < 1 || pI < 1 =
        error "Scrypt: invalid parameters: r and p must be positive"
    | rI * pI >= 0x40000000 =
        error "Scrypt: invalid parameters: r and p constraint"
    | popCount (n params) /= 1 =
        error "Scrypt: invalid parameters: n not a power of 2"
    | any (> maxInt) [vLenI, intLenI, xyLenI] =
        error "Scrypt: invalid parameters: memory requirement too large"
    | otherwise = unsafeDoIO $ do
        let b = PBKDF2.generate prf (PBKDF2.Parameters 1 intLen) password salt :: B.Bytes
        -- B.copy hands the callback a mutable copy of b; smix rewrites each
        -- 128*r-byte block of that copy in place.
        newSalt <- B.copy b $ \bPtr ->
            allocaBytesAligned vLen 8 $ \v ->
                allocaBytesAligned xyLen 8 $ \xy ->
                    forM_ [0 .. p params - 1] $ \i ->
                        ccrypton_scrypt_smix
                            (bPtr `plusPtr` (i * blockLen))
                            (fromIntegral (r params)) -- r < 2^30, fits in Word32
                            (n params)
                            v
                            xy
        return $
            PBKDF2.generate
                prf
                (PBKDF2.Parameters 1 (outputLength params))
                password
                (newSalt :: B.Bytes)
  where
    prf = PBKDF2.prfHMAC SHA256

    -- The parameters are widened to Integer so the guards above cannot be
    -- defeated by Int overflow. Plain Int arithmetic could wrap, for example
    -- with a huge n, and under-allocate the buffers handed to C.
    rI = toInteger (r params)
    pI = toInteger (p params)
    nI = toInteger (n params)
    maxInt = toInteger (maxBound :: Int)

    -- Exact byte sizes. Each is narrowed to Int only after the guards have
    -- shown that it fits.
    blockLenI = 128 * rI      -- one block mixed by smix
    intLenI = pI * blockLenI  -- intermediate PBKDF2 output: p blocks
    vLenI = blockLenI * nI    -- scratch buffer passed to smix as v
    xyLenI = 256 * rI + 64    -- scratch buffer passed to smix as xy

    blockLen, intLen, vLen, xyLen :: Int
    blockLen = fromInteger blockLenI
    intLen = fromInteger intLenI
    vLen = fromInteger vLenI
    xyLen = fromInteger xyLenI
{-# NOINLINE generate #-}