{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE ForeignFunctionInterface #-}
{-# LANGUAGE NamedFieldPuns #-}
{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE CPP #-}
{-# LANGUAGE PackageImports #-}
module Web.ClientSession
(
Key
, IV
, randomIV
, mkIV
, getKey
, getKeyEnv
, defaultKeyFile
, getDefaultKey
, initKey
, randomKey
, randomKeyEnv
, encrypt
, encryptIO
, decrypt
) where
import Control.Applicative ((<$>))
import Control.Concurrent (forkIO)
import Control.Monad (guard, when)
import Data.Bifunctor (first)
import Data.Function (on)
#if MIN_VERSION_base(4,7,0)
import System.Environment (lookupEnv, setEnv)
#elif MIN_VERSION_base(4,6,0)
import System.Environment (lookupEnv)
import System.SetEnv (setEnv)
#else
import System.LookupEnv (lookupEnv)
import System.SetEnv (setEnv)
#endif
import System.IO.Unsafe (unsafePerformIO)
import qualified Data.IORef as I
import System.Directory (doesFileExist)
import qualified Data.ByteString as S
import qualified Data.ByteString.Char8 as C
import qualified Data.ByteString.Base64 as B
import Data.Serialize (encode, Serialize (put, get), getBytes, putByteString)
import Data.Tagged (Tagged, untag)
import Crypto.Classes (constTimeEq)
import qualified Crypto.Cipher.AES as A
import Crypto.Cipher.Types(Cipher(..),BlockCipher(..),makeIV)
import Crypto.Error (eitherCryptoError)
import "crypton" Crypto.Random (ChaChaDRG,drgNew,randomBytesGenerate)
import Crypto.Skein (skeinMAC', Skein_512_256)
import System.Entropy (getEntropy)
-- | Server-side secret used to protect cookies.
--
-- Built by 'initKey' from 96 raw bytes: the first 64 bytes key the
-- Skein-512-256 MAC and the remaining 32 bytes key AES-256.
data Key = Key
    { aesKey :: !A.AES256
      -- ^ Initialised AES-256 cipher used (in CTR mode) to encrypt and
      -- decrypt cookie payloads.
    , macKey :: !(S.ByteString -> Skein_512_256)
      -- ^ MAC function with the MAC key already applied (a partial
      -- application of 'skeinMAC'').
    , keyRaw :: !S.ByteString
      -- ^ The raw bytes the two fields above were derived from; used by
      -- the 'Eq' and 'Serialize' instances.
    }
-- | Two keys are equal exactly when their raw key bytes are equal.
--
-- The comparison uses 'constTimeEq' instead of '==' on 'S.ByteString'.
-- Comparing secret key material this way does not reveal, through
-- timing, how many leading bytes match. The result is the same as '=='.
instance Eq Key where
    Key _ _ raw1 == Key _ _ raw2 = raw1 `constTimeEq` raw2
-- | A key is serialised as its 96 raw bytes, with no length prefix.
--
-- Deserialisation reads exactly 96 bytes and rebuilds the key with
-- 'initKey'. If 'initKey' rejects them, the 'Get' parser fails, so
-- 'Data.Serialize.decode' reports a 'Left'. Previously it produced a
-- 'Right' holding a value that called 'error' when forced.
instance Serialize Key where
    put = putByteString . keyRaw
    get = getBytes 96 >>= either fail return . initKey
-- | Opaque rendering: a 'Key' is always shown as a fixed placeholder so
-- that key material never ends up in logs or error messages.
instance Show Key where
    show = const "<Web.ClientSession.Key>"
-- | Initialisation vector for AES-CTR, held as raw bytes.
--
-- Values built by 'mkIV' or 'randomIV' are 16 bytes long. The
-- constructor itself does not check the length.
newtype IV = IV S.ByteString
-- | Wrap bytes as an 'IV' without checking their length.
--
-- The caller must supply exactly 16 bytes; 'mkIV' is the checked
-- alternative.
unsafeMkIV :: S.ByteString -> IV
unsafeMkIV = IV
-- | Recover the raw bytes inside an 'IV'.
unIV :: IV -> S.ByteString
unIV (IV raw) = raw
-- | IVs compare by their underlying bytes.
instance Eq IV where
    IV a == IV b = a == b
    IV a /= IV b = a /= b
-- | IVs are ordered by their underlying bytes, using the 'Ord' instance
-- of 'S.ByteString'.
instance Ord IV where
    compare (IV a) (IV b) = compare a b
    IV a <= IV b = a <= b
    IV a <  IV b = a <  b
    IV a >= IV b = a >= b
    IV a >  IV b = a >  b
-- | An IV is shown as its raw 'S.ByteString'.
instance Show IV where
    show (IV raw) = show raw
-- | An 'IV' is serialised with the 'Serialize' instance of its underlying
-- 'S.ByteString'.
--
-- Deserialisation checks the length through 'mkIV', so 'get' can only
-- produce 16-byte IVs. Any other length fails the parse, instead of
-- yielding an 'IV' that 'encrypt' would later reject with 'error'.
instance Serialize IV where
    put = put . unIV
    get = get >>= maybe badLength return . mkIV
      where
        badLength = fail "Web.ClientSession: an IV must be exactly 16 bytes"
-- | Checked 'IV' constructor: accepts the bytes only if there are
-- exactly 16 of them, otherwise returns 'Nothing'.
mkIV :: S.ByteString -> Maybe IV
mkIV bs
    | S.length bs /= 16 = Nothing
    | otherwise         = Just (unsafeMkIV bs)
-- | Draw a fresh 16-byte IV from the process-wide ChaCha generator.
--
-- Do not reuse an IV with the same key. 'encrypt' uses CTR mode
-- ('ctrCombine'), where reuse would repeat the keystream.
randomIV :: IO IV
randomIV =
    chaChaRNG
-- | File name used by 'getDefaultKey' to store the key.
defaultKeyFile :: FilePath
defaultKeyFile =
    "client_session_key.aes"
-- | Load (or create) the key stored in 'defaultKeyFile'; equivalent to
-- @'getKey' 'defaultKeyFile'@.
getDefaultKey :: IO Key
getDefaultKey =
    getKey defaultKeyFile
-- | Load a key from a file.
--
-- If the file does not exist, or its contents are rejected by 'initKey'
-- (for example because they are not 96 bytes), a new key is generated
-- with 'randomKey'. Its raw bytes are written to the file, overwriting
-- any previous contents, and the new key is returned. Exceptions from
-- reading or writing the file propagate to the caller.
--
-- NOTE(review): the file is created with 'S.writeFile' and no explicit
-- permissions. Confirm the deployment's umask keeps the key private.
getKey :: FilePath -- ^ Path of the key file.
       -> IO Key   -- ^ The loaded or freshly generated key.
getKey keyFile = do
    exists <- doesFileExist keyFile
    if not exists
        then generate
        else do
            contents <- S.readFile keyFile
            case initKey contents of
                Right key -> return key
                Left _    -> generate
  where
    -- Create a new key and persist its raw bytes to keyFile.
    generate = do
        (raw, key) <- randomKey
        S.writeFile keyFile raw
        return key
-- | Load a key from an environment variable holding its Base64 encoding.
--
-- If the variable is unset, or its value fails Base64 decoding or
-- 'initKey', a new key is created by 'randomKeyEnv'. That function
-- stores the key in the variable, prints it to stdout, and returns it.
getKeyEnv :: String -- ^ Name of the environment variable.
          -> IO Key -- ^ The loaded or freshly generated key.
getKeyEnv envVar = lookupEnv envVar >>= maybe generate fromValue
  where
    -- Decode the Base64 value and build a key, falling back on failure.
    fromValue value =
        case B.decode (C.pack value) >>= initKey of
            Right key -> return key
            Left _    -> generate
    generate = randomKeyEnv envVar
-- | Generate a new key from 96 bytes of system entropy ('getEntropy').
--
-- Returns the raw bytes alongside the 'Key' so the caller can store them
-- and later rebuild the same key with 'initKey'. The 'Left' branch
-- cannot occur in practice, because 'initKey' only checks the length
-- (always 96 here) and AES initialisation.
randomKey :: IO (S.ByteString, Key)
randomKey = do
    raw <- getEntropy 96
    either (\e -> error $ "Web.ClientSession.randomKey: never here, " ++ e)
           (\key -> return (raw, key))
           (initKey raw)
-- | Generate a new key and publish it.
--
-- The key's raw bytes are Base64-encoded and stored in the named
-- environment variable. A @NAME=value@ line is also printed to stdout,
-- and the key is returned.
randomKeyEnv :: String -> IO Key
randomKeyEnv envVar = do
    (raw, key) <- randomKey
    let encoded = C.unpack (B.encode raw)
    setEnv envVar encoded
    putStrLn (concat [envVar, "=", encoded])
    return key
-- | Build a 'Key' from exactly 96 bytes of key material.
--
-- The first 64 bytes become the Skein-512-256 MAC key and the last 32
-- bytes the AES-256 key. Returns 'Left' with a message when the input is
-- not 96 bytes long, or when the AES cipher cannot be initialised.
--
-- No key stretching is performed, so the input should be uniformly
-- random bytes (such as those from 'randomKey'), never a password.
initKey :: S.ByteString -> Either String Key
initKey bs
    | S.length bs /= 96 =
        Left $ "Web.ClientSession.initKey: length of "
            ++ show (S.length bs) ++ " /= 96."
    | otherwise = do
        let (preMacKey, preAesKey) = S.splitAt 64 bs
        cipher <- first show (eitherCryptoError (cipherInit preAesKey))
        Right Key { aesKey = cipher
                  , macKey = skeinMAC' preMacKey
                  , keyRaw = bs
                  }
-- | Like 'encrypt', but draws a fresh IV with 'randomIV'.
encryptIO :: Key -> S.ByteString -> IO S.ByteString
encryptIO key x = (\iv -> encrypt key iv x) <$> randomIV
-- | Encrypt (AES-256 in CTR mode), authenticate (Skein-512-256 MAC) and
-- Base64-encode cookie data.
--
-- The Base64 payload is @MAC || IV || ciphertext@, where the MAC covers
-- @IV || ciphertext@. Calls 'error' if the IV bytes are rejected by
-- 'makeIV'; IVs from 'mkIV' / 'randomIV' are 16 bytes.
encrypt :: Key          -- ^ Key of the server.
        -> IV           -- ^ Fresh IV; never reuse one (see 'randomIV').
        -> S.ByteString -- ^ Serialized cookie data.
        -> S.ByteString -- ^ Base64-encoded result for the client.
encrypt key (IV ivBytes) plaintext =
    maybe (error "Web.ClientSession.encrypt: Failed to makeIV")
          seal
          (makeIV ivBytes)
  where
    seal iv =
        let ciphertext = ctrCombine (aesKey key) iv plaintext
            authed     = ivBytes `S.append` ciphertext
            tag        = encode (macKey key authed)
        in B.encode (tag `S.append` authed)
-- | Reverse 'encrypt': Base64-decode, verify the Skein-512-256 MAC, then
-- decrypt with AES-256 in CTR mode.
--
-- Returns 'Nothing' in four cases: the input is not valid Base64, it is
-- shorter than 48 bytes (32-byte MAC plus 16-byte IV), the MAC does not
-- match, or the IV is rejected by 'makeIV'.
decrypt :: Key                -- ^ Key of the server.
        -> S.ByteString       -- ^ Base64-encoded cookie data from the client.
        -> Maybe S.ByteString -- ^ Serialized cookie data.
decrypt key dataBS64 =
    case B.decode dataBS64 of
        Left _        -> Nothing
        Right payload -> openPayload payload
  where
    -- payload layout: MAC (32 bytes) | IV (16 bytes) | ciphertext
    openPayload payload = do
        guard (S.length payload >= 48)
        let (tag, authed) = S.splitAt 32 payload
        -- Authenticate before decrypting. constTimeEq keeps the tag
        -- comparison free of a timing side channel.
        guard (encode (macKey key authed) `constTimeEq` tag)
        let (ivBytes, ciphertext) = S.splitAt 16 authed
        iv <- makeIV ivBytes
        return $! ctrCombine (aesKey key) iv ciphertext
-- | State of the IV generator. It holds the ChaCha DRG and the number of
-- IVs drawn from it since it was last (re)seeded.
data ChaChaState = CCSt
    {-# UNPACK #-} !ChaChaDRG -- generator state
    {-# UNPACK #-} !Int       -- IVs produced since the last (re)seed
-- | Create a freshly seeded generator state ('drgNew') with its counter
-- at zero. The state is forced before it is returned.
chaChaSeed :: IO ChaChaState
chaChaSeed = drgNew >>= \drg -> return $! CCSt drg 0
-- | Replace the shared generator with a freshly seeded one and reset the
-- counter to zero.
--
-- 'chaChaRNG' runs this on a forked thread while other threads keep
-- reading 'chaChaRef'. It therefore uses 'I.atomicWriteIORef', which
-- acts as a reordering barrier, where plain 'I.writeIORef' does not. The
-- new state is forced here, so readers never evaluate it.
chaChaReseed :: IO ()
chaChaReseed = do
    drg' <- drgNew
    I.atomicWriteIORef chaChaRef $! CCSt drg' 0
-- | Process-wide IV generator state.
--
-- A top-level 'unsafePerformIO' with NOINLINE yields a single 'I.IORef',
-- created on first use and shared by every caller.
chaChaRef :: I.IORef ChaChaState
chaChaRef = unsafePerformIO (I.newIORef =<< chaChaSeed)
{-# NOINLINE chaChaRef #-}
-- | Produce a 16-byte IV from the shared ChaCha generator.
--
-- Each call atomically advances the generator and bumps its counter.
-- When the counter before this call equals 'threshold', a background
-- thread reseeds the generator via 'chaChaReseed'.
--
-- 'I.atomicModifyIORef'' is used instead of the lazy variant. It forces
-- the new state, whose strict fields run the generator step, before the
-- call returns. Concurrent callers therefore cannot build up a chain of
-- unevaluated generator states inside the 'I.IORef'.
chaChaRNG :: IO IV
chaChaRNG = do
    (bs, count) <- I.atomicModifyIORef' chaChaRef step
    when (count == threshold) $
        forkIO chaChaReseed >> return ()
    return $! unsafeMkIV bs
  where
    -- Draw 16 bytes; return the advanced state and (bytes, old counter).
    step :: ChaChaState -> (ChaChaState, (S.ByteString, Int))
    step (CCSt drg n) =
        let (bs', drg') = randomBytesGenerate 16 drg
        in (CCSt drg' (succ n), (bs', n))
-- | Number of IVs 'chaChaRNG' hands out before it triggers a reseed of
-- the shared generator.
threshold :: Int
threshold =
    100000