{-# LANGUAGE UnboxedTuples #-}
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE PatternSynonyms #-}
{-# LANGUAGE BlockArguments #-}
-- {-# LANGUAGE DataKinds #-} -- needed for manual ZeroBitType def (unsure why)
-- {-# LANGUAGE FlexibleInstances #-}

{- | Struct parser.

We do still have to do failure checking, because unlike C we check some types
(e.g. bitfields). Hopefully inlining can remove those checks when unnecessary.
-}

module Bytezap.Parser.Struct where

import GHC.Exts
import GHC.ForeignPtr
import Data.Void ( Void )

import Data.Word ( Word8 )
import Data.ByteString.Internal qualified as B
import System.IO.Unsafe ( unsafePerformIO )

import Raehik.Compat.Data.Primitive.Types

import Data.Bits
  ( Bits( (.&.), unsafeShiftR, xor )
  , FiniteBits(countTrailingZeros)
  )

-- | State token type used by pure parsers (see 'Parser').
type PureMode = Proxy# Void
-- | State token type used by parsers that may embed @IO@ (see 'ParserIO').
type IOMode   = State# RealWorld
-- | State token type used by parsers that may embed @ST@ actions for state
--   thread @s@ (see 'ParserST').
type STMode s = State# s

-- | Primitive parser function wrapped by 'ParserT'.
--
-- Receives the pointer provenance, the base address, the current offset and a
-- state token, and produces a 'Res#'. Note that the result does not carry an
-- updated offset.
type ParserT# (st :: ZeroBitType) e a =
       ForeignPtrContents {- ^ pointer provenance (does not change) -}
    -> Addr#              {- ^ base address (does not change) -}
    -> Int#               {- ^ cursor offset from base -}
    -> st                 {- ^ state token -}
    -> Res# st e a        {- ^ result -}

{- | Like flatparse, but no buffer length (= no buffer overflow checking), and
     no 'Addr#' on success (= no dynamic length parses).

Unlike flatparse, we separate base address from offset, rather than adding
them. This fits the unaligned 'Addr#' primops (added in GHC 9.10) better, and
in my head should hopefully assist in emitting immediates where possible for
offsets on the assembly level.

Combining them like in flatparse might be faster; but I really don't know how to
find out, without doing both and comparing various examples. After a lot of
scratching my head, I think this is most appropriate.

The 'ForeignPtrContents' is for keeping the 'Addr#' data in scope.
-}
newtype ParserT (st :: ZeroBitType) e a = ParserT
    { runParserT# :: ParserT# st e a
      -- ^ Unwrap to the underlying primitive parser function.
    }

-- | Maps over the value of a successful parse; failures and errors pass
--   through untouched.
instance Functor (ParserT st e) where
  fmap f (ParserT g) = ParserT \fpc base os st0 -> case g fpc base os st0 of
    -- Force the mapped value before returning it, so no thunk over the parsed
    -- value is handed back to the caller.
    OK# st1 a -> let !b = f a in OK# st1 b
    -- The remaining shapes ('Fail#', 'Err#') hold no value of the result type,
    -- so only the type index changes; the result is reused as-is rather than
    -- being taken apart and rebuilt.
    x         -> unsafeCoerce# x
  {-# inline fmap #-}

-- No Applicative due to no offset passing.

-- | The type of pure parsers (state token 'PureMode').
type Parser     = ParserT PureMode

-- | The type of parsers which can embed @IO@ actions (state token 'IOMode').
type ParserIO   = ParserT IOMode

-- | The type of parsers which can embed @ST@ actions (state token 'STMode').
type ParserST s = ParserT (STMode s)

-- | Primitive parser result wrapped with a state token.
--
-- You should rarely need to manipulate values of this type directly. Use the
-- provided bidirectional pattern synonyms 'OK#', 'Fail#' and 'Err#'.
type Res# (st :: ZeroBitType) e a =
  (# st, ResI# e a #)

-- | Primitive parser result.
--
-- Like flatparse, but no 'Addr#' on success.
--
-- An unboxed sum with three alternatives, in the order matched by 'OK#',
-- 'Fail#' and 'Err#'.
type ResI# e a =
  (#
    (# a #)   -- success: the parsed value
  | (# #)     -- failure: no payload
  | (# e #)   -- error: the error value
  #)

-- | 'Res#' constructor for a successful parse.
--
-- Contains the return value and a state token. The value occupies the first
-- alternative of the underlying unboxed sum.
pattern OK# :: (st :: ZeroBitType) -> a -> Res# st e a
pattern OK# st a = (# st, (# (# a #) | | #) #)

-- | 'Res#' constructor for recoverable failure.
--
-- Contains only a state token. Uses the second (payload-free) alternative of
-- the underlying unboxed sum.
pattern Fail# :: (st :: ZeroBitType) -> Res# st e a
pattern Fail# st = (# st, (# | (# #) | #) #)

-- | 'Res#' constructor for errors which are by default non-recoverable.
--
-- Contains the error, plus a state token. The error occupies the third
-- alternative of the underlying unboxed sum.
pattern Err# :: (st :: ZeroBitType) -> e -> Res# st e a
pattern Err# st e = (# st, (# | | (# e #) #) #)
{-# complete OK#, Fail#, Err# #-}

-- | Run a pure parser over the memory backing a strict 'B.ByteString'.
--
-- The length stored in the 'B.ByteString' is ignored, so the
-- caller must guarantee that buffer is long enough for parser!!
unsafeRunParserBs :: forall a e. B.ByteString -> Parser e a -> Result e a
unsafeRunParserBs (B.BS fptr _len) = unsafeRunParserFPtr fptr

-- | Run a pure parser starting at a raw 'Ptr'.
--
-- 'FinalPtr' is supplied as the pointer provenance, so nothing in this call
-- keeps the pointed-to memory alive; that is up to the caller. Likewise, the
-- caller must guarantee that buffer is long enough for parser!!
unsafeRunParserPtr :: forall a e. Ptr Word8 -> Parser e a -> Result e a
unsafeRunParserPtr (Ptr base#) = unsafeRunParser' base# FinalPtr

-- | Run a pure parser over the memory owned by a 'ForeignPtr'.
--
-- The parser is handed the real 'ForeignPtrContents' of the 'ForeignPtr'
-- (rather than 'FinalPtr'), so the provenance it sees actually refers to the
-- buffer being parsed.
--
-- The parse result is forced to weak head normal form /inside/ the
-- 'B.unsafeWithForeignPtr' scope. Returning it lazily would let the parser run
-- only after the scope had ended, i.e. after the point up to which the buffer
-- is kept alive.
--
-- NOTE(review): 'B.unsafeWithForeignPtr' is documented as only safe for
-- actions that do not diverge; this assumes parsers neither loop nor throw —
-- confirm.
--
-- caller must guarantee that buffer is long enough for parser!!
unsafeRunParserFPtr :: forall a e. ForeignPtr Word8 -> Parser e a -> Result e a
unsafeRunParserFPtr fptr@(ForeignPtr base# fpc) p =
    unsafePerformIO $ B.unsafeWithForeignPtr fptr $ \_ptr ->
        pure $! unsafeRunParser' base# fpc p

-- | Run a pure parser from a base address and its pointer provenance.
--
-- The parser starts at offset 0 and is given 'proxy#' as its state token. The
-- primitive result is converted to the boxed 'Result'; the final state token
-- is discarded.
--
-- caller must guarantee that buffer is long enough for parser!!
unsafeRunParser'
    :: forall a e. Addr# -> ForeignPtrContents -> Parser e a -> Result e a
unsafeRunParser' base# fpc (ParserT p) =
    case p fpc base# 0# proxy# of
      OK#   _st1 a -> OK a
      Err#  _st1 e -> Err e
      Fail# _st1   -> Fail

-- | Higher-level boxed data type for parsing results.
--
-- The error field is strict; the success value is not.
data Result e a =
    OK a    -- ^ Contains return value.
  | Fail    -- ^ Recoverable-by-default failure.
  | Err !e  -- ^ Unrecoverable-by-default error.
  deriving Show

-- | Parser that always succeeds with the given value, without looking at the
--   input (pointer, base address and offset are all ignored).
--
-- can't provide via 'pure' as no 'Applicative'
constParse :: a -> ParserT st e a
constParse a = ParserT \_fpc _base _os st -> OK# st a

-- | Run two parsers one after the other and combine their results.
--
-- The first parser runs at the current offset; the second runs at the current
-- offset plus the given length. The length is taken on trust: nothing checks
-- it against what the first parser actually read.
--
-- If the first parser fails or errors, the second parser is not run and that
-- result is returned. Otherwise the second parser's failure/error is returned,
-- or the combined value on success.
sequenceParsers
    :: Int -> (a -> b -> c)
    -> ParserT st e a -> ParserT st e b -> ParserT st e c
sequenceParsers (I# len#) f (ParserT pa) (ParserT pb) =
    ParserT \fpc base os# st0 ->
        case pa fpc base os# st0 of
          OK# st1 a ->
            -- second parser starts @len#@ bytes past the first one
            case pb fpc base (os# +# len#) st1 of
              OK#   st2 b -> OK#   st2 (f a b)
              Fail# st2   -> Fail# st2
              Err#  st2 e -> Err#  st2 e
          Err#  st1 e -> Err#  st1 e
          Fail# st1   -> Fail# st1

-- TODO using indexWord8OffAddrAs to permit pure mode. flatparse does this (at
-- least for integers). guess it's OK?
-- TODO this doesn't use the state token. scary.

-- | Parse a primitive value by reading it at the current offset from the base
--   address. Always succeeds.
--
-- The read is forced (bang pattern) before the result is returned. A plain
-- variable pattern on a lifted type would leave the read as a thunk holding
-- the raw address, to be evaluated at some arbitrary later point.
prim :: forall a st e. Prim' a => ParserT st e a
prim = ParserT \_fpc base os st ->
    case indexWord8OffAddrAs# base os of !a -> OK# st a

-- | parse literal
--
-- Runs the given parser and compares its result with the expected value.
-- Succeeds with @()@ when they are equal and fails (recoverably, via 'Fail#')
-- when they differ. A failure or error from the inner parser is passed
-- through unchanged.
lit :: Eq a => a -> ParserT st e a -> ParserT st e ()
lit al (ParserT p) = ParserT \fpc base os st0 ->
    case p fpc base os st0 of
      OK#   st1 ar -> if al == ar then OK# st1 () else Fail# st1
      Err#  st1 e  -> Err#  st1 e
      Fail# st1    -> Fail# st1

-- | parse literal (CPS)
--
-- Runs the given parser and compares its result with the expected value. On a
-- match, the continuation parser is run at the current offset plus the given
-- length. On a mismatch the result is 'Fail#'. A failure or error from the
-- literal parser is passed through unchanged and the continuation is not run.
withLit
    :: Eq a => Int# -> a -> ParserT st e a -> ParserT st e r -> ParserT st e r
withLit len# al (ParserT p) (ParserT pCont) = ParserT \fpc base os# st0 ->
    case p fpc base os# st0 of
      OK#   st1 ar ->
        if al == ar then pCont fpc base (os# +# len#) st1 else Fail# st1
      Err#  st1 e  -> Err#  st1 e
      Fail# st1    -> Fail# st1

{- | parse literal, return first (leftmost) failing byte on error (CPS)

This can be used to parse large literals via chunking, rather than byte-by-byte,
while retaining useful error behaviour.

We don't check equality with XOR even though we use that when handling errors,
because it's hard to tell if it would be faster with modern CPUs and compilers.

Arguments, in order:

  * the offset increment applied before running the continuation
  * the expected literal value
  * the index of this chunk's first byte, added to the failing byte index in
    the reported error
  * the reader used to obtain the actual value from base address and offset
  * the continuation parser, run only when the literal matches

On mismatch the result is an 'Err#' holding the (adjusted) index of the first
differing byte together with the byte that was actually read.

NOTE(review): the failing byte is located from the least significant end of the
word, which corresponds to "leftmost in memory" only if the reader yields
little-endian values -- confirm.
-}
withLitErr
    :: (Integral a, FiniteBits a)
    => Int# -> a -> Int -> (Addr# -> Int# -> a)
    -> ParserT st (Int, Word8) r
    -> ParserT st (Int, Word8) r
withLitErr len# aLit idxStart p (ParserT pCont) = ParserT \fpc base# os# st ->
    let aParsed = p base# os#
    in  if   aLit == aParsed
        then pCont fpc base# (os# +# len#) st
        else let idxFail = firstNonMatchByteIdx aLit aParsed
                 bFailed = unsafeByteAt aParsed idxFail
             in  Err# st (idxStart + idxFail, fromIntegral bFailed)
{-# INLINE withLitErr #-}

-- | Given two non-equal words @wExpect@ and @wActual@, return the index of the
--   first non-matching byte. Zero indexed.
--
-- Bytes are counted from the least significant end: the words are XORed and
-- the number of trailing zero bits of the result is divided by 8.
--
-- If both words are equal, returns word_size (e.g. 4 for 'Word32').
firstNonMatchByteIdx :: FiniteBits a => a -> a -> Int
firstNonMatchByteIdx wExpect wActual =
    countTrailingZeros (wExpect `xor` wActual) `unsafeShiftR` 3
{-# INLINE firstNonMatchByteIdx #-}

-- | Get the byte at the given index.
--
-- Index 0 is the least significant byte: the value is shifted right by
-- @idx * 8@ bits and masked with @0xFF@.
--
-- The return value is guaranteed to be 0x00 - 0xFF (inclusive).
--
-- The index is not range checked before being passed to 'unsafeShiftR'.
--
-- TODO meaning based on endianness?
unsafeByteAt :: (Num a, Bits a) => a -> Int -> a
unsafeByteAt a idx = (a `unsafeShiftR` (idx * 8)) .&. 0xFF
{-# INLINE unsafeByteAt #-}