{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE CPP #-}
{-# LANGUAGE NamedFieldPuns #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}

-- | This module corresponds to "Control.Concurrent.STM.TVar" in the @stm@ package.
--
-- This module can be used as a drop-in replacement for
-- "Control.Concurrent.Class.MonadSTM.Strict.TVar", but not the other way
-- around.
module Control.Concurrent.Class.MonadSTM.Strict.TVar.Checked
  ( -- * StrictTVar
    LazyTVar
  , StrictTVar
  , castStrictTVar
  , fromLazyTVar
  , modifyTVar
  , newTVar
  , newTVarIO
  , newTVarWithInvariant
  , newTVarWithInvariantIO
  , readTVar
  , readTVarIO
  , stateTVar
  , swapTVar
  , toLazyTVar
  , unsafeToUncheckedStrictTVar
  , writeTVar

    -- * MonadLabelSTM
  , labelTVar
  , labelTVarIO

    -- * MonadTraceSTM
  , traceTVar
  , traceTVarIO

    -- * Invariant
  , checkInvariant
  ) where

import Control.Concurrent.Class.MonadSTM
  ( InspectMonadSTM
  , MonadLabelledSTM
  , MonadSTM
  , MonadTraceSTM
  , STM
  , TraceValue
  , atomically
  )
import Control.Concurrent.Class.MonadSTM.Strict.TVar qualified as Strict
import GHC.Stack (HasCallStack)

{-------------------------------------------------------------------------------
  StrictTVar
-------------------------------------------------------------------------------}

-- | The lazy @TVar@ type that backs a 'StrictTVar' in monad @m@.
type LazyTVar m = Strict.LazyTVar m

#if CHECK_TVAR_INVARIANTS
-- | A strict @TVar@ paired with an invariant on its contents.
--
-- This representation is selected when @CHECK_TVAR_INVARIANTS@ is enabled.
data StrictTVar m a = StrictTVar
  { invariant :: !(a -> Maybe String)
    -- ^ Invariant checked whenever updating the 'StrictTVar'. 'Nothing'
    -- means the value is acceptable; @'Just' msg@ describes the violation.
  , tvar      :: !(Strict.StrictTVar m a)
    -- ^ The wrapped, unchecked strict @TVar@.
  }
#else
-- | A strict @TVar@. Without @CHECK_TVAR_INVARIANTS@ no invariant is stored,
-- so this is a plain newtype wrapper around the unchecked strict @TVar@.
newtype StrictTVar m a = StrictTVar
  { tvar :: Strict.StrictTVar m a
    -- ^ The wrapped, unchecked strict @TVar@.
  }
#endif

-- | Reinterpret a 'StrictTVar' as one over monad @n@, which is allowed when
-- @m@ and @n@ share the same underlying lazy @TVar@ type.
--
-- The invariant obtained from the source variable (via 'getInvariant') is
-- attached to the result unchanged.
castStrictTVar ::
  LazyTVar m ~ LazyTVar n =>
  StrictTVar m a -> StrictTVar n a
castStrictTVar src =
  let inv    = getInvariant src
      casted = Strict.castStrictTVar (tvar src)
   in mkStrictTVar inv casted

-- | Get the underlying @TVar@.
--
-- Updates made through the returned 'LazyTVar' are not forced, so strictness
-- cannot be guaranteed; use with caution.
--
-- Such updates also bypass the invariant check, so they may break the
-- invariant the 'StrictTVar' originally held.
toLazyTVar :: StrictTVar m a -> LazyTVar m a
toLazyTVar sv = Strict.toLazyTVar (tvar sv)

-- | Create a 'StrictTVar' from a 'LazyTVar'.
--
-- It is not guaranteed that the 'LazyTVar' contains a value that is in WHNF, so
-- there is no guarantee that the resulting 'StrictTVar' contains a value that
-- is in WHNF. This should be used with caution.
--
-- The resulting 'StrictTVar' has a trivial invariant: one that always returns
-- 'Nothing' and therefore never fails.
fromLazyTVar :: LazyTVar m a -> StrictTVar m a
fromLazyTVar = mkStrictTVar (const Nothing) . Strict.fromLazyTVar

-- | Obtain an unchecked reference to the given checked 'StrictTVar'.
--
-- The invariant is only enforced for writes that go through the checked
-- 'StrictTVar'. Writes made through the returned unchecked reference skip the
-- check and may therefore break the invariant.
unsafeToUncheckedStrictTVar :: StrictTVar m a -> Strict.StrictTVar m a
unsafeToUncheckedStrictTVar StrictTVar {tvar} = tvar

-- | Create a new 'StrictTVar' in 'STM' whose invariant is trivial: it always
-- returns 'Nothing'.
newTVar :: MonadSTM m => a -> STM m (StrictTVar m a)
newTVar = fmap (mkStrictTVar (const Nothing)) . Strict.newTVar

-- | Create a new 'StrictTVar' in the base monad whose invariant is trivial.
-- Because that invariant never fails, the initial check cannot raise an error.
newTVarIO :: MonadSTM m => a -> m (StrictTVar m a)
newTVarIO = newTVarWithInvariantIO noInvariant
  where
    noInvariant _ = Nothing

-- | Create a new 'StrictTVar' in 'STM' guarded by the given invariant.
--
-- The initial value is forced to WHNF. It is also checked with
-- 'checkInvariant', which raises an error on violation when invariant
-- checking is enabled.
newTVarWithInvariant ::
  (MonadSTM m, HasCallStack) =>
  (a -> Maybe String) ->
  a ->
  STM m (StrictTVar m a)
newTVarWithInvariant check !initial =
  checkInvariant (check initial)
    (fmap (mkStrictTVar check) (Strict.newTVar initial))

-- | Create a new 'StrictTVar' in the base monad guarded by the given
-- invariant.
--
-- The initial value is forced to WHNF. It is also checked with
-- 'checkInvariant', which raises an error on violation when invariant
-- checking is enabled.
newTVarWithInvariantIO ::
  (MonadSTM m, HasCallStack) =>
  (a -> Maybe String) ->
  a ->
  m (StrictTVar m a)
newTVarWithInvariantIO check !initial =
  checkInvariant (check initial)
    (fmap (mkStrictTVar check) (Strict.newTVarIO initial))

-- | Read the current contents of a 'StrictTVar' inside 'STM'.
readTVar :: MonadSTM m => StrictTVar m a -> STM m a
readTVar sv = Strict.readTVar (tvar sv)

-- | Read the current contents of a 'StrictTVar' directly in the base monad.
readTVarIO :: MonadSTM m => StrictTVar m a -> m a
readTVarIO sv = Strict.readTVarIO (tvar sv)

-- | Write a new value into a 'StrictTVar'.
--
-- The value is forced to WHNF and passed through 'checkInvariant'. When
-- invariant checking is enabled, a violation raises an error instead of
-- yielding the write action.
writeTVar :: (MonadSTM m, HasCallStack) => StrictTVar m a -> a -> STM m ()
writeTVar sv !new = checkInvariant violation (Strict.writeTVar (tvar sv) new)
  where
    violation = getInvariant sv new

-- | Apply a pure function to the contents of a 'StrictTVar'.
--
-- The new value is stored via 'writeTVar', so it is forced to WHNF and
-- checked against the invariant (when enabled). The 'HasCallStack' constraint
-- lets an invariant-violation error report the caller of 'modifyTVar', rather
-- than stopping at this function.
modifyTVar :: (MonadSTM m, HasCallStack) => StrictTVar m a -> (a -> a) -> STM m ()
modifyTVar v f = readTVar v >>= writeTVar v . f

-- | Run a state transition on a 'StrictTVar', returning the transition's
-- result.
--
-- The new state is stored via 'writeTVar', so it is forced to WHNF and
-- checked against the invariant (when enabled). 'HasCallStack' propagates
-- the caller's location into any invariant-violation error.
stateTVar ::
  (MonadSTM m, HasCallStack) =>
  StrictTVar m s -> (s -> (a, s)) -> STM m a
stateTVar v f = do
  s <- readTVar v
  let (result, s') = f s
  writeTVar v s'
  return result

-- | Replace the contents of a 'StrictTVar', returning the previous value.
--
-- The new value is stored via 'writeTVar', so it is forced to WHNF and
-- checked against the invariant (when enabled). 'HasCallStack' propagates
-- the caller's location into any invariant-violation error.
swapTVar ::
  (MonadSTM m, HasCallStack) =>
  StrictTVar m a -> a -> STM m a
swapTVar v new = do
  old <- readTVar v
  writeTVar v new
  return old

--
-- Dealing with invariants
--

-- | Check invariant (if enabled) before continuing
--
-- @checkInvariant mErr x@ is equal to @x@ if @mErr == Nothing@, and throws
-- an error @err@ if @mErr == Just err@.
--
-- This is exported so that other code that wants to conditionally check
-- invariants can reuse the same logic, rather than having to introduce new
-- per-package flags.
checkInvariant :: HasCallStack => Maybe String -> a -> a

-- | The invariant stored in a 'StrictTVar'. When invariant checking is
-- disabled, this is the trivial invariant that always returns 'Nothing'.
getInvariant :: StrictTVar m a -> a -> Maybe String

-- | Wrap an unchecked strict @TVar@ together with an invariant. When
-- invariant checking is disabled, the invariant is discarded.
mkStrictTVar :: (a -> Maybe String) -> Strict.StrictTVar m a -> StrictTVar m a

#if CHECK_TVAR_INVARIANTS
checkInvariant mErr k =
  case mErr of
    Nothing  -> k
    Just err -> error ("StrictTVar invariant violation: " ++ err)

getInvariant = invariant

mkStrictTVar inv tv = StrictTVar {invariant = inv, tvar = tv}
#else
checkInvariant _ k = k

getInvariant _ _ = Nothing

mkStrictTVar _ tv = StrictTVar {tvar = tv}
#endif

{-------------------------------------------------------------------------------
  MonadLabelledSTM
-------------------------------------------------------------------------------}

-- | Attach a label to a 'StrictTVar' inside 'STM' by delegating to
-- 'Strict.labelTVar' on the wrapped @TVar@.
labelTVar :: MonadLabelledSTM m => StrictTVar m a -> String -> STM m ()
labelTVar sv label = Strict.labelTVar (tvar sv) label

-- | Attach a label to a 'StrictTVar' from the base monad by running
-- 'labelTVar' in its own 'atomically' transaction.
labelTVarIO :: MonadLabelledSTM m => StrictTVar m a -> String -> m ()
labelTVarIO sv label = atomically (labelTVar sv label)

{-------------------------------------------------------------------------------
  MonadTraceSTM
-------------------------------------------------------------------------------}

-- | Install a trace function for a 'StrictTVar' inside 'STM' by delegating
-- to 'Strict.traceTVar' on the wrapped @TVar@.
traceTVar ::
  MonadTraceSTM m =>
  proxy m ->
  StrictTVar m a ->
  (Maybe a -> a -> InspectMonadSTM m TraceValue) ->
  STM m ()
traceTVar p sv tracer = Strict.traceTVar p (tvar sv) tracer

-- | Install a trace function for a 'StrictTVar' from the base monad by
-- delegating to 'Strict.traceTVarIO' on the wrapped @TVar@.
traceTVarIO ::
  MonadTraceSTM m =>
  StrictTVar m a ->
  (Maybe a -> a -> InspectMonadSTM m TraceValue) ->
  m ()
traceTVarIO sv tracer = Strict.traceTVarIO (tvar sv) tracer