-- | Generic tree functions.
--
-- @since 1.1.0.0
module AtCoder.Extra.Tree
  ( -- * Tree properties
    diameter,
    diameterPath,

    -- * Minimum spanning tree
    mst,
    mstBy,

    -- * Tree folding

    -- | These functions are built around three type parameters: \(w\), \(f\) and \(a\).
    --
    -- - \(w\): Edge weight.
    -- - \(f\): Monoid action on a vertex value. These actions are created from a vertex value \(a\)
    -- and edge information @(Int, w)@.
    -- - \(a\): Monoid values stored at vertices.
    fold,
    scan,
    foldReroot,
  )
where

import AtCoder.Dsu qualified as Dsu
import AtCoder.Extra.Graph qualified as Gr
import Control.Monad (when)
import Control.Monad.ST (runST)
import Data.Bit (Bit (..))
import Data.Functor.Identity (runIdentity)
import Data.Maybe (isJust)
import Data.Ord (comparing)
import Data.Vector.Algorithms.Intro qualified as VAI
import Data.Vector.Generic qualified as VG
import Data.Vector.Generic.Mutable qualified as VGM
import Data.Vector.Unboxed qualified as VU
import Data.Vector.Unboxed.Mutable qualified as VUM
import GHC.Stack (HasCallStack)

