{-# LANGUAGE CPP #-}
{-# OPTIONS_GHC -fno-warn-name-shadowing #-}
#if defined(__GLASGOW_HASKELL__) && !defined(__HADDOCK__)
#include "MachDeps.h"
#endif
module Data.Binary.BitBuilder (
    
      BitBuilder
    , toLazyByteString
    
    , empty
    , singleton
    , append
    , fromByteString        
    , fromLazyByteString    
    , fromBits
    
    , flush
  ) where
import Foreign
import Data.Semigroup (Semigroup((<>)))
import qualified Data.ByteString      as S
import qualified Data.ByteString.Lazy as L
import System.IO.Unsafe (unsafePerformIO)
#ifdef BYTESTRING_IN_BASE
import Data.ByteString.Base (inlinePerformIO)
import qualified Data.ByteString.Base as S
#else
import Data.ByteString.Internal (inlinePerformIO)
import qualified Data.ByteString.Internal as S
import qualified Data.ByteString.Lazy.Internal as L
#endif
import Data.Binary.Strict.BitUtil
#if defined(__GLASGOW_HASKELL__) && !defined(__HADDOCK__)
import GHC.Base hiding ( empty, foldr )
#endif
-- | A builder for bit strings, represented as a function that transforms a
-- continuation over the current output 'Buffer' into a list of strict
-- 'S.ByteString' chunks.
newtype BitBuilder = BitBuilder {
        -- | Run the builder: given the continuation that produces the
        -- remaining chunks and the current buffer, produce all chunks.
        runBitBuilder :: (Buffer -> [S.ByteString]) -> Buffer -> [S.ByteString]
    }
-- | A builder wraps a function, so it is rendered as a fixed placeholder
-- instead of its contents.
instance Show BitBuilder where
  show _ = "<BitBuilder>"
-- | Combining two builders emits the bits of the left one followed by the
-- bits of the right one (see 'append').
instance Semigroup BitBuilder where
  lhs <> rhs = lhs `append` rhs
-- | The identity element is the builder that emits no bits ('empty').
instance Monoid BitBuilder where
    mempty  = empty
    -- Defined explicitly (in the canonical form) so that the instance is
    -- complete on versions of base where 'mappend' has no default in terms
    -- of the 'Semigroup' operator.
    mappend = (<>)
-- | The builder that emits no bits: it hands the continuation back
-- unchanged.
empty :: BitBuilder
empty = BitBuilder (\continue -> continue)
-- | A builder that emits exactly one bit: 'True' writes a 1, 'False' a 0.
-- Bits are placed starting from the most significant bit of each byte.
singleton :: Bool -> BitBuilder
singleton bit = writeN 1 storeBit
  where
    storeBit p phase = do
      old <- peek p
      let selector = 0x80 `shiftR` phase         -- the single bit position to write
          cleared  = old .&. complement selector  -- every other bit is left untouched
      poke p (if bit then cleared .|. selector else cleared)
{-# INLINE singleton #-}
-- | Append a bit string held in a strict 'S.ByteString'.
--
-- The second component of the pair is the number of valid (most
-- significant) bits in the final byte of the string; @0@ means that the
-- final byte is complete, i.e. every byte contributes 8 bits.
--
-- Exactly that many bits are appended. In particular, when the output is
-- not byte aligned and the input holds fewer bits than are needed to
-- complete the current output byte, only the bits actually present are
-- consumed.
fromByteString :: (S.ByteString, Int) -> BitBuilder
fromByteString (bs, bsPhase) = withPhase f
  where
    f phase
      -- Nothing to append.
      | S.null bs = empty
      -- Output and input are both byte aligned: flush what is buffered and
      -- emit the input as a chunk of its own, without copying.
      | phase == 0 && bsPhase == 0 =
          flush `append` BitBuilder (\k buf -> bs : k buf)
      -- Output is aligned but the input ends part-way through a byte: emit
      -- all but the last byte as a chunk, then write the trailing bits into
      -- the buffer.
      | phase == 0 =
          flush `append` BitBuilder (\k buf ->
            S.init bs :
              runBitBuilder (writeN bsPhase (\p _ -> poke p (S.last bs))) k buf)
      -- Output is not aligned: top up the current output byte from the
      -- first input byte, then continue with the input shifted left by the
      -- number of bits the output byte had room for.
      | otherwise =
          writeN takenBits (mergeByte (S.head bs))
            `append` fromByteString shiftedBS
      where
        -- Free bit positions in the current output byte.
        space = 8 - phase
        -- Total number of bits in the input.
        oldBitLength =
          if bsPhase == 0
             then 8 * S.length bs
             else (S.length bs - 1) * 8 + bsPhase
        -- Never consume more bits than the input holds; otherwise the
        -- buffer position would advance past the end of the real data.
        takenBits = min space oldBitLength
        mergeByte nextByte p bufPhase = do
          byte <- peek p
          let kept     = byte .&. topNBits bufPhase
              incoming = (nextByte .&. topNBits takenBits) `shiftR` bufPhase
          poke p (kept .|. incoming)
        -- What is left of the input after the first @space@ bits; this is
        -- empty whenever the input had no more than @space@ bits.
        shiftedBS = (S.take newLength shifted, bsPhase')
        shifted   = leftShift space bs
        newLength = ((oldBitLength - space) + 7) `div` 8
        bsPhase'  = (bsPhase - space) `mod` 8
-- | Append the @n@ least significant bits of the given value, most
-- significant of those bits first. A count of zero appends nothing.
fromBits :: (Integral a, Bits a) => Int -> a -> BitBuilder
fromBits n v
  | n == 0 = empty
  | otherwise = writeN n (spread n)
  where
    -- Write the lowest @remaining@ bits of the value starting at the given
    -- byte and bit position, moving on to following bytes as each fills up.
    spread remaining p phase
      | remaining <= room = place p phase v remaining
      | otherwise = do
          place p phase (v `shiftR` (remaining - room)) room
          spread (remaining - room) (p `plusPtr` 1) 0
      where
        room = 8 - phase
    -- Store the lowest @count@ bits of @value@ directly after the first
    -- @phase@ bits of the byte at @p@; the bits before them are preserved.
    place p phase value count = do
      old <- peek p
      let kept  = old .&. topNBits phase
          fresh = (fromIntegral value .&. bottomNBits count)
                    `shiftL` ((8 - phase) - count)
      poke p (kept .|. fresh)
{-# INLINE fromBits #-}
-- | Concatenate two builders: the result emits the bits of the first
-- argument followed by the bits of the second.
append :: BitBuilder -> BitBuilder -> BitBuilder
append front back =
  BitBuilder (\k -> runBitBuilder front (runBitBuilder back k))
-- | Append every chunk of a lazy 'L.ByteString' in order, treating each
-- chunk as a whole number of bytes (final-byte bit count of 0).
fromLazyByteString :: L.ByteString -> BitBuilder
fromLazyByteString lbs = foldr addChunk empty (L.toChunks lbs)
  where
    addChunk chunk rest = fromByteString (chunk, 0) `append` rest
-- | The output buffer threaded through a 'BitBuilder'.
data Buffer = Buffer {-# UNPACK #-} !(ForeignPtr Word8) -- underlying storage
                     {-# UNPACK #-} !Int                -- byte offset at which the not-yet-flushed bytes start
                     {-# UNPACK #-} !Int                -- phase: bits already written into the current, partial byte
                     {-# UNPACK #-} !Int                -- number of complete bytes written since the offset
                     {-# UNPACK #-} !Int                -- number of bytes still available
-- | Run a builder and collect its output as a lazy 'L.ByteString'.
--
-- After the builder itself, the unused bits of a trailing partial byte are
-- cleared and whatever remains buffered is flushed.
toLazyByteString :: BitBuilder -> L.ByteString
toLazyByteString m = L.fromChunks $ unsafePerformIO $ do
    fp <- S.mallocByteString initialBytes
    let startBuffer = Buffer fp 0 0 0 initialBytes
    return (runBitBuilder finished (const []) startBuffer)
  where
    -- 'defaultSize' is measured in bits; the allocation is in bytes.
    initialBytes = defaultSize `div` 8
    finished     = m `append` zeroExtendFinalByte `append` flush
-- | Emit the complete bytes written so far as a chunk of their own and
-- advance the buffer offset past them. A partially filled byte is not
-- emitted: it stays in the buffer together with its phase. When no
-- complete byte has been written the buffer is passed on unchanged.
flush :: BitBuilder
flush = BitBuilder emit
  where
    emit k buf@(Buffer fp off phase used left)
      | used == 0 = k buf
      | otherwise = S.PS fp off used : k (Buffer fp (off + used) phase 0 left)
-- | Default buffer capacity, measured in bits: 512 bytes less a small
-- allowance of two machine words (presumably to leave room for allocator
-- bookkeeping; the reason is not stated in this file).
defaultSize :: Int
defaultSize = bitsPerByte * (512 - overhead)
  where
    bitsPerByte = 8
    overhead    = sizeOf (undefined :: Int) * 2
-- | Turn an 'IO' action that transforms the buffer into a builder step.
-- The action is run with 'unsafePerformIO' when the resulting chunk list
-- is demanded, and the continuation is applied to the buffer it returns.
unsafeLiftIO :: (Buffer -> IO Buffer) -> BitBuilder
unsafeLiftIO f = BitBuilder (\k buf -> unsafePerformIO (fmap k (f buf)))
{-# INLINE unsafeLiftIO #-}
-- | Build a builder from the number of bits still free in the current
-- buffer: the remaining bytes times eight, less the bits already used in
-- the current partial byte.
withSize :: (Int -> BitBuilder) -> BitBuilder
withSize f = BitBuilder step
  where
    step k buf@(Buffer _ _ phase _ bytesLeft) =
      let freeBits = bytesLeft * 8 - phase
      in runBitBuilder (f freeBits) k buf
-- | Build a builder from the current phase, i.e. the number of bits
-- already written into the buffer's current byte.
withPhase :: (Int -> BitBuilder) -> BitBuilder
withPhase f = BitBuilder step
  where
    step k buf@(Buffer _ _ currentPhase _ _) =
      runBitBuilder (f currentPhase) k buf
-- | Pad the output up to the next byte boundary with zero bits. Does
-- nothing when the output is already byte aligned.
zeroExtendFinalByte :: BitBuilder
zeroExtendFinalByte = withPhase pad
  where
    pad 0     = empty
    pad phase = writeN (8 - phase) clearTail
    -- Keep the bits written so far and zero every bit after them.
    clearTail p bitsUsed =
      peek p >>= \byte -> poke p (byte .&. topNBits bitsUsed)
-- | Make sure at least @n@ bits can be written to the buffer. When the
-- current buffer is too small, the complete bytes are flushed and a new
-- buffer holding at least @max n defaultSize@ bits is installed.
ensureFree :: Int -> BitBuilder
ensureFree n = n `seq` withSize grow
  where
    grow freeBits
      | n <= freeBits = empty
      | otherwise     = flush `append` unsafeLiftIO (newBuffer (max n defaultSize))
{-# INLINE ensureFree #-}
-- | Reserve room for @n@ bits, then run the given action to store them.
-- The action receives a pointer to the byte currently being filled and the
-- number of bits already occupied in that byte; afterwards the buffer
-- position is advanced by @n@ bits.
writeN :: Int -> (Ptr Word8 -> Int -> IO ()) -> BitBuilder
writeN n f = append (ensureFree n) (unsafeLiftIO (writeNBuffer n f))
{-# INLINE [1] writeN #-}
-- | Run a write action against the byte currently being filled and return
-- the buffer advanced by @n@ bits. No capacity check is performed here.
writeNBuffer :: Int -> (Ptr Word8 -> Int -> IO ()) -> Buffer -> IO Buffer
writeNBuffer n f (Buffer fp bo phase u l) = do
    withForeignPtr fp $ \base -> f (base `plusPtr` (bo + u)) phase
    return (Buffer fp bo newPhase (u + wholeBytes) (l - wholeBytes))
  where
    -- Complete bytes gained by this write, and the bits left over in the
    -- next, partial byte.
    (wholeBytes, newPhase) = (phase + n) `divMod` 8
{-# INLINE writeNBuffer #-}
-- | Allocate a fresh buffer with room for at least @size@ bits.
--
-- When the old buffer ends in a partially written byte, that byte is
-- copied to the start of the new buffer and the phase is carried over; one
-- extra byte is allocated to account for it.
newBuffer :: Int -> Buffer -> IO Buffer
newBuffer size (Buffer oldFp bo phase u _)
  | phase == 0 = do
      let byteSize = (size + 7) `div` 8
      freshFp <- S.mallocByteString byteSize
      return $! Buffer freshFp 0 0 0 byteSize
  | otherwise = do
      let byteSize = (size + 15) `div` 8
      freshFp <- S.mallocByteString byteSize
      -- Carry the partially filled byte over to the new storage.
      withForeignPtr freshFp $ \dst ->
        withForeignPtr oldFp $ \src ->
          peek (src `plusPtr` (bo + u)) >>= poke dst
      return $! Buffer freshFp 0 phase 0 byteSize