{-# LANGUAGE RecordWildCards #-}

-- | Minimum binary heap. Mutable and fixed-sized.
--
-- <https://en.wikipedia.org/wiki/Binary_heap>
--
-- ==== __Example__
-- >>> import AtCoder.Internal.MinHeap qualified as MH
-- >>> heap <- MH.new @_ @Int 4
-- >>> MH.capacity heap
-- 4
--
-- >>> MH.push heap 10
-- >>> MH.push heap 0
-- >>> MH.push heap 5
-- >>> MH.length heap -- [0, 5, 10]
-- 3
--
-- >>> MH.pop heap    -- [5, 10]
-- Just 0
--
-- >>> MH.peek heap   -- [5, 10]
-- Just 5
--
-- >>> MH.pop heap    -- [10]
-- Just 5
--
-- >>> MH.clear heap  -- []
-- >>> MH.null heap
-- True
--
-- @since 1.0.0.0
module AtCoder.Internal.MinHeap
  ( -- * Heap
    Heap,

    -- * Constructors
    new,

    -- * Metadata
    capacity,
    length,
    null,

    -- * Reset
    clear,

    -- * Push/pop/peek
    push,
    pop,
    pop_,
    peek,
  )
where

import Control.Monad (when)
import Control.Monad.Primitive (PrimMonad, PrimState, stToPrim)
import Control.Monad.ST (ST)
import Data.Vector.Generic.Mutable qualified as VGM
import Data.Vector.Unboxed qualified as VU
import Data.Vector.Unboxed.Mutable qualified as VUM
import GHC.Stack (HasCallStack)
import Prelude hiding (length, null)

-- | Minimum binary heap. Mutable and fixed-sized.
--
-- Indices are zero-based. The children of index \(i\) live at \(2i + 1\) and \(2i + 2\), and the
-- parent of index \(i > 0\) lives at \(\lfloor (i - 1) / 2 \rfloor\).
--
-- @
--     0
--   1   2
--  3 4 5 6
-- @
--
-- INVARIANT (min heap): child values are bigger than or equal to their parent value.
--
-- @since 1.0.0.0
data Heap s a = Heap
  { -- | Size of the heap: a one-element vector whose slot @0@ holds the current number of
    -- elements.
    --
    -- @since 1.2.4.0
    sizeH :: !(VUM.MVector s Int),
    -- | Storage. Its length is the `capacity` of the heap; only the first `length` slots hold
    -- live elements.
    --
    -- @since 1.2.4.0
    dataH :: !(VUM.MVector s a)
  }

-- | \(O(n)\) Creates an empty `Heap` with capacity \(n\).
--
-- The size counter starts at zero. The storage is allocated with `VUM.unsafeNew`, so \(n\) is
-- expected to be non-negative and the slots are not initialized.
--
-- @since 1.0.0.0
{-# INLINE new #-}
new :: (PrimMonad m, VU.Unbox a) => Int -> m (Heap (PrimState m) a)
new n = do
  -- One-element vector used as a mutable counter.
  sizeH <- VUM.replicate 1 0
  dataH <- VUM.unsafeNew n
  pure Heap {..}

-- | \(O(1)\) Returns the maximum number of elements the heap can hold, i.e., the length of its
-- storage vector.
--
-- @since 1.0.0.0
{-# INLINE capacity #-}
capacity :: (VU.Unbox a) => Heap s a -> Int
capacity = VUM.length . dataH

-- | \(O(1)\) Returns the number of elements currently in the heap.
--
-- @since 1.0.0.0
{-# INLINE length #-}
length :: (PrimMonad m, VU.Unbox a) => Heap (PrimState m) a -> m Int
-- The counter vector always has exactly one slot (see `new`), so the unchecked read is in bounds.
length Heap {sizeH} = VGM.unsafeRead sizeH 0

-- | \(O(1)\) Returns `True` if the heap is empty, i.e., its `length` is zero.
--
-- @since 1.0.0.0
{-# INLINE null #-}
null :: (PrimMonad m, VU.Unbox a) => Heap (PrimState m) a -> m Bool
null = fmap (== 0) . length

-- | \(O(1)\) Sets the `length` to zero. The storage itself is left untouched.
--
-- @since 1.0.0.0
{-# INLINE clear #-}
clear :: (PrimMonad m, VU.Unbox a) => Heap (PrimState m) a -> m ()
clear Heap {sizeH} = VGM.unsafeWrite sizeH 0 0

-- | \(O(\log n)\) Inserts an element to the heap.
--
-- The element is written into the storage with a bounds-checked write before the size is
-- updated, so pushing to a heap that is already at `capacity` fails in that write instead of
-- corrupting memory.
--
-- @since 1.0.0.0
{-# INLINE push #-}
push :: (HasCallStack, PrimMonad m, Ord a, VU.Unbox a) => Heap (PrimState m) a -> a -> m ()
push heap x = stToPrim $ pushST heap x

-- | \(O(\log n)\) Removes the smallest element (the root) from the heap and returns it, or
-- `Nothing` if the heap is empty.
--
-- @since 1.0.0.0
{-# INLINE pop #-}
pop :: (HasCallStack, PrimMonad m, Ord a, VU.Unbox a) => Heap (PrimState m) a -> m (Maybe a)
pop heap = stToPrim $ popST heap

-- | \(O(\log n)\) `pop` with the return value discarded. Does nothing if the heap is empty.
--
-- @since 1.0.0.0
{-# INLINE pop_ #-}
pop_ :: (HasCallStack, Ord a, VU.Unbox a, PrimMonad m) => Heap (PrimState m) a -> m ()
pop_ heap = () <$ stToPrim (popST heap)

-- | \(O(1)\) Returns the smallest value in the heap, or `Nothing` if it is empty. The heap is not
-- modified.
--
-- @since 1.0.0.0
{-# INLINE peek #-}
peek :: (VU.Unbox a, PrimMonad m) => Heap (PrimState m) a -> m (Maybe a)
peek heap = do
  isNull <- null heap
  if isNull
    then pure Nothing
    -- The minimum always sits at the root (index 0).
    else Just <$> VGM.read (dataH heap) 0

-- -------------------------------------------------------------------------------------------------
-- Internal
-- -------------------------------------------------------------------------------------------------

-- | \(O(\log n)\) `ST` implementation of `push`: appends the element after the last live slot and
-- sifts it up until its parent is not bigger than it.
{-# INLINEABLE pushST #-}
pushST :: (HasCallStack, Ord a, VU.Unbox a) => Heap s a -> a -> ST s ()
pushST Heap {..} x = do
  i0 <- VGM.unsafeRead sizeH 0
  -- Bounds-checked write: this is where pushing beyond the capacity is rejected, before the
  -- size counter is touched.
  VGM.write dataH i0 x
  VGM.unsafeWrite sizeH 0 $ i0 + 1
  -- The value at index @i@ is always @x@ here, so it is enough to compare @x@ with the parent.
  let siftUp i = when (i > 0) $ do
        let iParent = (i - 1) `div` 2
        xParent <- VGM.read dataH iParent
        when (x < xParent) $ do
          VGM.swap dataH iParent i
          siftUp iParent
  siftUp i0

-- | \(O(\log n)\) `ST` implementation of `pop`: removes the root (the smallest element) and
-- returns it, or returns `Nothing` if the heap is empty.
{-# INLINEABLE popST #-}
popST :: (HasCallStack, Ord a, VU.Unbox a) => Heap s a -> ST s (Maybe a)
popST heap@Heap {..} = do
  len <- length heap
  if len == 0
    then pure Nothing
    else do
      -- @n@ is both the size after the removal and the index of the last live element.
      let n = len - 1
      VGM.unsafeWrite sizeH 0 n
      -- Remember the root, then swap it with the last element. The old root ends up at index
      -- @n@, which is outside of the shrunken heap.
      root <- VGM.read dataH 0
      VGM.swap dataH 0 n

      -- Sink the value at index @i@ while one of its children (within the first @n@ slots) is
      -- strictly smaller than it.
      let siftDown i = do
            let il = 2 * i + 1
            let ir = il + 1
            when (il < n) $ do
              x <- VGM.read dataH i
              xl <- VGM.read dataH il
              if ir < n
                then do
                  -- IMPORTANT: swap with the smaller child
                  xr <- VGM.read dataH ir
                  if xl <= xr && xl < x
                    then do
                      VGM.swap dataH i il
                      siftDown il
                    else when (xr < x) $ do
                      -- Reached when @xr < xl@, or when @x <= xl <= xr@; in the latter case
                      -- the guard is false and sifting stops.
                      VGM.swap dataH i ir
                      siftDown ir
                else when (xl < x) $ do
                  -- Only the left child exists.
                  VGM.swap dataH i il
                  siftDown il

      siftDown 0
      pure $ Just root