-- | \(O(n + m)\) Returns the two endpoints of a longest path (the diameter) of a tree together
-- with its length: \(((u, v), w)\).
--
-- Two BFS passes are performed. The first one starts from vertex \(0\) and picks a farthest
-- vertex \(u\); the second one starts from \(u\) and picks a farthest vertex \(v\).
--
-- ==== __Example__
-- >>> import AtCoder.Extra.Graph qualified as Gr
-- >>> import AtCoder.Extra.Tree qualified as Tree
-- >>> import Data.Vector.Unboxed qualified as VU
-- >>> let es = VU.fromList [(0, 1, 1 :: Int), (1, 2, 10), (1, 3, 10)]
-- >>> let gr = Gr.build 4 $ Gr.swapDupe es
-- >>> Tree.diameter 4 (Gr.adjW gr) (-1)
-- ((2,3),20)
--
-- @since 1.2.4.0
{-# INLINEABLE diameter #-}
diameter ::
  (HasCallStack, VU.Unbox w, Num w, Ord w) =>
  -- | The number of vertices.
  Int ->
  -- | Graph given as a function.
  (Int -> VU.Vector (Int, w)) ->
  -- | Distances assigned to unreachable vertices.
  w ->
  -- | Tuple of (endpoints of the longest path in a tree, distance of it).
  ((Int, Int), w)
diameter n gr !undefW = ((u, v), dist)
  where
    -- Distances from a single source vertex to every vertex.
    distsFrom s = Gr.bfs n gr undefW $ VU.singleton (s, 0)
    -- First pass: a farthest vertex from vertex 0 is one end of a diameter.
    !distsU = distsFrom 0
    !u = VU.maxIndex distsU
    -- Second pass: a farthest vertex from @u@ is the other end.
    !distsV = distsFrom u
    !v = VU.maxIndex distsV
    !dist = VU.maximum distsV

-- | \(O(n + m)\) Returns a longest path in a tree, as a sequence of vertices from one endpoint to
-- the other, together with its length.
--
-- Two BFS passes are performed. The first one starts from vertex \(0\) and picks a farthest
-- vertex; the second one starts from that vertex, records BFS parents and picks a farthest
-- vertex again. The path is then restored from the recorded parents.
--
-- ==== __Example__
-- >>> import AtCoder.Extra.Graph qualified as Gr
-- >>> import AtCoder.Extra.Tree qualified as Tree
-- >>> import Data.Vector.Unboxed qualified as VU
-- >>> let es = VU.fromList [(0, 1, 1 :: Int), (1, 2, 10), (1, 3, 10)]
-- >>> let gr = Gr.build 4 $ Gr.swapDupe es
-- >>> Tree.diameterPath 4 (Gr.adjW gr) (-1)
-- ([2,1,3],20)
--
-- @since 1.2.4.0
{-# INLINEABLE diameterPath #-}
diameterPath ::
  (HasCallStack, VU.Unbox w, Num w, Ord w) =>
  -- | The number of vertices.
  Int ->
  -- | Graph given as a function.
  (Int -> VU.Vector (Int, w)) ->
  -- | Distances assigned to unreachable vertices.
  w ->
  -- | Tuple of (the longest path, distance of it).
  (VU.Vector Int, w)
diameterPath n gr !undefW =
  let -- First pass: a farthest vertex from vertex 0 is one end of a diameter.
      !bfs1 = Gr.bfs n gr undefW $ VU.singleton (0, 0)
      !from = VU.maxIndex bfs1
      -- Second pass: also track BFS parents so that the path can be restored.
      (!bfs2, !parents) = Gr.trackingBfs n gr undefW $ VU.singleton (from, 0)
      !to = VU.maxIndex bfs2
      !w = bfs2 VG.! to
   in (Gr.constructPathFromRoot parents to, w)

-- | \(O(m \log m)\) Kruskal's algorithm. Builds a minimum spanning tree (a spanning forest if the
-- graph is disconnected) and returns a tuple of:
--
-- - the total weight of the chosen edges,
-- - a per-edge flag vector (@1@ at index \(i\) iff edge \(i\) is used),
-- - the spanning tree as a CSR graph.
--
-- NOTE: The edges should not be duplicated: only one of \((u, v, w)\) or \((v, u, w)\) is required
-- for each edge.
--
-- ==== __Example__
-- Create a minimum spanning tree:
--
-- >>> import AtCoder.Extra.Graph qualified as Gr
-- >>> import AtCoder.Extra.Tree qualified as Tree
-- >>> import Data.Vector.Unboxed qualified as VU
-- >>> let es = VU.fromList [(0, 1, 1 :: Int), (1, 2, 10), (0, 2, 2)]
-- >>> let (!wSum, !edgeUse, !gr) = Tree.mst 3 es
-- >>> wSum
-- 3
--
-- >>> edgeUse
-- [1,0,1]
--
-- >>> Gr.adj gr 0
-- [1,2]
--
-- @since 1.2.4.0
{-# INLINE mst #-}
mst :: (Num w, Ord w, VU.Unbox w) => Int -> VU.Vector (Int, Int, w) -> (w, VU.Vector Bit, Gr.Csr w)
mst nVerts edges = mstBy (comparing id) nVerts edges

-- | \(O(m \log m)\) Kruskal's algorithm. Builds a minimum/maximum spanning tree (a spanning
-- forest if the graph is disconnected), with edges taken in the order given by the comparator.
-- Returns a tuple of:
--
-- - the total weight of the chosen edges,
-- - a per-edge flag vector (@1@ at index \(i\) iff edge \(i\) is used),
-- - the spanning tree as a CSR graph (each chosen edge is stored in both directions).
--
-- NOTE: The edges should not be duplicated: only one of \((u, v, w)\) or \((v, u, w)\) is required
-- for each edge.
--
-- ==== __Example__
-- Create a maximum spanning tree:
--
-- >>> import AtCoder.Extra.Graph qualified as Gr
-- >>> import AtCoder.Extra.Tree qualified as Tree
-- >>> import Data.Ord (Down (..), comparing)
-- >>> import Data.Vector.Unboxed qualified as VU
-- >>> let es = VU.fromList [(0, 1, 1 :: Int), (1, 2, 10), (0, 2, 2)]
-- >>> let (!wSum, !edgeUse, !gr) = Tree.mstBy (comparing Down) 3 es
-- >>> wSum
-- 12
--
-- >>> edgeUse
-- [0,1,1]
--
-- >>> Gr.adj gr 0
-- [2]
--
-- @since 1.2.4.0
{-# INLINEABLE mstBy #-}
mstBy :: (Num w, VU.Unbox w) => (w -> w -> Ordering) -> Int -> VU.Vector (Int, Int, w) -> (w, VU.Vector Bit, Gr.Csr w)
mstBy !cmp nVerts edges = runST $ do
  dsu <- Dsu.new nVerts
  useM <- VUM.replicate nEdges (Bit False)
  -- Greedily take edges in comparator order, skipping those that would close a cycle.
  wSum <-
    VU.foldM'
      ( \ !acc i -> do
          merged <- isJust <$> Dsu.mergeMaybe dsu (us VG.! i) (vs VG.! i)
          when merged $ VGM.write useM i (Bit True)
          pure $ if merged then acc + ws VG.! i else acc
      )
      0
      sortedEdgeIndices
  -- @useM@ is not touched afterwards, so freezing in place is safe.
  use <- VU.unsafeFreeze useM
  let !gr = Gr.build nVerts . Gr.swapDupe $ VU.ifilter (\i _ -> unBit (use VG.! i)) edges
  pure (wSum, use, gr)
  where
    nEdges = VU.length edges
    (!us, !vs, !ws) = VU.unzip3 edges
    -- Edge indices sorted by their weights according to the comparator.
    sortedEdgeIndices =
      VU.modify (VAI.sortBy (\i j -> cmp (ws VG.! i) (ws VG.! j))) $ VU.generate nEdges id

-- Shared implementation of 'fold' and 'scan': a post-order DFS from @root@ that computes the
-- folded value of every subtree and reports each one through the @memo@ callback.
--
-- The value of a vertex @v@ starts as @valAt v@. Then, for each neighbor other than the parent
-- (in adjacency order), the neighbor's folded value @x@ is turned into an action with
-- @toF x (v, w)@ (@w@ being the edge weight) and applied with @act@.
{-# INLINEABLE foldImpl #-}
foldImpl ::
  forall m w f a.
  (HasCallStack, Monad m, VU.Unbox w) =>
  (Int -> VU.Vector (Int, w)) ->
  (Int -> a) ->
  (a -> (Int, w) -> f) ->
  (f -> a -> a) ->
  Int ->
  (Int -> a -> m ()) ->
  m a
foldImpl tree valAt toF act root memo = go (-1) root
  where
    -- @go parent v@ folds the subtree rooted at @v@; @parent@ is excluded from the neighbors.
    go :: Int -> Int -> m a
    go !parent !v = do
      let !initial = valAt v
          !children = VU.filter (\(u, _) -> u /= parent) (tree v)
          -- Folds a child subtree and applies its action onto the accumulator.
          step acc (!child, !w) = do
            x <- go v child
            pure $ act (toF x (v, w)) acc
      !res <- VU.foldM' step initial children
      memo v res
      pure res

-- | \(O(n)\) Folds a tree from a root vertex, also known as tree DP. Returns the folded value of
-- the root.
--
-- ==== __Example__
-- >>> import AtCoder.Extra.Graph qualified as Gr
-- >>> import AtCoder.Extra.Tree qualified as Tree
-- >>> import Data.Semigroup (Sum (..))
-- >>> import Data.Vector.Unboxed qualified as VU
-- >>> let gr = Gr.build @(Sum Int) 5 . Gr.swapDupe $ VU.fromList [(2, 1, Sum 1), (1, 0, Sum 1), (2, 3, Sum 1), (3, 4, Sum 1)]
-- >>> type W = Sum Int -- edge weight
-- >>> type F = Sum Int -- action type
-- >>> type X = Sum Int -- vertex value
-- >>> :{
--  let res = Tree.fold (gr `Gr.adjW`) valAt toF act 2
--        where
--          valAt :: Int -> X
--          valAt = const $ mempty @(Sum Int)
--          toF :: X -> (Int, W) -> F
--          toF x (!_i, !dx) = x + dx
--          act :: F -> X -> X
--          act dx x = dx + x
--   in getSum res
-- :}
-- 4
--
-- @since 1.1.0.0
{-# INLINE fold #-}
fold ::
  (HasCallStack, VU.Unbox w) =>
  -- | Graph as a function.
  (Int -> VU.Vector (Int, w)) ->
  -- | @valAt@: Assignment of initial vertex values.
  (Int -> a) ->
  -- | @toF@: Converts a vertex value into an action onto a neighbor vertex.
  (a -> (Int, w) -> f) ->
  -- | @act@: Performs an action onto a vertex value.
  (f -> a -> a) ->
  -- | Root vertex.
  Int ->
  -- | Tree folding result from the root vertex.
  a
fold tree valAt toF act root =
  runIdentity $ foldImpl tree valAt toF act root noMemo
  where
    -- Intermediate subtree results are not needed here, so they are discarded.
    noMemo _ _ = pure ()

-- | \(O(n)\) Folds a tree from a root vertex, also known as tree DP. The folded value of every
-- vertex's subtree is recorded and returned as a vector indexed by vertex.
--
-- Only vertices reachable from the root are written; other entries are left uninitialized, since
-- the result vector is allocated with 'VGM.unsafeNew'.
--
-- ==== __Example__
-- >>> import AtCoder.Extra.Graph qualified as Gr
-- >>> import AtCoder.Extra.Tree qualified as Tree
-- >>> import Data.Semigroup (Sum (..))
-- >>> import Data.Vector.Unboxed qualified as VU
-- >>> let n = 5
-- >>> let gr = Gr.build @(Sum Int) n . Gr.swapDupe $ VU.fromList [(2, 1, Sum 1), (1, 0, Sum 1), (2, 3, Sum 1), (3, 4, Sum 1)]
-- >>> type W = Sum Int -- edge weight
-- >>> type F = Sum Int -- action type
-- >>> type X = Sum Int -- vertex value
-- >>> :{
--  let res = Tree.scan n (gr `Gr.adjW`) valAt toF act 2
--        where
--          valAt :: Int -> X
--          valAt = const $ mempty @(Sum Int)
--          toF :: X -> (Int, W) -> F
--          toF x (!_i, !dx) = x + dx
--          act :: F -> X -> X
--          act dx x = dx + x
--   in VU.map getSum res
-- :}
-- [0,1,4,1,0]
--
-- @since 1.1.0.0
{-# INLINE scan #-}
scan ::
  (VU.Unbox w, VG.Vector v a) =>
  -- | The number of vertices.
  Int ->
  -- | Graph as a function.
  (Int -> VU.Vector (Int, w)) ->
  -- | @valAt@: Assignment of initial vertex values.
  (Int -> a) ->
  -- | @toF@: Converts a vertex value into an action onto a neighbor vertex.
  (a -> (Int, w) -> f) ->
  -- | @act@: Performs an action onto a vertex value.
  (f -> a -> a) ->
  -- | Root vertex.
  Int ->
  -- | Tree scanning result from a root vertex.
  v a
scan n tree valAt toF act root = VG.create $ do
  results <- VGM.unsafeNew n
  -- Each finished subtree value is stored at its vertex index.
  !_ <- foldImpl tree valAt toF act root (VGM.unsafeWrite results)
  pure results

-- | \(O(n)\) Folds a tree from every vertex, using the rerooting technique.
--
-- A tree DP rooted at vertex \(0\) is computed first with 'scan'. A second DFS then propagates
-- the action coming from the parent side downwards, so that the folded value of every vertex as
-- the root is obtained. Returns an empty vector when the number of vertices is \(0\).
--
-- ==== Constraints
-- - The action monoid \(f\) must be commutative.
--
-- ==== __Example__
-- >>> import AtCoder.Extra.Graph qualified as Gr
-- >>> import AtCoder.Extra.Tree qualified as Tree
-- >>> import Data.Semigroup (Sum (..))
-- >>> import Data.Vector.Unboxed qualified as VU
-- >>> let n = 5
-- >>> let gr = Gr.build @(Sum Int) n . Gr.swapDupe $ VU.fromList [(2, 1, Sum 1), (1, 0, Sum 1), (2, 3, Sum 1), (3, 4, Sum 1)]
-- >>> type W = Sum Int -- edge weight
-- >>> type F = Sum Int -- action type
-- >>> type X = Sum Int -- vertex value
-- >>> :{
--  let res = Tree.foldReroot n (gr `Gr.adjW`) valAt toF act
--        where
--          valAt :: Int -> X
--          valAt = const $ mempty @(Sum Int)
--          toF :: X -> (Int, W) -> F
--          toF x (!_i, !dx) = x + dx
--          act :: F -> X -> X
--          act dx x = dx + x
--   in VU.map getSum res
-- :}
-- [4,4,4,4,4]
--
-- @since 1.1.0.0
{-# INLINEABLE foldReroot #-}
foldReroot ::
  forall w f a.
  (HasCallStack, VU.Unbox w, VU.Unbox a, VU.Unbox f, Monoid f) =>
  -- | The number of vertices.
  Int ->
  -- | Graph as a function.
  (Int -> VU.Vector (Int, w)) ->
  -- | @valAt@: Assignment of initial vertex values.
  (Int -> a) ->
  -- | @toF@: Converts a vertex value into an action onto a neighbor vertex.
  (a -> (Int, w) -> f) ->
  -- | @act@: Performs an action onto a vertex value.
  (f -> a -> a) ->
  -- | Tree folding result from every vertex as a root.
  VU.Vector a
foldReroot n tree valAt toF act
  -- Without this guard, rerooting would start at vertex 0 of an empty tree and write out of bounds.
  | n == 0 = VU.empty
  | otherwise = VU.create $ do
      -- Tree DP rooted at vertex 0: the folded value of each subtree in that orientation.
      let !treeDp = scan n tree valAt toF act root0 :: VU.Vector a
      -- Calculate tree DP for every vertex as a root:
      !dp <- VUM.unsafeNew n
      let reroot parent parentF v1 = do
            let !children = VU.filter ((/= parent) . fst) $ tree v1
            -- Prefix / suffix compositions of the actions coming from the child subtrees:
            let !fL = VU.scanl' (\ !acc (!v2, !w) -> acc <> toF (treeDp VG.! v2) (v1, w)) f0 children
            let !fR = VU.scanr' (\(!v2, !w) !acc -> toF (treeDp VG.! v2) (v1, w) <> acc) f0 children

            -- Value of @v1@ as the root: all neighbors, including the parent side, act on it.
            let !x1 = (parentF <> VU.last fL) `act` valAt v1
            VGM.unsafeWrite dp v1 x1

            VU.iforM_ children $ \i2 (!v2, !w) -> do
              -- Composite action excluding @v2@:
              let !f1 = parentF <> (fL VG.! i2) <> (fR VG.! (i2 + 1))
              -- Value of @v1@ seen as a child of @v2@, turned into an action onto @v2@:
              let !v1Acc = f1 `act` valAt v1
              let !f2 = toF v1Acc (v2, w)
              reroot v1 f2 v2

      reroot (-1 :: Int) f0 root0
      pure dp
  where
    root0 = 0 :: Int
    f0 = mempty @f