{-# LANGUAGE GeneralizedNewtypeDeriving #-}

-- |
-- Module      : Crypto.MAC.CMAC
-- License     : BSD-style
-- Maintainer  : Kei Hibino <ex8k.hibino@gmail.com>
-- Stability   : experimental
-- Portability : unknown
--
-- Provides the CMAC (Cipher-based Message Authentication Code) base algorithm.
-- <http://en.wikipedia.org/wiki/CMAC>
-- <http://csrc.nist.gov/publications/nistpubs/800-38B/SP_800-38B.pdf>
module Crypto.MAC.CMAC (
    cmac,
    CMAC,
    subKeys,
) where

import Data.Bits (setBit, shiftL, testBit)
import Data.List (foldl')
import Data.Word

import Crypto.Cipher.Types
import Crypto.Internal.ByteArray (ByteArray, ByteArrayAccess, Bytes)
import qualified Crypto.Internal.ByteArray as B

-- | A CMAC authentication tag.
--
-- The phantom parameter records the cipher type that produced the tag
-- ('cmac' returns a @CMAC cipher@), so tags from different ciphers
-- cannot be mixed up.  The raw tag bytes are exposed through the
-- derived 'ByteArrayAccess' instance.
newtype CMAC a = CMAC Bytes
    deriving (ByteArrayAccess)

-- | Tags are compared with 'B.constEq' rather than an ordinary byte
-- comparison.  constEq is intended to take time independent of where the
-- inputs differ, so verifying a received tag does not leak timing
-- information (NOTE(review): confirm against the memory package docs).
instance Eq (CMAC a) where
    CMAC b1 == CMAC b2 = B.constEq b1 b2

-- | Compute the CMAC tag of a message using the supplied block cipher.
--
-- The message is cut into cipher blocks by 'cmacChunks', with the last
-- block masked by one of the subkeys from 'subKeys'.  The blocks are then
-- run through CBC-MAC: starting from an all-zero block, each block is
-- xored into the running value and the result is encrypted with the
-- cipher.  The final running value is the tag.
cmac
    :: (ByteArrayAccess bin, BlockCipher cipher)
    => cipher
    -- ^ key to compute CMAC with
    -> bin
    -- ^ input message
    -> CMAC cipher
    -- ^ output tag
cmac k msg = CMAC $ foldl' step zeroV ms
  where
    -- one CBC-MAC round: chain the previous value into the next block
    step acc m = ecbEncrypt k (bxor acc m)
    zeroV = B.replicate (blockSize k) 0 :: Bytes
    (k1, k2) = subKeys k
    ms = cmacChunks k k1 k2 $ B.convert msg

-- | Split a message into cipher-block-sized chunks ready for CBC-MAC,
-- masking the final chunk with a subkey.
--
-- Every chunk except the last is returned unchanged.  The last chunk is
-- xored with @k1@ when it is a complete block.  Otherwise it is padded
-- with a single @0x80@ byte followed by zero bytes up to the block size,
-- and then xored with @k2@.  An empty message therefore yields exactly one
-- fully padded block masked with @k2@.
--
-- NOTE(review): 'B.splitAt' on a generic 'ByteArray' may copy the
-- remaining tail on every step, which would make this quadratic in the
-- message length; confirm against the memory package before relying on
-- it for very large inputs.
cmacChunks :: (BlockCipher k, ByteArray ba) => k -> ba -> ba -> ba -> [ba]
cmacChunks k k1 k2 = go
  where
    -- the block size does not change between steps, so compute it once
    bytes = blockSize k
    go msg
        | B.null tl = [lastBlock]
        | otherwise = hd : go tl
      where
        (hd, tl) = B.splitAt bytes msg
        -- number of bytes missing from a complete final block
        lack = bytes - B.length hd
        lastBlock
            | lack == 0 = bxor k1 hd
            | otherwise =
                bxor k2 $ hd `B.append` B.pack (0x80 : replicate (lack - 1) 0)

-- | Derive the two CMAC subkeys @(K1, K2)@ for a cipher key.
--
-- @L@ is the encryption of an all-zero block.  @K1@ is @L@ doubled in
-- GF(2^n) by 'subKey', and @K2@ is @K1@ doubled again, where n is the
-- cipher's block size in bits.  The reduction constant comes from
-- 'cipherIPT'.  Only 64- and 128-bit block sizes have a polynomial
-- defined (see 'iPolynomial'); for other sizes evaluation can fail with
-- an error.
subKeys
    :: (BlockCipher k, ByteArray ba)
    => k
    -- ^ key to compute CMAC with
    -> (ba, ba)
    -- ^ sub-keys to compute CMAC
subKeys k = (k1, k2)
  where
    ipt = cipherIPT k
    k0 = ecbEncrypt k $ B.replicate (blockSize k) 0
    k1 = subKey ipt k0
    k2 = subKey ipt k1

-- | Multiply a block by @x@ in GF(2^n): the "doubling" step used to
-- calculate CMAC subkeys.
--
-- The block is shifted left by one bit.  If the bit shifted out (the most
-- significant bit of the first byte) was set, the result is reduced by
-- xoring in @ipt@, the tail bit pattern of the irreducible polynomial.
-- An empty input yields 'B.empty'.
subKey :: ByteArray ba => [Word8] -> ba -> ba
subKey ipt ws = case B.unpack ws of
    [] -> B.empty
    w : _
        | testBit w 7 -> B.pack ipt `bxor` shiftL1 ws
        | otherwise -> shiftL1 ws

-- | Shift a whole byte array left by one bit, treating it as one
-- big-endian number.  This is done by unpacking to bytes and applying
-- 'shiftL1W'.
shiftL1 :: ByteArray ba => ba -> ba
shiftL1 = B.pack . shiftL1W . B.unpack

-- | Shift a big-endian byte list left by one bit.
--
-- Each output byte is the corresponding input byte shifted left by one.
-- Its lowest bit is taken from the most significant bit of the following
-- byte.  The most significant bit of the first byte is dropped, and a
-- zero bit is shifted in at the end.  An empty list stays empty.
shiftL1W :: [Word8] -> [Word8]
shiftL1W ws = zipWith carryIn ws (drop 1 ws ++ [0])
  where
    -- shift @x@ left and pull in the top bit of the next byte
    carryIn :: Word8 -> Word8 -> Word8
    carryIn x next
        | testBit next 7 = setBit (shiftL x 1) 0
        | otherwise = shiftL x 1

-- | Xor two byte arrays.  This specialises 'B.xor' so that both inputs
-- and the result share a single 'ByteArray' type, which keeps type
-- inference simple at the call sites in this module.
bxor :: ByteArray ba => ba -> ba -> ba
bxor = B.xor

-----

-- | Reduction constant for the cipher's block size: the tail bit pattern
-- of its irreducible polynomial, as a big-endian byte list one block
-- long.  Fails with an error for block sizes that 'expandIPT' does not
-- support.
cipherIPT :: BlockCipher k => k -> [Word8]
cipherIPT = expandIPT . blockSize

-- | The smallest irreducible binary polynomial of a given degree, stored
-- by its three middle exponents only.
--
-- The leading (maximum-degree) term and the constant (degree-0) term are
-- omitted.  For example, the value @Q 7 2 1@ is used for degree 128 and
-- represents the polynomial x^128 + x^7 + x^2 + x^1 + 1.
data IPolynomial
    = Q Int Int Int

-- NOTE(review): commented-out alternative constructor, presumably for a
-- trinomial form; kept for reference.
---  | T Int

-- | Look up the irreducible polynomial for a degree given in bits.
--
-- Only degrees 64 (x^64 + x^4 + x^3 + x + 1) and 128
-- (x^128 + x^7 + x^2 + x + 1) are defined; any other degree yields
-- 'Nothing'.
iPolynomial :: Int -> Maybe IPolynomial
iPolynomial 64 = Just $ Q 4 3 1
iPolynomial 128 = Just $ Q 7 2 1
iPolynomial _ = Nothing

-- | Expand the tail bit pattern of the irreducible polynomial for a block
-- of @bytes@ bytes into a big-endian byte list of that length.
--
-- Calls 'error' when 'iPolynomial' has no polynomial for @bytes * 8@
-- bits.
expandIPT :: Int -> [Word8]
expandIPT bytes = case iPolynomial nb of
    Just ipt -> expandIPT' bytes ipt
    Nothing ->
        error $
            "Irreducible binary polynomial not defined against " ++ show nb ++ " bit"
  where
    -- polynomial degree = block width in bits
    nb = bytes * 8

-- | Expand an 'IPolynomial' into the bit pattern of its low-order terms.
--
-- Produces a big-endian byte list of @bytes@ bytes.  Bits 0, @z@, @y@ and
-- @x@ are set, with bit positions counted from the least significant bit
-- of the last byte; every other bit is zero.  Calls 'error' if a bit
-- position does not fit in @bytes@ bytes.
expandIPT'
    :: Int
    -- ^ width in bytes
    -> IPolynomial
    -- ^ irreducible binary polynomial definition
    -> [Word8]
    -- ^ result bit pattern (big-endian)
expandIPT' bytes (Q x y z) =
    -- build little-endian (byte 0 holds bits 0..7), then flip to big-endian
    reverse $ foldr setB (replicate bytes 0) [x, y, z, 0]
  where
    -- set bit @i@ of a little-endian byte list
    setB i ws = case tl of
        a : as -> hd ++ setBit a r : as
        _ -> error "expandIPT'"
      where
        (q, r) = i `quotRem` 8
        (hd, tl) = splitAt q ws