Added compile-time guards to all local and remote container lambdas by bwpriest · Pull Request #138 · LLNL/ygm · GitHub
[go: up one dir, main page]
More Web Proxy on the site http://driver.im/
Skip to content

Added compile-time guards to all local and remote container lambdas #138

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Mar 2, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 2 additions & 3 deletions examples/container/disjoint_set_cc.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,7 @@ int main(int argc, char** argv) {
connected_components.all_compress();

world.cout0("Person : Representative");
connected_components.for_all([&world](const auto& person_rep_pair) {
std::cout << person_rep_pair.first << " : " << person_rep_pair.second
<< std::endl;
connected_components.for_all([&world](const auto& person, const auto& rep) {
std::cout << person << " : " << rep << std::endl;
});
}
17 changes: 15 additions & 2 deletions include/ygm/container/detail/array_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ template <typename Value, typename Index>
class array_impl {
public:
using self_type = array_impl<Value, Index>;
using ptr_type = typename ygm::ygm_ptr<self_type>;
using value_type = Value;
using index_type = Index;

Expand Down Expand Up @@ -111,8 +112,20 @@ class array_impl {
ASSERT_RELEASE(l_index < parray->m_local_vec.size());
value_type &l_value = parray->m_local_vec[l_index];
Visitor *vis = nullptr;
ygm::meta::apply_optional(*vis, std::make_tuple(parray),
std::forward_as_tuple(i, l_value, args...));
if constexpr (std::is_invocable<decltype(visitor), const index_type &,
value_type &, VisitorArgs &...>() ||
std::is_invocable<decltype(visitor), ptr_type,
const index_type &, value_type &,
VisitorArgs &...>()) {
ygm::meta::apply_optional(*vis, std::make_tuple(parray),
std::forward_as_tuple(i, l_value, args...));
} else {
static_assert(
ygm::detail::always_false<>,
"remote array lambda signature must be invocable with (const "
"&index_type, value_type&, ...) or (ptr_type, const "
"&index_type, value_type&, ...) signatures");
}
};

m_comm.async(dest, visit_wrapper, pthis, index,
Expand Down
36 changes: 28 additions & 8 deletions include/ygm/container/detail/disjoint_set_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
#include <ygm/comm.hpp>
#include <ygm/container/detail/hash_partitioner.hpp>
#include <ygm/detail/ygm_ptr.hpp>
#include <ygm/detail/ygm_traits.hpp>

namespace ygm::container::detail {
template <typename Item, typename Partitioner>
Expand Down Expand Up @@ -96,12 +97,12 @@ class disjoint_set_impl {

template <typename Function, typename... FunctionArgs>
void async_union_and_execute(const value_type &a, const value_type &b,
Function fn, const FunctionArgs &... args) {
Function fn, const FunctionArgs &...args) {
// Walking up parent trees can be expressed as a recursive operation
struct simul_parent_walk_functor {
void operator()(self_ygm_ptr_type pdset, const value_type &my_item,
const value_type &other_item, const value_type &orig_a,
const value_type &orig_b, const FunctionArgs &... args) {
const value_type &orig_b, const FunctionArgs &...args) {
const auto my_parent = pdset->local_get_parent(my_item);

// Found root
Expand All @@ -110,9 +111,18 @@ class disjoint_set_impl {

// Perform user function after merge
Function *f = nullptr;
ygm::meta::apply_optional(
*f, std::make_tuple(pdset),
std::forward_as_tuple(orig_a, orig_b, args...));
if constexpr (std::is_invocable<decltype(fn), const value_type &,
const value_type &,
FunctionArgs &...>()) {
ygm::meta::apply_optional(
*f, std::make_tuple(pdset),
std::forward_as_tuple(orig_a, orig_b, args...));
} else {
static_assert(
ygm::detail::always_false<>,
"remote disjoint_set lambda signature must be invocable "
"with (const value_type &, const value_type &) signature");
}

return;
}
Expand Down Expand Up @@ -284,8 +294,18 @@ class disjoint_set_impl {
void for_all(Function fn) {
all_compress();

std::for_each(m_local_item_parent_map.begin(),
m_local_item_parent_map.end(), fn);
if constexpr (std::is_invocable<decltype(fn), const value_type &,
const value_type &>()) {
for (const std::pair<value_type, value_type> &item_rep :
m_local_item_parent_map) {
const auto [item, rep] = item_rep;
fn(item, rep);
}
} else {
static_assert(ygm::detail::always_false<>,
"local disjoint_set lambda signature must be invocable "
"with (const value_type &, const value_type &) signature");
}
}

std::map<value_type, value_type> all_find(
Expand Down Expand Up @@ -317,7 +337,7 @@ class disjoint_set_impl {
pdset->comm().async(
source_rank,
[](ygm_ptr<return_type> p_to_return,
const value_type & source_item,
const value_type &source_item,
const value_type &rep) { (*p_to_return)[source_item] = rep; },
p_to_return, source_item, parent);
} else {
Expand Down
10 changes: 8 additions & 2 deletions include/ygm/container/detail/set_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
#include <ygm/comm.hpp>
#include <ygm/container/detail/hash_partitioner.hpp>
#include <ygm/detail/ygm_ptr.hpp>
#include <ygm/detail/ygm_traits.hpp>

namespace ygm::container::detail {
template <typename Key, typename Partitioner = detail::hash_partitioner<Key>,
Expand Down Expand Up @@ -110,10 +111,15 @@ class set_impl {

ygm::comm &comm() { return m_comm; }

// protected:
template <typename Function>
void local_for_all(Function fn) {
std::for_each(m_local_set.begin(), m_local_set.end(), fn);
if constexpr (std::is_invocable<decltype(fn), const key_type &>()) {
std::for_each(m_local_set.begin(), m_local_set.end(), fn);
} else {
static_assert(ygm::detail::always_false<>,
"local set lambda signature must be invocable with (const "
"key_type &) signature");
}
}

int owner(const key_type &key) const {
Expand Down
42 changes: 42 additions & 0 deletions test/test_array.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,48 @@ int main(int argc, char **argv) {
});
}

// Test async_visit
{
int size = 64;

ygm::container::array<int> arr(world, size);

if (world.rank0()) {
for (int i = 0; i < size; ++i) {
arr.async_set(i, i);
}
}

world.barrier();

for (int i = 0; i < size; ++i) {
arr.async_visit(i, [](const auto index, const auto value) {
ASSERT_RELEASE(value == index);
});
}
}

// Test async_visit (ptr)
{
int size = 64;

ygm::container::array<int> arr(world, size);

if (world.rank0()) {
for (int i = 0; i < size; ++i) {
arr.async_set(i, i);
}
}

world.barrier();

for (int i = 0; i < size; ++i) {
arr.async_visit(i, [](auto ptr, const auto index, const auto value) {
ASSERT_RELEASE(value == index);
});
}
}

// Test value-only for_all
{
int size = 64;
Expand Down
8 changes: 4 additions & 4 deletions test/test_disjoint_set.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -162,8 +162,8 @@ int main(int argc, char** argv) {
dset.async_union(i, i);
}

dset.for_all([&counter](const auto& item_rep_pair) {
ASSERT_RELEASE(item_rep_pair.first == item_rep_pair.second);
dset.for_all([&counter](const auto& item, const auto& rep) {
ASSERT_RELEASE(item == rep);
++counter;
});

Expand All @@ -182,8 +182,8 @@ int main(int argc, char** argv) {
[](const int u, const int v) { counter++; });
dset.async_union_and_execute(1, 2,
[](const int u, const int v) { counter++; });
dset.async_union_and_execute(3, 4,
[](const int u, const int v) { counter++; });
dset.async_union_and_execute(
3, 4, [](const int u, const int v, const auto thing) { counter++; }, 0);

world.barrier();

Expand Down
19 changes: 18 additions & 1 deletion test/test_set.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
#include <ygm/comm.hpp>
#include <ygm/container/set.hpp>

int main(int argc, char** argv) {
int main(int argc, char **argv) {
ygm::comm world(&argc, &argv);

//
Expand Down Expand Up @@ -67,5 +67,22 @@ int main(int argc, char** argv) {
ASSERT_RELEASE(sset.count("car") == 1);
}

//
// Test for_all
{
ygm::container::set<std::string> sset1(world);
ygm::container::set<std::string> sset2(world);

sset1.async_insert("dog");
sset1.async_insert("apple");
sset1.async_insert("red");

sset1.for_all([&sset2](const auto &key) { sset2.async_insert(key); });

ASSERT_RELEASE(sset2.count("dog") == 1);
ASSERT_RELEASE(sset2.count("apple") == 1);
ASSERT_RELEASE(sset2.count("red") == 1);
}

return 0;
}