{-# LANGUAGE BangPatterns #-}
module Codec.Serialise.Internal.GeneralisedUTF8
    ( encodeGenUTF8
    , UTF8Encoding(..)
    , decodeGenUTF8
      
    , isSurrogate
    , isValid
    ) where
import Control.Monad.ST
import Data.Bits
import Data.Char
import Data.Word
import qualified Codec.CBOR.ByteArray.Sliced as BAS
import Data.Primitive.ByteArray
data UTF8Encoding = ConformantUTF8 | GeneralisedUTF8
                  deriving (Show, Eq)
isSurrogate :: Char -> Bool
isSurrogate c = c >= '\xd800' && c <= '\xdfff'
encodeGenUTF8 :: String -> (BAS.SlicedByteArray, UTF8Encoding)
encodeGenUTF8 st = runST $ do
    
    
    ba <- newByteArray (length st + 4)
    go ba ConformantUTF8 0 st
  where
    go :: MutableByteArray s -> UTF8Encoding
       -> Int -> [Char]
       -> ST s (BAS.SlicedByteArray, UTF8Encoding)
    go ba !enc !off  [] = do
        ba' <- unsafeFreezeByteArray ba
        return (BAS.SBA ba' 0 off, enc)
    go ba enc off  (c:cs)
      | off + 4 >= cap = do
        
        ba' <- newByteArray (cap + cap `div` 2 + 1)
        copyMutableByteArray ba' 0 ba 0 off
        go ba' enc off (c:cs)
      | c >= '\x10000' = do
        writeByteArray ba (off+0) (0xf0 .|. (0x07 .&. shiftedByte 18))
        writeByteArray ba (off+1) (0x80 .|. (0x3f .&. shiftedByte 12))
        writeByteArray ba (off+2) (0x80 .|. (0x3f .&. shiftedByte  6))
        writeByteArray ba (off+3) (0x80 .|. (0x3f .&. shiftedByte  0))
        go ba enc (off+4) cs
      | c >= '\x0800'  = do
        writeByteArray ba (off+0) (0xe0 .|. (0x0f .&. shiftedByte 12))
        writeByteArray ba (off+1) (0x80 .|. (0x3f .&. shiftedByte  6))
        writeByteArray ba (off+2) (0x80 .|. (0x3f .&. shiftedByte  0))
        
        let enc'
              | isSurrogate c = GeneralisedUTF8
              | otherwise     = enc
        go ba enc' (off+3) cs
      | c >= '\x0080'  = do
        writeByteArray ba (off+0) (0xc0 .|. (0x1f .&. shiftedByte  6))
        writeByteArray ba (off+1) (0x80 .|. (0x3f .&. shiftedByte  0))
        go ba enc (off+2) cs
      | c <= '\x007f'  = do
        writeByteArray ba off (fromIntegral n :: Word8)
        go ba enc (off+1) cs
      | otherwise      = error "encodeGenUTF8: Impossible"
      where
        cap = sizeofMutableByteArray ba
        n = ord c
        shiftedByte :: Int -> Word8
        shiftedByte shft = fromIntegral $ n `shiftR` shft
decodeGenUTF8 :: ByteArray -> String
decodeGenUTF8 ba = go 0
  where
    !len = sizeofByteArray ba
    index :: Int -> Int
    index i = fromIntegral (ba `indexByteArray` i :: Word8)
    go !off
      | off == len = []
      | n0 .&. 0xf8 == 0xf0 =
        let n1 = index (off + 1)
            n2 = index (off + 2)
            n3 = index (off + 3)
            c  = chr $  (n0 .&. 0x07) `shiftL` 18
                    .|. (n1 .&. 0x3f) `shiftL` 12
                    .|. (n2 .&. 0x3f) `shiftL`  6
                    .|. (n3 .&. 0x3f)
        in c : go (off + 4)
      | n0 .&. 0xf0 == 0xe0 =
        let n1 = index (off + 1)
            n2 = index (off + 2)
            c  = chr $  (n0 .&. 0x0f) `shiftL` 12
                    .|. (n1 .&. 0x3f) `shiftL`  6
                    .|. (n2 .&. 0x3f)
        in c : go (off + 3)
      | n0 .&. 0xe0 == 0xc0 =
        let n1 = index (off + 1)
            c  = chr $  (n0 .&. 0x1f) `shiftL`  6
                    .|. (n1 .&. 0x3f)
        in c : go (off + 2)
      | otherwise =
        let c =  chr $  (n0 .&. 0x7f)
        in c : go (off + 1)
      where !n0 = index off
isValid :: UTF8Encoding -> [Word8] -> Bool
isValid encoding = go
  where
    go [] = True
    go (b0:bs)
      | inRange 0x00 0x7f b0 = go bs
    go (b0:b1:bs)
      | inRange 0xc2 0xdf b0
      , inRange 0x80 0xbf b1 = go bs
    go (0xe0:b1:b2:bs)
      | inRange 0xa0 0xbf b1
      , inRange 0x80 0xbf b2 = go bs
    go (0xed:b1:_)
      
      | encoding == ConformantUTF8
      , inRange 0xa0 0xbf b1
      = False
    go (b0:b1:b2:bs)
      | inRange 0xe1 0xef b0
      , inRange 0x80 0xbf b1
      , inRange 0x80 0xbf b2 = go bs
    go (0xf0:b1:b2:b3:bs)
      | inRange 0x90 0xbf b1
      , inRange 0x80 0xbf b2
      , inRange 0x80 0xbf b3 = go bs
    go (b0:b1:b2:b3:bs)
      | inRange 0xf1 0xf3 b0
      , inRange 0x80 0xbf b1
      , inRange 0x80 0xbf b2
      , inRange 0x80 0xbf b3 = go bs
    go (0xf4:b1:b2:b3:bs)
      | inRange 0x80 0x8f b1
      , inRange 0x80 0xbf b2
      , inRange 0x80 0xbf b3 = go bs
    go _ = False
inRange :: Ord a => a -> a -> a -> Bool
inRange lower upper x = lower <= x && x <= upper