{-# LANGUAGE ScopedTypeVariables #-}
{-# OPTIONS_GHC -Wall -Werror #-}
module Documentation.SBV.Examples.Crypto.RC4 where
import Data.Char (ord, chr)
import Data.List (genericIndex)
import Data.Maybe (fromJust)
import Data.SBV
import Data.SBV.Tools.STree
import Numeric (showHex)
-- | The RC4 state array: one byte stored at each byte-sized index.
-- It is an 'STree' so entries can be read and written at /symbolic/
-- indices (via 'readSTree' / 'writeSTree').
type S = STree Word8 Word8
-- | The initial RC4 state array, built from the concrete bytes @0@ through
-- @255@ in order (the identity permutation). 'initRC4' starts its
-- key-scheduling pass from this value.
initS :: S
initS = mkSTree (map literal [0 .. 255])
-- | An RC4 key: a list of symbolic bytes. 'initRC4' rejects keys shorter
-- than 1 byte or longer than 256 bytes.
type Key = [SWord8]
-- | The running state threaded through 'prga': the state array together
-- with the two byte-sized indices @i@ and @j@.
type RC4 = (S, SWord8, SWord8)
-- | @swap i j st@ exchanges the entries of @st@ at the (possibly symbolic)
-- indices @i@ and @j@. Both old values are read from the original @st@
-- before either write, so the result is also correct when @i == j@.
swap :: SWord8 -> SWord8 -> S -> S
swap i j st = writeSTree (writeSTree st i stj) j sti
  where sti = readSTree st i
        stj = readSTree st j
-- | One step of the RC4 pseudo-random generation algorithm (PRGA).
--
-- From the state @(st', i', j')@ it increments @i@, adds @S[i]@ to @j@,
-- swaps @S[i]@ and @S[j]@, and returns the keystream byte
-- @S[S[i] + S[j]]@ (read from the swapped array) together with the new
-- state. All index arithmetic is on 'SWord8', so it wraps modulo 256.
prga :: RC4 -> (SWord8, RC4)
prga (st', i', j') = (readSTree st kInd, (st, i, j))
  where i    = i' + 1
        j    = j' + readSTree st' i
        st   = swap i j st'
        kInd = readSTree st i + readSTree st j
-- | The RC4 key-scheduling algorithm (KSA): permutes 'initS' using @key@.
--
-- Starting from @j = 0@, for each @i@ in @0 .. 255@ it sets
-- @j := j + S[i] + key[i mod keyLength]@ and swaps @S[i]@ with @S[j]@.
--
-- Calls 'error' unless the key has between 1 and 256 bytes (inclusive).
initRC4 :: Key -> S
initRC4 key
  | keyLength < 1 || keyLength > 256
  = error $ "RC4 requires a key of length between 1 and 256, received: " ++ show keyLength
  | otherwise
  = snd $ foldl mix (0, initS) [0 .. 255]
  where keyLength :: Int
        keyLength = length key

        -- One KSA round for the concrete index @iw@. The key position is
        -- computed in 'Int': converting keyLength to 'Word8' would turn a
        -- (permitted) 256-byte key into a modulus of 0 and crash with a
        -- divide-by-zero.
        mix :: (SWord8, S) -> Word8 -> (SWord8, S)
        mix (j', s) iw = let i     = literal iw
                             kByte = genericIndex key (fromIntegral iw `mod` keyLength)
                             j     = j' + readSTree s i + kByte
                         in (j, swap i j s)
-- | The RC4 keystream for @key@. It repeatedly applies 'prga', starting
-- from the state array produced by 'initRC4' with @i = j = 0@.
--
-- The result is an /infinite/ list; consume it lazily (e.g. via 'zipWith'
-- or 'take').
keySchedule :: Key -> [SWord8]
keySchedule key = genKeys (initRC4 key, 0, 0)
  where genKeys :: RC4 -> [SWord8]
        genKeys st = let (k, st') = prga st in k : genKeys st'
-- | 'keySchedule' for a key given as a 'String'. Each character becomes one
-- concrete byte via 'ord' followed by 'fromIntegral'. Characters above
-- U+00FF are therefore reduced to their low 8 bits.
keyScheduleString :: String -> [SWord8]
keyScheduleString = keySchedule . map (literal . fromIntegral . ord)
-- | RC4 encryption. It XORs each plaintext byte with the keystream derived
-- from @key@, so the output has the same length as the plaintext.
-- Plaintext characters are converted with 'ord' and 'fromIntegral', so code
-- points above 255 are reduced to their low 8 bits.
encrypt :: String -> String -> [SWord8]
encrypt key pt = zipWith xor (keyScheduleString key) (map cvt pt)
  where cvt :: Char -> SWord8
        cvt = literal . fromIntegral . ord
-- | RC4 decryption. It XORs the ciphertext with the same keystream that
-- 'encrypt' uses (XOR is its own inverse) and turns each resulting byte back
-- into a 'Char'. Every XORed byte must be concrete: if 'unliteral' yields
-- 'Nothing', 'fromJust' fails.
decrypt :: String -> [SWord8] -> String
decrypt key ct = map cvt $ zipWith xor (keyScheduleString key) ct
  where cvt :: SWord8 -> Char
        cvt = chr . fromIntegral . fromJust . unliteral
-- | Proves, for every 5-byte symbolic key and every 5-byte symbolic
-- plaintext, that XOR-ing the plaintext with the keystream twice gives back
-- the plaintext. In other words, decryption inverts encryption.
rc4IsCorrect :: IO ThmResult
rc4IsCorrect = prove $ do
  key <- mkFreeVars 5
  pt  <- mkFreeVars 5
  let ks  = keySchedule key
      ct  = zipWith xor ks pt
      pt' = zipWith xor ks ct
  return $ pt .== pt'
-- | Renders a concrete symbolic value in hexadecimal, as produced by
-- 'showHex', left-padded with @0@ to at least two digits. This is handy for
-- printing RC4 output bytes. Fails via 'fromJust' if the value is symbolic.
hex2 :: (SymVal a, Show a, Integral a) => SBV a -> String
hex2 v = replicate (2 - length s) '0' ++ s
  where s = flip showHex "" . fromJust . unliteral $ v