{-# LANGUAGE RecordWildCards #-}
module AtCoder.Extra.DsuMonoid
(
DsuMonoid (dsuDm, mDm),
new,
build,
merge,
mergeMaybe,
merge_,
leader,
same,
size,
groups,
read,
unsafeRead,
unsafeWrite,
)
where
import AtCoder.Dsu qualified as Dsu
import Control.Monad.Primitive (PrimMonad, PrimState, stToPrim)
import Data.Vector qualified as V
import Data.Vector.Generic.Mutable qualified as VGM
import Data.Vector.Unboxed qualified as VU
import Data.Vector.Unboxed.Mutable qualified as VUM
import GHC.Stack (HasCallStack)
import Prelude hiding (read)
-- | A disjoint set union ('Dsu.Dsu') that additionally stores one value of type
-- @a@ per vertex, used as a per-group aggregate.
--
-- The slot of a group's leader holds that group's value. 'merge' and
-- 'mergeMaybe' only write the slot of the resulting leader, so the slot of an
-- absorbed leader keeps its old (now stale) value.
data DsuMonoid s a = DsuMonoid
  { -- | The underlying union-find structure.
    dsuDm :: {-# UNPACK #-} !(Dsu.Dsu s),
    -- | Value storage indexed by vertex; meaningful at group leaders.
    mDm :: !(VUM.MVector s a)
  }
{-# INLINE new #-}
-- | Creates a 'DsuMonoid' over @n@ vertices, every slot initialised to 'mempty'.
--
-- Calls 'error' when @n@ is negative. 'HasCallStack' makes that error report
-- the caller's location rather than a location inside this module.
new :: (HasCallStack, PrimMonad m, Monoid a, VU.Unbox a) => Int -> m (DsuMonoid (PrimState m) a)
new n
  | n >= 0 = build $ VU.replicate n mempty
  | otherwise = error $ "AtCoder.Extra.DsuMonoid: given negative size (`" ++ show n ++ "`)"
{-# INLINE build #-}
-- | Creates a 'DsuMonoid' with @VU.length ms@ vertices, where the slot of
-- vertex @i@ starts out as @ms VU.! i@.
--
-- The input vector is copied with 'VU.thaw', so it is not aliased by the
-- result.
build :: (PrimMonad m, Semigroup a, VU.Unbox a) => VU.Vector a -> m (DsuMonoid (PrimState m) a)
build ms =
  -- Effects run left to right: the DSU is allocated first, then the copy.
  stToPrim $ DsuMonoid <$> Dsu.new (VU.length ms) <*> VU.thaw ms
{-# INLINEABLE merge #-}
-- | Unites the groups containing vertices @a@ and @b@ and returns the leader of
-- the resulting group.
--
-- If both vertices are already in the same group, their common leader is
-- returned and no value is touched. Otherwise the new leader's slot receives
-- @va <> vb@. Here @va@ is the value of @a@'s old group and @vb@ the value of
-- @b@'s old group, so the order matters for non-commutative semigroups.
merge :: (HasCallStack, PrimMonad m, Semigroup a, VU.Unbox a) => DsuMonoid (PrimState m) a -> Int -> Int -> m Int
merge (DsuMonoid dsu vals) a b = stToPrim $ do
  rootA <- Dsu.leader dsu a
  rootB <- Dsu.leader dsu b
  if rootA /= rootB
    then do
      -- Both values are read before the union, while both roots are still leaders.
      !va <- VGM.read vals rootA
      !vb <- VGM.read vals rootB
      root <- Dsu.merge dsu a b
      VGM.write vals root $! va <> vb
      pure root
    else pure rootA
{-# INLINEABLE mergeMaybe #-}
-- | Like 'merge', but returns 'Nothing' when @a@ and @b@ already share a group.
--
-- On an actual union it returns @Just@ the new leader. The leader's slot then
-- holds @va <> vb@, where @va@ comes from @a@'s old group and @vb@ from @b@'s
-- old group.
mergeMaybe :: (HasCallStack, PrimMonad m, Semigroup a, VU.Unbox a) => DsuMonoid (PrimState m) a -> Int -> Int -> m (Maybe Int)
mergeMaybe (DsuMonoid dsu vals) a b = stToPrim $ do
  leaderA <- Dsu.leader dsu a
  leaderB <- Dsu.leader dsu b
  if leaderA == leaderB
    then pure Nothing
    else do
      !valA <- VGM.read vals leaderA
      !valB <- VGM.read vals leaderB
      newLeader <- Dsu.merge dsu a b
      VGM.write vals newLeader $! valA <> valB
      pure (Just newLeader)
{-# INLINE merge_ #-}
-- | 'merge', discarding the resulting leader.
--
-- Carries 'HasCallStack' like 'merge' does, so an error raised inside 'merge'
-- reports the caller of 'merge_' instead of stopping at this wrapper.
merge_ :: (HasCallStack, PrimMonad m, Semigroup a, VU.Unbox a) => DsuMonoid (PrimState m) a -> Int -> Int -> m ()
merge_ dsu a b = () <$ merge dsu a b
{-# INLINE same #-}
-- | Whether vertices @a@ and @b@ are in the same group. This delegates to
-- 'Dsu.same' on the underlying DSU.
same :: (HasCallStack, PrimMonad m) => DsuMonoid (PrimState m) a -> Int -> Int -> m Bool
same = Dsu.same . dsuDm
{-# INLINE leader #-}
-- | The leader (representative) of the group containing the given vertex. This
-- delegates to 'Dsu.leader' on the underlying DSU.
leader :: (HasCallStack, PrimMonad m) => DsuMonoid (PrimState m) a -> Int -> m Int
leader = Dsu.leader . dsuDm
{-# INLINE size #-}
-- | Size of the group containing the given vertex. This delegates to
-- 'Dsu.size' on the underlying DSU.
size :: (HasCallStack, PrimMonad m) => DsuMonoid (PrimState m) a -> Int -> m Int
size = Dsu.size . dsuDm
{-# INLINE groups #-}
-- | The vertices partitioned into their groups, as computed by 'Dsu.groups'
-- on the underlying DSU.
groups :: (PrimMonad m) => DsuMonoid (PrimState m) a -> m (V.Vector (VU.Vector Int))
groups = Dsu.groups . dsuDm
{-# INLINE read #-}
-- | Reads the value of the group containing vertex @i@, i.e. the value stored
-- in the slot of @i@'s leader.
--
-- 'HasCallStack' is propagated to 'Dsu.leader' and 'VGM.read', so an
-- out-of-range @i@ is reported at the caller, consistent with 'leader' and 'same'.
read :: (HasCallStack, PrimMonad m, VU.Unbox a) => DsuMonoid (PrimState m) a -> Int -> m a
read DsuMonoid {..} i = Dsu.leader dsuDm i >>= VGM.read mDm
{-# INLINE unsafeRead #-}
-- | Reads slot @i@ directly, without resolving @i@'s leader.
--
-- The result is the group value only when @i@ is a leader. For other vertices
-- it is whatever was last written to that slot. NOTE(review): "unsafe" refers
-- to skipping the leader lookup. Indexing is still bounds-checked by 'VGM.read'.
unsafeRead :: (HasCallStack, PrimMonad m, VU.Unbox a) => DsuMonoid (PrimState m) a -> Int -> m a
unsafeRead DsuMonoid {..} i = VGM.read mDm i
{-# INLINE unsafeWrite #-}
-- | Writes @x@ into slot @i@ directly, without resolving @i@'s leader.
--
-- 'read' always goes through the leader, so a write to a non-leader vertex is
-- not visible through 'read'. Indexing is bounds-checked by 'VGM.write', and
-- 'HasCallStack' reports a bad index at the caller.
unsafeWrite :: (HasCallStack, PrimMonad m, VU.Unbox a) => DsuMonoid (PrimState m) a -> Int -> a -> m ()
unsafeWrite DsuMonoid {..} i x = VGM.write mDm i x