/* * Copyright (c) Meta Platforms, Inc. and affiliates. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ // // Docs: https://fburl.com/fbcref_asyncgenerator // #pragma once #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #if FOLLY_HAS_COROUTINES namespace folly { namespace coro { namespace detail { template class AsyncGeneratorPromise; } // namespace detail /** * The AsyncGenerator class represents a sequence of asynchronously produced * values where the values are produced by a coroutine. * * Values are produced by using the 'co_yield' keyword and the coroutine can * also consume other asynchronous operations using the 'co_await' keyword. * The end of the sequence is indicated by executing 'co_return;' either * explicitly or by letting execution run off the end of the coroutine. * * Reference Type * -------------- * The first template parameter controls the 'reference' type. * i.e. the type returned when you dereference the iterator using operator*(). * This type is typically specified as an actual reference type. * eg. 'const T&' (non-mutable), 'T&' (mutable) or 'T&&' (movable) depending * what access you want your consumers to have to the yielded values. * * It's also possible to specify the 'Reference' template parameter as a value * type. In this case the generator takes a copy of the yielded value (either * copied or move-constructed) and you get a copy of this value every time * you dereference the iterator with '*iter'. * This can be expensive for types that are expensive to copy, but can provide * a small performance win for types that are cheap to copy (like built-in * integer types). * * Value Type * ---------- * The second template parameter is optional, but if specified can be used as * the value-type that should be used to take a copy of the value returned by * the Reference type. * By default this type is the same as 'Reference' type stripped of qualifiers * and references. However, in some cases it can be a different type. * For example, if the 'Reference' type was a non-reference proxy type. * * Example: * * AsyncGenerator, std::tuple> getItems() { * auto firstMap = co_await getFirstMap(); * for (auto&& [k, v] : firstMap) { * co_yield {k, v}; * } * auto secondMap = co_await getSecondMap(); * for (auto&& [k, v] : secondMap) { * co_yield {k, v}; * } * } * * This is mostly useful for generic algorithms that need to take copies of * elements of the sequence. * * Executor Affinity * ----------------- * An AsyncGenerator coroutine has similar executor-affinity to that of the * folly::coro::Task coroutine type. Every time a consumer requests a new value * from the generator using 'co_await ++it' the generator inherits the caller's * current executor. The coroutine will ensure that it always resumes on the * associated executor when resuming from `co_await' expression until it hits * the next 'co_yield' or 'co_return' statement. * Note that the executor can potentially change at a 'co_yield' statement if * the next element of the sequence is requested from a consumer coroutine that * is associated with a different executor. * * Example: Writing an async generator. * * folly::coro::AsyncGenerator getRecordsAsync() { * auto resultSet = executeQuery(someQuery); * for (;;) { * auto resultSetPage = co_await resultSet.nextPage(); * if (resultSetPage.empty()) break; * for (auto& row : resultSetPage) { * co_yield Record{row.get("name"), row.get("email")}; * } * } * } * * Example: Consuming items from an async generator * * folly::coro::Task consumer() { * auto records = getRecordsAsync(); * while (auto item = co_await records.next()) { * auto&& record = *item; * process(record); * } * } * * Async Cleanup * ------------- * When the template parameter RequiresCleanup is true, the owner of an * AsyncGenerator is responsible for awaiting cleanup() before the generator * object's destructor is called. That allows to use folly::coro::co_scope_exit * awaitables inside AsyncGenerator, which are asynchronously executed when * cleanup() is awaited. Note that the AsyncGenerator coroutine frame is * destroyed before co_scope_exit awaitables are executed. * * There is an alias CleanableAsyncGenerator for AsyncGenerator with * RequiresCleanup set to true. * * Drain safety * ------------ * One significant difference between AsyncGenerator and folly::coro::Task is * that AsyncGenerator may be destroyed between next() calls - i.e. destroyed * without being fully drained. * * For example: * * AsyncGenerator gen() { * SCOPE_EXIT { * LOG(INFO) << "Step 4"; * }; * LOG(INFO) << "Step 1"; * co_yield 41; * SCOPE_EXIT { * LOG(INFO) << "Step 3"; * }; * LOG(INFO) << "Step 2"; * co_yield 42; * SCOPE_EXIT { * LOG(INFO) << "Never reached"; * }; * LOG(INFO) << "Never reached"; * co_yield 43; * } * * { * AsyncGenerator g = gen(); * while (auto next = co_await g.next()) { * LOG(INFO) << *next; * if (*next == 42) { * break; * // ^^^ this may trigger generator destruction before it is drained. * } * } * } * * This means that when writing an AsyncGenerator, you should always document * whether such AsyncGenerator requires draining before destruction (drain * unsafe). When possible you should always aim to make AsyncGenerator not * require draining before destruction (drain safe). * * If an AsyncGenerator is drain unsafe, always mention this in the * documentation and ideally include some assertions that help detect cases * where such AsyncGenerator is destroyed without being fully drained. * * Example: * * AsyncGenerator gen() { * auto drainGuard = makeGuard([] { LOG(FATAL) << "I shall be drained!"; }); * co_yield 41; * co_yield 42; * co_yield 43; * drainGuard.dismiss(); * } */ template < typename Reference, typename Value = remove_cvref_t, bool RequiresCleanup = false> class FOLLY_NODISCARD AsyncGenerator { static_assert( std::is_constructible::value, "AsyncGenerator 'value_type' must be constructible from a 'reference'."); public: using promise_type = detail::AsyncGeneratorPromise; private: using handle_t = coroutine_handle; public: using value_type = Value; using reference = Reference; using pointer = std::add_pointer_t; public: AsyncGenerator() noexcept : coro_() {} AsyncGenerator(AsyncGenerator&& other) noexcept : coro_(std::exchange(other.coro_, {})) {} ~AsyncGenerator() { if (coro_) { if constexpr (RequiresCleanup) { LOG(FATAL) << "cleanup() hasn't been called!"; } coro_.destroy(); } } class CleanupSemiAwaitable; class FOLLY_NODISCARD CleanupAwaitable { public: bool await_ready() noexcept { return !scopeExit_; } template FOLLY_NOINLINE auto await_suspend( coroutine_handle continuation) noexcept { asyncFrame_.setReturnAddress(); scopeExit_.promise().setContext( continuation, &asyncFrame_, executor_.get_alias()); if constexpr (detail::promiseHasAsyncFrame_v) { folly::pushAsyncStackFrameCallerCallee( continuation.promise().getAsyncFrame(), asyncFrame_); return scopeExit_; } else { folly::resumeCoroutineWithNewAsyncStackRoot(scopeExit_); } } void await_resume() noexcept {} private: friend CleanupSemiAwaitable; CleanupAwaitable( coroutine_handle scopeExit, folly::Executor::KeepAlive<> executor) noexcept : scopeExit_{scopeExit}, executor_{std::move(executor)} {} friend CleanupAwaitable tag_invoke( cpo_t, CleanupAwaitable awaitable) noexcept { return std::move(awaitable); } coroutine_handle scopeExit_; folly::AsyncStackFrame asyncFrame_; folly::Executor::KeepAlive<> executor_; }; class FOLLY_NODISCARD CleanupSemiAwaitable { public: CleanupAwaitable viaIfAsync(Executor::KeepAlive<> executor) noexcept { return CleanupAwaitable{scopeExit_, std::move(executor)}; } private: friend AsyncGenerator; explicit CleanupSemiAwaitable( coroutine_handle scopeExit) noexcept : scopeExit_{scopeExit} {} coroutine_handle scopeExit_; }; CleanupSemiAwaitable cleanup() && { static_assert(RequiresCleanup); if (coro_) { SCOPE_EXIT { std::exchange(coro_, {}).destroy(); }; return CleanupSemiAwaitable{coro_.promise().scopeExit_}; } else { return CleanupSemiAwaitable{{}}; } } AsyncGenerator& operator=(AsyncGenerator&& other) noexcept { auto oldCoro = std::exchange(coro_, std::exchange(other.coro_, {})); if (oldCoro) { CHECK(!RequiresCleanup) << "cleanup() hasn't been called!"; oldCoro.destroy(); } return *this; } void swap(AsyncGenerator& other) noexcept { std::swap(coro_, other.coro_); } class NextAwaitable; class NextSemiAwaitable; class NextResult { public: NextResult() noexcept : hasValue_(false) {} NextResult(NextResult&& other) noexcept : hasValue_(other.hasValue_) { if (hasValue_) { value_.construct(std::move(other.value_).get()); } } ~NextResult() { if (hasValue_) { value_.destruct(); } } NextResult& operator=(NextResult&& other) { if (&other != this) { if (has_value()) { hasValue_ = false; value_.destruct(); } if (other.has_value()) { value_.construct(std::move(other.value_).get()); hasValue_ = true; } } return *this; } bool has_value() const noexcept { return hasValue_; } explicit operator bool() const noexcept { return has_value(); } decltype(auto) value() & { DCHECK(has_value()); return value_.get(); } decltype(auto) value() && { DCHECK(has_value()); return std::move(value_).get(); } decltype(auto) value() const& { DCHECK(has_value()); return value_.get(); } decltype(auto) value() const&& { DCHECK(has_value()); return std::move(value_).get(); } decltype(auto) operator*() & { return value(); } decltype(auto) operator*() && { return std::move(*this).value(); } decltype(auto) operator*() const& { return value(); } decltype(auto) operator*() const&& { return std::move(*this).value(); } decltype(auto) operator->() { DCHECK(has_value()); auto&& x = value_.get(); return std::addressof(x); } decltype(auto) operator->() const { DCHECK(has_value()); auto&& x = value_.get(); return std::addressof(x); } private: friend NextAwaitable; explicit NextResult(handle_t coro) noexcept : hasValue_(true) { value_.construct(coro.promise().getRvalue()); } detail::ManualLifetime value_; bool hasValue_ = false; }; class NextAwaitable { public: bool await_ready() noexcept { return !coro_; } template FOLLY_NOINLINE auto await_suspend( coroutine_handle continuation) noexcept { auto& promise = coro_.promise(); promise.setContinuation(continuation); promise.clearValue(); auto& asyncFrame = promise.getAsyncFrame(); asyncFrame.setReturnAddress(); if constexpr (detail::promiseHasAsyncFrame_v) { folly::pushAsyncStackFrameCallerCallee( continuation.promise().getAsyncFrame(), asyncFrame); return coro_; } else { folly::resumeCoroutineWithNewAsyncStackRoot(coro_); } } NextResult await_resume() { if (!coro_) { return NextResult{}; } else if (!coro_.promise().hasValue()) { coro_.promise().throwIfException(); return NextResult{}; } else { return NextResult{coro_}; } } folly::Try await_resume_try() noexcept { if (coro_) { if (coro_.promise().hasValue()) { return folly::Try(NextResult{coro_}); } else if (coro_.promise().hasException()) { return folly::Try( std::move(coro_.promise().getException())); } } return folly::Try(NextResult{}); } private: friend NextSemiAwaitable; explicit NextAwaitable(handle_t coro) noexcept : coro_(coro) {} friend NextAwaitable tag_invoke( cpo_t, NextAwaitable awaitable) noexcept { return NextAwaitable{awaitable.coro_}; } handle_t coro_; }; class NextSemiAwaitable { public: NextAwaitable viaIfAsync(Executor::KeepAlive<> executor) noexcept { if (coro_) { coro_.promise().setExecutor(std::move(executor)); } return NextAwaitable{coro_}; } friend NextSemiAwaitable co_withCancellation( CancellationToken cancelToken, NextSemiAwaitable&& awaitable) { if (awaitable.coro_) { awaitable.coro_.promise().setCancellationToken(std::move(cancelToken)); } return NextSemiAwaitable{std::exchange(awaitable.coro_, {})}; } private: friend AsyncGenerator; explicit NextSemiAwaitable(handle_t coro) noexcept : coro_(coro) {} handle_t coro_; }; NextSemiAwaitable next() noexcept { DCHECK(!coro_ || !coro_.done()); return NextSemiAwaitable{coro_}; } template friend AsyncGenerator tag_invoke( tag_t, tag_t, F_ f, A_... a) { if constexpr (RequiresCleanup) { auto&& [fScoped, r] = co_await co_scope_exit( [](auto&&, auto&& gen) { return std::move(gen).cleanup(); }, static_cast(f), AsyncGenerator{}); r = invoke(static_cast(fScoped), static_cast(a)...); while (true) { co_yield co_result(co_await co_awaitTry(r.next())); } } else { auto r = invoke(static_cast(f), static_cast(a)...); while (true) { co_yield co_result(co_await co_awaitTry(r.next())); } } } private: friend promise_type; explicit AsyncGenerator(coroutine_handle coro) noexcept : coro_(coro) {} coroutine_handle coro_; }; template > using CleanableAsyncGenerator = AsyncGenerator; namespace detail { template struct BaseAsyncGeneratorPromise {}; template <> struct BaseAsyncGeneratorPromise { coroutine_handle scopeExit_; }; template class AsyncGeneratorPromise final : public ExtendedCoroutinePromise, BaseAsyncGeneratorPromise { class YieldAwaiter { public: bool await_ready() noexcept { return false; } coroutine_handle<> await_suspend( coroutine_handle h) noexcept { AsyncGeneratorPromise& promise = h.promise(); // Pop AsyncStackFrame first as clearContext() clears the frame state. folly::popAsyncStackFrameCallee(promise.getAsyncFrame()); promise.clearContext(); if (promise.hasException()) { auto [handle, frame] = promise.continuation_.getErrorHandle(promise.getException()); return handle.getHandle(); } return promise.continuation_.getHandle(); } void await_resume() noexcept {} }; public: template AsyncGeneratorPromise(Args&... args) { if constexpr (RequiresCleanup) { scheduleAutoCleanupIfNeeded( coroutine_handle::from_promise(*this), args...); } } ~AsyncGeneratorPromise() { switch (state_) { case State::VALUE: folly::coro::detail::deactivate(value_); break; case State::EXCEPTION_WRAPPER: folly::coro::detail::deactivate(exceptionWrapper_); break; case State::DONE: case State::INVALID: break; } } static void* operator new(std::size_t size) { return ::folly_coro_async_malloc(size); } static void operator delete(void* ptr, std::size_t size) { ::folly_coro_async_free(ptr, size); } AsyncGenerator get_return_object() noexcept { return AsyncGenerator{ coroutine_handle::from_promise(*this)}; } suspend_always initial_suspend() noexcept { return {}; } YieldAwaiter final_suspend() noexcept { DCHECK(!hasValue()); return {}; } YieldAwaiter yield_value(Reference&& value) noexcept( std::is_nothrow_move_constructible::value) { DCHECK(state_ == State::INVALID); folly::coro::detail::activate(value_, static_cast(value)); state_ = State::VALUE; return YieldAwaiter{}; } /// In the case where 'Reference' is not actually a reference-type we /// allow implicit conversion from the co_yield argument to Reference. /// However, we don't want to allow this for cases where 'Reference' _is_ /// a reference because this could result in the reference binding to a /// temporary that results from an implicit conversion. template < typename U, std::enable_if_t< !std::is_reference_v && std::is_convertible_v, int> = 0> YieldAwaiter yield_value(U&& value) noexcept( std::is_nothrow_constructible_v) { DCHECK(state_ == State::INVALID); folly::coro::detail::activate(value_, static_cast(value)); state_ = State::VALUE; return {}; } YieldAwaiter yield_value(co_error&& error) noexcept { DCHECK(state_ == State::INVALID); folly::coro::detail::activate( exceptionWrapper_, std::move(error.exception())); state_ = State::EXCEPTION_WRAPPER; return {}; } YieldAwaiter yield_value(co_result&& res) noexcept { if (res.result().hasValue()) { return yield_value(std::move(res.result().value())); } else if (res.result().hasException()) { return yield_value(co_error(res.result().exception())); } else { return_void(); return {}; } } YieldAwaiter yield_value( co_result:: NextResult>&& res) noexcept { DCHECK( res.result().hasValue() || (res.result().hasException() && res.result().exception())); if (res.result().hasException()) { return yield_value(co_error(res.result().exception())); } else if (res.result().hasValue()) { if (res.result()->has_value()) { return yield_value(std::move(res.result()->value())); } else { return_void(); return {}; } } return yield_value(co_error(UsingUninitializedTry{})); } variant_awaitable> await_transform( co_safe_point_t) noexcept { if (cancelToken_.isCancellationRequested()) { return yield_value(co_cancelled); } return ready_awaitable<>{}; } void unhandled_exception() noexcept { DCHECK(state_ == State::INVALID); folly::coro::detail::activate(exceptionWrapper_, current_exception()); state_ = State::EXCEPTION_WRAPPER; } void return_void() noexcept { DCHECK(state_ == State::INVALID); state_ = State::DONE; } // FIXME: Much of this class is currently copy-pasted from `TaskPromiseBase`, // Refactor this to use that, so as to avoid `co_await` behavior divergence. template < typename Awaitable, std::enable_if_t, int> = 0> auto await_transform(Awaitable&& awaitable) { bypassExceptionThrowing_ = bypassExceptionThrowing_ == BypassExceptionThrowing::REQUESTED ? BypassExceptionThrowing::ACTIVE : BypassExceptionThrowing::INACTIVE; return folly::coro::co_withAsyncStack(folly::coro::co_viaIfAsync( executor_.get_alias(), folly::coro::co_withCancellation( cancelToken_, static_cast(awaitable)))); } template < typename Awaitable, std::enable_if_t, int> = 0> auto await_transform(Awaitable awaitable) { bypassExceptionThrowing_ = bypassExceptionThrowing_ == BypassExceptionThrowing::REQUESTED ? BypassExceptionThrowing::ACTIVE : BypassExceptionThrowing::INACTIVE; return folly::coro::co_withAsyncStack(folly::coro::co_viaIfAsync( executor_.get_alias(), folly::coro::co_withCancellation( cancelToken_, mustAwaitImmediatelyUnsafeMover(std::move(awaitable))()))); } template auto await_transform(NothrowAwaitable awaitable) { bypassExceptionThrowing_ = BypassExceptionThrowing::REQUESTED; return await_transform( mustAwaitImmediatelyUnsafeMover(awaitable.unwrap())()); } auto await_transform(folly::coro::co_current_executor_t) noexcept { return ready_awaitable{executor_.get()}; } auto await_transform(folly::coro::co_current_cancellation_token_t) noexcept { return ready_awaitable{cancelToken_}; } void setCancellationToken(folly::CancellationToken cancelToken) noexcept { // Only keep the first cancellation token. // ie. the inner-most cancellation scope of the consumer's calling // context. if (!hasCancelTokenOverride_) { cancelToken_ = std::move(cancelToken); hasCancelTokenOverride_ = true; } } void setExecutor(folly::Executor::KeepAlive<> executor) noexcept { DCHECK(executor); executor_ = std::move(executor); } void setContinuation(ExtendedCoroutineHandle continuation) noexcept { continuation_ = continuation; } bool hasException() const noexcept { return state_ == State::EXCEPTION_WRAPPER; } folly::exception_wrapper& getException() noexcept { DCHECK(hasException()); return exceptionWrapper_.get(); } void throwIfException() { if (state_ == State::EXCEPTION_WRAPPER) { exceptionWrapper_.get().throw_exception(); } } decltype(auto) getRvalue() noexcept { DCHECK(hasValue()); return std::move(value_).get(); } void clearValue() noexcept { if (hasValue()) { state_ = State::INVALID; folly::coro::detail::deactivate(value_); } else { CHECK(state_ != State::DONE) << "Using generator after receiving completion."; CHECK(state_ != State::EXCEPTION_WRAPPER) << "Using generator after receiving exception."; } } bool hasValue() const noexcept { return state_ == State::VALUE; } folly::AsyncStackFrame& getAsyncFrame() noexcept { return asyncFrame_; } std::pair getErrorHandle( exception_wrapper& ex) final { if (bypassExceptionThrowing_ == BypassExceptionThrowing::ACTIVE) { auto yieldAwaiter = yield_value(co_error(std::move(ex))); DCHECK(!yieldAwaiter.await_ready()); return { yieldAwaiter.await_suspend( coroutine_handle::from_promise(*this)), // yieldAwaiter.await_suspend pops a frame getAsyncFrame().getParentFrame()}; } return { coroutine_handle::from_promise(*this), nullptr}; } private: friend AsyncGenerator; void clearContext() noexcept { executor_ = {}; cancelToken_ = {}; hasCancelTokenOverride_ = false; asyncFrame_ = {}; } friend coroutine_handle tag_invoke( cpo_t, AsyncGeneratorPromise& p, coroutine_handle scopeExit) noexcept { static_assert( RequiresCleanup, "Only CleanableAsyncGenerator (AsyncGenerator with RequiresCleanup" " template parameter set to true) supports attaching co_scope_exit"); return std::exchange(p.scopeExit_, scopeExit); } enum class State : std::uint8_t { INVALID, VALUE, EXCEPTION_WRAPPER, DONE, }; ExtendedCoroutineHandle continuation_; folly::AsyncStackFrame asyncFrame_; folly::Executor::KeepAlive<> executor_; folly::CancellationToken cancelToken_; union { ManualLifetime exceptionWrapper_; ManualLifetime value_; }; State state_ = State::INVALID; bool hasCancelTokenOverride_ = false; enum class BypassExceptionThrowing : uint8_t { INACTIVE, ACTIVE, REQUESTED, } bypassExceptionThrowing_{BypassExceptionThrowing::INACTIVE}; }; } // namespace detail template auto tag_invoke( cpo_t, CleanableAsyncGenerator&& gen) { return std::move(gen).cleanup(); } } // namespace coro // Standard `AsyncGenerator` coros can easily capture references & other unsafe // aliasing. // // Future: Implement a `coro/safe` generator wrapper, like `async_closure_gen`. template struct safe_alias_of<::folly::coro::AsyncGenerator> : safe_alias_constant {}; } // namespace folly #endif