{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE BinaryLiterals #-}
{-# LANGUAGE BlockArguments #-}
{-# LANGUAGE DerivingStrategies #-}
{-# LANGUAGE MagicHash #-}
{-# LANGUAGE MultiWayIf #-}
{-# LANGUAGE NumericUnderscores #-}
{-# LANGUAGE UnliftedFFITypes #-}

{- | Compress a contiguous sequence of bytes into an LZ4 frame
containing a single block, and decompress such single-block frames.
-}
module Lz4.Frame
  ( -- * Compression
    compressHighlyU
    -- * Decompression
  , decompressU
  ) where

import Lz4.Internal (requiredBufferSize,c_hs_compress_HC,c_hs_decompress_safe)

import Control.Monad (when)
import Control.Monad.ST (runST)
import Data.Bits ((.&.))
import Data.Bytes.Types (Bytes (Bytes))
import Data.Int (Int32)
import Data.Primitive (ByteArray (..), MutableByteArray (..))
import Data.Word (Word8, Word32)
import GHC.IO (unsafeIOToST)

import qualified Data.Primitive as PM
import qualified Data.Primitive.ByteArray.LittleEndian as LE
import qualified Data.Bytes as Bytes

-- | Decompress an LZ4 frame that contains a single block. The caller
-- must know the exact size of the decompressed byte array.
--
-- The input is expected to be laid out as follows (offsets relative to
-- the start of the slice):
--
-- * bytes 0-3: the magic identifier @04 22 4D 18@
-- * byte 4: the flag byte, which must be exactly @0b0110_0000@
-- * bytes 5-6: block-size descriptor and header checksum (both ignored)
-- * bytes 7-10: little-endian block size; the high bit marks the block
--   as stored uncompressed
-- * bytes 11 onward: the block payload
-- * final 4 bytes: all zero
--
-- Returns 'Nothing' if any of these checks fail, if the slice length does
-- not match the length implied by the block size, if the requested size
-- is negative, or if the payload does not decompress to exactly the
-- requested number of bytes.
--
-- Note: This currently fails if any of the optional headers are used.
-- It is difficult to find examples of lz4 frames that actually use
-- any of these. Open a PR with an example of an lz4 frame that fails
-- to decode if you find one.
decompressU ::
     Int -- ^ The exact size of the decompressed bytes
  -> Bytes -- ^ Compressed bytes
  -> Maybe ByteArray
decompressU !decompressedSize (Bytes arr@(ByteArray arr#) off len) = do
  -- A negative size can never be satisfied, and it must not be allowed
  -- to reach 'PM.newByteArray', which does not validate its argument.
  when (decompressedSize < 0) Nothing
  -- 11 bytes are needed to read the header and the block size safely.
  when (len < 11) Nothing
  -- Magic identifier.
  when (indexWord8 arr off /= 0x04) Nothing
  when (indexWord8 arr (off + 1) /= 0x22) Nothing
  when (indexWord8 arr (off + 2) /= 0x4D) Nothing
  when (indexWord8 arr (off + 3) /= 0x18) Nothing
  -- Only one exact flag byte is accepted; anything else is rejected.
  let !flag = indexWord8 arr (off + 4)
  when (flag /= 0b0110_0000) Nothing
  -- Here is the code that would read the size hint from the bd. However,
  -- there is no reason to use this since this function takes the actual
  -- size as an argument. We ignore the checksum at position off+6 as well.
  -- let !bd = indexWord8 arr (off + 5)
  -- maximumDecompressedSize <- case bd of
  --   0b0111_0000 -> pure 4194304
  --   0b0110_0000 -> pure 1048576
  --   0b0101_0000 -> pure 262144
  --   0b0100_0000 -> pure 65536
  --   _ -> Nothing
  -- when (maximumDecompressedSize < decompressedSize) Nothing
  let !compressedSize = LE.indexUnalignedByteArray arr (off + 7) :: Word32
  -- The low 31 bits hold the payload length; the high bit is a marker
  -- that is inspected separately below.
  let !compressedSizeI = fromIntegral (compressedSize .&. 0x7fff_ffff) :: Int
  -- magic (4) + descriptor (3) + block size (4) + trailing zeros (4).
  -- This also guarantees that every read below stays inside the slice.
  when (compressedSizeI + (4 + 3 + 4 + 4) /= len) Nothing
  -- The four bytes following the payload must all be zero.
  let !offPost = off + 11 + compressedSizeI
  when (indexWord8 arr offPost /= 0x00) Nothing
  when (indexWord8 arr (offPost + 1) /= 0x00) Nothing
  when (indexWord8 arr (offPost + 2) /= 0x00) Nothing
  when (indexWord8 arr (offPost + 3) /= 0x00) Nothing
  case compressedSize .&. 0x8000_0000 of
    0 -> runST $ do
      dst@(MutableByteArray dst#) <- PM.newByteArray decompressedSize
      actualSz <-
        unsafeIOToST
          (c_hs_decompress_safe arr# (off + 11) dst# 0 compressedSizeI decompressedSize)
      -- Note: actualSz will be negative if decompression fails. That's fine.
      if actualSz == decompressedSize
        then do
          dst' <- PM.unsafeFreezeByteArray dst
          pure (Just dst')
        else pure Nothing
    _ -> do
      -- When the upper bit of the size is set, it means that the data in
      -- the block is uncompressed. This code path is not tested in the test
      -- suite, and I cannot find examples of this feature used in the wild.
      -- If anyone knows of an example, open a PR.
      when (decompressedSize /= compressedSizeI) Nothing
      when (decompressedSize + 15 /= len) Nothing
      Just $! Bytes.toByteArrayClone (Bytes arr (off + 11) decompressedSize)

{- | Use HC compression to produce a frame with a single block.
All optional fields (checksums, content sizes, and dictionary IDs)
are omitted.

The output consists of the 4-byte magic identifier, a 3-byte frame
descriptor, the 4-byte little-endian size of the compressed block,
the compressed block itself, and four trailing zero bytes.

Note: Currently, this produces incorrect output when the size of
the input to be compressed is greater than 4MiB. The only way
to correct this function is to make it not compress large input.
This can be done by setting the high bit of the size. This needs
to be tested though since it is an uncommon code path.
-}
compressHighlyU ::
  -- | Compression level (Use 9 if uncertain)
  Int ->
  -- | Bytes to compress
  Bytes ->
  ByteArray
compressHighlyU !lvl (Bytes (ByteArray src#) off len) = runST do
  let -- Room for the compressed block plus 11 bytes of header and
      -- 4 trailing bytes.
      capacity = requiredBufferSize len + 15
      -- Second and third descriptor bytes: the block-size byte and
      -- the header checksum byte that goes with it. The smallest
      -- block-size class that holds the input is chosen.
      (blockDescriptor, headerChecksum)
        | len <= 65_536 = (0b0100_0000, 0x82)
        | len <= 262_144 = (0b0101_0000, 0xFB)
        | len <= 1_048_576 = (0b0110_0000, 0x51)
        | otherwise = (0b0111_0000, 0x73)
  buf@(MutableByteArray buf#) <- PM.newByteArray capacity
  mapM_
    (\(ix, byte) -> PM.writeByteArray buf ix (byte :: Word8))
    [ -- First 4 bytes: magic identifier
      (0, 0x04)
    , (1, 0x22)
    , (2, 0x4D)
    , (3, 0x18)
      -- Next 3 bytes: frame descriptor
    , (4, 0b0110_0000)
    , (5, blockDescriptor)
    , (6, headerChecksum)
    ]
  -- The compressed block is written starting at offset 11, leaving
  -- bytes 7-10 free for its size.
  -- NOTE(review): the sixth argument is presumably the destination
  -- capacity; it is the size of the whole buffer rather than the space
  -- remaining after offset 11. Confirm against c_hs_compress_HC.
  written <- unsafeIOToST (c_hs_compress_HC src# off buf# 11 len capacity lvl)
  LE.writeUnalignedByteArray buf 7 (fromIntegral written :: Int32)
  -- Four zero bytes directly after the compressed block.
  mapM_
    (\ix -> PM.writeByteArray buf ix (0x00 :: Word8))
    [written + 11 .. written + 14]
  -- Drop the unused tail of the buffer before freezing it.
  PM.shrinkMutableByteArray buf (written + 15)
  PM.unsafeFreezeByteArray buf

-- | Read a single 'Word8' from a 'ByteArray' at the given index.
-- This is 'PM.indexByteArray' with its element type fixed to 'Word8'.
indexWord8 :: ByteArray -> Int -> Word8
{-# INLINE indexWord8 #-}
indexWord8 bytes ix = PM.indexByteArray bytes ix