8000 Calling coro::sync_wait with coro::ring_buffer::consume returns defau… by jbaldwin · Pull Request #244 · jbaldwin/libcoro · GitHub
[go: up one dir, main page]
More Web Proxy on the site http://driver.im/
Skip to content

Calling coro::sync_wait with coro::ring_buffer::consume returns defau… #244

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Feb 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 19 additions & 16 deletions include/coro/ring_buffer.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,20 @@

namespace coro
{
/**
 * Result codes reported by the ring buffer's produce and consume awaiters.
 */
namespace rb
{
/**
 * Value returned when a produce operation is resumed.
 */
enum class produce_result
{
    /// Reported when the produce operation was not stopped.
    produced = 0,
    /// Reported when the produce operation was stopped before it completed.
    ring_buffer_stopped = 1,
};

/**
 * Error value carried by a consume operation that yields no element.
 */
enum class consume_result
{
    /// Reported when the consume operation was stopped before it completed.
    ring_buffer_stopped = 0,
};
} // namespace rb

/**
* @tparam element The type of element the ring buffer will store. Note that this type should be
* cheap to move if possible as it is moved into and out of the buffer upon produce and
Expand All @@ -20,17 +34,6 @@ template<typename element, size_t num_elements>
class ring_buffer
{
public:
enum class produce_result
{
produced,
ring_buffer_stopped
};

enum class consume_result
{
ring_buffer_stopped
};

/**
* static_assert If `num_elements` == 0.
*/
Expand Down Expand Up @@ -62,7 +65,7 @@ class ring_buffer
{
std::unique_lock lk{m_rb.m_mutex};
// It's possible a consumer on another thread consumed an item between await_ready() and await_suspend()
// so we must check to see if tehre is space again.
// so we must check to see if there is space again.
if (m_rb.try_produce_locked(lk, m_e))
{
return false;
Expand All @@ -84,9 +87,9 @@ class ring_buffer
/**
* @return produce_result
*/
auto await_resume() -> produce_result
auto await_resume() -> rb::produce_result
{
return !m_stopped ? produce_result::produced : produce_result::ring_buffer_stopped;
return !m_stopped ? rb::produce_result::produced : rb::produce_result::ring_buffer_stopped;
}

private:
Expand Down Expand Up @@ -140,11 +143,11 @@ class ring_buffer
/**
* @return The consumed element, or an unexpected consume_result::ring_buffer_stopped if the operation was stopped.
*/
auto await_resume() -> expected<element, consume_result>
auto await_resume() -> expected<element, rb::consume_result>
{
if (m_stopped)
{
return unexpected<consume_result>(consume_result::ring_buffer_stopped);
return unexpected<rb::consume_result>(rb::consume_result::ring_buffer_stopped);
}

return std::move(m_e);
Expand Down
185 changes: 154 additions & 31 deletions include/coro/sync_wait.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,22 @@
#include <condition_variable>
#include <exception>
#include <mutex>
#include <variant>

namespace coro
{
namespace detail
{

/**
 * Placeholder alternative meaning "no result has been stored yet".
 *
 * It is the first alternative of sync_wait_task_promise's result variant, so a
 * default-constructed variant holds this type until a result is emplaced.
 *
 * The default constructor is noexcept so that a std::variant whose first
 * alternative is this type is nothrow default constructible, matching the
 * noexcept default constructor declared by the promise that owns the variant.
 * Copy and move operations are deleted: the placeholder carries no state and
 * is never meant to be transferred.
 */
struct unset_return_value
{
    unset_return_value() noexcept = default;
    unset_return_value(unset_return_value&&) = delete;
    unset_return_value(const unset_return_value&) = delete;
    auto operator=(unset_return_value&&) -> unset_return_value& = delete;
    auto operator=(const unset_return_value&) -> unset_return_value& = delete;
};

class sync_wait_event
{
public:
Expand All @@ -28,7 +39,7 @@ class sync_wait_event
private:
std::mutex m_mutex;
std::condition_variable m_cv;
bool m_set{false};
std::atomic<bool> m_set{false};
};

class sync_wait_task_promise_base
Expand All @@ -53,8 +64,19 @@ class sync_wait_task_promise : public sync_wait_task_promise_base
public:
using coroutine_type = std::coroutine_handle<sync_wait_task_promise<return_type>>;

sync_wait_task_promise() noexcept = default;
~sync_wait_task_promise() = default;
static constexpr bool return_type_is_reference = std::is_reference_v<return_type>;
using stored_type = std::conditional_t<
return_type_is_reference,
std::remove_reference_t<return_type>*,
std::remove_const_t<return_type>>;
using variant_type = std::variant<unset_return_value, stored_type, std::exception_ptr>;

sync_wait_task_promise() noexcept = default;
sync_wait_task_promise(const sync_wait_task_promise&) = delete;
sync_wait_task_promise(sync_wait_task_promise&&) = delete;
auto operator=(const sync_wait_task_promise&) -> sync_wait_task_promise& = delete;
auto operator=(sync_wait_task_promise&&) -> sync_wait_task_promise& = delete;
~sync_wait_task_promise() = default;

auto start(sync_wait_event& event)
{
Expand All @@ -64,10 +86,32 @@ class sync_wait_task_promise : public sync_wait_task_promise_base

auto get_return_object() noexcept { return coroutine_type::from_promise(*this); }

auto yield_value(return_type&& value) noexcept
template<typename value_type>
requires(return_type_is_reference and std::is_constructible_v<return_type, value_type&&>) or
(not return_type_is_reference and
std::is_constructible_v<stored_type, value_type&&>) auto return_value(value_type&& value) -> void
{
m_return_value = std::addressof(value);
return final_suspend();
if constexpr (return_type_is_reference)
{
return_type ref = static_cast<value_type&&>(value);
m_storage.template emplace<stored_type>(std::addressof(ref));
}
else
{
m_storage.template emplace<stored_type>(std::forward<value_type>(value));
}
}

/**
 * Stores a by-value result in the promise's storage.
 * This overload only participates when return_type is not a reference.
 * @param value The result to store. It is moved into the storage when stored_type is
 *              move constructible and copied otherwise.
 */
auto return_value(stored_type value) -> void requires(not return_type_is_reference)
{
    if constexpr (not std::is_move_constructible_v<stored_type>)
    {
        // No usable move constructor: hand over the lvalue so emplace copy-constructs.
        m_storage.template emplace<stored_type>(value);
    }
    else
    {
        m_storage.template emplace<stored_type>(std::move(value));
    }
}

auto final_suspend() noexcept
Expand All @@ -82,19 +126,81 @@ class sync_wait_task_promise : public sync_wait_task_promise_base
return completion_notifier{};
}

auto result() -> return_type&&
auto result() & -> decltype(auto)
{
if (m_exception)
if (std::holds_alternative<stored_type>(m_storage))
{
std::rethrow_exception(m_exception);
if constexpr (return_type_is_reference)
{
return static_cast<return_type>(*std::get<stored_type>(m_storage));
}
else
{
return static_cast<const return_type&>(std::get<stored_type>(m_storage));
}
}
else if (std::holds_alternative<std::exception_ptr>(m_storage))
{
std::rethrow_exception(std::get<std::exception_ptr>(m_storage));
}
else
{
throw std::runtime_error{"The return value was never set, did you execute the coroutine?"};
}
}

/**
 * Read-only access to the outcome held by the promise.
 * @return For reference return types, the referred-to object cast to
 *         std::add_const_t<return_type>; otherwise a const reference to the stored value.
 * @throw Rethrows the stored exception when the storage holds a std::exception_ptr.
 * @throw std::runtime_error When the storage holds neither a value nor an exception.
 */
auto result() const& -> decltype(auto)
{
    // Deal with the two failure states first; the variant holds exactly one alternative.
    if (std::holds_alternative<std::exception_ptr>(m_storage))
    {
        std::rethrow_exception(std::get<std::exception_ptr>(m_storage));
    }

    if (not std::holds_alternative<stored_type>(m_storage))
    {
        throw std::runtime_error{"The return value was never set, did you execute the coroutine?"};
    }

    const auto& stored = std::get<stored_type>(m_storage);
    if constexpr (return_type_is_reference)
    {
        // Reference results are stored as a pointer to the referred-to object.
        return static_cast<std::add_const_t<return_type>>(*stored);
    }
    else
    {
        return static_cast<const return_type&>(stored);
    }
}

return static_cast<return_type&&>(*m_return_value);
/**
 * Access to the outcome held by an expiring promise.
 * @return For reference return types, the referred-to object cast to return_type.
 *         Otherwise an rvalue reference to the stored value, const-qualified when
 *         return_type is not assignable from stored_type.
 * @throw Rethrows the stored exception when the storage holds a std::exception_ptr.
 * @throw std::runtime_error When the storage holds neither a value nor an exception.
 */
auto result() && -> decltype(auto)
{
    // Deal with the two failure states first; the variant holds exactly one alternative.
    if (std::holds_alternative<std::exception_ptr>(m_storage))
    {
        std::rethrow_exception(std::get<std::exception_ptr>(m_storage));
    }

    if (not std::holds_alternative<stored_type>(m_storage))
    {
        throw std::runtime_error{"The return value was never set, did you execute the coroutine?"};
    }

    auto& stored = std::get<stored_type>(m_storage);
    if constexpr (return_type_is_reference)
    {
        // Reference results are stored as a pointer to the referred-to object.
        return static_cast<return_type>(*stored);
    }
    else if constexpr (std::is_assignable_v<return_type, stored_type>)
    {
        return static_cast<return_type&&>(stored);
    }
    else
    {
        return static_cast<const return_type&&>(stored);
    }
}
void return_void() noexcept {}

private:
std::remove_reference_t<return_type>* m_return_value;
variant_type m_storage{};
};

template<>
Expand Down Expand Up @@ -167,21 +273,9 @@ class sync_wait_task
}
}

auto start(sync_wait_event& event) noexcept { m_coroutine.promise().start(event); }

auto return_value() -> decltype(auto)
{
if constexpr (std::is_same_v<void, return_type>)
{
// Propagate exceptions.
m_coroutine.promise().result();
return;
}
else
{
return m_coroutine.promise().result();
}
}
/**
 * @return The promise of the wrapped coroutine, as a mutable reference.
 */
auto promise() & -> promise_type&
{
    return m_coroutine.promise();
}

/**
 * @return The promise of the wrapped coroutine, as a const reference.
 */
auto promise() const& -> const promise_type&
{
    return m_coroutine.promise();
}

/**
 * @return The promise of the wrapped coroutine cast to an rvalue, so that
 *         rvalue-qualified members of the promise can be called on it.
 */
auto promise() && -> promise_type&&
{
    return std::move(m_coroutine.promise());
}

private:
coroutine_type m_coroutine;
Expand All @@ -202,21 +296,50 @@ static auto make_sync_wait_task(awaitable_type&& a) -> sync_wait_task<return_typ
}
else
{
co_yield co_await std::forward<awaitable_type>(a);
co_return co_await std::forward<awaitable_type>(a);
}
}

} // namespace detail

