-- |
-- Module      : Crypto.Number.Generate
-- License     : BSD-style
-- Maintainer  : Vincent Hanquez <vincent@snarc.org>
-- Stability   : experimental
-- Portability : Good
module Crypto.Number.Generate (
    GenTopPolicy (..),
    generateParams,
    generatePrefix,
    generateMax,
    generateBetween,
) where

import Control.Monad (when)
import Crypto.Internal.ByteArray (ScrubbedBytes)
import qualified Crypto.Internal.ByteArray as B
import Crypto.Internal.Imports
import Crypto.Number.Basic
import Crypto.Number.Serialize
import Crypto.Random.Types
import Data.Bits (complement, shiftL, testBit, unsafeShiftR, (.&.), (.|.))
import Foreign.Ptr
import Foreign.Storable

-- | Which of the most significant bits of a generated number are forced
-- to 1 (consumed by 'generateParams').
data GenTopPolicy
    = SetHighest
    -- ^ set the highest bit
    | SetTwoHighest
    -- ^ set the two highest bits
    deriving (Eq, Show)

-- | Generate a random number of at most the given number of bits, and
-- optionally force its top and bottom bits.
--
-- Random bytes are drawn and converted with 'os2ip'. The code treats the
-- first byte as the most significant one (presumably big-endian, as in
-- PKCS #1 OS2IP). Bits of that first byte lying above the requested size
-- are cleared, so the result is always below @2^bits@.
--
-- If the top bit policy is 'Nothing', then nothing is done on the highest
-- bit (it is whatever the random generator set).
--
-- If @generateOdd@ is set to 'True', then the number generated is
-- guaranteed to be odd. Otherwise its lowest bit is whatever was generated.
--
-- A non-positive bit count yields @0@.
--
-- With 'SetTwoHighest' and a bit count of 1 only one bit exists, so only
-- that bit is set (the result is @1@).
generateParams
    :: MonadRandom m
    => Int
    -- ^ number of bits
    -> Maybe GenTopPolicy
    -- ^ top bit policy
    -> Bool
    -- ^ force the number to be odd
    -> m Integer
generateParams bits genTopPolicy generateOdd
    | bits <= 0 = return 0
    | otherwise = os2ip . tweak <$> getRandomBytes bytes
  where
    -- Copy the random bytes and adjust the top/bottom bits in the copy.
    tweak :: ScrubbedBytes -> ScrubbedBytes
    tweak orig = B.copyAndFreeze orig $ \p0 -> do
        let p1, pEnd :: Ptr Word8
            -- second byte; only inside the buffer when bytes > 1
            p1 = p0 `plusPtr` 1
            -- last byte of the buffer
            pEnd = p0 `plusPtr` (bytes - 1)
        case genTopPolicy of
            Nothing -> return ()
            Just SetHighest -> p0 |= (1 `shiftL` bit)
            Just SetTwoHighest
                | bit == 0 -> do
                    -- The top bit is bit 0 of the first byte, so the second
                    -- highest bit is bit 7 of the following byte.
                    p0 $= 0x1
                    -- A 1-bit request has no following byte: writing p1
                    -- would go past the end of the buffer, so only the
                    -- single available bit is set.
                    when (bytes > 1) $ p1 |= 0x80
                | otherwise -> p0 |= (0x3 `shiftL` (bit - 1))
        -- Clear the bits of the first byte that lie above the requested size.
        p0 &= complement mask
        when generateOdd (pEnd |= 0x1)

    -- Overwrite the byte at p.
    ($=) :: Ptr Word8 -> Word8 -> IO ()
    ($=) p w = poke p w

    -- Bitwise-OR w into the byte at p.
    (|=) :: Ptr Word8 -> Word8 -> IO ()
    (|=) p w = peek p >>= \v -> poke p (v .|. w)

    -- Bitwise-AND w into the byte at p.
    (&=) :: Ptr Word8 -> Word8 -> IO ()
    (&=) p w = peek p >>= \v -> poke p (v .&. w)

    -- number of bytes needed to hold the requested bits
    bytes :: Int
    bytes = (bits + 7) `div` 8

    -- position (0..7) of the highest requested bit within the first byte
    bit :: Int
    bit = (bits - 1) `mod` 8

    -- bits of the first byte above the requested size; 0 when bit == 7,
    -- because shifting a Word8 by its full width yields 0
    mask :: Word8
    mask = 0xff `shiftL` (bit + 1)

-- | Generate a number of the given size from the leading bits of a random
-- byte string.
--
-- * @'generateParams' n Nothing False@ generates bytes and uses the suffix of @n@ bits
-- * @'generatePrefix' n@ generates bytes and uses the prefix of @n@ bits
--
-- A non-positive size yields @0@.
generatePrefix :: MonadRandom m => Int -> m Integer
generatePrefix bits
    | bits <= 0 = return 0
    | otherwise = keepPrefix <$> getRandomBytes byteCount
  where
    -- Whole bytes needed for @bits@ bits. @rest@ is (bits - 1) mod 8.
    (byteCount, rest) = (bits + 7) `divMod` 8

    -- Surplus low-order bits drawn beyond the request:
    -- 8 * byteCount - bits == 7 - rest, always in [0, 7].
    surplus = 7 - rest

    -- Convert the bytes to an integer and drop the surplus trailing bits.
    keepPrefix :: ScrubbedBytes -> Integer
    keepPrefix raw = os2ip raw `unsafeShiftR` surplus

-- | Generate a non-negative integer @x@ such that @0 <= x < range@.
--
-- Returns @0@ when @range <= 1@. Ranges below 127 use a quick reduction
-- modulo @range@ (biased unless @range@ is a power of 2). Larger ranges use
-- rejection sampling; after 100 rejected draws the function calls 'error'.
generateMax
    :: MonadRandom m
    => Integer
    -- ^ range
    -> m Integer
generateMax range
    | range <= 1 = return 0
    -- this "generator" is mostly for quickcheck benefits. it'll be biased
    -- unless range is a power of 2, but overall, no security should be
    -- assumed for a number between 0 and 127.
    | range < 127 = (`mod` range) <$> generateParams bits Nothing False
    | canOverGenerate =
        retry (failure ") (over) doesn't seems to work properly") drawOver
    | otherwise =
        retry (failure ") (normal) doesn't seems to work properly") drawNormal
  where
    bits = numBits range

    isValid n = n < range

    tries :: Int
    tries = 100

    -- The two bits just below range's top bit (index bits - 1) are clear.
    -- Assuming numBits is the bit length, range < 1.25 * 2^(bits-1), so
    -- 3 * range fits inside the 2^(bits+1) values drawn by drawOver.
    canOverGenerate =
        bits > 3
            && not (testBit range (bits - 2))
            && not (testBit range (bits - 3))

    -- Fatal error raised once every attempt has been rejected.
    failure suffix =
        error $
            "internal: generateMax("
                ++ show range
                ++ " bits="
                ++ show bits
                ++ suffix

    -- Run draw at most tries times and return the first accepted value.
    retry onFailure draw = go tries
      where
        go remaining
            | remaining == 0 = onFailure
            | otherwise = draw >>= maybe (go (remaining - 1)) return

    -- Plain rejection sampling over bits-bit numbers.
    drawNormal = do
        r <- generateParams bits Nothing False
        return (if isValid r then Just r else Nothing)

    -- Draw one extra bit and fold the value back by subtracting range up to
    -- twice, so fewer draws are rejected.
    drawOver = do
        r <- generateParams (bits + 1) Nothing False
        return (firstValid [r, r - range, r - range - range])

    firstValid [] = Nothing
    firstValid (x : xs)
        | isValid x = Just x
        | otherwise = firstValid xs

-- | Generate a number within the inclusive bounds @[low, high]@.
--
-- The result is @low@ plus a draw from 'generateMax' over the
-- @high - low + 1@ candidate values. When @high < low@ that count is at
-- most 0, 'generateMax' returns 0, and the result is simply @low@.
generateBetween :: MonadRandom m => Integer -> Integer -> m Integer
generateBetween low high = do
    offset <- generateMax (high - low + 1)
    return (low + offset)