{-# LANGUAGE CPP #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE RecordWildCards #-}
{-|
Module      : Data.Password.Scrypt
Copyright   : (c) Dennis Gosnell, 2019; Felix Paulusma, 2020
License     : BSD-style (see LICENSE file)
Maintainer  : cdep.illabout@gmail.com
Stability   : experimental
Portability : POSIX

= scrypt

The @scrypt@ algorithm is a fairly new one. It was first published
in 2009, and later published by the IETF in 2016 as <https://tools.ietf.org/html/rfc7914 RFC 7914>.
Originally used for the Tarsnap backup service, it is
designed to be costly by requiring large amounts of memory.

== Other algorithms

@scrypt@ does increase the memory requirement in contrast to
@'Data.Password.Bcrypt.Bcrypt'@ and @'Data.Password.PBKDF2.PBKDF2'@, but it
turns out it is not as optimal as it could be, so others have set out
to search for algorithms that do fulfill their promises.
@'Data.Password.Argon2.Argon2'@ seems to be the winner in that search.

That is not to say using @scrypt@ somehow means your passwords
won't be properly protected. The cryptography is sound and
thus is fine for protection against brute-force attacks.
Because of the memory cost, it is generally advised to use
@'Data.Password.Bcrypt.Bcrypt'@ if you're not sure whether this might be a
problem on your system.
-}

module Data.Password.Scrypt (
  -- * Algorithm
  Scrypt
  -- * Plain-text Password
  , Password
  , mkPassword
  -- * Hash Passwords (scrypt)
  , hashPassword
  , PasswordHash(..)
  -- * Verify Passwords (scrypt)
  , checkPassword
  , PasswordCheck(..)
  -- * Hashing Manually (scrypt)
  , hashPasswordWithParams
  , defaultParams
  , extractParams
  , ScryptParams(..)
  -- ** Hashing with salt (DISADVISED)
  --
  -- | Hashing with a set 'Salt' is almost never what you want
  -- to do. Use 'hashPassword' or 'hashPasswordWithParams' to have
  -- automatic generation of randomized salts.
  , hashPasswordWithSalt
  , newSalt
  , Salt(..)
  -- * Unsafe debugging function to show a Password
  , unsafeShowPassword
  , -- * Setup for doctests.
    -- $setup
  ) where

import Control.Monad (guard)
import Control.Monad.IO.Class (MonadIO(liftIO))
import Crypto.KDF.Scrypt as Scrypt (Parameters(..), generate)
#if MIN_VERSION_base64(1,0,0)
import Data.Base64.Types (extractBase64)
#endif
import Data.ByteArray (Bytes, constEq, convert)
import Data.ByteString (ByteString)
import Data.ByteString.Base64 (encodeBase64)
import qualified Data.ByteString.Char8 as C8 (length)
import Data.Maybe (fromMaybe)
import qualified Data.Text as T (intercalate, split)
import Data.Word (Word32)

import Data.Password.Types (
    Password
  , PasswordHash(..)
  , mkPassword
  , unsafeShowPassword
  , Salt(..)
  )
import Data.Password.Internal (
    PasswordCheck(..)
  , from64
  , readT
  , showT
  , toBytes
  )
import qualified Data.Password.Internal (newSalt)

-- | Phantom type for __scrypt__
--
-- It has no constructors; it only appears at the type level (e.g. in
-- @'PasswordHash' 'Scrypt'@ and @'Salt' 'Scrypt'@) to tag values that belong
-- to the @scrypt@ algorithm, so they cannot be mixed up with hashes or salts
-- of other algorithms.
--
-- @since 2.0.0.0
data Scrypt

-- $setup
-- >>> :set -XFlexibleInstances
-- >>> :set -XOverloadedStrings
--
-- Import needed libraries.
--
-- >>> import Data.Password.Types
-- >>> import Data.ByteString (pack)
-- >>> import Test.QuickCheck (Arbitrary(arbitrary), Blind(Blind), vector)
-- >>> import Test.QuickCheck.Instances.Text ()
--
-- >>> instance Arbitrary (Salt a) where arbitrary = Salt . pack <$> vector 32
-- >>> instance Arbitrary Password where arbitrary = fmap mkPassword arbitrary
-- >>> let salt = Salt "abcdefghijklmnopqrstuvwxyz012345"
-- >>> let testParams = defaultParams {scryptRounds = 10}

-- -- >>> instance Arbitrary (PasswordHash Scrypt) where arbitrary = hashPasswordWithSalt testParams <$> arbitrary <*> arbitrary

-- | Hash the 'Password' using the 'Scrypt' hash algorithm.
--
-- This is 'hashPasswordWithParams' with 'defaultParams', so a fresh random
-- salt is generated on every call.
--
-- >>> hashPassword $ mkPassword "foobar"
-- PasswordHash {unPasswordHash = "14|8|1|...|..."}
hashPassword :: MonadIO m => Password -> m (PasswordHash Scrypt)
hashPassword pass = hashPasswordWithParams defaultParams pass

-- TODO: Add way to parse the following. From [https://hashcat.net/wiki/doku.php?id=example_hashes]
-- SCRYPT:1024:1:1:MDIwMzMwNTQwNDQyNQ==:5FW+zWivLxgCWj7qLiQbeC8zaNQ+qdO0NUinvqyFcfo=

-- | Parameters used in the 'Scrypt' hashing algorithm.
--
-- RFC 7914 requires the cost parameter @N@ (here @2 ^ 'scryptRounds'@) to be a
-- power of two greater than 1, and the block size @r@ ('scryptBlockSize') and
-- parallelization parameter @p@ ('scryptParallelism') to be positive integers.
--
-- @since 2.0.0.0
data ScryptParams = ScryptParams {
  scryptSalt :: Word32,
  -- ^ Bytes to randomly generate as a unique salt, default is __32__
  scryptRounds :: Word32,
  -- ^ log2(N) rounds to hash, default is __14__ (i.e. 2^14 rounds)
  --
  -- Must be at least @1@ (so that @N > 1@). @N@ is computed as a 64-bit
  -- word, so values of @64@ or more overflow.
  scryptBlockSize :: Word32,
  -- ^ Block size, default is __8__
  --
  -- Limits are min: @1@, and max: @scryptBlockSize * scryptParallelism < 2 ^ 30@
  scryptParallelism :: Word32,
  -- ^ Parallelism factor, default is __1__
  --
  -- Limits are min: @1@, and max: @scryptBlockSize * scryptParallelism < 2 ^ 30@
  scryptOutputLength :: Word32
  -- ^ Output key length in bytes, default is __64__
} deriving (Eq, Show)

-- | Default parameters for the 'Scrypt' algorithm.
--
-- These are the parameters 'hashPassword' uses: a 32-byte salt, @2 ^ 14@
-- rounds, block size 8, parallelism 1 and a 64-byte derived key.
--
-- >>> defaultParams
-- ScryptParams {scryptSalt = 32, scryptRounds = 14, scryptBlockSize = 8, scryptParallelism = 1, scryptOutputLength = 64}
--
-- @since 2.0.0.0
defaultParams :: ScryptParams
defaultParams =
  ScryptParams
    { scryptSalt         = 32
    , scryptRounds       = 14
    , scryptBlockSize    = 8
    , scryptParallelism  = 1
    , scryptOutputLength = 64
    }

-- | Hash a password with the given 'ScryptParams' and an explicitly supplied
-- 'Salt', instead of a random salt of 'scryptSalt' bytes.
--
-- Using 'hashPasswordWithSalt' is strongly __disadvised__;
-- 'hashPasswordWithParams' should be used instead.
-- /Never use a static salt in production applications!/
--
-- The resulting 'PasswordHash' contains the round count, block size and
-- parallelism used, followed by the salt and the derived key, all separated
-- by @|@. The salt and the derived key are base64 encoded.
--
-- >>> let salt = Salt "abcdefghijklmnopqrstuvwxyz012345"
-- >>> hashPasswordWithSalt defaultParams salt (mkPassword "foobar")
-- PasswordHash {unPasswordHash = "14|8|1|YWJjZGVmZ2hpamtsbW5vcHFyc3R1dnd4eXowMTIzNDU=|nENDaqWBmPKapAqQ3//H0iBImweGjoTqn5SvBS8Mc9FPFbzq6w65maYPZaO+SPamVZRXQjARQ8Y+5rhuDhjIhw=="}
--
-- (The example uses a fixed 'Salt' only so that its output is reproducible.
-- In general you should use 'hashPassword', which generates a new 'Salt'
-- every time it is called.)
hashPasswordWithSalt :: ScryptParams -> Salt Scrypt -> Password -> PasswordHash Scrypt
hashPasswordWithSalt params s@(Salt rawSalt) pass =
  PasswordHash . T.intercalate "|" $ settings ++ payload
  where
    -- Cost parameters first, in the order 'parseScryptPasswordHashParams'
    -- reads them back.
    settings =
      [ showT (scryptRounds params)
      , showT (scryptBlockSize params)
      , showT (scryptParallelism params)
      ]

    -- Salt and derived key, both base64 encoded.
    payload = [toB64 rawSalt, toB64 derivedKey]

    derivedKey :: ByteString
    derivedKey = hashPasswordWithSalt' params s pass
#if MIN_VERSION_base64(1,0,0)
    toB64 = extractBase64 . encodeBase64
#else
    toB64 = encodeBase64
#endif

-- | Derive the raw @scrypt@ key for a password (internal use only).
--
-- Unlike 'hashPasswordWithSalt', this returns only the derived key bytes,
-- without the encoded parameters or salt. It is shared by
-- 'hashPasswordWithSalt' and 'checkPassword'. 'scryptSalt' is not consulted,
-- because the salt is passed in explicitly.
hashPasswordWithSalt' :: ScryptParams -> Salt Scrypt -> Password -> ByteString
hashPasswordWithSalt' params (Salt salt) pass = convert derived
  where
    derived :: Bytes
    derived = Scrypt.generate kdfParams passwordBytes saltBytes

    passwordBytes = toBytes (unsafeShowPassword pass)

    saltBytes :: Bytes
    saltBytes = convert salt

    -- Translate our parameters into the KDF's; 'scryptRounds' is log2(N).
    kdfParams = Scrypt.Parameters
      { n            = 2 ^ scryptRounds params
      , r            = fromIntegral (scryptBlockSize params)
      , p            = fromIntegral (scryptParallelism params)
      , outputLength = fromIntegral (scryptOutputLength params)
      }


-- | Hash a password using the 'Scrypt' algorithm with the given 'ScryptParams'.
--
-- A random salt of 'scryptSalt' bytes is generated for every call.
--
-- __N.B.__: If you have any doubt in your knowledge of cryptography and/or the
-- 'Scrypt' algorithm, please just use 'hashPassword'.
--
-- Advice for setting the parameters:
--
-- * Memory used is about: @(2 ^ 'scryptRounds') * 'scryptBlockSize' * 128@
-- * Increasing 'scryptBlockSize' and 'scryptRounds' will increase CPU time
--   and memory used.
-- * Increasing 'scryptParallelism' will increase CPU time. (since this
--   implementation, like most, runs the 'scryptParallelism' parameter in
--   sequence, not in parallel)
--
-- @since 2.0.0.0
hashPasswordWithParams :: MonadIO m => ScryptParams -> Password -> m (PasswordHash Scrypt)
hashPasswordWithParams params pass =
  liftIO (withSalt <$> Data.Password.Internal.newSalt saltLength)
  where
    -- Hash with the caller's parameters and the freshly generated salt.
    withSalt salt = hashPasswordWithSalt params salt pass

    saltLength :: Int
    saltLength = fromIntegral (scryptSalt params)

-- | Check a 'Password' against a 'PasswordHash' 'Scrypt'.
--
-- Returns 'PasswordCheckSuccess' on success.
--
-- >>> let pass = mkPassword "foobar"
-- >>> passHash <- hashPassword pass
-- >>> checkPassword pass passHash
-- PasswordCheckSuccess
--
-- Returns 'PasswordCheckFail' if an incorrect 'Password' or 'PasswordHash' 'Scrypt' is used.
--
-- >>> let badpass = mkPassword "incorrect-password"
-- >>> checkPassword badpass passHash
-- PasswordCheckFail
--
-- A 'PasswordHash' that cannot be parsed, or whose stored key is empty, is
-- also rejected:
--
-- >>> checkPassword pass (PasswordHash "14|8|1|YWJj|")
-- PasswordCheckFail
--
-- This should always fail if an incorrect password is given.
--
-- prop> \(Blind badpass) -> let correctPasswordHash = hashPasswordWithSalt testParams salt "foobar" in checkPassword badpass correctPasswordHash == PasswordCheckFail
checkPassword :: Password -> PasswordHash Scrypt -> PasswordCheck
checkPassword pass passHash =
  fromMaybe PasswordCheckFail $ do
    (params, salt, hashedKey) <- parseScryptPasswordHashParams passHash
    -- The output length is recovered from the stored key's length. If the
    -- key field decodes to zero bytes, scrypt is asked for a 0-byte output,
    -- and the empty result would compare equal to the empty stored key,
    -- accepting ANY password. Fail closed on such truncated/corrupted hashes.
    guard $ C8.length hashedKey > 0
    let producedKey = hashPasswordWithSalt' params salt pass
    -- Constant-time comparison to avoid leaking timing information.
    guard $ hashedKey `constEq` producedKey
    return PasswordCheckSuccess

-- | Parse a hash in the @rounds|blockSize|parallelism|salt64|key64@ format
-- produced by 'hashPasswordWithSalt'.
--
-- 'scryptSalt' and 'scryptOutputLength' are not stored explicitly; they are
-- recovered from the byte lengths of the decoded salt and key. Returns
-- 'Nothing' if there are not exactly five fields, or if any field fails to
-- read or base64-decode.
parseScryptPasswordHashParams :: PasswordHash Scrypt -> Maybe (ScryptParams, Salt Scrypt, ByteString)
parseScryptPasswordHashParams (PasswordHash encoded) =
  case T.split (== '|') encoded of
    [roundsT, blockSizeT, parallelismT, salt64, key64] -> do
      rounds      <- readT roundsT
      blockSize   <- readT blockSizeT
      parallelism <- readT parallelismT
      saltBytes   <- from64 salt64
      keyBytes    <- from64 key64
      let params =
            ScryptParams
              { scryptSalt         = byteCount saltBytes
              , scryptRounds       = rounds
              , scryptBlockSize    = blockSize
              , scryptParallelism  = parallelism
              , scryptOutputLength = byteCount keyBytes
              }
      pure (params, Salt saltBytes, keyBytes)
    _ -> Nothing
  where
    byteCount :: ByteString -> Word32
    byteCount = fromIntegral . C8.length

-- | Extracts 'ScryptParams' from a 'PasswordHash' 'Scrypt'.
--
-- Returns 'Just ScryptParams' on success, and 'Nothing' if the hash cannot
-- be parsed. The salt and derived key are discarded.
--
-- >>> let pass = mkPassword "foobar"
-- >>> passHash <- hashPassword pass
-- >>> extractParams passHash == Just defaultParams
-- True
--
-- @since 3.0.2.0
extractParams :: PasswordHash Scrypt -> Maybe ScryptParams
extractParams passHash = do
  (params, _salt, _key) <- parseScryptPasswordHashParams passHash
  pure params

-- | Generate a random 32-byte @scrypt@ salt.
--
-- 32 bytes is also the 'scryptSalt' value in 'defaultParams'.
--
-- @since 2.0.0.0
newSalt :: MonadIO m => m (Salt Scrypt)
newSalt = Data.Password.Internal.newSalt saltBytes
  where
    saltBytes :: Int
    saltBytes = 32