{-# LANGUAGE BangPatterns
           , CPP
           , ForeignFunctionInterface
           , MagicHash
           , UnboxedTuples
           #-}
-- | An atomic integer value. All operations are thread safe.
module Data.Atomic
    (
      Atomic
    , new
    , read
    , write
    , inc
    , dec
    , add
    , subtract
    ) where

#include "MachDeps.h"
#ifndef SIZEOF_HSINT
#error "MachDeps.h didn't define SIZEOF_HSINT"
#endif

import Prelude hiding (read, subtract)

import GHC.Int

#if SIZEOF_HSINT == 8

-- 64-bit imports
import GHC.IO
import GHC.Prim

#else

-- 32-bit imports
import Data.IORef

#endif


-- 64-bit machine, Int ~ Int64, do it the fast way:
#if SIZEOF_HSINT == 8

#if MIN_VERSION_base(4,17,0)
-- From base 4.17 on, the 'I64#' constructor carries an 'Int64#', which is a
-- separate primitive type from 'Int#', so convert explicitly with the
-- primops. This code is only compiled when SIZEOF_HSINT == 8, so both
-- types are 64 bits wide and the conversions are lossless.

-- | Convert the payload of an 'Int64' into a machine 'Int#'.
int64ToInt :: Int64# -> Int#
int64ToInt = int64ToInt#

-- | Convert a machine 'Int#' into the payload of an 'Int64'.
intToInt64 :: Int# -> Int64#
intToInt64 = intToInt64#
#else
-- Older base: on a 64-bit platform 'I64#' carries an 'Int#' directly, so
-- both conversions are the identity.
int64ToInt :: Int# -> Int#
int64ToInt i = i

intToInt64 :: Int# -> Int#
intToInt64 i = i
#endif

-- | A mutable, atomic integer.
--
-- The value is kept in a one-word 'MutableByteArray#'; every operation in
-- this branch reads or writes the element at index 0.
data Atomic
    = C (MutableByteArray# RealWorld)

-- | Create a new atomic, initialized to the given value.
new :: Int64 -> IO Atomic
new (I64# n64) = IO $ \s ->
    -- Allocate exactly one machine word, then store the initial value with
    -- an atomic write. The primop that allocates the array is not
    -- documented to zero the memory, so the explicit store is required.
    case newByteArray# SIZEOF_HSINT# s of { (# s1, mba #) ->
    case atomicWriteIntArray# mba 0# (int64ToInt n64) s1 of { s2 ->
    (# s2, C mba #) }}

-- | Read the current value of the atomic.
--
-- Uses 'atomicReadIntArray#', so the read is ordered with respect to the
-- other atomic operations on the same value.
read :: Atomic -> IO Int64
read (C mba) = IO $ \s ->
    case atomicReadIntArray# mba 0# s of { (# s1, n #) ->
    (# s1, I64# (intToInt64 n) #)}

-- | Set the atomic to the given value.
--
-- The store is performed with 'atomicWriteIntArray#' rather than a plain
-- write, so it is safe to race with the other operations in this module.
write :: Atomic -> Int64 -> IO ()
write (C mba) (I64# n64) = IO $ \s ->
    case atomicWriteIntArray# mba 0# (int64ToInt n64) s of { s1 ->
    (# s1, () #) }

-- | Increase the atomic by the given amount.
--
-- Implemented as a single 'fetchAddIntArray#', so concurrent increments
-- are not lost. The previous value returned by the primop is discarded.
add :: Atomic -> Int64 -> IO ()
add (C mba) (I64# n64) = IO $ \s ->
    case fetchAddIntArray# mba 0# (int64ToInt n64) s of { (# s1, _ #) ->
    (# s1, () #) }

-- | Decrease the atomic by the given amount.
--
-- Implemented as a single 'fetchSubIntArray#', so concurrent decrements
-- are not lost. The previous value returned by the primop is discarded.
subtract :: Atomic -> Int64 -> IO ()
subtract (C mba) (I64# n64) = IO $ \s ->
    case fetchSubIntArray# mba 0# (int64ToInt n64) s of { (# s1, _ #) ->
    (# s1, () #) }

#else

-- 32-bit machine, Int ~ Int32, fall back to IORef. This could be replaced with
-- faster implementations for specific 32-bit machines in the future, but the
-- idea is to preserve 64-bit width for counters.

-- | A mutable, atomic integer, backed by an 'IORef' so that the full
-- 64-bit width is kept even though the machine 'Int' is narrower.
newtype Atomic
    = C (IORef Int64)

-- | Create a new atomic, initialized to the given value.
--
-- The argument is forced before it is stored, which matches the strictness
-- of 'write', 'add' and 'subtract' here and of the 64-bit implementation,
-- and keeps an unevaluated thunk out of the 'IORef'.
new :: Int64 -> IO Atomic
new !n = fmap C (newIORef n)

-- | Read the current value of the atomic.
read :: Atomic -> IO Int64
read atomic = case atomic of
    C ref -> readIORef ref

-- | Set the atomic to the given value.
--
-- The new value is evaluated before the store, so the 'IORef' never ends
-- up holding a thunk.
write :: Atomic -> Int64 -> IO ()
write (C ref) n = n `seq` atomicWriteIORef ref n

-- | Increase the atomic by the given amount.
--
-- The read-modify-write goes through 'atomicModifyIORef'', which makes it
-- atomic and stores the evaluated sum.
add :: Atomic -> Int64 -> IO ()
add (C ref) !delta = atomicModifyIORef' ref bump
  where
    bump !current = (current + delta, ())

-- | Decrease the atomic by the given amount.
--
-- The read-modify-write goes through 'atomicModifyIORef'', which makes it
-- atomic and stores the evaluated difference.
subtract :: Atomic -> Int64 -> IO ()
subtract (C ref) !delta = atomicModifyIORef' ref drop1
  where
    drop1 !current = (current - delta, ())

#endif

-- | Increase the atomic by one. Equivalent to @'add' atomic 1@.
inc :: Atomic -> IO ()
inc atomic = add atomic 1

-- | Decrease the atomic by one. Equivalent to @'subtract' atomic 1@.
dec :: Atomic -> IO ()
dec atomic = subtract atomic 1