{-# LANGUAGE BangPatterns      #-}
{-# LANGUAGE OverloadedStrings #-}

{-# LANGUAGE CApiFFI           #-}

{-# OPTIONS_GHC -Wall  #-}

-- |
-- Copyright: © 2020  Herbert Valerio Riedel
-- SPDX-License-Identifier: GPL-2.0-or-later
--
module Main
    ( main
    , xor32ByteString'ref
    , xor32ByteString'v3
    , xor32ByteString'v4
    ) where

import           Control.Exception        (assert)
import           Control.Monad
import           Criterion.Main
import           Data.Bits
import qualified Data.ByteString          as BS
import           Data.ByteString.Internal as BS
import qualified Data.ByteString.Short    as SBS
import           Data.Word                (Word32, Word8, byteSwap32)
import           Foreign.ForeignPtr
import           Foreign.Ptr
import           Foreign.Storable
import           GHC.ByteOrder            (ByteOrder (..), targetByteOrder)

import qualified Data.XOR                 as IUT

-- | Benchmark executable entry point: runs the 'benches' suite through
-- Criterion's 'defaultMain'.
main :: IO ()
main = defaultMain suite
  where
    suite = benches

-- | The benchmark matrix: one 'bgroup' per input size, each timing every
-- XOR-masking variant on the same input. The 4k+1 .. 4k+3 sizes leave a
-- 1-3 byte tail after the last whole 32-bit word.
benches :: [Benchmark]
benches =
    [ doGroup "4k"   bs4k
    , doGroup "4k1"  bs4k1
    , doGroup "4k2"  bs4k2
    , doGroup "4k3"  bs4k3
    , doGroup "32k"  bs32k
    , doGroup "256k" bs256k
    ]
  where
    -- Build the group for one input. The ShortByteString copy is forced with
    -- a bang so the conversion runs when the group is constructed, rather
    -- than during the first run of the "IUT/SBS" benchmark.
    doGroup :: String -> BS.ByteString -> Benchmark
    doGroup label bs = let !sbs = SBS.toShort bs in bgroup label
      [ bench "REF"      $ whnf (xor32ByteString'ref msk)       bs
      , bench "IUT"      $ whnf (IUT.xor32StrictByteString msk) bs
      , bench "IUT/SBS"  $ whnf (IUT.xor32ShortByteString msk) sbs
      , bench "v3"       $ whnf (xor32ByteString'v3 msk)        bs
      , bench "v4"       $ whnf (xor32ByteString'v4 msk)        bs
      , bench "REF 8bit" $ whnf (xor8StrictByteString'ref msk8) bs
      , bench "IUT 8bit" $ whnf (IUT.xor8StrictByteString msk8) bs
      ]

    -- All inputs and masks are NOINLINE (previously only bs32k and msk were)
    -- so that their values are not inlined into the benchmarked call sites.
    {-# NOINLINE bs4k #-}
    !bs4k   = BS.replicate (4*1024)   0x55
    {-# NOINLINE bs4k1 #-}
    !bs4k1  = BS.replicate (4*1024+1) 0x55
    {-# NOINLINE bs4k2 #-}
    !bs4k2  = BS.replicate (4*1024+2) 0x55
    {-# NOINLINE bs4k3 #-}
    !bs4k3  = BS.replicate (4*1024+3) 0x55
    {-# NOINLINE bs32k #-}
    !bs32k  = BS.replicate (32*1024)  0x55
    {-# NOINLINE bs256k #-}
    !bs256k = BS.replicate (256*1024) 0x55

    -- 32-bit mask for the xor32 variants
    {-# NOINLINE msk #-}
    msk :: Word32
    msk = 0x12345678

    -- 8-bit mask for the xor8 variants
    {-# NOINLINE msk8 #-}
    msk8 :: Word8
    msk8 = 0x42

----------------------------------------------------------------------------

-- reference impl

{-# NOINLINE xor32ByteString'ref #-}
-- | Straightforward byte-at-a-time reference implementation of 32-bit XOR
-- masking.
--
-- Byte @i@ of the input is XORed with byte @i `mod` 4@ of the mask taken in
-- big-endian order. For mask @0x12345678@ the input bytes are XORed with
-- @0x12, 0x34, 0x56, 0x78, 0x12, ...@. A zero mask returns the input as is.
xor32ByteString'ref :: Word32 -> BS.ByteString -> BS.ByteString
xor32ByteString'ref 0    = id
xor32ByteString'ref msk0 = snd . BS.mapAccumL step msk0
  where
    -- Rotate first, then use the low byte. That byte was the top byte of the
    -- previous mask state, which gives the big-endian byte order.
    step :: Word32 -> Word8 -> (Word32, Word8)
    step acc byte = out `seq` (acc', out)
      where
        acc' = acc `rotateL` 8
        out  = byte `xor` fromIntegral acc'


{-# NOINLINE xor8StrictByteString'ref #-}
-- | Reference 8-bit XOR masking: every byte is XORed with @msk0@. A zero
-- mask returns the input unchanged.
xor8StrictByteString'ref :: Word8 -> BS.ByteString -> BS.ByteString
xor8StrictByteString'ref msk0
  | msk0 == 0 = id
  | otherwise = BS.map (`xor` msk0)

-- {-# NOINLINE xor32ByteString'v2 #-}
-- xor32ByteString'v2 :: Word32 -> BS.ByteString -> BS.ByteString
-- xor32ByteString'v2 msk0 = snd . BS.mapAccumL go mskstr
--   where
--     mskstr :: [Word8]
--     mskstr = cycle (map fromIntegral (tail (take 5 (iterate rotl8 msk0))))
--
--     rotl8 :: Word32 -> Word32
--     rotl8 = flip rotateL 8
--
--     go (x:xs) b = let !b' = xor x b in (xs,b')

{-# NOINLINE xor32ByteString'v3 #-}
-- | Third experimental variant of 32-bit XOR masking.
--
-- The input is copied into a fresh buffer. The whole-word prefix is XORed in
-- 'Word32' steps by 'xor32PtrAligned', and the 0-3 trailing bytes are then
-- handled one at a time by 'xor32PtrNonAligned'. A zero mask or an empty
-- input is returned unchanged. Fails in 'IO' if the fresh buffer is not
-- 4-byte aligned.
xor32ByteString'v3 :: Word32 -> BS.ByteString -> BS.ByteString
xor32ByteString'v3 0 bs = bs
xor32ByteString'v3 _ bs | BS.null bs = bs
xor32ByteString'v3 msk0 (BS.PS fp off len) =
    unsafeCreate len $ \dst ->
      withForeignPtr fp $ \src -> do
        memcpy dst (src `plusPtr` off) (fromIntegral len)
        let wordDst = castPtr dst :: Ptr Word32
            bodyLen = 4 * (len `quot` 4)   -- bytes covered by whole words
            tailLen = len - bodyLen        -- 0..3 leftover bytes
            tailDst = wordDst `plusPtr` bodyLen :: Ptr Word8
        -- the word-wise loop needs a naturally aligned destination
        when (alignPtr wordDst 4 /= wordDst) $
          fail "bytestring allocation not aligned"
        xor32PtrAligned msk0 wordDst bodyLen
        -- after a whole number of words the rotating mask is back at msk0
        void $ xor32PtrNonAligned msk0 tailDst tailLen

{-# NOINLINE xor32ByteString'v4 #-}
-- | Fourth variant of 32-bit XOR masking.
--
-- The input is copied into a fresh buffer, which is then XORed in place by
-- the implementation-under-test's 'IUT.xor32CStringLen'. The mask value that
-- function returns is discarded. A zero mask or an empty input is returned
-- unchanged.
xor32ByteString'v4 :: Word32 -> BS.ByteString -> BS.ByteString
xor32ByteString'v4 msk0 bs
  | msk0 == 0 || BS.null bs = bs
xor32ByteString'v4 msk0 (BS.PS fp off len) =
    unsafeCreate len $ \dst ->
      withForeignPtr fp $ \src -> do
        memcpy dst (src `plusPtr` off) (fromIntegral len)
        void $ IUT.xor32CStringLen msk0 (castPtr dst, len)

{-# INLINE xor32PtrNonAligned #-}
-- | XOR @n@ bytes starting at @p0@ in place, one byte at a time, with a
-- rotating 32-bit mask.
--
-- Before each byte the mask is rotated left by 8 and its low byte is applied,
-- so the bytes of @mask0@ are used most-significant first. Returns the mask
-- state to continue with for the byte after the last one processed. A
-- non-positive length is a no-op that returns @mask0@; it can never walk
-- past the buffer.
xor32PtrNonAligned :: Word32 -> Ptr Word8 -> Int -> IO Word32
xor32PtrNonAligned mask0 p0 n
  | n <= 0    = return mask0
  | otherwise = go mask0 p0
  where
    p' = p0 `plusPtr` n :: Ptr Word8
    go m p
      | p >= p'   = return m
      | otherwise = do
          let m' = rotateL m 8
          xor8Ptr1 (fromIntegral m') p
          go m' (p `plusPtr` 1)

{-# INLINE xor32PtrAligned #-}
-- | XOR @n@ bytes at the 4-byte-aligned pointer @p0@ in place, one 'Word32'
-- at a time.
--
-- @mask0be@ gives the mask bytes in big-endian order: its most significant
-- byte is applied to the first byte in memory. This matches
-- 'xor32PtrNonAligned' and 'xor32ByteString'ref'. On little-endian targets
-- the mask is therefore byte-swapped, so that the in-memory layout of the
-- 'Word32' XORed into each word has that byte order.
--
-- @n@ is expected to be a multiple of 4 (asserted). The loop only touches
-- whole words inside @[p0, p0+n)@, so if the assertion is disabled, a stray
-- remainder is left untouched rather than overrun.
xor32PtrAligned :: Word32 -> Ptr Word32 -> Int -> IO ()
xor32PtrAligned _ _ 0 = return ()
xor32PtrAligned mask0be p0 n
  = assert (p0 == alignPtr p0 4 && n `rem` 4 == 0) $ go p0
  where
    p' = p0 `plusPtr` n :: Ptr Word32
    -- stop as soon as fewer than 4 bytes remain before the end pointer
    go p
      | p' `minusPtr` p < 4 = return ()
      | otherwise           = do { xor32Ptr1 mask0 p; go (p `plusPtr` 4) }

    -- native-order Word32 whose in-memory bytes equal mask0be's big-endian bytes
    mask0 = case targetByteOrder of
              LittleEndian -> byteSwap32 mask0be
              BigEndian    -> mask0be

----------------------------------------------------------------------------

-- | XOR the single byte at @ptr@ in place with @msk@.
xor8Ptr1 :: Word8 -> Ptr Word8 -> IO ()
xor8Ptr1 msk ptr = peek ptr >>= poke ptr . xor msk

-- xor16Ptr1 :: Word16 -> Ptr Word16 -> IO ()
-- xor16Ptr1 msk ptr = do { x <- peek ptr; poke ptr (xor msk x) }

-- | XOR the 'Word32' at @ptr@ in place with @msk@. Both are treated as
-- native-order 'Word32' values, since they are read and written via
-- 'peek'/'poke'.
xor32Ptr1 :: Word32 -> Ptr Word32 -> IO ()
xor32Ptr1 msk ptr = peek ptr >>= \w -> poke ptr (w `xor` msk)