{-# LANGUAGE MagicHash #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE CPP #-}
{-# LANGUAGE UnliftedFFITypes #-}
{-# LANGUAGE BangPatterns #-}
module Std.Data.CBytes
  ( CBytes
  , create
  , pack
  , unpack
  , null , length
  , empty, append, concat
  , toBytes, fromBytes, fromText
  , fromCStringMaybe, fromCString, fromCStringN
  , withCBytes
  ) where
import           Control.Monad
import           Control.Monad.Primitive
import           Control.Monad.ST
import           Data.Bits
import           Data.Foldable           (foldlM)
import           Data.Hashable           (Hashable(..),
                                            hashByteArrayWithSalt, hashPtrWithSalt)
import qualified Data.List               as List
import           Data.Monoid             (Monoid (..))
import           Data.Semigroup          (Semigroup (..))
import           Data.String             (IsString (..))
import           Data.Primitive.PrimArray
import           Data.Word
import           Foreign.C
import           Foreign.Storable        (peekElemOff)
import           GHC.CString
import           GHC.Ptr
import           Prelude                 hiding (all, any, appendFile, break,
                                          concat, concatMap, drop, dropWhile,
                                          elem, filter, foldl, foldl1, foldr,
                                          foldr1, getContents, getLine, head,
                                          init, interact, last, length, lines,
                                          map, maximum, minimum, notElem, null,
                                          putStr, putStrLn, readFile, replicate,
                                          reverse, scanl, scanl1, scanr, scanr1,
                                          span, splitAt, tail, take, takeWhile,
                                          unlines, unzip, writeFile, zip,
                                          zipWith)
import           Std.Data.Array
import qualified Std.Data.Text           as T
import           Std.Data.Text.UTF8Codec (encodeCharModifiedUTF8)
import qualified Std.Data.Vector.Base    as V
import           Std.IO.Exception
import           System.IO.Unsafe        (unsafeDupablePerformIO)
-- | A NUL-terminated byte string that can be passed to C as a 'CString'
-- (see 'withCBytes').
--
-- Two representations exist: bytes owned by a Haskell array, and a bare
-- pointer to an already NUL-terminated C string.
data CBytes
    = CBytesOnHeap  {-# UNPACK #-} !(PrimArray Word8)   -- ^ Payload bytes followed by one
                                                        -- @\\NUL@ byte. Every construction
                                                        -- site in this module starts from
                                                        -- 'newPinnedPrimArray'.
    | CBytesLiteral {-# UNPACK #-} !CString             -- ^ Pointer to a NUL-terminated string, used as-is without copying.
-- | Allocate a pinned buffer, let a callback fill it through a 'CString'
-- pointer, and wrap the bytes it wrote as a 'CBytes'.
--
-- The callback returns how many payload bytes it wrote. A @\\NUL@ terminator
-- is stored right after them and the array is shrunk to payload plus
-- terminator.
create :: HasCallStack
       => Int  -- ^ capacity in bytes offered to the callback; negative values are treated as 0
       -> (CString -> IO Int)  -- ^ fills the buffer and returns the number of
                               -- bytes written
       -> IO CBytes
{-# INLINE create #-}
create n fill = do
    let !cap = max 0 n
    -- One byte beyond the advertised capacity is reserved for the terminator,
    -- so a callback that uses all @cap@ bytes cannot push the NUL write past
    -- the end of the array.
    mba <- newPinnedPrimArray (cap+1) :: IO (MutablePrimArray RealWorld Word8)
    l <- withMutablePrimArrayContents mba (fill . castPtr)
    -- A length outside [0, cap] cannot describe bytes inside this buffer;
    -- clamp it so the terminator write and the shrink stay in bounds.
    let !l' = max 0 (min cap l)
    writePrimArray mba l' 0
    shrinkMutablePrimArray mba (l'+1)
    CBytesOnHeap <$> unsafeFreezePrimArray mba
-- | Shows the decoded contents the way a 'String' is shown: quoted and
-- escaped. This is the syntax the 'Read' instance of 'CBytes' parses, so the
-- output of 'show' can be fed back to 'read'.
instance Show CBytes where
    showsPrec p = showsPrec p . unpack
-- | Parses whatever the 'String' reader accepts and converts each parse
-- with 'pack', keeping the unconsumed remainder untouched.
instance Read CBytes where
    readsPrec prec input = do
        (str, rest) <- readsPrec prec input
        return (pack str, rest)
-- | Two values are equal when they resolve to the same pointer, or when
-- 'c_strcmp' on their pointers returns 0.
instance Eq CBytes where
    lhs == rhs = unsafeDupablePerformIO $
        withCBytes lhs $ \ lp ->
            withCBytes rhs $ \ rp ->
                if lp /= rp
                -- Different addresses: fall back to a byte-wise comparison.
                then (== 0) <$> c_strcmp lp rp
                else return True
-- | Ordering follows the sign of 'c_strcmp' on the two pointers; identical
-- pointers compare 'EQ' without calling it.
instance Ord CBytes where
    compare lhs rhs = unsafeDupablePerformIO $
        withCBytes lhs $ \ lp ->
            withCBytes rhs $ \ rp ->
                case lp == rp of
                    True  -> return EQ
                    -- Negative, zero and positive results map to LT, EQ, GT.
                    False -> (`compare` 0) <$> c_strcmp lp rp
-- | '<>' is 'append'.
instance Semigroup CBytes where
    lhs <> rhs = append lhs rhs
-- | The identity is 'empty'; 'mappend' and 'mconcat' delegate to 'append'
-- and 'concat'.
instance Monoid CBytes where
    {-# INLINE mempty #-}
    mempty = empty
    {-# INLINE mappend #-}
    mappend lhs rhs = append lhs rhs
    {-# INLINE mconcat #-}
    mconcat chunks = concat chunks
-- | Hashes the payload bytes (terminator excluded) with the salt, using
-- 'V.c_fnv_hash_ba' for heap arrays and 'V.c_fnv_hash_addr' for literal
-- pointers.
instance Hashable CBytes where
    hashWithSalt salt cbytes = unsafeDupablePerformIO $ case cbytes of
        CBytesOnHeap arr@(PrimArray ba#) ->
            -- The last array element is the terminator, so leave it out.
            let payloadLen = sizeofPrimArray arr - 1
            in V.c_fnv_hash_ba ba# 0 payloadLen salt
        CBytesLiteral cstr@(Ptr addr#) ->
            -- A literal carries no stored size; measure it first.
            c_strlen cstr >>= \ n ->
                V.c_fnv_hash_addr addr# (fromIntegral n) salt
-- | Concatenate two 'CBytes'.
--
-- When either side has length 0 the other side is returned unchanged;
-- otherwise both payloads are copied into a new pinned array that ends with
-- a @\\NUL@ terminator.
append :: CBytes -> CBytes -> CBytes
{-# INLINABLE append #-}
append strA strB = case (length strA, length strB) of
    (0, _) -> strB
    (_, 0) -> strA
    (lenA, lenB) -> unsafeDupablePerformIO $ do
        let total = lenA + lenB
        buf <- newPinnedPrimArray (total + 1)
        withCBytes strA $ \ srcA ->
            withCBytes strB $ \ srcB -> do
                copyPtrToMutablePrimArray buf 0    (castPtr srcA) lenA
                copyPtrToMutablePrimArray buf lenA (castPtr srcB) lenB
                -- Terminator goes right after the second payload.
                writePrimArray buf total 0
                CBytesOnHeap <$> unsafeFreezePrimArray buf
-- | The 'CBytes' of length 0: a literal pointing at a single @\\NUL@ byte.
empty :: CBytes
{-# NOINLINE empty #-}
empty = CBytesLiteral nulOnly
  where
    nulOnly = Ptr "\0"#
-- | Concatenate a list of 'CBytes'.
--
-- A first pass counts the non-empty elements and sums their lengths. With no
-- non-empty element the result is 'empty'; with exactly one, that element is
-- returned as-is; otherwise all payloads are copied into one new pinned
-- array followed by a @\\NUL@ terminator.
concat :: [CBytes] -> CBytes
{-# INLINABLE concat #-}
concat bs = case pre 0 0 bs of
    (0, _) -> empty
    (1, _) -> case List.find (not . null) bs of
        Just b  -> b
        -- Not expected: 'pre' just counted one non-empty element. Falling
        -- back to 'empty' keeps the match total.
        Nothing -> empty
    (_, l) -> runST $ do
        buf <- newPinnedPrimArray (l+1)
        copy bs 0 buf
        writePrimArray buf l 0
        CBytesOnHeap <$> unsafeFreezePrimArray buf
  where
    -- Count non-empty elements and accumulate their total payload length.
    pre :: Int -> Int -> [CBytes] -> (Int, Int)
    pre !nacc !lacc [] = (nacc, lacc)
    pre !nacc !lacc (b:rest)
        | l <= 0    = pre nacc lacc rest
        | otherwise = pre (nacc+1) (l + lacc) rest
      where
        -- Measured once per element; for a literal 'length' calls 'c_strlen'.
        l = length b
    -- Copy every non-empty payload into the buffer, starting at offset @i@.
    copy :: [CBytes] -> Int -> MutablePrimArray s Word8 -> ST s ()
    copy [] !_ !_       = return ()
    copy (b:rest) !i !mba = do
        let l = length b
        when (l /= 0) (case b of
            CBytesOnHeap ba ->
                copyPrimArray mba i ba 0 l
            CBytesLiteral p ->
                copyPtrToMutablePrimArray mba i (castPtr p) l)
        copy rest (i+l) mba
-- | String literals (with @OverloadedStrings@) become 'CBytes' through
-- 'pack'. 'fromString' is inlined so that call sites expose a 'pack' call,
-- which is what the rewrite rules on 'pack' match.
instance IsString CBytes where
    {-# INLINE fromString #-}
    fromString = pack
-- Rewrite 'pack' applied to an unpacked string literal into a
-- 'CBytesLiteral' that points at the literal's address directly, instead of
-- running 'pack' on the unpacked 'String'.
{-# RULES
    "CBytes pack/unpackCString#" forall addr# .
        pack (unpackCString# addr#) = CBytesLiteral (Ptr addr#)
 #-}
-- Same rewrite for literals that are unpacked with 'unpackCStringUtf8#'.
{-# RULES
    "CBytes pack/unpackCStringUtf8#" forall addr# .
        pack (unpackCStringUtf8# addr#) = CBytesLiteral (Ptr addr#)
 #-}
-- | Encode a 'String' into a heap-allocated, NUL-terminated 'CBytes'.
--
-- Characters are written one by one with 'encodeCharModifiedUTF8' into a
-- pinned buffer that doubles in size whenever it is close to full. Inlining
-- is delayed to phase 1 so the @pack/unpackCString#@ rewrite rules get a
-- chance to fire on string literals first.
pack :: String -> CBytes
{-# INLINE CONLIKE [1] pack #-}
pack s = runST $ do
    mba <- newPinnedPrimArray V.defaultInitSize
    (SP2 i mba') <- foldlM go (SP2 0 mba) s
    writePrimArray mba' i 0     -- terminator right after the last encoded byte
    shrinkMutablePrimArray mba' (i+1)
    ba <- unsafeFreezePrimArray mba'
    return (CBytesOnHeap ba)
  where
    -- Encode one character at the current offset, growing the buffer first
    -- when fewer than 5 bytes remain.
    -- NOTE(review): the margin of 4 presumably covers the widest encoding
    -- plus the terminator written after the fold — confirm against
    -- 'encodeCharModifiedUTF8'.
    go :: SP2 s -> Char -> ST s (SP2 s)
    go (SP2 i mba) !c     = do
        siz <- getSizeofMutablePrimArray mba
        if i < siz - 4
        then do
            i' <- encodeCharModifiedUTF8 mba i c
            return (SP2 i' mba)
        else do
            let !siz' = siz `shiftL` 1
            -- Grow by allocating a new pinned array and copying the bytes
            -- written so far. An in-place style resize may hand back a newly
            -- allocated array that is not pinned, while 'withCBytes' later
            -- takes a raw pointer to this buffer.
            !mba' <- newPinnedPrimArray siz'
            copyMutablePrimArray mba' 0 mba 0 i
            i' <- encodeCharModifiedUTF8 mba' i c
            return (SP2 i' mba')
-- | Strict pair of a write offset and the buffer being filled; the
-- accumulator threaded through the fold in 'pack'.
data SP2 s = SP2 {-# UNPACK #-}!Int {-# UNPACK #-}!(MutablePrimArray s Word8)
-- | Decode a 'CBytes' into a 'String' with 'unpackCStringUtf8#'.
--
-- The whole string is forced before this function's 'withCBytes' callback
-- returns, so the result never reads from the underlying buffer afterwards.
-- The cost is that decoding is no longer incremental.
unpack :: CBytes -> String
{-# INLINABLE unpack #-}
unpack cbytes = unsafeDupablePerformIO . withCBytes cbytes $ \ (Ptr addr#) -> do
    let str = unpackCStringUtf8# addr#
    -- 'unpackCStringUtf8#' produces its list lazily from the raw address.
    -- Walk the full spine and every element here, while the pointer is still
    -- inside the 'withCBytes' scope; a lazy result would keep dereferencing
    -- the address after that scope has ended.
    case List.foldl' (\ () c -> c `seq` ()) () str of
        () -> return str
-- | Test whether a 'CBytes' has no payload, i.e. whether its first byte is
-- already the @\\NUL@ terminator.
null :: CBytes -> Bool
{-# INLINE null #-}
null cbytes = case cbytes of
    CBytesOnHeap arr   -> indexPrimArray arr 0 == 0
    CBytesLiteral cstr -> unsafeDupablePerformIO (peekElemOff cstr 0) == 0
-- | Payload length in bytes, not counting the @\\NUL@ terminator.
--
-- For a heap value this is the array size minus one; for a literal it is
-- measured with 'c_strlen' on every call.
length :: CBytes -> Int
{-# INLINE length #-}
length cbytes = case cbytes of
    CBytesOnHeap arr   -> sizeofPrimArray arr - 1
    CBytesLiteral cstr -> fromIntegral (unsafeDupablePerformIO (c_strlen cstr))
-- | Convert to 'V.Bytes' holding the payload bytes only; the @\\NUL@
-- terminator is not part of the resulting vector.
--
-- A heap value is shared without copying: the result is a slice over the
-- existing array that stops before the terminator. A literal is copied into
-- a new vector.
toBytes :: CBytes -> V.Bytes
{-# INLINABLE toBytes #-}
toBytes cbytes@(CBytesOnHeap pa) = V.PrimVector pa 0 l
  where l = length cbytes
-- The vector is sized to exactly @l@ so that, like the heap case above, it
-- does not contain the terminator.
-- NOTE(review): assumes @V.create n@ yields a vector of length @n@ — confirm
-- against "Std.Data.Vector.Base".
toBytes cbytes@(CBytesLiteral p) = V.create l (\ mpa ->
    copyPtrToMutablePrimArray mpa 0 (castPtr p) l)
  where l = length cbytes
-- | Copy a 'V.Bytes' slice into a new pinned array and append a @\\NUL@
-- terminator.
--
-- The bytes are copied verbatim; no check for @\\NUL@ bytes inside the
-- input is made.
fromBytes :: V.Bytes -> CBytes
{-# INLINABLE fromBytes #-}
fromBytes (V.Vec src off len) = runST $ do
    buf <- newPinnedPrimArray (len+1)
    copyPrimArray buf 0 src off len
    writePrimArray buf len 0
    CBytesOnHeap <$> unsafeFreezePrimArray buf
-- | Build a 'CBytes' from the UTF-8 bytes of a 'T.Text', via 'fromBytes'.
fromText :: T.Text -> CBytes
{-# INLINABLE fromText #-}
fromText txt = fromBytes (T.getUTF8Bytes txt)
-- | Copy a NUL-terminated C string into a new 'CBytes'.
--
-- Returns 'Nothing' for a null pointer. Otherwise the string is measured
-- with 'c_strlen' and copied, together with a terminator, into a pinned
-- array.
fromCStringMaybe :: HasCallStack => CString -> IO (Maybe CBytes)
{-# INLINABLE fromCStringMaybe #-}
fromCStringMaybe cstring
    | cstring == nullPtr = return Nothing
    | otherwise = do
        n <- fromIntegral <$> c_strlen cstring
        buf <- newPinnedPrimArray (n+1)
        copyPtrToMutablePrimArray buf 0 (castPtr cstring) n
        writePrimArray buf n 0
        Just . CBytesOnHeap <$> unsafeFreezePrimArray buf
-- | Copy a NUL-terminated C string into a new 'CBytes'.
--
-- Throws 'InvalidArgument' (carrying the call stack) when given a null
-- pointer. Otherwise the string is measured with 'c_strlen' and copied,
-- together with a terminator, into a pinned array.
fromCString :: HasCallStack
            => CString
            -> IO CBytes
{-# INLINABLE fromCString #-}
fromCString cstring
    | cstring == nullPtr =
        throwIO (InvalidArgument
            (IOEInfo "" "unexpected null pointer" callStack))
    | otherwise = do
        n <- fromIntegral <$> c_strlen cstring
        buf <- newPinnedPrimArray (n+1)
        copyPtrToMutablePrimArray buf 0 (castPtr cstring) n
        writePrimArray buf n 0
        CBytesOnHeap <$> unsafeFreezePrimArray buf
-- | Copy the first @len@ bytes at a C pointer into a new 'CBytes' and append
-- a @\\NUL@ terminator.
--
-- Throws 'InvalidArgument' (carrying the call stack) when given a null
-- pointer. A negative @len@ is treated as 0. The bytes are copied verbatim;
-- no check for @\\NUL@ bytes inside the copied range is made.
fromCStringN :: HasCallStack
            => CString
            -> Int
            -> IO CBytes
{-# INLINABLE fromCStringN #-}
fromCStringN cstring len =
    if cstring == nullPtr
    then throwIO (InvalidArgument
        (IOEInfo "" "unexpected null pointer" callStack))
    else do
        -- A negative count would produce a negative copy length and an
        -- out-of-range terminator index, so clamp it first.
        let !n = max 0 len
        mpa <- newPinnedPrimArray (n+1)
        copyPtrToMutablePrimArray mpa 0 (castPtr cstring) n
        writePrimArray mpa n 0
        pa <- unsafeFreezePrimArray mpa
        return (CBytesOnHeap pa)
-- | Run an action on the 'CString' view of a 'CBytes'.
--
-- A heap value is exposed through 'withPrimArrayContents'; a literal's
-- pointer is handed to the action directly.
-- NOTE(review): for heap values the pointer is presumably only valid while
-- the action runs — confirm against 'withPrimArrayContents'.
withCBytes :: CBytes -> (CString -> IO a) -> IO a
{-# INLINABLE withCBytes #-}
withCBytes cbytes f = case cbytes of
    CBytesOnHeap arr   -> withPrimArrayContents arr $ \ p -> f (castPtr p)
    CBytesLiteral cstr -> f cstr
-- | Unwrap both pointers to raw addresses and call 'V.c_strcmp' on them.
c_strcmp :: CString -> CString -> IO CInt
{-# INLINE c_strcmp #-}
c_strcmp lhs rhs = case lhs of
    Ptr a# -> case rhs of
        Ptr b# -> V.c_strcmp a# b#
-- | Unwrap the pointer to a raw address and call 'V.c_strlen' on it.
c_strlen :: CString -> IO CSize
{-# INLINE c_strlen #-}
c_strlen cstr = case cstr of
    Ptr a# -> V.c_strlen a#