-- |
-- Description : Fast, mutable sparse sets.
-- Copyright   : (c) Jeremy Nuttall, 2025
-- License     : BSD-3-Clause
-- Maintainer  : jeremy@jeremy-nuttall.com
-- Stability   : experimental
-- Portability : GHC
--
-- __This implementation is NOT thread-safe.__ Thread safety must be maintained by a whole-set
-- locking mechanism.
module Data.SparseSet.Storable.Mutable (
  MutableSparseSet,
  IOMutableSparseSet,
  STMutableSparseSet,

  -- * Creation
  withCapacity,
  new,

  -- * Read
  length,
  contains,
  members,
  lookup,

  -- * Update
  insert,
  delete,
  clear,
  compact,

  -- * Iteration
  foldM,
  ifoldM,
  mapM_,
  imapM_,
  ifoldIntersectionM,
)
where

import Control.Monad.Primitive
import Data.Typeable (Typeable)
import Data.Vector.Storable qualified as VS
import GHC.Generics (Generic)
import Prelude hiding (length, lookup, mapM_)

import Data.SparseSet.Generic.Mutable qualified as G

-- | A mutable sparse set whose underlying vectors are storable
-- ('VS.MVector').
--
-- This is a thin newtype over the generic implementation in
-- "Data.SparseSet.Generic.Mutable", specialised to storable vectors. @s@ is
-- the state token of the monad the set is used in and @a@ is the type of the
-- values stored in the set.
newtype MutableSparseSet s a = MSS (G.MutableSparseSet VS.MVector s a)
  deriving stock (Generic, Typeable)

-- | A 'MutableSparseSet' whose state token is fixed to 'RealWorld' (the state token used by 'IO').
type IOMutableSparseSet = MutableSparseSet RealWorld
-- | A 'MutableSparseSet' parameterised over an arbitrary state token @s@.
type STMutableSparseSet s = MutableSparseSet s

-- | Create a sparse set with a given dense and sparse capacity.
--
-- It's a good idea to use this function if you have an estimate of your data requirements,
-- as it can prevent costly re-allocations as the set grows.
--
-- @since 0.1.0.0
withCapacity
  :: (PrimMonad m, VS.Storable a)
  => Int
  -- ^ Capacity for the dense set
  -> Int
  -- ^ Capacity for the sparse set
  -> m (MutableSparseSet (PrimState m) a)
