{-# LANGUAGE TypeFamilies #-}

-- | A dynamic, sparse segment tree that covers a half-open interval \([l_0, r_0)\). Nodes are
-- instantiated as needed, with the required capacity being \(q\), where \(q\) is the number of
-- mutable operations. The trade-off compared to the non-sparse variant is that initial monoid
-- values are fixed at `mempty`.
--
-- ==== __Example__
--
-- >>> import AtCoder.Extra.DynSparseSegTree qualified as Seg
-- >>> import Data.Semigroup (Sum (..))
-- >>> import Data.Vector.Unboxed qualified as VU
--
-- Create a `DynSparseSegTree` over \([0, 4)\) with some initial capacity:
--
-- >>> let capacityFor len q = q * max 2 (2 + ceiling (logBase 2 (fromIntegral len) :: Double))
-- >>> let len = 4; q = 2
-- >>> seg <- Seg.new @_ @(Sum Int) (capacityFor len q) 0 4
--
-- Unlike the @SegTree@ module, it requires an explicit root handle:
--
-- >>> -- [0, 0, 0, 0]
-- >>> root <- Seg.newRoot seg
-- >>> Seg.write seg root 1 $ Sum 10
-- >>> Seg.write seg root 2 $ Sum 20
-- >>> -- [0, 10, 20, 0]
-- >>> Seg.prod seg root 0 3
-- Sum {getSum = 30}
--
-- >>> Seg.maxRight seg root (< (Sum 30))
-- 2
--
-- @since 1.2.1.0
module AtCoder.Extra.DynSparseSegTree
  ( -- * Dynamic, sparse segment tree
    Raw.DynSparseSegTree (..),

    -- * Re-exports
    P.Handle (..),

    -- * Constructors
    new,
    recommendedCapacity,
    newRoot,

    -- * Accessing elements
    write,
    modify,
    modifyM,
    -- exchange,
    -- read,

    -- * Products
    prod,
    -- prodMaybe,
    allProd, -- FIXME: rename it to prodAll

    -- * Binary searches
    maxRight,
    maxRightM,
    -- -- * Conversions
    -- freeze,
  )
where

import AtCoder.Extra.DynSparseSegTree.Raw qualified as Raw
import AtCoder.Extra.Pool qualified as P
import Control.Monad.Primitive (PrimMonad, PrimState, stToPrim)
import Data.Vector.Generic.Mutable qualified as VGM
import Data.Vector.Unboxed qualified as VU
import GHC.Stack (HasCallStack)
import Prelude hiding (read)

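-- This module is based on `Handle` because a root starts out as `P.undefIndex` and every update
-- returns a (possibly new) root index, which has to be stored back into a mutable handle.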

-- | \(O(n)\) Allocates a `DynSparseSegTree` with room for \(n\) nodes that covers the half-open
-- interval \([l_0, r_0)\). Every leaf implicitly starts out as `mempty`.
--
-- @since 1.2.1.0
{-# INLINE new #-}
new ::
  (HasCallStack, PrimMonad m, Monoid a, VU.Unbox a) =>
  -- | Capacity \(n\)
  Int ->
  -- | Left index boundary \(l_0\)
  Int ->
  -- | Right index boundary \(r_0\)
  Int ->
  -- | Dynamic, sparse segment tree
  m (Raw.DynSparseSegTree (PrimState m) a)
-- NOTE(review): the leading 'False' passed to 'Raw.newST' is presumably its persistence flag
-- (this module exposes the non-persistent variant) — confirm against the Raw module.
new capacity l r = stToPrim (Raw.newST False capacity l r)

-- | \(O(1)\) Returns the recommended node capacity for an interval of length \(L\) and \(q\)
-- mutable operations. The result is just \(q\); the interval length is ignored.
--
-- @since 1.2.1.0
{-# INLINE recommendedCapacity #-}
recommendedCapacity :: Int -> Int -> Int
recommendedCapacity _ = id

-- | \(O(1)\) Creates a new root over \([l_0, r_0)\), whose leaves all read as `mempty`, and
-- returns a mutable handle to it.
--
-- @since 1.2.1.0
{-# INLINE newRoot #-}
newRoot :: (HasCallStack, PrimMonad m, Monoid a, VU.Unbox a) => Raw.DynSparseSegTree (PrimState m) a -> m (P.Handle (PrimState m))
newRoot dst = stToPrim $ do
  -- The root index lives in a mutable handle because updates hand back a root index that
  -- replaces it (the other operations read it from slot 0).
  root <- Raw.newRootST dst
  P.newHandle root

-- | \(O(\log L)\) Overwrites the monoid value of the leaf at index \(i\) with \(x\).
--
-- ==== Constraints
-- - \(l_0 \le i \lt r_0\)
--
-- @since 1.2.1.0
{-# INLINE write #-}
write :: (HasCallStack, PrimMonad m, Monoid a, VU.Unbox a) => Raw.DynSparseSegTree (PrimState m) a -> P.Handle (PrimState m) -> Int -> a -> m ()
write dst (P.Handle handle) i x = stToPrim $ VGM.modifyM handle updateRoot 0
  where
    -- 'Raw.modifyMST' returns a (possibly new) root index, which 'VGM.modifyM' stores back into
    -- slot 0 of the handle. The old leaf value is ignored.
    updateRoot root = Raw.modifyMST dst root (\_ -> pure x) i

-- | \(O(\log L)\) Applies the pure function \(f\) to the monoid value of the leaf at index \(i\).
--
-- ==== Constraints
-- - \(l_0 \le i \lt r_0\)
--
-- @since 1.2.1.0
{-# INLINE modify #-}
modify :: (HasCallStack, PrimMonad m, Monoid a, VU.Unbox a) => Raw.DynSparseSegTree (PrimState m) a -> P.Handle (PrimState m) -> (a -> a) -> Int -> m ()
modify dst (P.Handle handle) f i = stToPrim $ VGM.modifyM handle updateRoot 0
  where
    -- 'Raw.modifyMST' returns a (possibly new) root index, which 'VGM.modifyM' stores back into
    -- slot 0 of the handle.
    updateRoot root = Raw.modifyMST dst root (\v -> pure (f v)) i

-- | \(O(\log L)\) Applies the monadic function \(f\) to the monoid value of the leaf at index
-- \(i\).
--
-- ==== Constraints
-- - \(l_0 \le i \lt r_0\)
--
-- @since 1.2.1.0
{-# INLINE modifyM #-}
modifyM :: (HasCallStack, PrimMonad m, Monoid a, VU.Unbox a) => Raw.DynSparseSegTree (PrimState m) a -> P.Handle (PrimState m) -> (a -> m a) -> Int -> m ()
-- Unlike 'modify', this runs directly in @m@ rather than via 'stToPrim', because @f@ itself
-- lives in @m@. The root index returned by 'Raw.modifyMST' is stored back into slot 0.
modifyM dst (P.Handle handle) f i =
  VGM.modifyM handle (\root -> Raw.modifyMST dst root f i) 0

-- | \(O(\log L)\) Returns the monoid product \(a_l a_{l + 1} \dots a_{r - 1}\) over \([l, r)\).
--
-- ==== Constraints
-- - \(l_0 \le l \le r \le r_0\)
--
-- @since 1.2.1.0
{-# INLINE prod #-}
prod :: (HasCallStack, PrimMonad m, Monoid a, VU.Unbox a) => Raw.DynSparseSegTree (PrimState m) a -> P.Handle (PrimState m) -> Int -> Int -> m a
-- The current root index is kept in slot 0 of the handle.
prod dst (P.Handle handle) l r =
  stToPrim $ VGM.read handle 0 >>= \root -> Raw.prodST dst root l r

-- | \(O(\log L)\) Returns the monoid product of the whole interval \([l_0, r_0)\).
--
-- @since 1.2.1.0
{-# INLINE allProd #-}
allProd :: (HasCallStack, PrimMonad m, Monoid a, VU.Unbox a) => Raw.DynSparseSegTree (PrimState m) a -> P.Handle (PrimState m) -> m a
allProd dst (P.Handle handle) = stToPrim $ do
  -- The current root index is kept in slot 0 of the handle.
  root <- VGM.read handle 0
  Raw.prodST dst root (Raw.l0Dsst dst) (Raw.r0Dsst dst)

-- | \(O(\log L)\) Returns the maximum \(r \in [l_0, r_0]\) such that
-- \(f(a_{l_0} a_{l_0 + 1} \dots a_{r - 1})\) holds. The right end is inclusive: \(r = r_0\)
-- corresponds to the product of the whole interval.
--
-- @since 1.2.1.0
{-# INLINE maxRight #-}
maxRight :: (HasCallStack, PrimMonad m, Monoid a, VU.Unbox a) => Raw.DynSparseSegTree (PrimState m) a -> P.Handle (PrimState m) -> (a -> Bool) -> m Int
maxRight dst (P.Handle handle) f = do
  -- The current root index is kept in slot 0 of the handle.
  root <- VGM.read handle 0
  -- Lift the pure predicate into @m@ and reuse the monadic search.
  Raw.maxRightM dst root (pure . f)

-- | \(O(\log L)\) Monadic variant of `maxRight`. Returns the maximum \(r \in [l_0, r_0]\) such
-- that \(f(a_{l_0} a_{l_0 + 1} \dots a_{r - 1})\) holds, where the predicate \(f\) runs in
-- @m@. The right end is inclusive: \(r = r_0\) corresponds to the product of the whole interval.
--
-- @since 1.2.1.0
{-# INLINE maxRightM #-}
maxRightM :: (HasCallStack, PrimMonad m, Monoid a, VU.Unbox a) => Raw.DynSparseSegTree (PrimState m) a -> P.Handle (PrimState m) -> (a -> m Bool) -> m Int
maxRightM dst (P.Handle handle) f = do
  -- The current root index is kept in slot 0 of the handle. No 'stToPrim' is used here,
  -- because the predicate runs in @m@.
  root <- VGM.read handle 0
  Raw.maxRightM dst root f