{-# LANGUAGE RecordWildCards #-}

-- | A disjoint set union with commutative monoid values associated with each group.
--
-- ==== __Example__
--
-- >>> import AtCoder.Extra.DsuMonoid qualified as Dm
-- >>> import Data.Semigroup (Sum (..))
-- >>> import Data.Vector.Unboxed qualified as VU
-- >>> dsu <- Dm.build $ VU.generate 4 Sum
-- >>> Dm.merge dsu 0 1
-- 0
--
-- >>> Dm.read dsu 0
-- Sum {getSum = 1}
--
-- >>> Dm.read dsu 1
-- Sum {getSum = 1}
--
-- >>> Dm.mergeMaybe dsu 0 2
-- Just 0
--
-- >>> Dm.read dsu 0
-- Sum {getSum = 3}
--
-- @since 1.5.3.0
module AtCoder.Extra.DsuMonoid
  ( -- * Disjoint set union
    DsuMonoid (dsuDm, mDm),

    -- * Constructors
    new,
    build,

    -- * Merging
    merge,
    mergeMaybe,
    merge_,

    -- * Leader
    leader,

    -- * Component information
    same,
    size,
    groups,

    -- * Monoid values
    read,
    unsafeRead,
    unsafeWrite,
  )
where

import AtCoder.Dsu qualified as Dsu
import Control.Monad.Primitive (PrimMonad, PrimState, stToPrim)
import Data.Vector qualified as V
import Data.Vector.Generic.Mutable qualified as VGM
import Data.Vector.Unboxed qualified as VU
import Data.Vector.Unboxed.Mutable qualified as VUM
import GHC.Stack (HasCallStack)
import Prelude hiding (read)

-- | A disjoint set union with commutative monoid values associated with each group.
--
-- The value of a group lives in the slot of the group's leader. Slots of vertices that are no
-- longer leaders keep a stale value and are not meaningful.
--
-- @since 1.5.3.0
data DsuMonoid s a = DsuMonoid
  { -- | The original DSU.
    --
    -- @since 1.5.3.0
    dsuDm :: {-# UNPACK #-} !(Dsu.Dsu s),
    -- | Commutative monoid values indexed by vertex. Only the entries at group leaders are
    -- meaningful.
    --
    -- @since 1.5.3.0
    mDm :: !(VUM.MVector s a)
  }

-- | Creates a DSU over \(n\) vertices and \(0\) edges. Every vertex starts as its own group
-- holding 'mempty'.
--
-- ==== Constraints
-- - \(0 \le n\)
--
-- ==== Complexity
-- - \(O(n)\)
--
-- @since 1.5.3.0
{-# INLINE new #-}
new :: (HasCallStack, PrimMonad m, Monoid a, VU.Unbox a) => Int -> m (DsuMonoid (PrimState m) a)
new n
  | n >= 0 = build $ VU.replicate n mempty
  -- Reject negative sizes up front with a descriptive message. HasCallStack makes the error
  -- point at the caller.
  | otherwise = error $ "AtCoder.Extra.DsuMonoid: given negative size (`" ++ show n ++ "`)"

-- | Creates a DSU over \(n\) vertices and \(0\) edges, where \(n\) is the length of the given
-- vector. Vertex \(i\) starts as its own group holding the \(i\)-th value of the vector.
--
-- The input vector is copied, so later merges never modify it.
--
-- ==== Complexity
-- - \(O(n)\)
--
-- @since 1.5.3.0
{-# INLINE build #-}
build :: (PrimMonad m, Semigroup a, VU.Unbox a) => VU.Vector a -> m (DsuMonoid (PrimState m) a)
build ms = stToPrim $ do
  dsuDm <- Dsu.new $ VU.length ms
  -- 'VU.thaw' yields a fresh mutable copy of the initial values.
  mDm <- VU.thaw ms
  pure DsuMonoid {..}

-- | Connects vertices \(a\) and \(b\) and returns the leader of the resulting group.
--
-- If \(a\) and \(b\) already share a group, nothing changes and that group's leader is returned.
-- Otherwise the two groups are united, and the new leader's slot receives the product
-- @x <> y@ of the former values (\(x\) from \(a\)'s group, \(y\) from \(b\)'s group).
--
-- ==== Constraints
-- - \(0 \leq a < n\)
-- - \(0 \leq b < n\)
--
-- ==== Complexity
-- - \(O(\alpha(n))\) amortized
--
-- @since 1.5.3.0
{-# INLINEABLE merge #-}
merge :: (HasCallStack, PrimMonad m, Semigroup a, VU.Unbox a) => DsuMonoid (PrimState m) a -> Int -> Int -> m Int
merge DsuMonoid {..} a b = stToPrim $ do
  leaderA <- Dsu.leader dsuDm a
  leaderB <- Dsu.leader dsuDm b
  if leaderA /= leaderB
    then do
      -- Fetch both group values before the union decides which slot stays the leader.
      !valA <- VGM.read mDm leaderA
      !valB <- VGM.read mDm leaderB
      newLeader <- Dsu.merge dsuDm a b
      VGM.write mDm newLeader $! valA <> valB
      pure newLeader
    else pure leaderA

-- | Adds an edge \((a, b)\). It returns the representative of the new connected component, or
-- `Nothing` if the two vertices are in the same connected component.
--
-- When two components are united, the new representative's slot receives @x <> y@, where \(x\)
-- is the former value of \(a\)'s component and \(y\) that of \(b\)'s component.
--
-- ==== Constraints
-- - \(0 \leq a < n\)
-- - \(0 \leq b < n\)
--
-- ==== Complexity
-- - \(O(\alpha(n))\) amortized
--
-- @since 1.5.3.0
{-# INLINEABLE mergeMaybe #-}
mergeMaybe :: (HasCallStack, PrimMonad m, Semigroup a, VU.Unbox a) => DsuMonoid (PrimState m) a -> Int -> Int -> m (Maybe Int)
mergeMaybe DsuMonoid {..} a b = stToPrim $ do
  r1 <- Dsu.leader dsuDm a
  r2 <- Dsu.leader dsuDm b
  if r1 == r2
    then pure Nothing
    else do
      -- Read both values before merging: afterwards only the new leader's slot is meaningful.
      !m1 <- VGM.read mDm r1
      !m2 <- VGM.read mDm r2
      r' <- Dsu.merge dsuDm a b
      VGM.write mDm r' $! m1 <> m2
      pure $ Just r'

-- | `merge` with the return value discarded.
--
-- ==== Constraints
-- - \(0 \leq a < n\)
-- - \(0 \leq b < n\)
--
-- ==== Complexity
-- - \(O(\alpha(n))\) amortized
--
-- @since 1.5.3.0
{-# INLINE merge_ #-}
merge_ :: (HasCallStack, PrimMonad m, Semigroup a, VU.Unbox a) => DsuMonoid (PrimState m) a -> Int -> Int -> m ()
-- HasCallStack is forwarded so that out-of-range vertices are reported at the caller, as
-- with 'merge'.
merge_ dsu a b = () <$ merge dsu a b

-- | Tests whether vertices \(a\) and \(b\) belong to the same group.
--
-- ==== Constraints
-- - \(0 \leq a < n\)
-- - \(0 \leq b < n\)
--
-- ==== Complexity
-- - \(O(\alpha(n))\) amortized
--
-- @since 1.5.3.0
{-# INLINE same #-}
same :: (HasCallStack, PrimMonad m) => DsuMonoid (PrimState m) a -> Int -> Int -> m Bool
same DsuMonoid {..} a b = Dsu.same dsuDm a b

-- | Looks up the leader (representative) of the group containing vertex \(a\).
--
-- ==== Constraints
-- - \(0 \leq a \lt n\)
--
-- ==== Complexity
-- - \(O(\alpha(n))\) amortized
--
-- @since 1.5.3.0
{-# INLINE leader #-}
leader :: (HasCallStack, PrimMonad m) => DsuMonoid (PrimState m) a -> Int -> m Int
leader DsuMonoid {..} a = Dsu.leader dsuDm a

-- | Counts the vertices in the group containing vertex \(a\).
--
-- ==== Constraints
-- - \(0 \leq a < n\)
--
-- ==== Complexity
-- - \(O(\alpha(n))\) amortized
--
-- @since 1.5.3.0
{-# INLINE size #-}
size :: (HasCallStack, PrimMonad m) => DsuMonoid (PrimState m) a -> Int -> m Int
size DsuMonoid {..} a = Dsu.size dsuDm a

-- | \(O(n)\) Divides the graph into connected components and returns the vector of them.
--
-- More precisely, it returns a vector of the "vector of the vertices in a connected component".
-- Both the order of the connected components and the order of the vertices within each
-- component are unspecified.
--
-- @since 1.5.3.0
{-# INLINE groups #-}
groups :: (PrimMonad m) => DsuMonoid (PrimState m) a -> m (V.Vector (VU.Vector Int))
groups = Dsu.groups . dsuDm

-- | Reads the monoid value of the group that contains vertex \(i\).
--
-- ==== Constraints
-- - \(0 \leq i < n\)
--
-- ==== Complexity
-- - \(O(\alpha(n))\) amortized, dominated by the leader lookup.
--
-- @since 1.5.3.0
{-# INLINE read #-}
read :: (HasCallStack, PrimMonad m, VU.Unbox a) => DsuMonoid (PrimState m) a -> Int -> m a
read DsuMonoid {..} i = do
  -- Group values are only kept up to date at the leader's slot.
  r <- Dsu.leader dsuDm i
  VGM.read mDm r

-- | \(O(1)\) Reads the raw value stored in slot \(i\) without resolving its leader.
--
-- The result is the value of \(i\)'s group only when \(i\) is currently a leader (see 'leader').
-- For other vertices it is a stale value. Use 'read' to get the group value of any vertex.
--
-- ==== Constraints
-- - \(0 \leq i < n\)
--
-- @since 1.5.3.0
{-# INLINE unsafeRead #-}
unsafeRead :: (PrimMonad m, VU.Unbox a) => DsuMonoid (PrimState m) a -> Int -> m a
unsafeRead DsuMonoid {..} i = VGM.read mDm i

-- | \(O(1)\) Overwrites the raw value stored in slot \(i\) without resolving its leader.
--
-- To change the value of a group, pass the group's leader (see 'leader'). A value written to a
-- non-leader slot is not observed by 'read' or 'merge' while \(i\) is not a leader.
--
-- ==== Constraints
-- - \(0 \leq i < n\)
--
-- @since 1.5.3.0
{-# INLINE unsafeWrite #-}
unsafeWrite :: (PrimMonad m, VU.Unbox a) => DsuMonoid (PrimState m) a -> Int -> a -> m ()
unsafeWrite DsuMonoid {..} i x = VGM.write mDm i x