{-# LANGUAGE DeriveDataTypeable #-}
{-# LANGUAGE RecordWildCards #-}
module Network.ByteOrder (
    
    Buffer
  , Offset
  , BufferSize
  , BufferOverrun(..)
    
  , poke8
  , poke16
  , poke24
  , poke32
  , poke64
    
  , peek8
  , peek16
  , peek24
  , peek32
  , peek64
  , peekByteString
    
  , bytestring8
  , bytestring16
  , bytestring32
  , bytestring64
    
  , word8
  , word16
  , word32
  , word64
    
  , unsafeWithByteString
  , copy
  , bufferIO
    
  , Readable(..)
    
  , ReadBuffer
  , newReadBuffer
  , withReadBuffer
  , read16
  , read24
  , read32
  , read64
  , extractByteString
  , extractShortByteString
    
  , WriteBuffer(..)
  , newWriteBuffer
  , clearWriteBuffer
  , withWriteBuffer
  , withWriteBuffer'
  , write8
  , write16
  , write24
  , write32
  , write64
  , copyByteString
  , copyShortByteString
  , shiftLastN
  , toByteString
  , toShortByteString
  , currentOffset
    
  , Word8, Word16, Word32, Word64, ByteString
  ) where
import Control.Exception (bracket, throwIO, Exception)
import Control.Monad (when)
import Data.Bits (shiftR, shiftL, (.&.), (.|.))
import Data.ByteString.Internal (ByteString(..), create, memcpy, ByteString(..), unsafeCreate)
import Data.ByteString.Short (ShortByteString)
import qualified Data.ByteString.Short.Internal as Short
import Data.IORef (IORef, newIORef, readIORef, writeIORef)
import Data.Typeable
import Data.Word (Word8, Word8, Word16, Word32, Word64)
import Foreign.ForeignPtr (withForeignPtr, newForeignPtr_)
import Foreign.Marshal.Alloc
import Foreign.Ptr (Ptr, plusPtr, plusPtr, minusPtr)
import Foreign.Storable (peek, poke, poke, peek)
import System.IO.Unsafe (unsafeDupablePerformIO)
-- | Size of a buffer, in bytes.
type BufferSize = Int
-- | Byte offset relative to a 'Buffer'.
type Offset = Int
-- | Raw pointer to the first byte of a byte buffer.
type Buffer = Ptr Word8
-- | Pointer to the byte that lies @off@ bytes past @buf@.
(+.) :: Buffer -> Offset -> Buffer
buf +. off = plusPtr buf off
-- | Store a single byte @w@ at @off@ bytes past @ptr@.
poke8 :: Word8 -> Buffer -> Offset -> IO ()
poke8 w ptr off = poke dst w
  where
    dst = ptr +. off
-- | Store a 'Word16' at @off@ bytes past @ptr@, most significant byte
-- first (big-endian).
poke16 :: Word16 -> Buffer -> Offset -> IO ()
poke16 w ptr off = do
    poke8 (byte 1) ptr off
    poke8 (byte 0) ptr (off + 1)
  where
    -- n-th byte counted from the least significant end; the narrowing
    -- 'fromIntegral' keeps only the low 8 bits, so no mask is needed.
    byte :: Int -> Word8
    byte n = fromIntegral (w `shiftR` (8 * n))
-- | Store the low 24 bits of a 'Word32' as three bytes at @off@ bytes
-- past @ptr@, most significant byte first. Bits 24-31 are ignored.
poke24 :: Word32 -> Buffer -> Offset -> IO ()
poke24 w ptr off = do
    poke8 (byte 2) ptr off
    poke8 (byte 1) ptr (off + 1)
    poke8 (byte 0) ptr (off + 2)
  where
    -- n-th byte counted from the least significant end.
    byte :: Int -> Word8
    byte n = fromIntegral (w `shiftR` (8 * n))
-- | Store a 'Word32' at @off@ bytes past @ptr@, most significant byte
-- first (big-endian).
poke32 :: Word32 -> Buffer -> Offset -> IO ()
poke32 w ptr off = do
    poke8 (byte 3) ptr off
    poke8 (byte 2) ptr (off + 1)
    poke8 (byte 1) ptr (off + 2)
    poke8 (byte 0) ptr (off + 3)
  where
    -- n-th byte counted from the least significant end.
    byte :: Int -> Word8
    byte n = fromIntegral (w `shiftR` (8 * n))
-- | Store a 'Word64' at @off@ bytes past @ptr@, most significant byte
-- first (big-endian).
poke64 :: Word64 -> Buffer -> Offset -> IO ()
poke64 w ptr off = do
    poke8 (byte 7) ptr off
    poke8 (byte 6) ptr (off + 1)
    poke8 (byte 5) ptr (off + 2)
    poke8 (byte 4) ptr (off + 3)
    poke8 (byte 3) ptr (off + 4)
    poke8 (byte 2) ptr (off + 5)
    poke8 (byte 1) ptr (off + 6)
    poke8 (byte 0) ptr (off + 7)
  where
    -- n-th byte counted from the least significant end.
    byte :: Int -> Word8
    byte n = fromIntegral (w `shiftR` (8 * n))
-- | Load the single byte located @off@ bytes past @ptr@.
peek8 :: Buffer -> Offset -> IO Word8
peek8 ptr = peek . (ptr +.)
-- | Load a big-endian 'Word16' from @off@ bytes past @ptr@.
peek16 :: Buffer -> Offset -> IO Word16
peek16 ptr off = do
    hi <- peek8 ptr off
    lo <- peek8 ptr (off + 1)
    pure (widen hi 8 .|. widen lo 0)
  where
    -- Promote a byte and move it to bit position @s@.
    widen :: Word8 -> Int -> Word16
    widen b s = fromIntegral b `shiftL` s
-- | Load three big-endian bytes from @off@ bytes past @ptr@ into the low
-- 24 bits of a 'Word32' (the top byte of the result is zero).
peek24 :: Buffer -> Offset -> IO Word32
peek24 ptr off = do
    b0 <- peek8 ptr off
    b1 <- peek8 ptr (off + 1)
    b2 <- peek8 ptr (off + 2)
    pure (widen b0 16 .|. widen b1 8 .|. widen b2 0)
  where
    -- Promote a byte and move it to bit position @s@.
    widen :: Word8 -> Int -> Word32
    widen b s = fromIntegral b `shiftL` s
-- | Load a big-endian 'Word32' from @off@ bytes past @ptr@.
peek32 :: Buffer -> Offset -> IO Word32
peek32 ptr off = do
    b0 <- peek8 ptr off
    b1 <- peek8 ptr (off + 1)
    b2 <- peek8 ptr (off + 2)
    b3 <- peek8 ptr (off + 3)
    pure (widen b0 24 .|. widen b1 16 .|. widen b2 8 .|. widen b3 0)
  where
    -- Promote a byte and move it to bit position @s@.
    widen :: Word8 -> Int -> Word32
    widen b s = fromIntegral b `shiftL` s
-- | Load a big-endian 'Word64' from @off@ bytes past @ptr@.
peek64 :: Buffer -> Offset -> IO Word64
peek64 ptr off = do
    b0 <- peek8 ptr off
    b1 <- peek8 ptr (off + 1)
    b2 <- peek8 ptr (off + 2)
    b3 <- peek8 ptr (off + 3)
    b4 <- peek8 ptr (off + 4)
    b5 <- peek8 ptr (off + 5)
    b6 <- peek8 ptr (off + 6)
    b7 <- peek8 ptr (off + 7)
    pure $ widen b0 56 .|. widen b1 48 .|. widen b2 40 .|. widen b3 32
       .|. widen b4 24 .|. widen b5 16 .|. widen b6 8 .|. widen b7 0
  where
    -- Promote a byte and move it to bit position @s@.
    widen :: Word8 -> Int -> Word64
    widen b s = fromIntegral b `shiftL` s
-- | Copy @len@ bytes starting at @src@ into a freshly allocated
-- 'ByteString'. No bounds checking is performed on @src@.
peekByteString :: Buffer -> Int -> IO ByteString
peekByteString src len = create len fill
  where
    fill dst = memcpy dst src len
-- | Encode a 'Word8' as a one-byte 'ByteString'.
bytestring8 :: Word8 -> ByteString
bytestring8 w = unsafeCreate 1 fill
  where
    fill p = poke8 w p 0
-- | Encode a 'Word16' as a two-byte big-endian 'ByteString'.
bytestring16 :: Word16 -> ByteString
bytestring16 w = unsafeCreate 2 fill
  where
    fill p = poke16 w p 0
-- | Encode a 'Word32' as a four-byte big-endian 'ByteString'.
bytestring32 :: Word32 -> ByteString
bytestring32 w = unsafeCreate 4 fill
  where
    fill p = poke32 w p 0
-- | Encode a 'Word64' as an eight-byte big-endian 'ByteString'.
bytestring64 :: Word64 -> ByteString
bytestring64 w = unsafeCreate 8 fill
  where
    fill p = poke64 w p 0
-- | Decode the first byte of a 'ByteString'.
--
-- Throws 'BufferOverrun' (when the result is evaluated) if the
-- 'ByteString' is empty, instead of reading memory past its end.
word8 :: ByteString -> Word8
word8 bs@(PS _ _ len) = unsafeDupablePerformIO $ do
    when (len < 1) $ throwIO BufferOverrun
    unsafeWithByteString bs peek8
-- | Decode the first two bytes of a 'ByteString' as a big-endian 'Word16'.
--
-- Throws 'BufferOverrun' (when the result is evaluated) if the
-- 'ByteString' is shorter than 2 bytes, instead of reading memory past
-- its end.
word16 :: ByteString -> Word16
word16 bs@(PS _ _ len) = unsafeDupablePerformIO $ do
    when (len < 2) $ throwIO BufferOverrun
    unsafeWithByteString bs peek16
-- | Decode the first four bytes of a 'ByteString' as a big-endian 'Word32'.
--
-- Throws 'BufferOverrun' (when the result is evaluated) if the
-- 'ByteString' is shorter than 4 bytes, instead of reading memory past
-- its end.
word32 :: ByteString -> Word32
word32 bs@(PS _ _ len) = unsafeDupablePerformIO $ do
    when (len < 4) $ throwIO BufferOverrun
    unsafeWithByteString bs peek32
-- | Decode the first eight bytes of a 'ByteString' as a big-endian 'Word64'.
--
-- Throws 'BufferOverrun' (when the result is evaluated) if the
-- 'ByteString' is shorter than 8 bytes, instead of reading memory past
-- its end.
word64 :: ByteString -> Word64
word64 bs@(PS _ _ len) = unsafeDupablePerformIO $ do
    when (len < 8) $ throwIO BufferOverrun
    unsafeWithByteString bs peek64
-- | Run @io@ with the base pointer of the 'ByteString''s storage and the
-- 'ByteString''s offset into it. The length is ignored, so @io@ is
-- responsible for staying in bounds; the pointer is only valid inside
-- @io@ (it is kept alive by 'withForeignPtr').
unsafeWithByteString :: ByteString -> (Buffer -> Offset -> IO a) -> IO a
unsafeWithByteString (PS fp o _) io = withForeignPtr fp (`io` o)
-- | Copy every byte of the 'ByteString' to @dst@ and return the pointer
-- just past the copied bytes. No bounds checking is done on @dst@.
copy :: Buffer -> ByteString -> IO Buffer
copy dst (PS fp off len) = withForeignPtr fp $ \base -> do
    let src = base `plusPtr` off
    memcpy dst src (fromIntegral len)
    pure (dst `plusPtr` len)
{-# INLINE copy #-}
-- | Expose @siz@ bytes at @ptr@ to @io@ as a 'ByteString' without copying.
-- The foreign pointer is created with 'newForeignPtr_' (no finalizer), so
-- the 'ByteString' aliases the caller's memory and must not outlive it.
bufferIO :: Buffer -> Int -> (ByteString -> IO a) -> IO a
bufferIO ptr siz io = newForeignPtr_ ptr >>= \fptr -> io (PS fptr 0 siz)
-- | A mutable cursor over the memory region @[start, limit)@.
-- The same structure backs writing (the @write*@ / @copy*@ functions) and
-- reading (the 'Readable' instance, also reused by 'ReadBuffer').
data WriteBuffer = WriteBuffer {
    start :: Buffer
    -- ^ First byte of the region.
  , limit :: Buffer
    -- ^ One past the last usable byte ('newWriteBuffer' sets it to @start + size@).
  , offset :: IORef Buffer
    -- ^ Current read/write position.
  , oldoffset :: IORef Buffer
    -- ^ Position recorded by 'save' and restored by 'goBack'.
  }
-- | Wrap @siz@ bytes starting at @buf@ in a 'WriteBuffer' whose current
-- and saved positions both start at @buf@. The memory is not copied or
-- owned; the caller keeps it alive.
newWriteBuffer :: Buffer -> BufferSize -> IO WriteBuffer
newWriteBuffer buf siz = do
    cur <- newIORef buf
    old <- newIORef buf
    pure WriteBuffer
        { start     = buf
        , limit     = buf `plusPtr` siz
        , offset    = cur
        , oldoffset = old
        }
-- | Reset both the current and the saved position to the start of the
-- buffer. The buffer contents are left untouched.
clearWriteBuffer :: WriteBuffer -> IO ()
clearWriteBuffer WriteBuffer{..} = mapM_ (`writeIORef` start) [offset, oldoffset]
{-# INLINE write8 #-}
-- | Write one byte at the current position and advance it by one.
-- Throws 'BufferOverrun' if the byte would not fit before 'limit'.
write8 :: WriteBuffer -> Word8 -> IO ()
write8 WriteBuffer{..} w = do
    cur <- readIORef offset
    let next = cur `plusPtr` 1
    if next > limit
        then throwIO BufferOverrun
        else do
            poke cur w
            writeIORef offset next
{-# INLINE write16 #-}
-- | Write a big-endian 'Word16' at the current position and advance it by
-- two. Throws 'BufferOverrun' if it would not fit before 'limit'.
write16 :: WriteBuffer -> Word16 -> IO ()
write16 WriteBuffer{..} w = do
    cur <- readIORef offset
    let next = cur `plusPtr` 2
    if next > limit
        then throwIO BufferOverrun
        else do
            poke16 w cur 0
            writeIORef offset next
{-# INLINE write24 #-}
-- | Write the low 24 bits of a 'Word32' big-endian at the current position
-- and advance it by three. Throws 'BufferOverrun' if it would not fit
-- before 'limit'.
write24 :: WriteBuffer -> Word32 -> IO ()
write24 WriteBuffer{..} w = do
    cur <- readIORef offset
    let next = cur `plusPtr` 3
    if next > limit
        then throwIO BufferOverrun
        else do
            poke24 w cur 0
            writeIORef offset next
{-# INLINE write32 #-}
-- | Write a big-endian 'Word32' at the current position and advance it by
-- four. Throws 'BufferOverrun' if it would not fit before 'limit'.
write32 :: WriteBuffer -> Word32 -> IO ()
write32 WriteBuffer{..} w = do
    cur <- readIORef offset
    let next = cur `plusPtr` 4
    if next > limit
        then throwIO BufferOverrun
        else do
            poke32 w cur 0
            writeIORef offset next
{-# INLINE write64 #-}
-- | Write a big-endian 'Word64' at the current position and advance it by
-- eight. Throws 'BufferOverrun' if it would not fit before 'limit'.
write64 :: WriteBuffer -> Word64 -> IO ()
write64 WriteBuffer{..} w = do
    cur <- readIORef offset
    let next = cur `plusPtr` 8
    if next > limit
        then throwIO BufferOverrun
        else do
            poke64 w cur 0
            writeIORef offset next
{-# INLINE shiftLastN #-}
-- | Shift the last @len@ bytes written to the buffer by @i@ bytes, and
-- move the write position by @i@ as well. A negative @i@ shifts towards
-- the start of the buffer and a positive @i@ towards the end; @i == 0@
-- is a no-op.
--
-- The bytes in @[offset - len, offset)@ end up in
-- @[offset - len + i, offset + i)@, and the position becomes
-- @offset + i@. Overlapping ranges are handled by copying in the
-- direction opposite to the shift.
--
-- Throws 'BufferOverrun' in these cases, before anything is modified:
--
-- * @len@ is negative;
-- * the new position would pass 'limit';
-- * the source or destination range would begin before 'start'.
shiftLastN :: WriteBuffer -> Int -> Int -> IO ()
shiftLastN _ 0 _ = return ()
shiftLastN WriteBuffer{..} i len = do
    ptr <- readIORef offset
    let ptr' = ptr `plusPtr` i           -- new write position
        src0 = ptr `plusPtr` negate len  -- first byte to move
        dst0 = src0 `plusPtr` i          -- where that byte ends up
    -- The new position may reach 'limit' exactly (the buffer is then
    -- full), which is the same bound 'write8' and friends use.
    when (ptr' > limit) $ throwIO BufferOverrun
    -- Never read or write before the buffer. A negative length is also
    -- rejected, because the copy loops below would never terminate.
    when (len < 0 || src0 < start || dst0 < start) $ throwIO BufferOverrun
    if i < 0 then do
        shiftLeft dst0 src0 len
        writeIORef offset ptr'
      else do
        let src = ptr `plusPtr` (-1)
            dst = ptr' `plusPtr` (-1)
        shiftRight dst src len
        writeIORef offset ptr'
  where
    -- Forward byte-by-byte copy; safe for overlap when dst < src.
    shiftLeft :: Buffer -> Buffer -> Int -> IO ()
    shiftLeft dst src n
      | n <= 0    = return ()
      | otherwise = do
          peek src >>= poke dst
          shiftLeft (dst `plusPtr` 1) (src `plusPtr` 1) (n - 1)
    -- Backward byte-by-byte copy starting from the last byte; safe for
    -- overlap when dst > src.
    shiftRight :: Buffer -> Buffer -> Int -> IO ()
    shiftRight dst src n
      | n <= 0    = return ()
      | otherwise = do
          peek src >>= poke dst
          shiftRight (dst `plusPtr` (-1)) (src `plusPtr` (-1)) (n - 1)
{-# INLINE copyByteString #-}
-- | Append the bytes of a 'ByteString' at the current position and advance
-- past them. Throws 'BufferOverrun' if they would not fit before 'limit'.
copyByteString :: WriteBuffer -> ByteString -> IO ()
copyByteString WriteBuffer{..} (PS fptr off len) = withForeignPtr fptr $ \base -> do
    cur <- readIORef offset
    let next = cur `plusPtr` len
    if next > limit
        then throwIO BufferOverrun
        else do
            memcpy cur (base `plusPtr` off) len
            writeIORef offset next
-- | Append the bytes of a 'ShortByteString' at the current position and
-- advance past them. Throws 'BufferOverrun' if they would not fit before
-- 'limit'.
copyShortByteString :: WriteBuffer -> ShortByteString -> IO ()
copyShortByteString WriteBuffer{..} sbs = do
    cur <- readIORef offset
    let n    = Short.length sbs
        next = cur `plusPtr` n
    if next > limit
        then throwIO BufferOverrun
        else do
            Short.copyToPtr sbs 0 cur n
            writeIORef offset next
-- | Copy everything between 'start' and the current position into a fresh
-- 'ByteString'.
toByteString :: WriteBuffer -> IO ByteString
toByteString WriteBuffer{..} = do
    end <- readIORef offset
    let n = end `minusPtr` start
    create n (\dst -> memcpy dst start n)
-- | Copy everything between 'start' and the current position into a fresh
-- 'ShortByteString'.
toShortByteString :: WriteBuffer -> IO ShortByteString
toShortByteString WriteBuffer{..} =
    readIORef offset >>= \end -> Short.createFromPtr start (end `minusPtr` start)
-- | Allocate a temporary @siz@-byte buffer, run @action@ on it, and return
-- a copy of what was written. The temporary buffer is freed afterwards,
-- including when @action@ throws.
withWriteBuffer :: BufferSize -> (WriteBuffer -> IO ()) -> IO ByteString
withWriteBuffer siz action = bracket (mallocBytes siz) free run
  where
    run buf = do
        wbuf <- newWriteBuffer buf siz
        action wbuf
        toByteString wbuf
-- | Like 'withWriteBuffer', but also return the result of @action@.
-- The temporary buffer is freed afterwards, including when @action@
-- throws.
withWriteBuffer' :: BufferSize -> (WriteBuffer -> IO a) -> IO (ByteString, a)
withWriteBuffer' siz action = bracket (mallocBytes siz) free run
  where
    run buf = do
        wbuf <- newWriteBuffer buf siz
        result <- action wbuf
        (\bs -> (bs, result)) <$> toByteString wbuf
{-# INLINE currentOffset #-}
-- | The current read/write position of the buffer.
currentOffset :: WriteBuffer -> IO Buffer
currentOffset = readIORef . offset
-- | Byte readers with a movable cursor.
--
-- The method descriptions follow the 'WriteBuffer' instance in this
-- module; other instances presumably behave the same way (TODO confirm).
class Readable a where
    -- | Read one byte and advance the cursor by one.
    --   Throws 'BufferOverrun' when no bytes remain.
    read8 :: a -> IO Word8
    -- | Read one byte as an 'Int' and advance the cursor by one.
    --   Returns @-1@ instead of throwing when no bytes remain.
    readInt8 :: a -> IO Int
    -- | Move the cursor by the given, possibly negative, number of bytes.
    --   Throws 'BufferOverrun' if the result would leave the buffer.
    ff :: a -> Offset -> IO ()
    -- | Number of bytes between the cursor and the end of the buffer.
    remainingSize :: a -> IO Int
    -- | Run an action with a pointer to the current cursor position.
    withCurrentOffSet :: a -> (Buffer -> IO b) -> IO b
    -- | Record the current cursor position.
    save :: a -> IO ()
    -- | Distance in bytes from the position recorded by 'save' to the cursor.
    savingSize :: a -> IO Int
    -- | Move the cursor back to the position recorded by 'save'.
    goBack :: a -> IO ()
-- | Reading from a 'WriteBuffer' moves the same 'offset' cursor that the
-- @write*@ functions use; the readable region is @[start, limit)@.
instance Readable WriteBuffer where
    {-# INLINE read8 #-}
    -- Throws once the cursor has reached 'limit'.
    read8 WriteBuffer{..} = do
        cur <- readIORef offset
        when (cur >= limit) $ throwIO BufferOverrun
        b <- peek cur
        writeIORef offset (cur `plusPtr` 1)
        pure b
    {-# INLINE readInt8 #-}
    -- Signals end of input with -1 rather than an exception.
    readInt8 WriteBuffer{..} = do
        cur <- readIORef offset
        if cur >= limit
            then pure (-1)
            else do
                b <- peek cur
                writeIORef offset (cur `plusPtr` 1)
                pure (fromIntegral (b :: Word8))
    {-# INLINE ff #-}
    -- The cursor may end anywhere in [start, limit], inclusive of limit.
    ff WriteBuffer{..} n = do
        cur <- readIORef offset
        let next = cur `plusPtr` n
        when (next < start || next > limit) $ throwIO BufferOverrun
        writeIORef offset next
    {-# INLINE remainingSize #-}
    remainingSize WriteBuffer{..} = (limit `minusPtr`) <$> readIORef offset
    {-# INLINE withCurrentOffSet #-}
    withCurrentOffSet WriteBuffer{..} action = do
        cur <- readIORef offset
        action cur
    {-# INLINE save #-}
    save WriteBuffer{..} = do
        cur <- readIORef offset
        writeIORef oldoffset cur
    {-# INLINE savingSize #-}
    -- Reads the saved position first, then the current one.
    savingSize WriteBuffer{..} =
        flip minusPtr <$> readIORef oldoffset <*> readIORef offset
    {-# INLINE goBack #-}
    goBack WriteBuffer{..} = readIORef oldoffset >>= writeIORef offset
-- | Every method delegates unchanged to the wrapped 'WriteBuffer'.
instance Readable ReadBuffer where
    {-# INLINE read8 #-}
    read8 (ReadBuffer inner) = read8 inner
    {-# INLINE readInt8 #-}
    readInt8 (ReadBuffer inner) = readInt8 inner
    {-# INLINE ff #-}
    ff (ReadBuffer inner) n = ff inner n
    {-# INLINE remainingSize #-}
    remainingSize (ReadBuffer inner) = remainingSize inner
    {-# INLINE withCurrentOffSet #-}
    withCurrentOffSet (ReadBuffer inner) act = withCurrentOffSet inner act
    {-# INLINE save #-}
    save (ReadBuffer inner) = save inner
    {-# INLINE savingSize #-}
    savingSize (ReadBuffer inner) = savingSize inner
    {-# INLINE goBack #-}
    goBack (ReadBuffer inner) = goBack inner
-- | A 'WriteBuffer' that is exposed only through the 'Readable' interface.
-- The constructor is not in the module's export list.
newtype ReadBuffer = ReadBuffer WriteBuffer
-- | Create a 'ReadBuffer' over @siz@ bytes starting at @buf@, with the
-- cursor at @buf@. The memory is neither copied nor owned.
newReadBuffer :: Buffer -> BufferSize -> IO ReadBuffer
newReadBuffer buf siz = do
    wbuf <- newWriteBuffer buf siz
    pure (ReadBuffer wbuf)
-- | Run @action@ with a 'ReadBuffer' over the bytes of a 'ByteString'.
-- The 'ReadBuffer' points directly into the 'ByteString''s memory, which
-- is kept alive by 'withForeignPtr' only while @action@ runs, so the
-- 'ReadBuffer' must not escape @action@.
withReadBuffer :: ByteString -> (ReadBuffer -> IO a) -> IO a
withReadBuffer (PS fp off siz) action = withForeignPtr fp $ \base ->
    newReadBuffer (base `plusPtr` off) siz >>= action
-- | Copy bytes near the cursor into a fresh 'ByteString'.
--
-- * @len > 0@: copy the next @len@ bytes and advance the cursor past them.
--   Throws 'BufferOverrun' if fewer than @len@ bytes remain.
-- * @len == 0@: return the empty 'ByteString'.
-- * @len < 0@: copy the @-len@ bytes immediately before the cursor,
--   without moving it and without bounds checking.
extractByteString :: Readable a => a -> Int -> IO ByteString
extractByteString rbuf len = case compare len 0 of
    EQ -> return mempty
    GT -> do
        checkR rbuf len
        bs <- withCurrentOffSet rbuf $ \src ->
            create len (\dst -> memcpy dst src len)
        ff rbuf len
        return bs
    LT -> withCurrentOffSet rbuf $ \cur -> do
        let n = negate len
        create n (\dst -> memcpy dst (cur `plusPtr` len) n)
-- | 'ShortByteString' counterpart of 'extractByteString':
--
-- * @len > 0@: copy the next @len@ bytes and advance the cursor past them.
--   Throws 'BufferOverrun' if fewer than @len@ bytes remain.
-- * @len == 0@: return the empty 'ShortByteString'.
-- * @len < 0@: copy the @-len@ bytes immediately before the cursor,
--   without moving it and without bounds checking.
extractShortByteString :: Readable a => a -> Int -> IO ShortByteString
extractShortByteString rbuf len = case compare len 0 of
    EQ -> return mempty
    GT -> do
        checkR rbuf len
        sbs <- withCurrentOffSet rbuf (`Short.createFromPtr` len)
        ff rbuf len
        return sbs
    LT -> withCurrentOffSet rbuf $ \cur ->
        Short.createFromPtr (cur `plusPtr` len) (negate len)
-- | Read a big-endian 'Word16' at the cursor and advance it by two.
-- Throws 'BufferOverrun' if fewer than two bytes remain.
read16 :: Readable a => a -> IO Word16
read16 rbuf = do
    checkR rbuf 2
    w <- withCurrentOffSet rbuf $ \p -> peek16 p 0
    w <$ ff rbuf 2
-- | Read three big-endian bytes at the cursor into the low 24 bits of a
-- 'Word32' and advance the cursor by three. Throws 'BufferOverrun' if
-- fewer than three bytes remain.
read24 :: Readable a => a -> IO Word32
read24 rbuf = do
    checkR rbuf 3
    w <- withCurrentOffSet rbuf $ \p -> peek24 p 0
    w <$ ff rbuf 3
-- | Read a big-endian 'Word32' at the cursor and advance it by four.
-- Throws 'BufferOverrun' if fewer than four bytes remain.
read32 :: Readable a => a -> IO Word32
read32 rbuf = do
    checkR rbuf 4
    w <- withCurrentOffSet rbuf $ \p -> peek32 p 0
    w <$ ff rbuf 4
-- | Read a big-endian 'Word64' at the cursor and advance it by eight.
-- Throws 'BufferOverrun' if fewer than eight bytes remain.
read64 :: Readable a => a -> IO Word64
read64 rbuf = do
    checkR rbuf 8
    w <- withCurrentOffSet rbuf $ \p -> peek64 p 0
    w <$ ff rbuf 8
-- | Throw 'BufferOverrun' unless at least @siz@ bytes remain after the cursor.
checkR :: Readable a => a -> Int -> IO ()
checkR rbuf siz =
    remainingSize rbuf >>= \avail -> when (avail < siz) (throwIO BufferOverrun)
-- | Thrown when a read, write, or cursor move would leave the bounds of a
-- buffer.
data BufferOverrun = BufferOverrun
    deriving (Show, Eq, Typeable)

instance Exception BufferOverrun