withCapacity dc sc = MSS <$> G.withCapacity dc sc
{-# INLINE withCapacity #-}

-- | Create an empty sparse set with default capacities.
--
-- The defaults are whatever the generic implementation's @new@ chooses; use
-- 'withCapacity' to pick the capacities explicitly.
--
-- @since 0.1.0.0
new :: forall a m. (PrimMonad m, VS.Storable a) => m (MutableSparseSet (PrimState m) a)
new = MSS <$> G.new
{-# INLINE new #-}

-- | O(1) Number of elements in the set (dense).
--
-- @since 0.1.0.0
length :: forall a m. (PrimMonad m) => MutableSparseSet (PrimState m) a -> m Int
length (MSS g) = G.length g
{-# INLINE length #-}

-- | O(1) Check whether an element is in the set.
--
-- The 'Int' argument is the key to test for membership.
--
-- @since 0.1.0.0
contains :: (PrimMonad m) => MutableSparseSet (PrimState m) a -> Int -> m Bool
contains (MSS g) = G.contains g
{-# INLINE contains #-}

-- | O(n) The members of the set in an unspecified order.
--
-- The result is a storable vector of the 'Int' keys.
--
-- @since 0.1.0.0
members :: (PrimMonad m) => MutableSparseSet (PrimState m) a -> m (VS.Vector Int)
members (MSS g) = G.members g
{-# INLINE members #-}

-- | O(1) Look up an element in the set.
--
-- Takes the set and the 'Int' key to look up, and yields a 'Maybe' value.
--
-- @since 0.1.0.0
lookup :: (PrimMonad m, VS.Storable a) => MutableSparseSet (PrimState m) a -> Int -> m (Maybe a)
lookup (MSS g) = G.lookup g
{-# INLINE lookup #-}

-- | O(1) amortized. Insert a value for a given key.
--
-- If the key is already in the set, its value is overwritten.
--
-- __INVARIANT__: Keys cannot be negative. An unchecked exception is
-- thrown if a negative key is added to the set.
--
-- @since 0.1.0.0
insert :: (PrimMonad m, VS.Storable a) => MutableSparseSet (PrimState m) a -> Int -> a -> m ()
insert (MSS g) = G.insert g
{-# INLINE insert #-}

-- | O(1) Delete an element from the set.
--
-- Takes the set and the 'Int' key to delete, and yields a 'Maybe' value as
-- produced by the generic implementation's @delete@.
--
-- @since 0.1.0.0
delete :: (PrimMonad m, VS.Storable a) => MutableSparseSet (PrimState m) a -> Int -> m (Maybe a)
delete (MSS g) = G.delete g
{-# INLINE delete #-}

-- | O(1) Clear all elements from the set.
--
-- @since 0.1.0.0
clear :: (PrimMonad m) => MutableSparseSet (PrimState m) a -> m ()
clear (MSS g) = G.clear g
{-# INLINE clear #-}

-- | O(n) Shrink the capacity of the set to fit exactly the current number of elements.
--
-- @since 0.1.0.0
compact :: (PrimMonad m, VS.Storable a) => MutableSparseSet (PrimState m) a -> m ()
compact (MSS g) = G.compact g
{-# INLINE compact #-}

-- | O(n) Fold over the values of the set.
--
-- @since 0.1.0.0
foldM
  :: (PrimMonad m, VS.Storable a)
  => (b -> a -> m b)
  -- ^ Step function: takes the accumulator and a value
  -> b
  -- ^ Initial accumulator
  -> MutableSparseSet (PrimState m) a
  -- ^ Set to fold over
  -> m b
foldM f initAcc (MSS g) = G.foldM f initAcc g
{-# INLINE foldM #-}

-- | O(n) Fold over the keys and values of the set.
--
-- @since 0.1.0.0
ifoldM
  :: (PrimMonad m, VS.Storable a)
  => (b -> (Int, a) -> m b)
  -- ^ Step function: takes the accumulator and a (key, value) pair
  -> b
  -- ^ Initial accumulator
  -> MutableSparseSet (PrimState m) a
  -- ^ Set to fold over
  -> m b
ifoldM f initAcc (MSS g) = G.ifoldM f initAcc g
{-# INLINE ifoldM #-}

-- | O(n) Iterate over the values of the set.
--
-- @since 0.1.0.0
mapM_
  :: (PrimMonad m, VS.Storable a)
  => (a -> m ())
  -- ^ Action to perform on each value
  -> MutableSparseSet (PrimState m) a
  -- ^ Set to iterate over
  -> m ()
mapM_ f (MSS g) = G.mapM_ f g
{-# INLINE mapM_ #-}

-- | O(n) Iterate over the keys and values of the set.
--
-- @since 0.1.0.0
imapM_
  :: (PrimMonad m, VS.Storable a)
  => ((Int, a) -> m ())
  -- ^ Action to perform on each (key, value) pair
  -> MutableSparseSet (PrimState m) a
  -- ^ Set to iterate over
  -> m ()
imapM_ f (MSS g) = G.imapM_ f g
{-# INLINE imapM_ #-}

-- | O(min(n, m)) Iterate over the intersection of two sets with an accumulator.
--
-- The order of the arguments does not matter - the smaller of the two sets is
-- selected as the iteratee.
--
-- @since 0.1.0.0
ifoldIntersectionM
  :: (PrimMonad m, VS.Storable a, VS.Storable b)
  => (c -> Int -> a -> b -> m c)
  -- ^ Step function: takes the accumulator, a key, and the value held for
  -- that key in each of the two sets
  -> c
  -- ^ Initial accumulator
  -> MutableSparseSet (PrimState m) a
  -- ^ First set
  -> MutableSparseSet (PrimState m) b
  -- ^ Second set
  -> m c
ifoldIntersectionM f initAcc (MSS a) (MSS b) = G.ifoldIntersectionM f initAcc a b
{-# INLINE ifoldIntersectionM #-}