{-
  Copyright (c) Meta Platforms, Inc. and affiliates.
  All rights reserved.

  This source code is licensed under the BSD-style license found in the
  LICENSE file in the root directory of this source tree.
-}

-- | Processing streams with a fixed number of worker threads
module Control.Concurrent.Stream
  ( stream
  , streamBound
  , streamWithInput
  , streamWithOutput
  , streamWithInputOutput
  , mapConcurrentlyBounded
  , forConcurrentlyBounded
  ) where

import Control.Concurrent.Async
import Control.Concurrent.STM
import Control.Exception
import Control.Monad
import Data.Maybe
import Data.IORef

-- | Selects how the worker threads of a stream are spawned.
data ShouldBindThreads
  = BoundThreads -- ^ each worker runs in a bound thread (via 'withAsyncBound')
  | UnboundThreads -- ^ workers run as ordinary threads (via 'mapConcurrently_')

-- | Run a fixed-size pool of workers over the items emitted by a producer.
--
-- The producer is handed a callback and calls it once per work item; every
-- item is processed by one of the @maxConcurrency@ workers. A synchronous
-- exception thrown by a worker is re-thrown to the caller.
stream
  :: Int -- ^ Maximum Concurrency
  -> ((a -> IO ()) -> IO ()) -- ^ Producer
  -> (a -> IO ()) -- ^ Worker
  -> IO ()
stream maxConcurrency producer worker =
  streamWithInput producer units (\_ item -> worker item)
  where
    -- One placeholder input per worker; the workers themselves ignore it.
    units = replicate maxConcurrency ()

-- | Same as 'stream', except that every worker runs in a bound thread.
-- 'Control.Concurrent.forkOS' explains what bound threads are.
streamBound
  :: Int -- ^ Maximum Concurrency
  -> ((a -> IO ()) -> IO ()) -- ^ Producer
  -> (a -> IO ()) -- ^ Worker
  -> IO ()
streamBound maxConcurrency producer worker =
  stream_ BoundThreads producer units (\_ item -> worker item)
  where
    -- One placeholder input per worker; the workers themselves ignore it.
    units = replicate maxConcurrency ()

-- | Like 'stream', but one worker is started per element of the input list,
-- and that worker receives its element alongside every item it processes.
streamWithInput
  :: ((a -> IO ()) -> IO ()) -- ^ Producer
  -> [b] -- ^ Worker state
  -> (b -> a -> IO ()) -- ^ Worker
  -> IO ()
streamWithInput producer workerState worker =
  stream_ UnboundThreads producer workerState worker

-- | Like 'stream', but gathers the value returned by each worker call.
-- The results come back in the order in which the producer emitted the
-- corresponding items.
streamWithOutput
  :: Int -- ^ Maximum Concurrency
  -> ((a -> IO ()) -> IO ()) -- ^ Producer
  -> (a -> IO c) -- ^ Worker
  -> IO [c]
streamWithOutput maxConcurrency producer worker =
  streamWithInputOutput producer units (\_ item -> worker item)
  where
    -- One placeholder input per worker; the workers themselves ignore it.
    units = replicate maxConcurrency ()

-- | Like 'streamWithInput', but collects the result of each worker call.
--
-- Every produced item gets its own result slot, registered at the moment
-- the producer emits it. The returned list is therefore in production
-- order, not in the order in which workers finished.
streamWithInputOutput
  :: ((a -> IO ()) -> IO ()) -- ^ Producer
  -> [b] -- ^ Worker input
  -> (b -> a -> IO c) -- ^ Worker
  -> IO [c]
streamWithInputOutput producer workerInput worker = do
  -- Result slots, newest first.
  results <- newIORef []
  let prod write = producer $ \a -> do
        res <- newIORef Nothing
        -- Atomic and strict: the producer is free to call its callback from
        -- several threads, and a plain 'modifyIORef' could then lose slots
        -- (it also accumulates a thunk chain until the final read).
        atomicModifyIORef' results $ \rs -> (res : rs, ())
        write (a, res)
  stream_ UnboundThreads prod workerInput $ \s (a, ref) ->
    worker s a >>= writeIORef ref . Just
  slots <- readIORef results
  -- Slots were prepended, so reverse to recover production order. Every slot
  -- is expected to be 'Just' once stream_ has returned; catMaybes only
  -- guards against any slot that was left unfilled.
  catMaybes <$> mapM readIORef (reverse slots)
    
