{-# LANGUAGE BangPatterns #-}

module DataFrame.IO.Parquet.Encoding where

import Data.Bits
import qualified Data.ByteString as BS
import qualified Data.ByteString.Unsafe as BSU
import Data.List (foldl')
import qualified Data.Vector.Unboxed as VU
import Data.Word
import DataFrame.IO.Parquet.Binary

-- | Ceiling of the base-2 logarithm of @x@: the smallest @k@ with @2 ^ k >= x@.
--
-- Inputs less than or equal to 1 (including zero and negatives) yield 0.
--
-- Computed as the number of significant bits in @x - 1@. This avoids the
-- @x + 1@ step of a halving recursion, which wraps around to a negative
-- number when @x == maxBound@ and cuts the count short.
ceilLog2 :: Int -> Int
ceilLog2 x
    | x <= 1 = 0
    -- For x >= 2, x - 1 is positive, so its leading-zero count is well defined
    -- and (word size - leading zeros) is exactly the bit length of x - 1.
    | otherwise = finiteBitSize x - countLeadingZeros (x - 1)

-- | Bit width derived from a maximum level: 'ceilLog2' applied to
-- @maxLevel + 1@.
bitWidthForMaxLevel :: Int -> Int
bitWidthForMaxLevel = ceilLog2 . (+ 1)

-- | Number of whole bytes required to hold @bw@ bits (rounds up to the next
-- multiple of eight bits).
bytesForBW :: Int -> Int
bytesForBW bw = (bw + 7) `shiftR` 3 -- arithmetic shift by 3 == floor division by 8

-- | Decode @count@ values of @bw@ bits each from the front of @bs@.
--
-- Returns the decoded values together with the bytes that follow the packed
-- region. The packed region is @ceil (bw * count / 8)@ bytes long. When
-- @count@ is not positive, or @bs@ is empty, nothing is decoded and @bs@ is
-- returned untouched.
unpackBitPacked :: Int -> Int -> BS.ByteString -> ([Word32], BS.ByteString)
unpackBitPacked bw count bs
    | count <= 0 || BS.null bs = ([], bs)
    | otherwise = (extractBits bw count packed, leftover)
  where
    -- Size of the packed region, rounded up to whole bytes.
    packedLen = (bw * count + 7) `div` 8
    (packed, leftover) = BS.splitAt packedLen bs

-- | Pull up to @count@ values of @bw@ bits each out of @bs@, least-significant
-- bit first.
--
-- Bytes are read one at a time, by index, into a 64-bit window; values are
-- peeled off the low end of that window. Decoding stops early (yielding fewer
-- than @count@ values) if the input runs out before enough bits are available.
extractBits :: Int -> Int -> BS.ByteString -> [Word32]
extractBits bw count bs = pump 0 0 0 count
  where
    -- Mask selecting the low @bw@ bits of the window.
    !lowBits = (if bw == 32 then maxBound else (1 `shiftL` bw) - 1) :: Word64
    !inputLen = BS.length bs

    widen :: Word8 -> Word64
    widen = fromIntegral

    -- pump position window bitsInWindow valuesStillWanted
    pump :: Int -> Word64 -> Int -> Int -> [Word32]
    pump !pos !window !avail !todo
        -- All requested values have been produced.
        | todo <= 0 = []
        -- Enough bits buffered: emit one value and drop it from the window.
        | avail >= bw =
            fromIntegral (window .&. lowBits)
                : pump pos (window `shiftR` bw) (avail - bw) (todo - 1)
        -- Need more bits but the input is exhausted.
        | pos >= inputLen = []
        -- Refill: place the next byte above the bits already in the window.
        -- The index is in range because of the guard above.
        | otherwise =
            pump
                (pos + 1)
                (window .|. (widen (BSU.unsafeIndex bs pos) `shiftL` avail))
                (avail + 8)
                todo

-- | Decode up to @need@ values of bit width @bw@ from a stream of RLE and
-- bit-packed runs, returning the values and the unconsumed input.
--
-- Each run starts with a varint header read by 'readUVarInt'. If the lowest
-- header bit is set, the run is bit-packed and holds @(header >> 1) * 8@
-- values; otherwise it is a repeated run of @header >> 1@ copies of one value
-- stored in @bytesForBW bw@ bytes.
--
-- When @bw@ is 0 the input is not inspected: the result is @need@ zeros and
-- @bs@ is returned unchanged. Decoding also stops when the input is exhausted,
-- in which case fewer than @need@ values are returned.
decodeRLEBitPackedHybrid ::
    Int -> Int -> BS.ByteString -> ([Word32], BS.ByteString)
decodeRLEBitPackedHybrid bw need bs
    | bw == 0 = (replicate need 0, bs)
    | otherwise = loop need bs []
  where
    -- Mask keeping only the low @bw@ bits of a repeated-run value.
    valueMask :: Word32
    valueMask
        | bw == 32 = maxBound
        | otherwise = (1 `shiftL` bw) - 1

    -- loop valuesStillWanted remainingInput outputInReverseOrder
    loop :: Int -> BS.ByteString -> [Word32] -> ([Word32], BS.ByteString)
    loop pending input revOut
        | pending == 0 || BS.null input = (reverse revOut, input)
        | testBit header 0 = bitPackedRun
        | otherwise = repeatedRun
      where
        (header, body) = readUVarInt input
        runCount = fromIntegral (header `shiftR` 1) :: Int

        -- Bit-packed run: the whole run is consumed from the input even if
        -- only part of it is still wanted.
        bitPackedRun =
            let valueCount = runCount * 8
                (unpacked, remainder) = unpackBitPacked bw valueCount body
                wanted = min pending valueCount
             in loop
                    (pending - wanted)
                    remainder
                    (reverse (take wanted unpacked) ++ revOut)

        -- Repeated run: up to four bytes are read and masked down to @bw@
        -- bits, while the input advances by exactly @bytesForBW bw@ bytes.
        repeatedRun =
            let repeated = littleEndianWord32 (BS.take 4 body) .&. valueMask
                remainder = BS.drop (bytesForBW bw) body
                wanted = min pending runCount
             in loop
                    (pending - wanted)
                    remainder
                    (replicate wanted repeated ++ revOut)

-- | Decode @need@ dictionary indices from @bs@.
--
-- The first byte of the stream is taken as the bit width; the rest is handed
-- to 'decodeRLEBitPackedHybrid'. Returns the indices as an unboxed vector of
-- 'Int' together with the unconsumed input. The second argument (dictionary
-- cardinality) is currently not used by the decoder.
--
-- Calls 'error' if @bs@ is empty.
decodeDictIndicesV1 ::
    Int -> Int -> BS.ByteString -> (VU.Vector Int, BS.ByteString)
decodeDictIndicesV1 need _dictCard bs
    | Just (widthByte, payload) <- BS.uncons bs =
        let (rawIndices, leftover) =
                decodeRLEBitPackedHybrid (fromIntegral widthByte) need payload
         in (VU.fromList [fromIntegral ix | ix <- rawIndices], leftover)
    | otherwise = error "empty dictionary index stream"