module Freckle.App.Async
  ( async
  , foldConcurrently
  , immortalCreate
  , immortalCreateLogged
  , ThreadContext (..)
  , getThreadContext
  , withThreadContext
  , forConcurrently
  , forConcurrently_
  , mapConcurrently
  , mapConcurrently_
  ) where

import Freckle.App.Prelude

import Blammo.Logging (Message (..), MonadLogger, logError, (.=))
import Blammo.Logging.ThreadContext (MonadMask)
import Blammo.Logging.ThreadContext qualified as Blammo
import Control.Immortal qualified as Immortal
import Control.Monad (forever)
import Data.Aeson (Value)
import Data.Aeson.Compat (KeyMap)
import Data.Aeson.Compat qualified as KeyMap
import OpenTelemetry.Context qualified as OpenTelemetry
import OpenTelemetry.Context.ThreadLocal qualified as OpenTelemetry
import UnliftIO.Async (Async, conc, runConc)
import UnliftIO.Async qualified as UnliftIO
import UnliftIO.Concurrent (threadDelay)

-- | 'UnliftIO.Async.async' but passing the thread context along
--
-- The caller's 'ThreadContext' (Blammo logging context and OpenTelemetry
-- context) is captured on the calling thread and re-applied inside the
-- forked thread before @f@ runs.
async :: (MonadMask m, MonadUnliftIO m) => m a -> m (Async a)
async f = do
  -- Must be read here, on the calling thread, so it can be handed over.
  context <- getThreadContext
  UnliftIO.async $ withThreadContext context f

-- | Run a 'Foldable' collection of actions concurrently, combining their
-- results with the 'Monoid' instance of @a@
--
-- The forked threads will have the current thread context copied to them.
foldConcurrently
  :: (MonadUnliftIO m, MonadMask m, Monoid a, Foldable t) => t (m a) -> m a
foldConcurrently xs = do
  context <- getThreadContext
  -- Each action becomes its own 'conc' branch, wrapped so it re-applies the
  -- captured context before running.
  runConc $ foldMap (conc . withThreadContext context) xs

-- | 'UnliftIO.Async.forConcurrently' but passing the thread context along
--
-- This is 'mapConcurrently' with its arguments flipped.
forConcurrently
  :: (MonadUnliftIO m, MonadMask m, Traversable t) => t a -> (a -> m b) -> m (t b)
forConcurrently = flip mapConcurrently

-- | 'UnliftIO.Async.mapConcurrently' but passing the thread context along
--
-- The calling thread's 'ThreadContext' is captured once and re-applied in
-- every forked thread before @f@ runs there.
mapConcurrently
  :: (MonadUnliftIO m, MonadMask m, Traversable t) => (a -> m b) -> t a -> m (t b)
mapConcurrently f xs = do
  context <- getThreadContext
  UnliftIO.mapConcurrently (withThreadContext context . f) xs

-- | 'UnliftIO.Async.forConcurrently_' but passing the thread context along
--
-- Results are discarded, so only 'Foldable' is required, matching the
-- underlying 'UnliftIO.Async.forConcurrently_'. The calling thread's
-- 'ThreadContext' is captured once and re-applied in every forked thread.
forConcurrently_
  :: (MonadUnliftIO m, MonadMask m, Foldable t) => t a -> (a -> m b) -> m ()
forConcurrently_ xs f = do
  context <- getThreadContext
  UnliftIO.forConcurrently_ xs (withThreadContext context . f)

-- | 'UnliftIO.Async.mapConcurrently_' but passing the thread context along
--
-- Results are discarded, so only 'Foldable' is required, matching the
-- underlying 'UnliftIO.Async.mapConcurrently_'. The calling thread's
-- 'ThreadContext' is captured once and re-applied in every forked thread.
mapConcurrently_
  :: (MonadUnliftIO m, MonadMask m, Foldable t) => (a -> m b) -> t a -> m ()
mapConcurrently_ f xs = do
  context <- getThreadContext
  UnliftIO.mapConcurrently_ (withThreadContext context . f) xs

-- | Wrapper around creating "Control.Immortal" processes
--
-- Features:
--
-- - Ensures the thread context is correctly passed to both your spawned action
--   and your error handler
-- - Blocks forever after spawning your thread.
immortalCreate
  :: (MonadMask m, MonadUnliftIO m)
  => (Either SomeException () -> m ())
  -- ^ How to handle unexpected finish
  -> m ()
  -- ^ The action to run persistently
  -> m a
immortalCreate onUnexpected act = do
  context <- getThreadContext

  let
    -- Both the action and the handler run inside the thread spawned by
    -- 'Immortal.create', so each one re-applies the caller's context.
    act' = withThreadContext context act
    onUnexpected' = withThreadContext context . onUnexpected

  void $ Immortal.create $ \thread ->
    Immortal.onUnexpectedFinish thread onUnexpected' act'

  -- Park the calling thread; 'forever' re-arms the delay if it ever returns.
  forever $ threadDelay maxBound

-- | 'immortalCreate' with logging of unexpected finishes
--
-- If the action ends with an exception ('Left'), the exception's
-- 'displayException' text is logged at error level under the @exception@
-- key. A normal return ('Right') is not logged.
immortalCreateLogged
  :: (MonadMask m, MonadUnliftIO m, MonadLogger m) => m () -> m a
immortalCreateLogged = immortalCreate $ either logEx pure
 where
  logEx ex =
    logError $ "Unexpected Finish" :# ["exception" .= displayException ex]

-- | Per-thread context that 'getThreadContext' captures and
-- 'withThreadContext' re-applies
data ThreadContext = ThreadContext
  { blammoContext :: KeyMap Value
  -- ^ Blammo logging context key/value pairs
  , openTelemetryContext :: Maybe OpenTelemetry.Context
  -- ^ OpenTelemetry context of the thread, if one was attached
  }

-- | Snapshot the calling thread's Blammo logging context and OpenTelemetry
-- context
getThreadContext :: MonadIO m => m ThreadContext
getThreadContext =
  ThreadContext
    -- Blammo's lookup also needs 'MonadThrow', so it is run in 'IO' and lifted.
    <$> liftIO Blammo.myThreadContext
    <*> OpenTelemetry.lookupContext

-- | Run an action with the given 'ThreadContext' applied to the current thread
--
-- The Blammo pairs are applied around @continue@ via
-- 'Blammo.withThreadContext'. The OpenTelemetry context, when present, is
-- attached with 'OpenTelemetry.attachContext'. Nothing here detaches it or
-- restores the previous context after @continue@ finishes.
--
-- NOTE(review): that is harmless on a freshly forked thread (the use in this
-- module). Confirm it is acceptable for callers that use this on a
-- long-lived thread.
withThreadContext :: (MonadIO m, MonadMask m) => ThreadContext -> m a -> m a
withThreadContext ThreadContext {blammoContext, openTelemetryContext} continue =
  Blammo.withThreadContext (KeyMap.toList blammoContext) $ do
    traverse_ @Maybe OpenTelemetry.attachContext openTelemetryContext
    continue