template<concepts::awaitable awaitable_type>
template<
concepts::awaitable awaitable_type,
typename return_type = typename concepts::awaitable_traits<awaitable_type>::awaiter_return_type>
auto sync_wait(awaitable_type&& a) -> decltype(auto)
{
detail::sync_wait_event e{};
auto task = detail::make_sync_wait_task(std::forward<awaitable_type>(a));
task.start(e);
task.promise().start(e);
e.wait();

return task.return_value();
if constexpr (std::is_void_v<return_type>)
{
task.promise().result();
return;
}
else if constexpr (std::is_reference_v<return_type>)
{
return task.promise().result();
}
else if constexpr (std::is_move_assignable_v<return_type>)
{
// issue-242
// For non-trivial types (or possibly types that don't fit in a register)
// the compiler will end up calling the ~return_type() when the promise
// is destructed at the end of sync_wait(). This causes the return_type
// object to also be destructed causing the final return/move from
// sync_wait() to be a 'use after free' bug. To work around this the result
// must be moved off the promise object before the promise is destructed.
// Other solutions could be heap allocating the return_type but that has
// other downsides, for now it is determined that a double move is an
// acceptable solution to work around this bug.
auto result = std::move(task).promise().result();
return result;
}
else
{
return task.promise().result();
}
}

} // namespace coro
13 changes: 4 additions & 9 deletions src/sync_wait.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,24 +8,19 @@ sync_wait_event::sync_wait_event(bool initially_set) : m_set(initially_set)