-- | Core implementation shared by every streaming entry point.
--
-- One worker is started per element of the worker-input list, so the length
-- of that list is the degree of concurrency. Items emitted by the producer
-- reach the workers through a bounded queue whose capacity equals the number
-- of workers. Once the producer returns, one end marker ('Nothing') per
-- worker is enqueued so that every worker loop terminates.
--
-- With an empty worker-input list there are no workers. In that case a
-- producer that emits nothing completes normally in both threading modes.
-- A producer that does emit an item gets an 'ErrorCall' exception, rather
-- than blocking forever on a queue that nobody drains.
stream_
  :: ShouldBindThreads -- ^ Whether the workers use bound threads
  -> ((a -> IO ()) -> IO ()) -- ^ Producer
  -> [b] -- ^ Worker input (one worker per element)
  -> (b -> a -> IO ()) -- ^ Worker
  -> IO ()
stream_ useBoundThreads producer workerInput worker = do
  let maxConcurrency = length workerInput
  -- Keep the capacity positive. With no workers 'write' never touches the
  -- queue, so the clamp only affects that degenerate case.
  q <- atomically $ newTBQueue (fromIntegral (max 1 maxConcurrency))
  let write x
        | maxConcurrency == 0 =
            throwIO $ ErrorCall
              "Control.Concurrent.Stream: item produced but there are no workers"
        | otherwise = atomically $ writeTBQueue q (Just x)
  -- Worker threads are started from inside 'mask'. Only the user's worker
  -- action (in runWorker) and the producer side are explicitly unmasked.
  mask $ \unmask ->
    concurrently_ (runWorkers unmask q) $ unmask $ do
      -- run the producer
      producer write
      -- write end-markers for all workers
      replicateM_ maxConcurrency $
        atomically $ writeTBQueue q Nothing
  where
    runWorkers unmask q = case useBoundThreads of
      BoundThreads ->
        -- foldr1 is partial, so the no-worker case is handled explicitly to
        -- match the UnboundThreads behaviour.
        case map (runWorker unmask q) workerInput of
          [] -> return ()
          ws -> foldr1 concurrentlyBound ws
      UnboundThreads ->
        mapConcurrently_ (runWorker unmask q) workerInput

    -- Run two actions in bound threads and wait for both to finish.
    concurrentlyBound l r =
      withAsyncBound l $ \a ->
      withAsyncBound r $ \b ->
      void $ waitBoth a b

    -- Worker loop: process queued items until an end marker is read.
    runWorker unmask q s = do
      v <- atomically $ readTBQueue q
      case v of
        Nothing -> return ()
        Just t -> do
          unmask (worker s t)
          runWorker unmask q s

-- | Map an IO action over a list, with at most @maxConcurrency@ invocations
-- running at the same time. The results are returned in input order.
mapConcurrentlyBounded
  :: Int -- ^ Maximum concurrency
  -> (a -> IO b) -- ^ Function to map over the input values
  -> [a] -- ^ List of input values
  -> IO [b] -- ^ List of output values
mapConcurrentlyBounded maxConcurrency f input =
  streamWithOutput maxConcurrency emitAll f
  where
    -- Producer: hand every input element to the pool, in list order.
    emitAll emit = mapM_ emit input
  
-- | 'mapConcurrentlyBounded' with the list and the action swapped.
forConcurrentlyBounded
  :: Int -- ^ Maximum concurrency
  -> [a] -- ^ List of input values
  -> (a -> IO b) -- ^ Function to map over the input values
  -> IO [b] -- ^ List of output values
forConcurrentlyBounded maxConcurrency input f =
  mapConcurrentlyBounded maxConcurrency f input