module Data.Yaml.Marked.Replace
  ( Replace
  , newReplace
  , replaceMarked
  , ReplaceException (..)
  , runReplaces
  , runReplacesOnOverlapping
  ) where

import Prelude

import Control.Monad (void, when)
import Control.Monad.Trans.Resource (MonadThrow (..))
import Data.Bifunctor (second)
import Data.ByteString (ByteString)
import qualified Data.ByteString.Char8 as BS8
import Data.List (sortOn)
import Data.Yaml.Marked
import Numeric.Natural
import UnliftIO.Exception (Exception (..))

-- | A single edit: remove 'replacedLength' bytes starting at 'replaceIndex'
-- and put 'replacedBy' in their place
data Replace = Replace
  { replaceIndex :: Natural
  -- ^ Position, relative to the overall input, where the replaced region starts
  , replacedLength :: Natural
  -- ^ Number of bytes removed from the input
  , replacedBy :: ByteString
  -- ^ Content inserted in place of the removed bytes
  }
  deriving stock Eq
  deriving stock Show

-- | Create a 'Replace' directly at given index and length
--
-- The arguments are, in order: the 'replaceIndex', the 'replacedLength', and
-- the 'replacedBy' content.
--
-- NB. This function is unsafe in that it can be used with negative literals and
-- will fail at runtime. Prefer 'replaceMarked'.
newReplace :: Natural -> Natural -> ByteString -> Replace
newReplace = Replace

-- | Create a 'Replace' for something 'Marked'
--
-- The replaced region runs from the 'locationIndex' of the value's start
-- location to that of its end location. If the end index is smaller than the
-- start index, the region is empty (length 0).
replaceMarked :: Marked a -> ByteString -> Replace
replaceMarked Marked {markedLocationStart = startLoc, markedLocationEnd = endLoc} =
  newReplace startIdx regionLen
 where
  startIdx = locationIndex startLoc
  endIdx = locationIndex endLoc
  -- Guard against Natural underflow when the end precedes the start
  regionLen
    | endIdx < startIdx = 0
    | otherwise = endIdx - startIdx

-- | Failures raised while applying 'Replace's
data ReplaceException
  = -- | The 'Replace' does not fit in the input remaining at that point; the
    -- 'Natural' is that remaining input's length in bytes
    ReplaceOutOfBounds Replace Natural
  | -- | The 'Replace' overlaps the region of an earlier replacement
    OverlappingReplace Replace
  deriving stock Eq
  deriving stock Show

-- | Human-readable messages; 'show' remains the derived instance.
instance Exception ReplaceException where
  displayException (ReplaceOutOfBounds r bLen) =
    concat
      [ "The replacement "
      , show r
      , " is trying to replace more characters than remain in the ByteString ("
      , show bLen
      , ")"
      ]
  displayException (OverlappingReplace r) =
    concat
      [ "The replacement "
      , show r
      , " is where an earlier replacement has already been made"
      ]

-- | Apply a set of 'Replace's to a 'ByteString'
--
-- Throws 'OverlappingReplace' for the first 'Replace' that overlaps an earlier
-- one. May also throw 'ReplaceOutOfBounds' for a 'Replace' that does not fit
-- the input.
runReplaces :: MonadThrow m => [Replace] -> ByteString -> m ByteString
runReplaces replaces input =
  runReplacesOnOverlapping (throwM . OverlappingReplace) replaces input

-- | Apply a set of 'Replace's to a 'ByteString', choosing how to react to
-- overlapping ones
--
-- The 'Replace's may be given in any order; they are sorted by 'replaceIndex'
-- before being applied.
runReplacesOnOverlapping
  :: MonadThrow m
  => (Replace -> m a)
  -- ^ What to do with each overlapping 'Replace' encountered
  --
  -- NB. the overlapping replace(s) will be ignored, but this allows you to log
  -- them as warnings, or use 'throwM' to halt on the first one (which
  -- 'runReplaces' does).
  -> [Replace]
  -> ByteString
  -> m ByteString
runReplacesOnOverlapping onOverlap replaces input =
  filterOverlapping onOverlap (sortOn replaceIndex replaces)
    >>= \kept -> runReplaces' 0 mempty kept input

