{-# LANGUAGE BangPatterns             #-}
{-# LANGUAGE CPP                      #-}
{-# LANGUAGE MagicHash                #-}
{-# LANGUAGE UnboxedTuples            #-}
{-# LANGUAGE ForeignFunctionInterface #-}
{-# LANGUAGE ScopedTypeVariables      #-}
#include "cbor.h"
module Codec.CBOR.Magic
  ( 
    grabWord8         
  , grabWord16        
  , grabWord32        
  , grabWord64        
    
  , eatTailWord8      
  , eatTailWord16     
  , eatTailWord32     
  , eatTailWord64     
    
  , wordToFloat16     
  , floatToWord16     
    
  , wordToFloat32     
  , wordToFloat64     
    
  , word8ToWord       
  , word16ToWord      
  , word32ToWord      
  , word64ToWord      
  
  , word8ToInt        
  , word16ToInt       
  , word32ToInt       
  , word64ToInt       
  , intToInt64        
#if defined(ARCH_32bit)
  , word8ToInt64      
  , word16ToInt64     
  , word32ToInt64     
  , word64ToInt64     
  , word8ToWord64     
  , word16ToWord64    
  , word32ToWord64    
#endif
    
  , nintegerFromBytes 
  , uintegerFromBytes 
    
  , Counter           
  , newCounter        
  , readCounter       
  , writeCounter      
  , incCounter        
  , decCounter        
    
  , copyByteStringToByteArray
  , copyByteArrayToByteString
  ) where
import           GHC.Exts
import           GHC.ST (ST(ST))
import           GHC.IO (IO(IO), unsafeDupablePerformIO)
import           GHC.Word
import           GHC.Int
#if MIN_VERSION_base(4,11,0)
import           GHC.Float (castWord32ToFloat, castWord64ToDouble)
#endif
import           Foreign.Ptr
#if defined(OPTIMIZE_GMP)
import qualified GHC.Integer.GMP.Internals      as Gmp
#endif
import           Data.ByteString (ByteString)
import qualified Data.ByteString          as BS
import qualified Data.ByteString.Internal as BS
import qualified Data.ByteString.Unsafe   as BS
import           Data.Primitive.ByteArray as Prim
import           Foreign.ForeignPtr (withForeignPtr)
import           Foreign.C (CUShort)
import qualified Numeric.Half as Half
#if !defined(HAVE_BYTESWAP_PRIMOPS) || !defined(MEM_UNALIGNED_OPS) || !defined(OPTIMIZE_GMP)
import           Data.Bits ((.|.), unsafeShiftL)
#endif
#if defined(ARCH_32bit)
import           GHC.IntWord64 (wordToWord64#, word64ToWord#,
                                intToInt64#, int64ToInt#,
                                leWord64#, ltWord64#, word64ToInt64#)
#endif
-- | Read one byte at the given address. No bounds checking is done.
grabWord8 :: Ptr () -> Word8
{-# INLINE grabWord8 #-}
-- | Read a 16-bit unsigned integer stored big-endian (most significant
-- byte first) at the given address. No bounds checking is done.
grabWord16 :: Ptr () -> Word16
{-# INLINE grabWord16 #-}
-- | Read a 32-bit unsigned integer stored big-endian at the given address.
-- No bounds checking is done.
grabWord32 :: Ptr () -> Word32
{-# INLINE grabWord32 #-}
-- | Read a 64-bit unsigned integer stored big-endian at the given address.
-- No bounds checking is done.
grabWord64 :: Ptr () -> Word64
{-# INLINE grabWord64 #-}
grabWord8 (Ptr ip#) = W8# (indexWord8OffAddr# ip# 0#)
-- The multi-byte readers have three implementations, selected at compile
-- time by CPP configuration macros.
#if defined(HAVE_BYTESWAP_PRIMOPS) && \
    defined(MEM_UNALIGNED_OPS) && \
   !defined(WORDS_BIGENDIAN)
-- Little-endian host with word-sized loads and byte-swap primops: load the
-- value in native order, then reverse its bytes to obtain the big-endian
-- value.
grabWord16 (Ptr ip#) = W16# (narrow16Word# (byteSwap16# (indexWord16OffAddr# ip# 0#)))
grabWord32 (Ptr ip#) = W32# (narrow32Word# (byteSwap32# (indexWord32OffAddr# ip# 0#)))
#if defined(ARCH_64bit)
-- On 64-bit targets the @W64#@ payload is the native word, so the
-- word-sized byteSwap# is used.
grabWord64 (Ptr ip#) = W64# (byteSwap# (indexWord64OffAddr# ip# 0#))
#else
grabWord64 (Ptr ip#) = W64# (byteSwap64# (indexWord64OffAddr# ip# 0#))
#endif
#elif defined(MEM_UNALIGNED_OPS) && \
      defined(WORDS_BIGENDIAN)
-- Big-endian host: the in-memory layout already is the wanted byte order,
-- so a plain word-sized load is enough.
grabWord16 (Ptr ip#) = W16# (indexWord16OffAddr# ip# 0#)
grabWord32 (Ptr ip#) = W32# (indexWord32OffAddr# ip# 0#)
grabWord64 (Ptr ip#) = W64# (indexWord64OffAddr# ip# 0#)
#else
-- Portable fallback: read the bytes one at a time and assemble the value
-- with shifts, most significant byte first.
grabWord16 (Ptr ip#) =
    case indexWord8OffAddr# ip# 0# of
     w0# ->
      case indexWord8OffAddr# ip# 1# of
       w1# -> W16# w0# `unsafeShiftL` 8 .|.
              W16# w1#
grabWord32 (Ptr ip#) =
    case indexWord8OffAddr# ip# 0# of
     w0# ->
      case indexWord8OffAddr# ip# 1# of
       w1# ->
        case indexWord8OffAddr# ip# 2# of
         w2# ->
          case indexWord8OffAddr# ip# 3# of
           w3# -> W32# w0# `unsafeShiftL` 24 .|.
                  W32# w1# `unsafeShiftL` 16 .|.
                  W32# w2# `unsafeShiftL`  8 .|.
                  W32# w3#
grabWord64 (Ptr ip#) =
    case indexWord8OffAddr# ip# 0# of
     w0# ->
      case indexWord8OffAddr# ip# 1# of
       w1# ->
        case indexWord8OffAddr# ip# 2# of
         w2# ->
          case indexWord8OffAddr# ip# 3# of
           w3# ->
            case indexWord8OffAddr# ip# 4# of
             w4# ->
              case indexWord8OffAddr# ip# 5# of
               w5# ->
                case indexWord8OffAddr# ip# 6# of
                 w6# ->
                  case indexWord8OffAddr# ip# 7# of
                   w7# -> w w0# `unsafeShiftL` 56 .|.
                          w w1# `unsafeShiftL` 48 .|.
                          w w2# `unsafeShiftL` 40 .|.
                          w w3# `unsafeShiftL` 32 .|.
                          w w4# `unsafeShiftL` 24 .|.
                          w w5# `unsafeShiftL` 16 .|.
                          w w6# `unsafeShiftL`  8 .|.
                          w w7#
  where
    -- Widen one byte to a Word64. On 32-bit targets the @W64#@ payload is a
    -- Word64#, so the Word# must be converted with wordToWord64# first.
#if defined(ARCH_64bit)
    w w# = W64# w#
#else
    w w# = W64# (wordToWord64# w#)
#endif
#endif
-- | Skip the first byte of the input and read the byte after it.
-- The input must be at least 2 bytes long; nothing is checked here.
eatTailWord8 :: ByteString -> Word8
eatTailWord8 = withBsPtr grabWord8 . BS.unsafeTail
{-# INLINE eatTailWord8 #-}
-- | Skip the first byte of the input and read the following 2 bytes as a
-- big-endian Word16. The input must be at least 3 bytes long; nothing is
-- checked here.
eatTailWord16 :: ByteString -> Word16
eatTailWord16 = withBsPtr grabWord16 . BS.unsafeTail
{-# INLINE eatTailWord16 #-}
-- | Skip the first byte of the input and read the following 4 bytes as a
-- big-endian Word32. The input must be at least 5 bytes long; nothing is
-- checked here.
eatTailWord32 :: ByteString -> Word32
eatTailWord32 = withBsPtr grabWord32 . BS.unsafeTail
{-# INLINE eatTailWord32 #-}
-- | Skip the first byte of the input and read the following 8 bytes as a
-- big-endian Word64. The input must be at least 9 bytes long; nothing is
-- checked here.
eatTailWord64 :: ByteString -> Word64
eatTailWord64 = withBsPtr grabWord64 . BS.unsafeTail
{-# INLINE eatTailWord64 #-}
-- | Apply a pure pointer reader @f@ to the address of the first byte of a
-- strict ByteString (the buffer start plus the slice offset).
--
-- The result is forced to WHNF with @$!@ inside @withForeignPtr@, so @f@
-- runs while the underlying buffer is still kept alive.
withBsPtr :: (Ptr b -> a) -> ByteString -> a
withBsPtr f (BS.PS fptr offset _) =
    unsafeDupablePerformIO (withForeignPtr fptr readAt)
  where
    readAt base = return $! f (base `plusPtr` offset)
{-# INLINE withBsPtr #-}
-- | Decode a half-precision bit pattern (as understood by Numeric.Half)
-- into a Float.
wordToFloat16 :: Word16 -> Float
wordToFloat16 = Half.fromHalf . Half.Half . toCUShort
  where
    toCUShort :: Word16 -> CUShort
    toCUShort = fromIntegral
{-# INLINE wordToFloat16 #-}
-- | Round a Float to half precision (via Numeric.Half) and return the
-- resulting 16-bit pattern.
floatToWord16 :: Float -> Word16
floatToWord16 = fromCUShort . Half.getHalf . Half.toHalf
  where
    fromCUShort :: CUShort -> Word16
    fromCUShort = fromIntegral
{-# INLINE floatToWord16 #-}
-- | Reinterpret the bits of a Word32 as a Float (a bit cast, not a numeric
-- conversion).
wordToFloat32 :: Word32 -> Float
-- The INLINE pragma is outside the CPP conditional so it applies to both
-- implementations below, like the other conversions in this module.
{-# INLINE wordToFloat32 #-}
#if MIN_VERSION_base(4,11,0)
wordToFloat32 = GHC.Float.castWord32ToFloat
#else
wordToFloat32 (W32# w#) = F# (wordToFloat32# w#)

-- Fallback bit cast for base versions without castWord32ToFloat: write the
-- word into a fresh 4-byte array and read it back as a Float, threading
-- realWorld# by hand.
wordToFloat32# :: Word# -> Float#
wordToFloat32# w# =
    case newByteArray# 4# realWorld# of
      (# s', mba# #) ->
        case writeWord32Array# mba# 0# w# s' of
          s'' ->
            case readFloatArray# mba# 0# s'' of
              (# _, f# #) -> f#
{-# NOINLINE wordToFloat32# #-}
#endif
-- | Reinterpret the bits of a Word64 as a Double (a bit cast, not a numeric
-- conversion).
wordToFloat64 :: Word64 -> Double
-- The INLINE pragma is outside the CPP conditional so it applies to both
-- implementations below, like the other conversions in this module.
{-# INLINE wordToFloat64 #-}
#if MIN_VERSION_base(4,11,0)
wordToFloat64 = GHC.Float.castWord64ToDouble
#else
wordToFloat64 (W64# w#) = D# (wordToFloat64# w#)

-- Fallback bit cast for base versions without castWord64ToDouble: write the
-- word into a fresh 8-byte array and read it back as a Double, threading
-- realWorld# by hand. The W64# payload is a Word# on 64-bit targets and a
-- Word64# on 32-bit ones, hence the two signatures.
#if defined(ARCH_64bit)
wordToFloat64# :: Word# -> Double#
#else
wordToFloat64# :: Word64# -> Double#
#endif
wordToFloat64# w# =
    case newByteArray# 8# realWorld# of
      (# s', mba# #) ->
        case writeWord64Array# mba# 0# w# s' of
          s'' ->
            case readDoubleArray# mba# 0# s'' of
              (# _, f# #) -> f#
{-# NOINLINE wordToFloat64# #-}
#endif
-- | Widen an unsigned 8-bit value to a native Word (always lossless).
word8ToWord  :: Word8  -> Word
-- | Widen an unsigned 16-bit value to a native Word (always lossless).
word16ToWord :: Word16 -> Word
-- | Widen an unsigned 32-bit value to a native Word (always lossless).
word32ToWord :: Word32 -> Word
#if defined(ARCH_64bit)
-- | Convert a Word64 to a native Word; lossless on 64-bit targets.
word64ToWord :: Word64 -> Word
#else
-- | Convert a Word64 to a native 32-bit Word, or @Nothing@ if the value is
-- larger than 0xffffffff.
word64ToWord :: Word64 -> Maybe Word
#endif
-- | Convert to a native Int (always lossless).
word8ToInt  :: Word8  -> Int
-- | Convert to a native Int (always lossless).
word16ToInt :: Word16 -> Int
#if defined(ARCH_64bit)
-- | Convert to a native Int; lossless on 64-bit targets.
word32ToInt :: Word32 -> Int
#else
-- | Convert to a native 32-bit Int, or @Nothing@ if the value is 2^31 or
-- larger.
word32ToInt :: Word32 -> Maybe Int
#endif
-- | Convert to a native Int, or @Nothing@ if the value is too large for a
-- signed Int on this target.
word64ToInt :: Word64 -> Maybe Int
#if defined(ARCH_32bit)
-- Conversions to explicit 64-bit types, only defined on 32-bit targets,
-- where the native Int and Word are 32 bits wide.
word8ToInt64  :: Word8  -> Int64
word16ToInt64 :: Word16 -> Int64
word32ToInt64 :: Word32 -> Int64
-- | Convert to Int64, or @Nothing@ if the value is 2^63 or larger.
word64ToInt64 :: Word64 -> Maybe Int64
word8ToWord64  :: Word8  -> Word64
word16ToWord64 :: Word16 -> Word64
word32ToWord64 :: Word32 -> Word64
#endif
-- | Widen a native Int to an Int64. This is lossless: Int is at most 64 bits
-- wide on the targets this module supports.
intToInt64 :: Int -> Int64
intToInt64 n = fromIntegral n
{-# INLINE intToInt64 #-}
-- On the GHC versions this module targets, W8#, W16# and W32# carry a full
-- Word#, so widening to Word just re-wraps the payload.
word8ToWord  (W8#  w#) = W# w#
word16ToWord (W16# w#) = W# w#
word32ToWord (W32# w#) = W# w#
#if defined(ARCH_64bit)
word64ToWord (W64# w#) = W# w#
#else
-- Accept only values up to 0xffffffff, the largest 32-bit Word; anything
-- larger would be truncated by word64ToWord#.
word64ToWord (W64# w64#) =
  case isTrue# (w64# `leWord64#` wordToWord64# 0xffffffff##) of
    True  -> Just (W# (word64ToWord# w64#))
    False -> Nothing
#endif
{-# INLINE word8ToWord #-}
{-# INLINE word16ToWord #-}
{-# INLINE word32ToWord #-}
{-# INLINE word64ToWord #-}
-- word2Int# reinterprets the bits of a Word# as an Int#. That preserves the
-- value as long as the top (sign) bit is clear, which the guards below
-- ensure where the input could be large enough to set it.
word8ToInt  (W8#  w#) = I# (word2Int# w#)
word16ToInt (W16# w#) = I# (word2Int# w#)
#if defined(ARCH_64bit)
word32ToInt (W32# w#) = I# (word2Int# w#)
#else
-- Values at or above 2^31 would turn negative as a 32-bit Int.
word32ToInt (W32# w#) =
  case isTrue# (w# `ltWord#` 0x80000000##) of
    True  -> Just (I# (word2Int# w#))
    False -> Nothing
#endif
#if defined(ARCH_64bit)
-- Values at or above 2^63 would turn negative as a 64-bit Int.
word64ToInt (W64# w#) =
  case isTrue# (w# `ltWord#` 0x8000000000000000##) of
    True  -> Just (I# (word2Int# w#))
    False -> Nothing
#else
-- Only values below 2^31 fit a 32-bit Int; the narrowing goes through
-- Int64# because the payload is a Word64# here.
word64ToInt (W64# w#) =
  case isTrue# (w# `ltWord64#` wordToWord64# 0x80000000##) of
    True  -> Just (I# (int64ToInt# (word64ToInt64# w#)))
    False -> Nothing
#endif
{-# INLINE word8ToInt #-}
{-# INLINE word16ToInt #-}
{-# INLINE word32ToInt #-}
{-# INLINE word64ToInt #-}
#if defined(ARCH_32bit)
-- Explicit 64-bit conversions for 32-bit targets, where Int64 and Word64
-- wrap the separate Int64# and Word64# primitive types.
word8ToInt64  (W8#  w#) = I64# (intToInt64# (word2Int# w#))
word16ToInt64 (W16# w#) = I64# (intToInt64# (word2Int# w#))
-- Go through Word64# (zero extension) rather than word2Int#, so values at
-- or above 2^31 stay positive.
word32ToInt64 (W32# w#) = I64# (word64ToInt64# (wordToWord64# w#))
-- Reject values at or above 2^63 (1 shifted left by 63), which would turn
-- negative as an Int64.
word64ToInt64 (W64# w#) =
  case isTrue# (w# `ltWord64#` uncheckedShiftL64# (wordToWord64# 1##) 63#) of
    True  -> Just (I64# (word64ToInt64# w#))
    False -> Nothing
word8ToWord64  (W8#  w#) = W64# (wordToWord64# w#)
word16ToWord64 (W16# w#) = W64# (wordToWord64# w#)
word32ToWord64 (W32# w#) = W64# (wordToWord64# w#)
{-# INLINE word8ToInt64  #-}
{-# INLINE word16ToInt64 #-}
{-# INLINE word32ToInt64 #-}
{-# INLINE word64ToInt64 #-}
{-# INLINE word8ToWord64  #-}
{-# INLINE word16ToWord64 #-}
{-# INLINE word32ToWord64 #-}
#endif
-- | Decode a negative integer stored as @-1 - n@, where @n@ is the unsigned
-- big-endian value of the given bytes (see @uintegerFromBytes@).
nintegerFromBytes :: BS.ByteString -> Integer
nintegerFromBytes bs = negate (uintegerFromBytes bs + 1)
-- | Interpret the bytes as an unsigned big-endian integer: the first byte
-- is the most significant. An empty input yields 0.
uintegerFromBytes :: BS.ByteString -> Integer
#if defined(OPTIMIZE_GMP)
uintegerFromBytes (BS.PS fptr (I# offset#) (I# size#)) =
  -- Hand the raw bytes straight to GMP. The final 1# argument asks for
  -- most-significant-byte-first order, matching the portable fold below.
  unsafeDupablePerformIO $
      withForeignPtr fptr $ \(Ptr base#) ->
          Gmp.importIntegerFromAddr (base# `plusAddr#` offset#) (int2Word# size#) 1#
#else
uintegerFromBytes = BS.foldl' step 0
  where
    -- Shift the accumulated value up by one byte and add the next byte.
    step acc byte = acc `unsafeShiftL` 8 + fromIntegral byte
#endif
-- | A mutable Int cell for use in the @ST s@ state thread. The value is
-- stored as element 0 of an unboxed mutable byte array.
data Counter s = Counter (MutableByteArray# s)
-- | Allocate a new counter holding the given initial value. The backing
-- array is 8 bytes, which is room for one Int of up to 64 bits.
newCounter :: Int -> ST s (Counter s)
newCounter (I# initial#) = ST allocate
  where
    allocate s0 =
      case newByteArray# 8# s0 of
        (# s1, arr# #) ->
          case writeIntArray# arr# 0# initial# s1 of
            s2 -> (# s2, Counter arr# #)
{-# INLINE newCounter   #-}
-- | Return the current value of the counter.
readCounter :: Counter s -> ST s Int
readCounter (Counter arr#) = ST $ \s0 ->
    case readIntArray# arr# 0# s0 of
      (# s1, value# #) -> (# s1, I# value# #)
{-# INLINE readCounter  #-}
-- | Overwrite the value stored in the counter.
writeCounter :: Counter s -> Int -> ST s ()
writeCounter (Counter arr#) (I# value#) =
    ST (\s0 -> (# writeIntArray# arr# 0# value# s0, () #))
{-# INLINE writeCounter #-}
-- | Add one to the counter (a read followed by a write).
incCounter :: Counter s -> ST s ()
incCounter c = readCounter c >>= writeCounter c . (+ 1)
{-# INLINE incCounter #-}
-- | Subtract one from the counter (a read followed by a write).
decCounter :: Counter s -> ST s ()
decCounter c = readCounter c >>= writeCounter c . subtract 1
{-# INLINE decCounter #-}
-- | Copy the bytes of a strict ByteString into a freshly allocated,
-- immutable ByteArray of the same length.
copyByteStringToByteArray :: BS.ByteString -> Prim.ByteArray
copyByteStringToByteArray (BS.PS fptr offset size) = unsafeDupablePerformIO $ do
    marr <- Prim.newByteArray size
    -- Only the copy needs the source buffer kept alive.
    withForeignPtr fptr $ \base ->
      copyPtrToMutableByteArray (base `plusPtr` offset) marr 0 size
    Prim.unsafeFreezeByteArray marr
-- | Copy a slice of a ByteArray into a freshly allocated strict ByteString.
copyByteArrayToByteString :: Prim.ByteArray  -- ^ source array
                          -> Int             -- ^ offset into the source, in bytes
                          -> Int             -- ^ number of bytes to copy
                          -> BS.ByteString
copyByteArrayToByteString src offset size = unsafeDupablePerformIO $ do
    fptr <- BS.mallocByteString size
    withForeignPtr fptr $ \dst -> copyByteArrayToPtr src offset dst size
    return (BS.PS fptr 0 size)
-- | Copy raw bytes from a memory address into a mutable byte array.
copyPtrToMutableByteArray :: Ptr a                       -- ^ source address
                          -> MutableByteArray RealWorld  -- ^ destination array
                          -> Int                         -- ^ destination offset, in bytes
                          -> Int                         -- ^ number of bytes to copy
                          -> IO ()
copyPtrToMutableByteArray (Ptr src#) (MutableByteArray dst#) (I# dstOff#) (I# count#) =
    IO $ \s0 -> (# copyAddrToByteArray# src# dst# dstOff# count# s0, () #)
-- | Copy raw bytes from an immutable byte array to a memory address.
copyByteArrayToPtr :: ByteArray  -- ^ source array
                   -> Int        -- ^ source offset, in bytes
                   -> Ptr a      -- ^ destination address
                   -> Int        -- ^ number of bytes to copy
                   -> IO ()
copyByteArrayToPtr (ByteArray src#) (I# srcOff#) (Ptr dst#) (I# count#) =
    IO $ \s0 -> (# copyByteArrayToAddr# src# srcOff# dst# count# s0, () #)