Find extra global load by PointKernel · Pull Request #688 · NVIDIA/cuCollections · GitHub
[go: up one dir, main page]
More Web Proxy on the site http://driver.im/
Skip to content

Find extra global load #688

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 3 commits into
base: dev
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 0 additions & 56 deletions benchmarks/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -46,63 +46,7 @@ endfunction(ConfigureBench)
### benchmark sources #############################################################################
###################################################################################################
# Each ConfigureBench() call below passes a target name followed by the .cu sources for that
# target. NOTE(review): ConfigureBench is defined above this section; presumably it creates one
# benchmark executable per call — confirm against its definition.

###################################################################################################
# - static_set benchmarks -------------------------------------------------------------------------
ConfigureBench(STATIC_SET_BENCH
static_set/contains_bench.cu
static_set/find_bench.cu
static_set/insert_bench.cu
static_set/retrieve_bench.cu
static_set/retrieve_all_bench.cu
static_set/size_bench.cu
static_set/rehash_bench.cu)

###################################################################################################
# - static_map benchmarks -------------------------------------------------------------------------
ConfigureBench(STATIC_MAP_BENCH
static_map/insert_bench.cu
static_map/find_bench.cu
static_map/contains_bench.cu
static_map/erase_bench.cu
static_map/insert_or_apply_bench.cu)

###################################################################################################
# - static_multiset benchmarks --------------------------------------------------------------------
ConfigureBench(STATIC_MULTISET_BENCH
static_multiset/contains_bench.cu
static_multiset/retrieve_bench.cu
static_multiset/count_bench.cu
static_multiset/find_bench.cu
static_multiset/insert_bench.cu)

###################################################################################################
# - static_multimap benchmarks --------------------------------------------------------------------
ConfigureBench(STATIC_MULTIMAP_BENCH
static_multimap/insert_bench.cu
static_multimap/retrieve_bench.cu
static_multimap/query_bench.cu
static_multimap/count_bench.cu)

###################################################################################################
# - dynamic_map benchmarks ------------------------------------------------------------------------
ConfigureBench(DYNAMIC_MAP_BENCH
dynamic_map/insert_bench.cu
dynamic_map/find_bench.cu
dynamic_map/contains_bench.cu
dynamic_map/erase_bench.cu)

###################################################################################################
# - hash function benchmarks ----------------------------------------------------------------------
ConfigureBench(HASH_FUNCTION_BENCH
hash_function/hash_function_bench.cu)

###################################################################################################
# - hyperloglog benchmarks ------------------------------------------------------------------------
ConfigureBench(HYPERLOGLOG_BENCH
hyperloglog/hyperloglog_bench.cu)