/**
 * Marks the event as set and wakes every thread blocked in wait().
 *
 * The flag is written exactly once, while holding m_mutex. wait() evaluates its
 * predicate under the same mutex, so a waiter cannot see the flag as unset and
 * then start blocking after this notification has already been issued.
 */
auto sync_wait_event::set() noexcept -> void
{
    {
        std::lock_guard<std::mutex> g{m_mutex};
        m_set.store(true, std::memory_order::release);
    }

    // The lock is released before notifying; the flag is already published at this point.
    m_cv.notify_all();
}

/**
 * Clears the event's flag.
 *
 * The flag is written exactly once, while holding m_mutex, the same mutex
 * wait() holds while it evaluates its predicate.
 */
auto sync_wait_event::reset() noexcept -> void
{
    std::lock_guard<std::mutex> g{m_mutex};
    m_set.store(false, std::memory_order::release);
}

/**
 * Blocks the calling thread until the event's flag is observed as set.
 *
 * A single predicate wait is sufficient: the condition variable re-checks the
 * predicate under m_mutex on every wake-up, so spurious wake-ups are handled
 * and the call returns immediately if the flag is already set.
 */
auto sync_wait_event::wait() noexcept -> void
{
    std::unique_lock<std::mutex> lk{m_mutex};
    m_cv.wait(lk, [this] { return m_set.load(std::memory_order::acquire); });
}

} // namespace coro::detail
} // namespace coro::detail
Loading
0