-- |
--
-- Module      : Data.JSON.Patch.Apply
-- Copyright   : (c) 2025 Patrick Brisbin
-- License     : AGPL-3
-- Maintainer  : pbrisbin@gmail.com
-- Stability   : experimental
-- Portability : POSIX
module Data.JSON.Patch.Apply
  ( applyPatches
  , PatchError (..)
  ) where

import Prelude

import Control.Monad (unless, void, when)
import Control.Monad.Except (MonadError, runExcept, throwError)
import Control.Monad.State (MonadState, execStateT, gets, modify, put)
import Data.Aeson
import Data.Aeson.KeyMap (KeyMap)
import Data.Aeson.Optics
import Data.Foldable (traverse_)
import Data.JSON.Patch.Error
import Data.JSON.Patch.Type
import Data.JSON.Pointer
import Data.JSON.Pointer.Token
import Data.Vector (Vector)
import Data.Vector qualified as V
import Optics

-- | Apply a sequence of patch operations, in order, to a document.
--
-- Each operation runs against the document produced by the previous one.
-- The first operation that throws aborts the whole sequence and its
-- 'PatchError' is returned as 'Left'; otherwise the final document is
-- returned as 'Right'.
applyPatches :: [Patch] -> Value -> Either PatchError Value
applyPatches ps = runExcept . execStateT (traverse_ applyPatch ps)

-- | Apply a single patch operation to the document held in state.
--
-- * @replace@ is a 'remove' followed by an 'add' at the same path.
-- * @move@ reads the source value, removes it, then adds it at the
--   destination path.
-- * @copy@ reads the source value and adds it at the destination path.
-- * @test@ leaves the document unchanged and throws 'TestFailed' when the
--   value at the path is not equal to the expected value.
applyPatch :: (MonadError PatchError m, MonadState Value m) => Patch -> m ()
applyPatch = \case
  Add op -> add op.value op.path
  Remove op -> remove op.path
  Replace op -> do
    remove op.path
    add op.value op.path
  Move op -> do
    v <- get op.from
    remove op.from
    add v op.path
  Copy op -> do
    v <- get op.from
    add v op.path
  Test op -> do
    v <- get op.path
    unless (v == op.value) $
      throwError $ TestFailed op.path v op.value

-- | Read the value a pointer refers to.
--
-- * 'PointerEmpty' yields the whole document.
-- * 'PointerPath' yields the value at the full token path, or throws
--   'PointerNotFound'.
-- * 'PointerPathEnd' yields the /last/ element of the array at the parent
--   path. It throws 'InvalidArrayOperation' if the parent is not an array
--   and 'EmptyArray' if the array has no elements.
--
-- NOTE(review): if 'PointerPathEnd' models the RFC 6901 @-@ token, RFC 6902
-- treats it as a nonexistent element when read; returning the last element
-- here is a deliberate deviation — confirm against the test suite.
get :: (MonadError PatchError m, MonadState Value m) => Pointer -> m Value
get = \case
  PointerEmpty -> gets id
  PointerPath ts t -> assertExists $ ts <> [t]
  PointerPathEnd ts -> do
    target <- assertExists ts
    vec <- assertArray ts target
    snd <$> assertUnsnoc ts vec

-- | Insert a value at a pointer.
--
-- * 'PointerEmpty' replaces the whole document with the value.
-- * 'PointerPath': the parent path must exist ('PointerNotFound'
--   otherwise). Then, depending on the final token:
--
--     * a key token requires the parent to be an object
--       ('InvalidObjectOperation' otherwise);
--     * an index token on an object is used as a key, with no bounds check;
--     * an index token on an array must satisfy @0 <= n <= length@
--       ('IndexOutOfBounds' otherwise);
--     * an index token on any other value throws 'InvalidArrayOperation'.
--
-- * 'PointerPathEnd': the parent must be an array, and the value is
--   appended to it.
add :: (MonadError PatchError m, MonadState Value m) => Value -> Pointer -> m ()
add v = \case
  PointerEmpty -> put v
  PointerPath ts t -> do
    target <- assertExists ts

    -- Additional validations based on type of final target
    case t of
      K _ -> void $ assertObject ts target
      N n -> case target of
        Object {} -> pure () -- n will be used as Key, no bounds check
        Array vec ->
          -- n == length is accepted (one past the last element)
          when (n < 0 || n > V.length vec) $
            throwError $ IndexOutOfBounds ts n vec
        other -> throwError $ InvalidArrayOperation ts other

    modify $ tokensL ts % atTokenL t ?~ v
  PointerPathEnd ts -> do
    target <- assertExists ts
    void $ assertArray ts target
    modify $ tokensL ts % _Array %~ (<> pure v)

-- | Delete the value at a pointer.
--
-- * 'PointerEmpty' replaces the whole document with 'Null'.
-- * 'PointerPath' requires the full path to exist ('PointerNotFound'
--   otherwise), then unsets the final token within its parent.
-- * 'PointerPathEnd' drops the last element of the array at the parent
--   path. It throws 'InvalidArrayOperation' if the parent is not an array
--   and 'EmptyArray' if the array is empty.
remove :: (MonadError PatchError m, MonadState Value m) => Pointer -> m ()
remove = \case
  PointerEmpty -> put Null -- NB. unspecified behavior
  PointerPath ts t -> do
    void $ assertExists $ ts <> [t]

    -- NB. odd that the tests don't exercise any additional validation (e.g.
    -- bounds checking) like we saw with add.

    modify $ tokensL ts % atTokenL t .~ Nothing
  PointerPathEnd ts -> do
    target <- assertExists ts
    vec <- assertArray ts target
    (rest, _) <- assertUnsnoc ts vec
    modify $ tokensL ts % _Array .~ rest

-- | Look up the value at a token path in the current document.
--
-- Throws 'PointerNotFound' (with no extra detail) when the path does not
-- resolve to a value.
assertExists
  :: (MonadError PatchError m, MonadState Value m) => [Token] -> m Value
assertExists ts =
  gets (preview $ tokensL ts) >>= \case
    Nothing -> throwError $ PointerNotFound ts Nothing
    Just v -> pure v

-- | Require a value to be a JSON object and return its key map.
--
-- Throws 'InvalidObjectOperation' at the given path for any other value.
assertObject :: MonadError PatchError m => [Token] -> Value -> m (KeyMap Value)
assertObject ts = \case
  Object o -> pure o
  v -> throwError $ InvalidObjectOperation ts v

-- | Require a value to be a JSON array and return its elements.
--
-- Throws 'InvalidArrayOperation' at the given path for any other value.
assertArray :: MonadError PatchError m => [Token] -> Value -> m (Vector Value)
assertArray ts = \case
  Array vec -> pure vec
  v -> throwError $ InvalidArrayOperation ts v

-- | Split an array into all-but-last and last element.
--
-- Throws 'EmptyArray' at the given path when the array has no elements.
assertUnsnoc
  :: MonadError PatchError m
  => [Token]
  -> Vector Value
  -> m (Vector Value, Value)
assertUnsnoc ts vec = case V.unsnoc vec of
  Nothing -> throwError $ EmptyArray ts
  Just split -> pure split