###################################################################################################
# - bloom_filter benchmarks -----------------------------------------------------------------------
ConfigureBench(BLOOM_FILTER_BENCH
bloom_filter/add_bench.cu
bloom_filter/contains_bench.cu)
2 changes: 1 addition & 1 deletion include/cuco/detail/extent/extent.inl
8000
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ template <int32_t CGSize, int32_t BucketSize, typename SizeType, std::size_t N>
return bucket_extent<SizeType>{static_cast<SizeType>(
*cuco::detail::lower_bound(
cuco::detail::primes.begin(), cuco::detail::primes.end(), static_cast<uint64_t>(size)) *
CGSize)};
CGSize * BucketSize)};
}
if constexpr (N != dynamic_extent) {
return bucket_extent<SizeType,
Expand Down
113 changes: 42 additions & 71 deletions include/cuco/detail/open_addressing/open_addressing_ref_impl.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -430,71 +430,34 @@ class open_addressing_ref_impl {
__device__ bool insert(cooperative_groups::thread_block_tile<cg_size> const& group,
Value const& value) noexcept
{
auto const val = this->heterogeneous_value(value);
auto const key = this->extract_key(val);
auto probing_iter = probing_scheme_(group, key, storage_ref_.bucket_extent());
auto const init_idx = *probing_iter;
auto const val = this->heterogeneous_value(value);
auto const key = this->extract_key(val);
auto probing_iter = probing_scheme_(group, key, storage_ref_.bucket_extent());
auto* data = reinterpret_cast<char*>(storage_ref_.data());

while (true) {
auto const bucket_slots = storage_ref_[*probing_iter];
value_type bucket_slots[2];
auto const tmp =
*reinterpret_cast<uint4 const*>(data + *probing_iter * sizeof(value_type) * 2);
memcpy(&bucket_slots[0], &tmp, 2 * sizeof(value_type));

auto const [state, intra_bucket_index] = [&]() {
for (auto i = 0; i < bucket_size; ++i) {
switch (
this->predicate_.operator()<is_insert::YES>(key, this->extract_key(bucket_slots[i]))) {
case detail::equal_result::AVAILABLE:
return bucket_probing_results{detail::equal_result::AVAILABLE, i};
case detail::equal_result::EQUAL: {
if constexpr (allows_duplicates) {
continue;
} else {
return bucket_probing_results{detail::equal_result::EQUAL, i};
}
}
default: continue;
}
}
// returns dummy index `-1` for UNEQUAL
return bucket_probing_results{detail::equal_result::UNEQUAL, -1};
}();
auto const first_slot_is_empty =
detail::bitwise_compare(bucket_slots[0].first, this->empty_key_sentinel());
auto const second_slot_is_empty =
detail::bitwise_compare(bucket_slots[1].first, this->empty_key_sentinel());

if constexpr (not allows_duplicates) {
// If the key is already in the container, return false
if (group.any(state == detail::equal_result::EQUAL)) { return false; }
}
auto const bucket_contains_empty = group.ballot(first_slot_is_empty or second_slot_is_empty);

auto const group_contains_available = group.ballot(state == detail::equal_result::AVAILABLE);
if (group_contains_available) {
auto const src_lane = __ffs(group_contains_available) - 1;
if (bucket_contains_empty) {
auto const src_lane = __ffs(bucket_contains_empty) - 1;
auto status = insert_result::CONTINUE;
if (group.thread_rank() == src_lane) {
if constexpr (SupportsErase) {
status =
attempt_insert((storage_ref_.data() + *probing_iter)->data() + intra_bucket_index,
bucket_slots[intra_bucket_index],
val);
} else {
status =
attempt_insert((storage_ref_.data() + *probing_iter)->data() + intra_bucket_index,
this->empty_slot_sentinel(),
val);
}
status = attempt_insert(bucket_slots, this->empty_slot_sentinel(), val);
}

switch (group.shfl(status, src_lane)) {
case insert_result::SUCCESS: return true;
case insert_result::DUPLICATE: {
if constexpr (allows_duplicates) {
[[fallthrough]];
} else {
return false;
}
}
default: continue;
}
if (group.any(status == insert_result::SUCCESS)) { return true; }
} else {
++probing_iter;
if (*probing_iter == init_idx) { return false; }
}
}
}
Expand Down Expand Up @@ -990,27 +953,35 @@ class open_addressing_ref_impl {
[[nodiscard]] __device__ size_type count(
cooperative_groups::thread_block_tile<cg_size> const& group, ProbeKey const& key) const noexcept
{
auto probing_iter = probing_scheme_(group, key, storage_ref_.bucket_extent());
auto const init_idx = *probing_iter;
size_type count = 0;
auto probing_iter = probing_scheme_(group, key, storage_ref_.bucket_extent());
size_type count = 0;

while (true) {
auto const bucket_slots = storage_ref_[*probing_iter];
auto* data = reinterpret_cast<char*>(storage_ref_.data());

auto const state = [&]() {
auto res = detail::equal_result::UNEQUAL;
for (auto& slot : bucket_slots) {
res = this->predicate_.operator()<is_insert::NO>(key, this->extract_key(slot));
if (res == detail::equal_result::EMPTY) { return res; }
count += static_cast<size_type>(res);
}
return res;
}();
if constexpr (has_payload) {
while (true) {
value_type bucket_slots[2];
auto const tmp =
*reinterpret_cast<uint4 const*>(data + *probing_iter * sizeof(value_type) * 2);
memcpy(&bucket_slots[0], &tmp, 2 * sizeof(value_type));

if (group.any(state == detail::equal_result::EMPTY)) { return count; }
++probing_iter;
if (*probing_iter == init_idx) { return count; }
auto const first_slot_is_empty =
detail::bitwise_compare(bucket_slots[0].first, this->empty_key_sentinel());
auto const second_slot_is_empty =
detail::bitwise_compare(bucket_slots[1].first, this->empty_key_sentinel());
auto const first_equals =
(not first_slot_is_empty and predicate_.equal_(key, bucket_slots[0].first));
auto const second_equals =
(not second_slot_is_empty and predicate_.equal_(key, bucket_slots[1].first));

count += (first_equals + second_equals);

if (group.any(first_slot_is_empty or second_slot_is_empty)) { return count; }

++probing_iter;
}
}
return count;
}

/**
Expand Down
13 changes: 13 additions & 0 deletions include/cuco/detail/storage/kernels.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -51,5 +51,18 @@ CUCO_KERNEL void initialize(BucketT* buckets,
}
}

/**
 * @brief Overwrites each of the first `n` elements of `buckets` with `value`.
 *
 * Work is distributed with a grid-stride loop: every thread starts at
 * `cuco::detail::global_thread_id()` and advances by `cuco::detail::grid_stride()`
 * until the index reaches `n`.
 *
 * @tparam BucketT Element type of the buffer being filled
 *
 * @param buckets Pointer to the first element to overwrite
 * @param n Number of elements to overwrite
 * @param value Value copied into every element
 */
template <typename BucketT>
CUCO_KERNEL void initialize(BucketT* buckets, cuco::detail::index_type n, BucketT value)
{
  auto const stride = cuco::detail::grid_stride();
  for (auto i = cuco::detail::global_thread_id(); i < n; i += stride) {
    buckets[i] = value;
  }
}

} // namespace detail
} // namespace cuco
34 changes: 34 additions & 0 deletions include/cuco/detail/storage/storage.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#pragma once

#include <cuco/bucket_storage.cuh>
#include <cuco/flat_storage.cuh>

namespace cuco {
namespace detail {
Expand Down Expand Up @@ -60,5 +61,38 @@ class storage : StorageImpl::template impl<T, Extent, Allocator> {
}
};

/**
 * @brief Storage wrapper that privately derives from
 * `StorageImpl::template impl<T, Extent, Allocator>` and re-exports a selected subset of its
 * interface.
 *
 * Because the base is inherited privately (default for `class`), only the members named by the
 * using-declarations below are accessible to users of this type.
 *
 * NOTE(review): the name suggests this is a slot-granular counterpart of `storage`; confirm the
 * intended difference before relying on it.
 *
 * @tparam StorageImpl Storage implementation type exposing a nested `impl` class template
 * @tparam T Element type forwarded to `StorageImpl::impl`
 * @tparam Extent Extent type forwarded to `StorageImpl::impl`
 * @tparam Allocator Allocator type forwarded to `StorageImpl::impl`
 */
template <class StorageImpl, class T, class Extent, class Allocator>
class slot_storage : StorageImpl::template impl<T, Extent, Allocator> {
 public:
  /// Storage implementation type
  using impl_type      = typename StorageImpl::template impl<T, Extent, Allocator>;
  using ref_type       = typename impl_type::ref_type;        ///< Storage ref type
  using value_type     = typename impl_type::value_type;      ///< Storage value type
  using allocator_type = typename impl_type::allocator_type;  ///< Storage allocator type

  /// Number of elements per bucket, as reported by the implementation
  static constexpr int bucket_size = impl_type::bucket_size;

  using impl_type::allocator;
  using impl_type::bucket_extent;
  using impl_type::capacity;
  using impl_type::data;
  using impl_type::initialize;
  using impl_type::initialize_async;
  using impl_type::num_buckets;
  using impl_type::ref;

  /**
   * @brief Constructs storage.
   *
   * Both arguments are forwarded unchanged to the implementation's constructor.
   *
   * @param size Number of slots to (de)allocate
   * @param allocator Allocator used for (de)allocating device storage
   */
  explicit constexpr slot_storage(Extent size, Allocator const& allocator)
    : impl_type{size, allocator}
  {
  }
};

} // namespace detail
} // namespace cuco
Loading
Loading
0