{-# LANGUAGE BangPatterns, ForeignFunctionInterface #-}
-- | An atomic integer value. All operations are thread safe.
module Data.Atomic
    (
      Atomic
    , new
    , read
    , write
    , inc
    , dec
    , add
    , subtract
    ) where

import Data.Int (Int64)
import Foreign.ForeignPtr (ForeignPtr, mallocForeignPtr, withForeignPtr)
import Foreign.Ptr (Ptr)
import Foreign.Storable (poke)
import Prelude hiding (read, subtract)

-- | A mutable, atomic 64-bit integer.
--
-- The constructor is not exported, so outside this module an 'Atomic' can
-- only be obtained through 'new'.
newtype Atomic
    = C (ForeignPtr Int64)  -- pointer to the cell holding the current value

-- | Create a new atomic, initialized to the given value.
--
-- The cell is allocated with 'mallocForeignPtr', so its storage is
-- reclaimed by the garbage collector once the 'Atomic' becomes unreachable.
new :: Int64      -- ^ Initial value.
    -> IO Atomic
new n = do
    fp <- mallocForeignPtr
    -- A plain 'poke' is sufficient: 'fp' has not been returned to any
    -- caller yet, so no other thread can observe the cell at this point.
    withForeignPtr fp $ \p -> poke p n
    return (C fp)

-- | Read the current value of the atomic.
--
-- The load is performed by the C helper @hs_atomic_read@; 'withForeignPtr'
-- keeps the cell alive for the duration of the foreign call.
read :: Atomic -> IO Int64
read (C fp) = withForeignPtr fp cRead

-- Foreign helper backing 'read' (implemented in C, not visible here).
-- Imported @unsafe@; NOTE(review): this presumes the C side never blocks.
foreign import ccall unsafe "hs_atomic_read" cRead :: Ptr Int64 -> IO Int64

-- | Set the atomic to the given value.
--
-- The store is performed by the C helper @hs_atomic_write@; 'withForeignPtr'
-- keeps the cell alive for the duration of the foreign call.
write :: Atomic -> Int64 -> IO ()
write (C fp) n = withForeignPtr fp $ \p -> cWrite p n

-- Foreign helper backing 'write' (implemented in C, not visible here).
foreign import ccall unsafe "hs_atomic_write" cWrite
    :: Ptr Int64 -> Int64 -> IO ()

-- | Increase the atomic by one.
--
-- Equivalent to @'add' atomic 1@.
inc :: Atomic -> IO ()
inc atomic = add atomic 1

-- | Decrease the atomic by one.
--
-- Equivalent to @'subtract' atomic 1@ (this module's 'subtract', not the
-- Prelude one, which is hidden on import).
dec :: Atomic -> IO ()
dec atomic = subtract atomic 1

-- | Increase the atomic by the given amount.
--
-- The update is performed by the C helper @hs_atomic_add@; 'withForeignPtr'
-- keeps the cell alive for the duration of the foreign call.
add :: Atomic -> Int64 -> IO ()
add (C fp) n = withForeignPtr fp $ \p -> cAdd p n

-- | Decrease the atomic by the given amount.
--
-- The update is performed by the C helper @hs_atomic_subtract@;
-- 'withForeignPtr' keeps the cell alive for the duration of the foreign call.
subtract :: Atomic -> Int64 -> IO ()
subtract (C fp) n = withForeignPtr fp $ \p -> cSubtract p n

-- | Foreign helper backing 'add': increase the pointed-to cell by the
-- given amount (implemented in C, not visible here).
foreign import ccall unsafe "hs_atomic_add" cAdd
    :: Ptr Int64 -> Int64 -> IO ()

-- | Foreign helper backing 'subtract': decrease the pointed-to cell by the
-- given amount (implemented in C, not visible here).
foreign import ccall unsafe "hs_atomic_subtract" cSubtract
    :: Ptr Int64 -> Int64 -> IO ()