{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE LambdaCase #-}

module Data.Attoparsec.ByteString.Extra
  ( takeWhileMN
  , countMN
  ) where

import Control.Applicative ( optional )
import Control.Monad ( MonadPlus )
import Data.Attoparsec.ByteString ( Parser, scan )
import Data.ByteString ( ByteString )
import qualified Data.ByteString as BS
import Data.Word ( Word8 )
import Prelude

-- | Consume the longest prefix of the input, at most @n@ bytes long, whose
-- bytes all satisfy the predicate, and return the consumed input.
--
-- This parser fails when:
--
-- * @m > n@, or
-- * the length of the consumed input does not satisfy @m <= len <= n@
--   (in practice: fewer than @m@ bytes matched, since consumption stops
--   once @n@ bytes have been taken).
takeWhileMN ::
  -- | @m@: minimum number of bytes that must be consumed.
  Word ->
  -- | @n@: maximum number of bytes that may be consumed.
  Word ->
  -- | Predicate.
  (Word8 -> Bool) ->
  Parser ByteString
takeWhileMN m n f
  | m > n = fail "takeWhileMN: m cannot be greater than n"
  | otherwise = do
      bs <- scan 0 step
      -- 'BS.length' is never negative, so widening it to 'Word' is lossless.
      -- Comparing in 'Word' (rather than narrowing @m@/@n@ to 'Int') keeps
      -- bounds above @maxBound :: Int@ from wrapping around to negative
      -- values and spuriously rejecting every input.
      let len :: Word
          len = fromIntegral (BS.length bs)
      if m <= len && len <= n
        then pure bs
        else
          fail $
            "takeWhileMN: consumed input length ("
              <> show len
              <> ") must be >= "
              <> show m
              <> " and <= "
              <> show n
              <> "."
  where
    -- Scanner state is the number of bytes accepted so far. Accept the next
    -- byte only while fewer than @n@ bytes have been taken and the predicate
    -- holds for it; otherwise stop scanning.
    step :: Word -> Word8 -> Maybe Word
    step taken b
      | taken < n && f b = Just (taken + 1)
      | otherwise = Nothing

-- | Applies from @m@ to @n@ occurrences of @p@. Returns a list of the returned
-- values of @p@, in the order they were produced. Each value returned by @p@
-- is forced to WHNF.
--
-- The first @m@ applications are mandatory: if any of them fails, the whole
-- action fails. After that, @p@ is applied optionally until it fails or @n@
-- values have been collected in total.
--
-- Calls 'error' if @m > n@.
countMN :: MonadPlus m => Word -> Word -> m a -> m [a]
countMN m n p
  | m > n = error "countMN: m cannot be greater than n"
  | n == 0 = pure []
  | otherwise = reverse <$> required 0 []
  where
    -- The helpers below deliberately have no type signatures: mentioning the
    -- outer @m@ and @a@ would require @ScopedTypeVariables@.
    --
    -- Both helpers carry the number of collected values alongside the
    -- (reversed) accumulator, so the stop conditions are O(1) instead of
    -- re-measuring the list on every step, and the bounds are compared as
    -- 'Word' without a lossy conversion to 'Int'.

    -- Mandatory phase: run @p@ until @m@ values have been collected.
    required k acc
      | k == m = optionalRuns k acc
      | otherwise = do
          x <- p
          x `seq` required (k + 1) (x : acc)

    -- Optional phase: keep running @p@ until it fails or @n@ values have
    -- been collected.
    optionalRuns k acc
      | k == n = pure acc
      | otherwise =
          optional p >>= \r -> case r of
            Nothing -> pure acc
            Just x -> x `seq` optionalRuns (k + 1) (x : acc)