From 89b4dda53d91587278c7a59f4b9b400edab330b8 Mon Sep 17 00:00:00 2001 From: jbaldwin Date: Fri, 2 Feb 2024 09:57:54 -0700 Subject: [PATCH] Calling coro::sync_wait with coro::ring_buffer::consume returns default constructed objects for complex return values Running the test that replicates the bug under valgrind or ASan shows that the compiler calls the destructor of sync_wait()'s promise before moving the complex return value out. The object is therefore destructed first, and the final move out then reads freed (use after free) memory, leaving the object in a bad, empty state. To resolve this for now, a double move has been introduced to move the complex object off the promise object and onto the sync_wait() function's call stack. This seems to keep the object alive so that it is not destructed when sync_wait() finally returns. Other solutions could be to heap allocate the promise() object or even the return_type, but this is probably more expensive than a double move, although the cost will vary by use case. Closes #242 --- include/coro/ring_buffer.hpp | 35 ++++--- include/coro/sync_wait.hpp | 185 +++++++++++++++++++++++++++------ src/sync_wait.cpp | 13 +-- test/test_ring_buffer.cpp | 191 +++++++++++++++++++++++++++++++++++ 4 files changed, 368 insertions(+), 56 deletions(-) diff --git a/include/coro/ring_buffer.hpp b/include/coro/ring_buffer.hpp index ae359808..3e5149dc 100644 --- a/include/coro/ring_buffer.hpp +++ b/include/coro/ring_buffer.hpp @@ -10,6 +10,20 @@ namespace coro { +namespace rb +{ +enum class produce_result +{ + produced, + ring_buffer_stopped +}; + +enum class consume_result +{ + ring_buffer_stopped +}; +} // namespace rb + /** * @tparam element The type of element the ring buffer will store. 
Note that this type should be * cheap to move if possible as it is moved into and out of the buffer upon produce and @@ -20,17 +34,6 @@ template class ring_buffer { public: - enum class produce_result - { - produced, - ring_buffer_stopped - }; - - enum class consume_result - { - ring_buffer_stopped - }; - /** * static_assert If `num_elements` == 0. */ @@ -62,7 +65,7 @@ class ring_buffer { std::unique_lock lk{m_rb.m_mutex}; // Its possible a consumer on another thread consumed an item between await_ready() and await_suspend() - // so we must check to see if tehre is space again. + // so we must check to see if there is space again. if (m_rb.try_produce_locked(lk, m_e)) { return false; @@ -84,9 +87,9 @@ class ring_buffer /** * @return produce_result */ - auto await_resume() -> produce_result + auto await_resume() -> rb::produce_result { - return !m_stopped ? produce_result::produced : produce_result::ring_buffer_stopped; + return !m_stopped ? rb::produce_result::produced : rb::produce_result::ring_buffer_stopped; } private: @@ -140,11 +143,11 @@ class ring_buffer /** * @return The consumed element or std::nullopt if the consume has failed. 
*/ - auto await_resume() -> expected + auto await_resume() -> expected { if (m_stopped) { - return unexpected(consume_result::ring_buffer_stopped); + return unexpected(rb::consume_result::ring_buffer_stopped); } return std::move(m_e); diff --git a/include/coro/sync_wait.hpp b/include/coro/sync_wait.hpp index c07b2425..333d187b 100644 --- a/include/coro/sync_wait.hpp +++ b/include/coro/sync_wait.hpp @@ -6,11 +6,22 @@ #include #include #include +#include namespace coro { namespace detail { + +struct unset_return_value +{ + unset_return_value() {} + unset_return_value(unset_return_value&&) = delete; + unset_return_value(const unset_return_value&) = delete; + auto operator=(unset_return_value&&) = delete; + auto operator=(const unset_return_value&) = delete; +}; + class sync_wait_event { public: @@ -28,7 +39,7 @@ class sync_wait_event private: std::mutex m_mutex; std::condition_variable m_cv; - bool m_set{false}; + std::atomic m_set{false}; }; class sync_wait_task_promise_base @@ -53,8 +64,19 @@ class sync_wait_task_promise : public sync_wait_task_promise_base public: using coroutine_type = std::coroutine_handle>; - sync_wait_task_promise() noexcept = default; - ~sync_wait_task_promise() = default; + static constexpr bool return_type_is_reference = std::is_reference_v; + using stored_type = std::conditional_t< + return_type_is_reference, + std::remove_reference_t*, + std::remove_const_t>; + using variant_type = std::variant; + + sync_wait_task_promise() noexcept = default; + sync_wait_task_promise(const sync_wait_task_promise&) = delete; + sync_wait_task_promise(sync_wait_task_promise&&) = delete; + auto operator=(const sync_wait_task_promise&) -> sync_wait_task_promise& = delete; + auto operator=(sync_wait_task_promise&&) -> sync_wait_task_promise& = delete; + ~sync_wait_task_promise() = default; auto start(sync_wait_event& event) { @@ -64,10 +86,32 @@ class sync_wait_task_promise : public sync_wait_task_promise_base auto get_return_object() noexcept { return 
coroutine_type::from_promise(*this); } - auto yield_value(return_type&& value) noexcept + template + requires(return_type_is_reference and std::is_constructible_v) or + (not return_type_is_reference and + std::is_constructible_v) auto return_value(value_type&& value) -> void { - m_return_value = std::addressof(value); - return final_suspend(); + if constexpr (return_type_is_reference) + { + return_type ref = static_cast(value); + m_storage.template emplace(std::addressof(ref)); + } + else + { + m_storage.template emplace(std::forward(value)); + } + } + + auto return_value(stored_type value) -> void requires(not return_type_is_reference) + { + if constexpr (std::is_move_constructible_v) + { + m_storage.template emplace(std::move(value)); + } + else + { + m_storage.template emplace(value); + } } auto final_suspend() noexcept @@ -82,19 +126,81 @@ class sync_wait_task_promise : public sync_wait_task_promise_base return completion_notifier{}; } - auto result() -> return_type&& + auto result() & -> decltype(auto) { - if (m_exception) + if (std::holds_alternative(m_storage)) { - std::rethrow_exception(m_exception); + if constexpr (return_type_is_reference) + { + return static_cast(*std::get(m_storage)); + } + else + { + return static_cast(std::get(m_storage)); + } + } + else if (std::holds_alternative(m_storage)) + { + std::rethrow_exception(std::get(m_storage)); + } + else + { + throw std::runtime_error{"The return value was never set, did you execute the coroutine?"}; + } + } + + auto result() const& -> decltype(auto) + { + if (std::holds_alternative(m_storage)) + { + if constexpr (return_type_is_reference) + { + return static_cast>(*std::get(m_storage)); + } + else + { + return static_cast(std::get(m_storage)); + } + } + else if (std::holds_alternative(m_storage)) + { + std::rethrow_exception(std::get(m_storage)); } + else + { + throw std::runtime_error{"The return value was never set, did you execute the coroutine?"}; + } + } - return static_cast(*m_return_value); + 
auto result() && -> decltype(auto) + { + if (std::holds_alternative(m_storage)) + { + if constexpr (return_type_is_reference) + { + return static_cast(*std::get(m_storage)); + } + else if constexpr (std::is_assignable_v) + { + return static_cast(std::get(m_storage)); + } + else + { + return static_cast(std::get(m_storage)); + } + } + else if (std::holds_alternative(m_storage)) + { + std::rethrow_exception(std::get(m_storage)); + } + else + { + throw std::runtime_error{"The return value was never set, did you execute the coroutine?"}; + } } - void return_void() noexcept {} private: - std::remove_reference_t* m_return_value; + variant_type m_storage{}; }; template<> @@ -167,21 +273,9 @@ class sync_wait_task } } - auto start(sync_wait_event& event) noexcept { m_coroutine.promise().start(event); } - - auto return_value() -> decltype(auto) - { - if constexpr (std::is_same_v) - { - // Propagate exceptions. - m_coroutine.promise().result(); - return; - } - else - { - return m_coroutine.promise().result(); - } - } + auto promise() & -> promise_type& { return m_coroutine.promise(); } + auto promise() const& -> const promise_type& { return m_coroutine.promise(); } + auto promise() && -> promise_type&& { return std::move(m_coroutine.promise()); } private: coroutine_type m_coroutine; @@ -202,21 +296,50 @@ static auto make_sync_wait_task(awaitable_type&& a) -> sync_wait_task(a); + co_return co_await std::forward(a); } } } // namespace detail -template +template< + concepts::awaitable awaitable_type, + typename return_type = typename concepts::awaitable_traits::awaiter_return_type> auto sync_wait(awaitable_type&& a) -> decltype(auto) { detail::sync_wait_event e{}; auto task = detail::make_sync_wait_task(std::forward(a)); - task.start(e); + task.promise().start(e); e.wait(); - return task.return_value(); + if constexpr (std::is_void_v) + { + task.promise().result(); + return; + } + else if constexpr (std::is_reference_v) + { + return task.promise().result(); + } + else if 
constexpr (std::is_move_assignable_v) + { + // issue-242 + // For non-trivial types (or possibly types that don't fit in a register) + // the compiler will end up calling the ~return_type() when the promise + // is destructed at the end of sync_wait(). This causes the return_type + // object to also be destructed causing the final return/move from + // sync_wait() to be a 'use after free' bug. To work around this the result + // must be moved off the promise object before the promise is destructed. + // Other solutions could be heap allocating the return_type but that has + // other downsides, for now it is determined that a double move is an + // acceptable solution to work around this bug. + auto result = std::move(task).promise().result(); + return result; + } + else + { + return task.promise().result(); + } } } // namespace coro diff --git a/src/sync_wait.cpp b/src/sync_wait.cpp index 61b287f5..e610ad25 100644 --- a/src/sync_wait.cpp +++ b/src/sync_wait.cpp @@ -8,24 +8,19 @@ sync_wait_event::sync_wait_event(bool initially_set) : m_set(initially_set) auto sync_wait_event::set() noexcept -> void { - { - std::lock_guard g{m_mutex}; - m_set = true; - } - + m_set.exchange(true, std::memory_order::release); m_cv.notify_all(); } auto sync_wait_event::reset() noexcept -> void { - std::lock_guard g{m_mutex}; - m_set = false; + m_set.exchange(false, std::memory_order::release); } auto sync_wait_event::wait() noexcept -> void { std::unique_lock lk{m_mutex}; - m_cv.wait(lk, [this] { return m_set.load(std::memory_order::acquire); }); } -} // namespace coro::detail \ No newline at end of file +} // namespace coro::detail diff --git a/test/test_ring_buffer.cpp b/test/test_ring_buffer.cpp index 2cda5d22..39fe6ebb 100644 --- a/test/test_ring_buffer.cpp +++ b/test/test_ring_buffer.cpp @@ -173,3 +173,194 @@ TEST_CASE("ring_buffer producer consumer separate threads", "[ring_buffer]") REQUIRE(rb.empty()); } + +TEST_CASE("ring_buffer 
issue-242 default constructed complex objects on consume", "[ring_buffer]") +{ + struct message + { + message(uint32_t i, std::string t) : id(i), text(std::move(t)) {} + message(const message&) = delete; + message(message&& other) : id(other.id), text(std::move(other.text)) {} + auto operator=(const message&) -> message& = delete; + auto operator=(message&& other) -> message& + { + if (std::addressof(other) != this) + { + this->id = std::exchange(other.id, 0); + this->text = std::move(other.text); + } + + return *this; + } + + ~message() { id = 0; } + + uint32_t id; + std::string text; + }; + + struct example + { + example() { std::cerr << "I'm being created\n"; } + example(const example&) = delete; + example(example&& other) : msg(std::move(other.msg)) + { + std::cerr << "i'm being moved constructed with msg = "; + if (msg.has_value()) + { + std::cerr << "id = " << msg.value().id << ", msg = " << msg.value().text << "\n"; + } + else + { + std::cerr << "nullopt\n"; + } + } + + ~example() + { + std::cerr << "I'm being deleted with msg = "; + if (msg.has_value()) + { + std::cerr << "id = " << msg.value().id << ", msg = " << msg.value().text << "\n"; + } + else + { + std::cerr << "nullopt\n"; + } + } + + auto operator=(const example&) -> example& = delete; + auto operator=(example&& other) -> example& + { + if (std::addressof(other) != this) + { + this->msg = std::move(other.msg); + + std::cerr << "i'm being moved assigned with msg = "; + if (msg.has_value()) + { + std::cerr << msg.value().id << ", " << msg.value().text << "\n"; + } + else + { + std::cerr << "nullopt\n"; + } + } + + return *this; + } + + std::optional msg{std::nullopt}; + }; + + coro::ring_buffer buffer; + + const auto produce = [&buffer]() -> coro::task + { + std::cerr << "enter produce coroutine\n"; + example data{}; + data.msg = {message{1, "Hello World!"}}; + std::cerr << "ID: " << data.msg.value().id << "\n"; + std::cerr << "Text: " << data.msg.value().text << "\n"; + std::cerr << 
"buffer.produce(move(data)) start\n"; + auto result = co_await buffer.produce(std::move(data)); + std::cerr << "buffer.produce(move(data)) done\n"; + REQUIRE(result == coro::rb::produce_result::produced); + std::cerr << "exit produce coroutine\n"; + co_return; + }; + + coro::sync_wait(produce()); + std::cerr << "enter sync_wait\n"; + auto result = coro::sync_wait(buffer.consume()); + std::cerr << "exit sync_wait\n"; + REQUIRE(result); + + auto& data = result.value(); + REQUIRE(data.msg.has_value()); + REQUIRE(data.msg.value().id == 1); + REQUIRE(data.msg.value().text == "Hello World!"); + std::cerr << "Outside the coroutine\n"; + std::cerr << "ID: " << data.msg.value().id << "\n"; + std::cerr << "Text: " << data.msg.value().text << "\n"; +} + +TEST_CASE("ring_buffer issue-242 default constructed complex objects on consume in coroutines", "[ring_buffer]") +{ + struct message + { + uint32_t id; + std::string text; + }; + + struct example + { + example() {} + example(const example&) = delete; + example(example&& other) : msg(std::move(other.msg)) {} + + auto operator=(const example&) -> example& = delete; + auto operator=(example&& other) -> example& + { + if (std::addressof(other) != this) + { + this->msg = std::move(other.msg); + } + + return *this; + } + + std::optional msg{std::nullopt}; + }; + + coro::ring_buffer buffer; + + const auto produce = [&buffer]() -> coro::task + { + example data{}; + data.msg = {message{.id = 1, .text = "Hello World!"}}; + std::cout << "Inside the coroutine\n"; + std::cout << "ID: " << data.msg.value().id << "\n"; + std::cout << "Text: " << data.msg.value().text << "\n"; + auto result = co_await buffer.produce(std::move(data)); + REQUIRE(result == coro::rb::produce_result::produced); + co_return; + }; + + const auto consume = [&buffer]() -> coro::task + { + auto result = co_await buffer.consume(); + REQUIRE(result.has_value()); + REQUIRE(result.value().msg.has_value()); + auto data = std::move(*result); + co_return std::move(data); + 
}; + + coro::sync_wait(produce()); + auto data = coro::sync_wait(consume()); + + REQUIRE(data.msg.has_value()); + REQUIRE(data.msg.value().id == 1); + REQUIRE(data.msg.value().text == "Hello World!"); + std::cout << "Outside the coroutine\n"; + std::cout << "ID: " << data.msg.value().id << "\n"; + std::cout << "Text: " << data.msg.value().text << "\n"; +} + +TEST_CASE("ring_buffer issue-242 basic type", "[ring_buffer]") +{ + coro::ring_buffer buffer; + + const auto foo = [&buffer]() -> coro::task + { + co_await buffer.produce(1); + co_return; + }; + + coro::sync_wait(foo()); + auto result = coro::sync_wait(buffer.consume()); + REQUIRE(result); + + auto data = std::move(*result); + REQUIRE(data == 1); +}