{-# LANGUAGE RecordWildCards #-}

-- | A 2D, static wavelet matrix with segment tree, that can handle point add and rectangle sum
-- queries. Points cannot be added after construction, but monoid values in each point can be
-- modified later. Duplicate monoids at the same coordinate will be combined into one.
--
-- ==== `SegTree2d` vs `WaveletMatrix2d`
-- They basically have the same functionality and performance; however, `SegTree2d` performs better
-- in @ac-library-hs@.
--
-- ==== __Example__
-- Create a `WaveletMatrix2d` with initial point values:
--
-- >>> import AtCoder.Extra.WaveletMatrix2d qualified as WM
-- >>> import Data.Semigroup (Sum (..))
-- >>> import Data.Vector.Unboxed qualified as VU
-- >>> -- 8  9 10 11
-- >>> -- 4  5  6  7
-- >>> -- 0  1  2  3
-- >>> wm <- WM.build negate $ VU.generate 12 $ \i -> let (!y, !x) = i `divMod` 4 in (x, y, Sum i)
--
-- Read the value at \(x = 2, y = 1\):
--
-- >>> WM.read wm (2, 1)
-- Sum {getSum = 6}
--
-- Other segment tree methods are also available, but in 2D:
--
-- >>> WM.allProd wm -- (0 + 11) * 12 / 2 = 66
-- Sum {getSum = 66}
--
-- >>> WM.prod wm {- x -} 1 3 {- y -} 0 3 -- 1 + 2 + 5 + 6 + 9 + 10
-- Sum {getSum = 33}
--
-- >>> WM.modify wm (+ 2) (1, 1)
-- >>> WM.prod wm {- x -} 1 3 {- y -} 0 3 -- 1 + 2 + 7 + 6 + 9 + 10
-- Sum {getSum = 35}
--
-- >>> WM.write wm (1, 1) $ Sum 0
-- >>> WM.prod wm {- x -} 1 3 {- y -} 0 3 -- 1 + 2 + 0 + 6 + 9 + 10
-- Sum {getSum = 28}
--
-- @since 1.1.0.0
module AtCoder.Extra.WaveletMatrix2d
  ( -- * Wavelet matrix 2D
    WaveletMatrix2d (..),

    -- * Constructors
    new,
    build,

    -- * Segment tree methods
    read,
    write,
    modify,
    prod,
    prodMaybe,
    allProd,
    -- wavelet matrix methods could be implemented, too
  )
where

import AtCoder.Extra.Bisect (lowerBound)
import AtCoder.Extra.WaveletMatrix.BitVector qualified as BV
import AtCoder.Extra.WaveletMatrix.Raw qualified as Rwm
import AtCoder.Internal.Assert qualified as ACIA
import AtCoder.SegTree qualified as ST
import Control.Monad.Primitive (PrimMonad, PrimState, stToPrim)
import Control.Monad.ST (ST)
import Data.Bit (Bit (..))
import Data.Bits (Bits (testBit))
import Data.Vector qualified as V
import Data.Vector.Algorithms.Intro qualified as VAI
import Data.Vector.Generic qualified as VG
import Data.Vector.Unboxed qualified as VU
import GHC.Stack (HasCallStack)
import Prelude hiding (read)

-- NOTE: There are many possible improvements.
-- - Use cumulative sum or fenwick tree instead for the speed.
-- - The inverse operator is not actually required.
-- - Wavelet matrix methods such as `rank` can be implemented
-- - `maxRight` can be implemented.

-- | Segment Tree on Wavelet Matrix: points on a 2D plane and rectangle products of them.
--
-- The set of points is fixed at construction time; only the monoid values attached to the points
-- are mutable (they live in the mutable segment trees of `segTreesWm2d`).
--
-- @since 1.3.0.0
data WaveletMatrix2d s a = WaveletMatrix2d
  { -- | The wavelet matrix that represents points on a 2D plane. It is built over the compressed
    -- y indices of the points, taken in the order of `xyDictWm2d`.
    --
    -- @since 1.3.0.0
    rawWm2d :: !Rwm.RawWaveletMatrix,
    -- | (x, y) index compression dictionary: the sorted, deduplicated input points.
    --
    -- @since 1.1.0.0
    xyDictWm2d :: !(VU.Vector (Int, Int)),
    -- | y index compression dictionary: the sorted, deduplicated y coordinates.
    --
    -- @since 1.1.0.0
    yDictWm2d :: !(VU.Vector Int),
    -- | The segment trees of the weights of the points, one per row of the wavelet matrix. The
    -- tree of a row stores the weights in the order of the points /after/ the stable partition by
    -- that row's bit, not in the order of `xyDictWm2d`.
    --
    -- @since 1.1.0.0
    segTreesWm2d :: !(V.Vector (ST.SegTree s a)),
    -- | The inverse operator of the interested monoid.
    --
    -- @since 1.1.0.0
    invWm2d :: !(a -> a)
  }

-- | \(O(n \log n)\) Creates a `WaveletMatrix2d` with `mempty` as the initial monoid
-- values for each point.
--
-- The input points are sorted and deduplicated, and their y coordinates are compressed into
-- indices of the sorted, deduplicated y coordinates before the wavelet matrix is built.
--
-- @since 1.1.0.0
{-# INLINEABLE new #-}
new ::
  (HasCallStack, PrimMonad m, Monoid a, VU.Unbox a) =>
  -- | Inverse operator of the monoid
  (a -> a) ->
  -- | Input points
  VU.Vector (Int, Int) ->
  -- | A 2D wavelet matrix
  m (WaveletMatrix2d (PrimState m) a)
new invWm2d xys = stToPrim $ do
  -- Number of input points (duplicates included); used as the capacity of every segment tree.
  let n = VG.length xys
  -- Sorted, deduplicated points: the position of a point in this vector is its compressed index.
  let xyDictWm2d = VU.uniq . VU.modify (VAI.sortBy compare) $ xys
  let (!_, !ys) = VU.unzip xys
  -- Sorted, deduplicated y coordinates.
  let yDictWm2d = VU.uniq $ VU.modify (VAI.sortBy compare) ys
  -- REMARK: Be sure to use `n + 1` because the product function cannot handle the case
  --         `yUpper` is `2^{height}`.
  let (!_, !ysInput) = VU.unzip xyDictWm2d
  let rawWm2d = Rwm.build (n + 1) $ VU.map (lowerBound yDictWm2d) ysInput
  -- One segment tree per wavelet matrix row, all initialized with `mempty`.
  segTreesWm2d <- V.replicateM (Rwm.heightRwm rawWm2d) (ST.new n)
  pure WaveletMatrix2d {..}

-- | \(O(n \log^2 n)\) Creates a `WaveletMatrix2d` with initial monoid values. Duplicate monoids at
-- the same coordinate will be combined with `(<>)`.
--
-- The structure is first created with `new` (all values `mempty`), then every input value is
-- folded into its point with `modify`, one point at a time.
--
-- @since 1.1.0.0
{-# INLINEABLE build #-}
build ::
  (HasCallStack, PrimMonad m, Monoid a, VU.Unbox a) =>
  -- | Inverse operator of the monoid
  (a -> a) ->
  -- | Input points with initial values
  VU.Vector (Int, Int, a) ->
  -- | A 2D wavelet matrix
  m (WaveletMatrix2d (PrimState m) a)
build invWm2d xysw = stToPrim $ do
  let (!xs, !ys, !_) = VU.unzip3 xysw
  wm <- new invWm2d $ VU.zip xs ys
  -- not the fastest implementation though
  VU.forM_ xysw $ \(!x, !y, !w) -> do
    -- Append `w` to whatever is already stored at the point, so duplicates are combined.
    modify wm (<> w) (x, y)
  pure wm

-- | \(O(\log n)\) Returns the monoid value at \((x, y)\). Access to unknown points is undefined.
--
-- The point is located in `xyDictWm2d` with a binary search. The first-row segment tree does not
-- store the values in the order of `xyDictWm2d`: `write` and `modify` store them at the index
-- obtained /after/ the stable partition by the first row's bit. The same mapping is therefore
-- applied here before reading.
--
-- @since 1.1.0.0
{-# INLINEABLE read #-}
read :: (HasCallStack, PrimMonad m, VU.Unbox a, Monoid a) => WaveletMatrix2d (PrimState m) a -> (Int, Int) -> m a
read WaveletMatrix2d {..} (!x, !y) = do
  -- Compressed index of the point.
  let !i = lowerBound xyDictWm2d (x, y)
  -- Map the index through the first row, exactly as `write` / `modify` do for `iRow = 0`.
  let !bits = V.head $ Rwm.bitsRwm rawWm2d
  let !i0 = BV.rank0 bits i
  let !i'
        | unBit $ BV.bitsBv bits VG.! i = i + Rwm.nZerosRwm rawWm2d VG.! 0 - i0
        | otherwise = i0
  ST.read (V.head segTreesWm2d) i'

-- | \(O(\log^2 n)\) Writes the monoid value at \((x, y)\). Access to unknown points is undefined.
--
-- The point's index is followed down through every row of the wavelet matrix, and the value is
-- written into each row's segment tree at the index the point has after that row's partition.
--
-- @since 1.1.0.0
{-# INLINEABLE write #-}
write :: (HasCallStack, PrimMonad m, Monoid a, VU.Unbox a) => WaveletMatrix2d (PrimState m) a -> (Int, Int) -> a -> m ()
write WaveletMatrix2d {..} (!x, !y) v = stToPrim $ do
  -- Compressed index of the point in the order of `xyDictWm2d`.
  let !i_ = lowerBound xyDictWm2d (x, y)
  V.ifoldM'_
    ( \i iRow (!bits, !seg) -> do
        let !i0 = BV.rank0 bits i
        -- Index after this row's partition: zeros keep their rank, ones follow all the zeros.
        let !i'
              | unBit $ VG.unsafeIndex (BV.bitsBv bits) i = i + Rwm.nZerosRwm rawWm2d VG.! iRow - i0
              | otherwise = i0
        ST.write seg i' v
        pure i'
    )
    i_
    $ V.zip (Rwm.bitsRwm rawWm2d) segTreesWm2d

-- | \(O(\log^2 n)\) Given a user function \(f\), modifies the monoid value at \((x, y)\). Access to
-- unknown points is undefined.
--
-- The point's index is followed down through every row of the wavelet matrix, and \(f\) is applied
-- to the entry of each row's segment tree at the index the point has after that row's partition.
--
-- @since 1.1.0.0
{-# INLINEABLE modify #-}
modify :: (HasCallStack, PrimMonad m, Monoid a, VU.Unbox a) => WaveletMatrix2d (PrimState m) a -> (a -> a) -> (Int, Int) -> m ()
modify WaveletMatrix2d {..} f (!x, !y) = stToPrim $ do
  -- Compressed index of the point in the order of `xyDictWm2d`.
  let !i_ = lowerBound xyDictWm2d (x, y)
  V.ifoldM'_
    ( \i iRow (!bits, !seg) -> do
        let !i0 = BV.rank0 bits i
        -- Index after this row's partition: zeros keep their rank, ones follow all the zeros.
        let !i'
              | unBit $ VG.unsafeIndex (BV.bitsBv bits) i = i + Rwm.nZerosRwm rawWm2d VG.! iRow - i0
              | otherwise = i0
        ST.modify seg f i'
        pure i'
    )
    i_
    $ V.zip (Rwm.bitsRwm rawWm2d) segTreesWm2d

-- | \(O(\log^2 n)\) Returns monoid product \(\Pi_{p \in [x_1, x_2) \times [y_1, y_2)} a_p\).
--
-- The bounds are given in original coordinates and are compressed with a binary search, so they
-- need not coincide with coordinates of existing points. An interval whose compressed left bound
-- exceeds its compressed right bound is rejected by the interval assertion; an empty compressed
-- interval yields `mempty`.
--
-- @since 1.1.0.0
{-# INLINEABLE prod #-}
prod :: (HasCallStack, PrimMonad m, VU.Unbox a, Monoid a) => WaveletMatrix2d (PrimState m) a -> Int -> Int -> Int -> Int -> m a
prod wm@WaveletMatrix2d {..} !xl !xr !yl !yr
  | xl' >= xr' || yl' >= yr' = pure mempty
  | otherwise = unsafeProd wm xl' xr' yl' yr'
  where
    -- x coordinates of the sorted, deduplicated points (non-decreasing, may contain duplicates).
    (!xDict, !_) = VU.unzip xyDictWm2d
    -- NOTE: clamping here!
    xl' = lowerBound xDict xl
    xr' = lowerBound xDict xr
    yl' = lowerBound yDictWm2d yl
    yr' = lowerBound yDictWm2d yr
    -- Strict bindings: the interval checks run before any guard is evaluated.
    !_ = ACIA.checkInterval "AtCoder.Extra.WaveletMatrix2d.prod (compressed x)" xl' xr' (VG.length xDict)
    !_ = ACIA.checkInterval "AtCoder.Extra.WaveletMatrix2d.prod (compressed y)" yl' yr' (VG.length yDictWm2d)

-- | \(O(\log^2 n)\) Returns the monoid product in \([x_1, x_2) \times [y_1, y_2)\). Returns `Nothing` for invalid
-- intervals.
--
-- The bounds are given in original coordinates and are compressed with a binary search before the
-- interval tests. A valid but empty compressed interval yields @Just mempty@.
--
-- @since 1.1.0.0
{-# INLINEABLE prodMaybe #-}
prodMaybe :: (PrimMonad m, VU.Unbox a, Monoid a) => WaveletMatrix2d (PrimState m) a -> Int -> Int -> Int -> Int -> m (Maybe a)
prodMaybe wm@WaveletMatrix2d {..} !xl !xr !yl !yr
  | not (ACIA.testInterval xl' xr' (VG.length xDict)) = pure Nothing
  | not (ACIA.testInterval yl' yr' (VG.length yDictWm2d)) = pure Nothing
  | xl' >= xr' || yl' >= yr' = pure $ Just mempty
  | otherwise = Just <$> unsafeProd wm xl' xr' yl' yr'
  where
    -- x coordinates of the sorted, deduplicated points (non-decreasing, may contain duplicates).
    (!xDict, !_) = VU.unzip xyDictWm2d
    -- NOTE: clamping here!
    xl' = lowerBound xDict xl
    xr' = lowerBound xDict xr
    yl' = lowerBound yDictWm2d yl
    yr' = lowerBound yDictWm2d yr

-- | \(O(1)\) Returns the monoid product in \((-\infty, \infty) \times (-\infty, \infty)\), i.e.,
-- the product of the values of all the points.
--
-- This is a single `ST.allProd` call on the first row's segment tree, so the cost is that of
-- `ST.allProd`. NOTE(review): \(O(1)\) is taken from the segment tree's documented complexity, so
-- confirm against "AtCoder.SegTree".
--
-- @since 1.1.0.0
{-# INLINEABLE allProd #-}
allProd :: (HasCallStack, PrimMonad m, VU.Unbox a, Monoid a) => WaveletMatrix2d (PrimState m) a -> m a
allProd WaveletMatrix2d {..} = do
  ST.allProd (V.head segTreesWm2d)

-- | \(O(\log^2 n)\) Returns the monoid product in a rectangle given by /compressed/ indices: the x
-- bounds index into `xyDictWm2d` and the y bounds index into `yDictWm2d`. No bounds check is
-- performed.
--
-- The product is computed as the product of the points with \(y < y_r\), combined with the
-- inverse of the product of the points with \(y < y_l\). The result is therefore only meaningful
-- when `invWm2d` is a true inverse for `(<>)`. NOTE(review): this presumably also requires `(<>)`
-- to be commutative, so confirm with the intended use.
--
-- @since 1.1.0.0
{-# INLINE unsafeProd #-}
unsafeProd :: (PrimMonad m, VU.Unbox a, Monoid a) => WaveletMatrix2d (PrimState m) a -> Int -> Int -> Int -> Int -> m a
unsafeProd wm xl' xr' yl' yr' = stToPrim $ do
  sR <- prodLT wm xl' xr' yr'
  sL <- prodLT wm xl' xr' yl'
  pure $! sR <> invWm2d wm sL

-- | \(O(\log^2 n)\) Returns the monoid product of the points whose compressed index is in
-- \([l, r)\) and whose compressed y index is less than @yUpper@.
--
-- The fold walks the wavelet matrix rows from the first one, carrying the accumulated product and
-- the current index interval. When the bit of @yUpper@ that corresponds to the row is set, the
-- points with a zero bit in that row are added to the product via the row's segment tree and the
-- interval moves to the one-side; otherwise the interval moves to the zero-side.
{-# INLINEABLE prodLT #-}
prodLT :: (Monoid a, VU.Unbox a) => WaveletMatrix2d s a -> Int -> Int -> Int -> ST s a
prodLT WaveletMatrix2d {..} !l_ !r_ yUpper = do
  (!res, !_, !_) <- do
    V.ifoldM'
      ( \(!acc, !l, !r) !iRow (!bits, !seg) -> do
          -- Where the interval ends up among the zero bits of this row.
          let !l0 = BV.rank0 bits l
              !r0 = BV.rank0 bits r
          -- REMARK: The function cannot handle the case yUpper = N = 2^i. See the constructor for
          -- how it's handled and note that l_ and r_ are compressed indices.
          if testBit yUpper (Rwm.heightRwm rawWm2d - 1 - iRow)
            then do
              -- The row's tree is stored in post-partition order, so the zeros occupy [l0, r0).
              !acc' <- (acc <>) <$> ST.prod seg l0 r0
              -- Continue with the ones, which are placed after all the zeros of the row.
              let !l' = l + Rwm.nZerosRwm rawWm2d VG.! iRow - l0
              let !r' = r + Rwm.nZerosRwm rawWm2d VG.! iRow - r0
              pure (acc', l', r')
            else do
              -- Continue with the zeros; nothing is added to the product on this row.
              pure (acc, l0, r0)
      )
      (mempty, l_, r_)
      $ V.zip (Rwm.bitsRwm rawWm2d) segTreesWm2d
  pure res