-- | Apply sorted, non-overlapping 'Replace's to a 'ByteString'
--
-- The result is the given prefix followed by the remaining input with every
-- 'Replace' applied. Errors from 'breakAtOffsetReplace' are propagated.
runReplaces'
  :: MonadThrow m
  => Natural
  -- ^ Position of the remaining input within the overall input, used to shift
  -- each 'replaceIndex'
  -> ByteString
  -- ^ Output produced so far (prefix of the result)
  -> [Replace]
  -> ByteString
  -- ^ Remaining, not-yet-processed input
  -> m ByteString
runReplaces' offset0 acc0 replaces input0 = go offset0 [acc0] replaces input0
 where
  -- Output chunks are kept most-recent-first and concatenated once at the end.
  -- This avoids re-copying the whole accumulated output on every replacement,
  -- which repeated strict-ByteString '<>' would do.
  go _ chunks [] rest = pure $ BS8.concat $ reverse $ rest : chunks
  go offset chunks (r : rs) rest = do
    (before, after) <- breakAtOffsetReplace offset r rest
    -- 'after' begins just past the replaced region, so advance the offset by
    -- the kept prefix plus the removed bytes
    let newOffset =
          offset + fromIntegral (BS8.length before) + replacedLength r
    go newOffset (replacedBy r : before : chunks) rs after

-- | Drop every 'Replace' that overlaps one already kept, running the given
-- action on each dropped 'Replace' (in input order)
--
-- Only a candidate's start is compared against the ends of kept 'Replace's.
-- That detects every overlap only when the input is sorted by 'replaceIndex',
-- which is how 'runReplacesOnOverlapping' calls it. Kept 'Replace's are
-- returned in their original order.
filterOverlapping
  :: Monad m => (Replace -> m a) -> [Replace] -> m [Replace]
filterOverlapping onOverlap = loop []
 where
  -- 'kept' holds the accepted replacements, most recent first
  loop kept [] = pure $ reverse kept
  loop kept (candidate : remaining)
    | clashes = void (onOverlap candidate) >> loop kept remaining
    | otherwise = loop (candidate : kept) remaining
   where
    clashes = any (candidate `precedesEndOf`) kept

-- | @precedesEndOf a b@ holds when @a@ starts strictly before the end of the
-- region replaced by @b@ (its 'replaceIndex' plus its 'replacedLength').
precedesEndOf :: Replace -> Replace -> Bool
precedesEndOf later earlier = startOf later < endOf earlier
 where
  startOf = replaceIndex
  endOf x = replaceIndex x + replacedLength x

-- | Break a 'ByteString' into the content before/after a replacement
--
-- Returns the bytes preceding the replaced region and the bytes following it;
-- the replaced region itself is dropped.
--
-- Will throw 'ReplaceException' if the 'Replace' is not valid for the given
-- input:
--
-- * 'OverlappingReplace' if its 'replaceIndex' is before the offset, i.e. it
--   starts inside input already consumed by an earlier replacement;
-- * 'ReplaceOutOfBounds' if the replaced region (shifted start plus
--   'replacedLength') extends past the end of the input.
breakAtOffsetReplace
  :: MonadThrow m
  => Natural
  -- ^ An amount to shift the 'replaceIndex' by
  --
  -- Since this function is called recursively to incrementally replace within
  -- an overall 'ByteString', to which the 'replaceIndex' is relative, we need
  -- to track how much to shift it as we recur.
  -> Replace
  -> ByteString
  -> m (ByteString, ByteString)
breakAtOffsetReplace offset r bs = do
  -- Report a start inside already-consumed input as an exception rather than
  -- crashing with 'error'.
  when (rIdx < offset) $ throwM $ OverlappingReplace r
  -- Only forced after the check above has passed, so this Natural subtraction
  -- cannot underflow.
  let sIdx = rIdx - offset
  -- Check the whole region, not just its length. Otherwise a region starting
  -- near (or past) the end would be silently truncated, and a huge start could
  -- wrap when converted to 'Int'.
  when (sIdx + rLen > bLen) $ throwM $ ReplaceOutOfBounds r bLen
  pure
    $ second (BS8.drop $ fromIntegral rLen)
    $ BS8.splitAt (fromIntegral sIdx) bs
 where
  rIdx :: Natural
  rIdx = replaceIndex r
  rLen :: Natural
  rLen = replacedLength r
  bLen :: Natural
  bLen = fromIntegral $ BS8.length bs