-- | Data null-padded to a given length.

{-# LANGUAGE UndecidableInstances #-} -- for PredicateName
{-# LANGUAGE OverloadedStrings #-} -- for refine error builder

module Binrep.Type.NullPadded where

import Binrep
import Bytezap.Poke qualified as BZ
import Bytezap.Struct qualified as BZ.Struct
import FlatParse.Basic qualified as FP
import Raehik.Compat.FlatParse.Basic.WithLength qualified as FP
import Control.Monad.Combinators ( skipCount )

import Rerefined.Predicate.Common
import Rerefined.Refine
import TypeLevelShow.Natural
import TypeLevelShow.Utils
import Data.Text.Builder.Linear qualified as TBL

import GHC.TypeNats
import Util.TypeNats ( natValInt )

import Bytezap.Parser.Struct qualified as BZG
import GHC.Exts ( Int(I#) )

-- | Predicate tag for values that are to be null-padded to @n@ bytes.
--
-- Declared without constructors: it only appears at the type level, as the
-- predicate parameter of a refined type.
data NullPad (n :: Natural)
-- | The predicate is displayed as @NullPad \<n\>@ (with @n@ in decimal),
--   wrapped in parentheses when the surrounding precedence @d@ exceeds 9.
instance Predicate (NullPad n) where
    type PredicateName d (NullPad n) = ShowParen (d > 9)
        ("NullPad " ++ ShowNatDec n)

{- | A type which is to be null-padded to a given total length.

Given some @ra :: 'NullPadded' n a@ wrapping the value @a = 'unrefine' ra@, it
is guaranteed that

@
'blen' a '<=' 'natValInt' \@n
@

thus

@
'natValInt' \@n '-' 'blen' a '>=' 0
@

That is, the serialized stored data will not be longer than the total length.
When serializing, the remaining bytes are filled with nulls (@0x00@).

Note that the length of the padded value itself is always @n@; the inequality
above concerns the wrapped value only.
-}
type NullPadded n a = Refined (NullPad n) a

-- | A null-padded value has its length fixed at the type level: it is always
--   @n@, whatever the wrapped value is.
instance IsCBLen (NullPadded n a) where
    type CBLen (NullPadded n a) = n

-- | The term-level length is obtained from the type-level one above.
deriving via ViaCBLen (NullPadded n a)
    instance KnownNat n => BLen (NullPadded n a)

-- | Check that the value's serialized length does not exceed @n@.
--
-- On failure, the message reports both the actual length and the limit.
instance (KnownPredicateName (NullPad n), BLen a, KnownNat n)
  => Refine (NullPad n) a where
    validate p a = validateBool p fits tooLong
      where
        -- Total length the value is to be padded to.
        maxLen :: Int
        maxLen = natValInt @n

        -- Length of the value being refined.
        actualLen :: Int
        actualLen = blen a

        fits :: Bool
        fits = actualLen <= maxLen

        tooLong =
               "too long: "
            <> TBL.fromDec actualLen
            <> " > "
            <> TBL.fromDec maxLen

-- | Serialize the wrapped value, followed by as many null bytes as are needed
--   to reach @n@ bytes in total.
instance (BLen a, KnownNat n, Put a) => PutC (NullPadded n a) where
    putC padded =
        BZ.Struct.sequencePokes payloadPoke payloadLen nullPoke
      where
        payload = unrefine padded

        payloadLen :: Int
        payloadLen = blen payload

        -- The wrapped value's own poke, lifted to a struct poke.
        payloadPoke = BZ.toStructPoke (put payload)

        -- The refinement guarantees that the padding length is >= 0.
        nullPoke = BZ.Struct.replicateByte (natValInt @n - payloadLen) 0x00

-- | Serialize the wrapped value, then append null bytes up to @n@ bytes in
--   total.
instance (BLen a, KnownNat n, Put a) => Put (NullPadded n a) where
    put padded = put payload <> nullPadding
      where
        payload = unrefine padded

        -- The refinement guarantees that this is >= 0.
        padLen :: Int
        padLen = natValInt @n - blen payload

        nullPadding = BZ.replicateByte padLen 0x00

-- | Run a @Getter a@ isolated to @n@ bytes.
--
-- The parsed value is wrapped without re-validation. The count of bytes the
-- inner parser left unconsumed is currently ignored, so the padding is /not/
-- checked to consist of nulls (see the TODO below).
instance (KnownNat n, Get a) => GetC (NullPadded n a) where
    getC =
        case natValInt @n of
          I# totalLen# ->
            fpToBz get totalLen# $ \parsed _unconsumed# ->
                -- TODO consume nulls lol
                BZG.constParse (unsafeRefine parsed)

-- | Parse the wrapped value, then consume the null padding that follows it.
--
-- The number of bytes the inner parser consumed is subtracted from @n@ to get
-- the expected padding length:
--
-- * if the inner parser consumed more than @n@ bytes, parsing fails with a
--   "too long" error;
-- * otherwise exactly that many @0x00@ bytes must follow.
--
-- The result is wrapped without re-validation.
instance (Get a, KnownNat n) => Get (NullPadded n a) where
    get = do
        (a, len) <- FP.parseWithLength get
        let paddingLen :: Int
            paddingLen = natValInt @n - len
        if paddingLen < 0
          then err1 ["too long: parsed value exceeds the null-padded length"]
          else do
            -- Every padding byte must be a null.
            skipCount paddingLen (FP.word8 0x00)
            pure (unsafeRefine a)