Accept only unique_ptr for distributed vector constructors by MarcelKoch · Pull Request #1284 · ginkgo-project/ginkgo · GitHub
[go: up one dir, main page]
More Web Proxy on the site http://driver.im/
Skip to content

Accept only unique_ptr for distributed vector constructors #1284

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Mar 14, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 6 additions & 10 deletions benchmark/solver/distributed/solver.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -56,11 +56,9 @@ struct Generator : public DistributedDefaultSystemGenerator<SolverGenerator> {
{
return Vec::create(
exec, comm, gko::dim<2>{system_matrix->get_size()[0], FLAGS_nrhs},
gko::as<typename LocalGenerator::Vec>(
local_generator.generate_rhs(
exec, gko::as<Mtx>(system_matrix)->get_local_matrix().get(),
config))
.get());
local_generator.generate_rhs(
exec, gko::as<Mtx>(system_matrix)->get_local_matrix().get(),
config));
}

std::unique_ptr<Vec> generate_initial_guess(
Expand All @@ -69,11 +67,9 @@ struct Generator : public DistributedDefaultSystemGenerator<SolverGenerator> {
{
return Vec::create(
exec, comm, gko::dim<2>{rhs->get_size()[0], FLAGS_nrhs},
gko::as<typename LocalGenerator::Vec>(
local_generator.generate_initial_guess(
exec, gko::as<Mtx>(system_matrix)->get_local_matrix().get(),
rhs->get_local_vector()))
.get());
local_generator.generate_initial_guess(
exec, gko::as<Mtx>(system_matrix)->get_local_matrix().get(),
rhs->get_local_vector()));
}
};

Expand Down
5 changes: 3 additions & 2 deletions benchmark/utils/generator.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -279,8 +279,9 @@ struct DistributedDefaultSystemGenerator {
auto global_rows = local->get_size()[0];
comm.all_reduce(gko::ReferenceExecutor::create(), &global_rows, 1,
MPI_SUM);
return Vec::create(
exec, comm, gko::dim<2>{global_rows, local->get_size()[1]}, local);
return Vec::create(exec, comm,
gko::dim<2>{global_rows, local->get_size()[1]},
std::move(local));
}

gko::experimental::mpi::communicator comm;
Expand Down
37 changes: 31 additions & 6 deletions core/distributed/vector.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ Vector<ValueType>::Vector(std::shared_ptr<const Executor> exec,
template <typename ValueType>
Vector<ValueType>::Vector(std::shared_ptr<const Executor> exec,
mpi::communicator comm, dim<2> global_size,
ptr_param<local_vector_type> local_vector)
std::unique_ptr<local_vector_type> local_vector)
: EnableDistributedLinOp<Vector<ValueType>>{exec, global_size},
DistributedBase{comm},
local_{exec}
Expand All @@ -114,7 +114,7 @@ Vector<ValueType>::Vector(std::shared_ptr<const Executor> exec,
template <typename ValueType>
Vector<ValueType>::Vector(std::shared_ptr<const Executor> exec,
mpi::communicator comm,
ptr_param<local_vector_type> local_vector)
std::unique_ptr<local_vector_type> local_vector)
: EnableDistributedLinOp<Vector<ValueType>>{exec, {}},
DistributedBase{comm},
local_{exec}
Expand All @@ -124,6 +124,32 @@ Vector<ValueType>::Vector(std::shared_ptr<const Executor> exec,
}


/**
 * Creates a distributed vector that is only exposed through a
 * pointer-to-const, taking ownership of the given local vector.
 *
 * @param exec  Executor associated with the new vector
 * @param comm  Communicator associated with the new vector
 * @param global_size  Global size of the new vector
 * @param local_vector  Local vector whose data is moved into the new vector
 *
 * @return  the new vector, accessible only as const
 */
template <typename ValueType>
std::unique_ptr<const Vector<ValueType>> Vector<ValueType>::create_const(
    std::shared_ptr<const Executor> exec, mpi::communicator comm,
    dim<2> global_size, std::unique_ptr<const local_vector_type> local_vector)
{
    // The constructor moves the local data into the new vector, so it needs a
    // mutable local vector. We hold the only owning pointer to it, and the
    // result is only handed out as const, so casting away const here does not
    // let callers mutate data through a const handle. (This assumes the
    // pointee was not itself created as a const object.)
    //
    // Re-wrap the released pointer into a unique_ptr *before* evaluating the
    // new-expression: the allocation performed by `new Vector(...)` may be
    // (and since C++17 is) sequenced before its arguments are evaluated. If
    // that allocation threw while only a raw pointer existed, the local vector
    // would leak.
    std::unique_ptr<local_vector_type> mutable_local_vector{
        const_cast<local_vector_type*>(local_vector.release())};

    // The constructor is not public, so std::make_unique cannot be used here.
    return std::unique_ptr<const Vector<ValueType>>(new Vector<ValueType>(
        std::move(exec), std::move(comm), global_size,
        std::move(mutable_local_vector)));
}


/**
 * Creates a const distributed vector from a local vector, deducing the global
 * size from the local sizes of all ranks. This incurs a collective
 * communication over comm.
 */
template <typename ValueType>
std::unique_ptr<const Vector<ValueType>> Vector<ValueType>::create_const(
    std::shared_ptr<const Executor> exec, mpi::communicator comm,
    std::unique_ptr<const local_vector_type> local_vector)
{
    // Determine the global size from this rank's local size (collective).
    const auto local_size = local_vector->get_size();
    auto deduced_global_size = compute_global_size(exec, comm, local_size);

    // Delegate to the overload that takes an explicit global size.
    return create_const(std::move(exec), std::move(comm), deduced_global_size,
                        std::move(local_vector));
}


template <typename ValueType>
std::unique_ptr<Vector<ValueType>> Vector<ValueType>::create_with_config_of(
ptr_param<const Vector> other)
Expand Down Expand Up @@ -585,10 +611,9 @@ Vector<ValueType>::create_real_view() const
const auto num_cols =
is_complex<ValueType>() ? 2 * this->get_size()[1] : this->get_size()[1];

return real_type::create(this->get_executor(), this->get_communicator(),
dim<2>{num_global_rows, num_cols},
const_cast<typename real_type::local_vector_type*>(
local_.create_real_view().get()));
return real_type::create_const(
this->get_executor(), this->get_communicator(),
dim<2>{num_global_rows, num_cols}, local_.create_real_view());
}


Expand Down
47 changes: 40 additions & 7 deletions include/ginkgo/core/distributed/vector.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -458,6 +458,35 @@ class Vector

/**
 * Returns the stride of the underlying local vector.
 */
size_type get_stride() const noexcept
{
    return local_.get_stride();
}

/**
 * Creates a constant (immutable) distributed Vector from a constant local
 * vector.
 *
 * @note Ownership of local_vector is taken over and its data is moved into
 *       the new distributed vector. To keep using the original local vector,
 *       pass a copy (e.g. via gko::clone) instead. If local_vector is itself
 *       a non-owning view of some data, the returned vector refers to that
 *       same data.
 *
 * @param exec Executor associated with this vector
 * @param comm Communicator associated with this vector
 * @param global_size The global size of the vector
 * @param local_vector The underlying local vector, whose data is moved into
 *                     the new vector
 *
 * @return A distributed vector that is only accessible through a
 *         pointer-to-const
 */
static std::unique_ptr<const Vector> create_const(
    std::shared_ptr<const Executor> exec, mpi::communicator comm,
    dim<2> global_size,
    std::unique_ptr<const local_vector_type> local_vector);

/**
 * Creates a constant (immutable) distributed Vector from a constant local
 * vector. The global size will be deduced from the local sizes, which will
 * incur a collective communication.
 *
 * @note Because of the collective communication, every rank of comm has to
 *       call this function.
 * @note Ownership of local_vector is taken over and its data is moved into
 *       the new distributed vector. To keep using the original local vector,
 *       pass a copy (e.g. via gko::clone) instead. If local_vector is itself
 *       a non-owning view of some data, the returned vector refers to that
 *       same data.
 *
 * @param exec Executor associated with this vector
 * @param comm Communicator associated with this vector
 * @param local_vector The underlying local vector, whose data is moved into
 *                     the new vector
 *
 * @return A distributed vector that is only accessible through a
 *         pointer-to-const
 */
static std::unique_ptr<const Vector> create_const(
    std::shared_ptr<const Executor> exec, mpi::communicator comm,
    std::unique_ptr<const local_vector_type> local_vector);
Comment on lines +486 to +488
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Tests are missing for this and the previous overload of create_const.


protected:
/**
* Creates an empty distributed vector with a specified size
Expand Down Expand Up @@ -488,8 +517,10 @@ class Vector
* Creates a distributed vector from local vectors with a specified size.
*
* @note The data from the local_vector will be moved into the new
* distributed vector. This means that access to local_vector
* will be invalid after this call.
* distributed vector. You could either move in a std::unique_ptr
* directly, copy a local vector with gko::clone, or create a
* unique non-owning view of a given local vector with
* gko::make_dense_view.
*
* @param exec Executor associated with this vector
* @param comm Communicator associated with this vector
Expand All @@ -498,24 +529,26 @@ class Vector
* into this
*/
Vector(std::shared_ptr<const Executor> exec, mpi::communicator comm,
dim<2> global_size, ptr_param<local_vector_type> local_vector);
dim<2> global_size, std::unique_ptr<local_vector_type> local_vector);

/**
* Creates a distributed vector from local vectors. The global size will
* be deduced from the local sizes, which will incur a collective
* communication.
*
* @note The data from the local_vector will be moved into the new
* distributed vector. This means that access to local_vector
* will be invalid after this call.
* distributed vector. You could either move in a std::unique_ptr
* directly, copy a local vector with gko::clone, or create a
* unique non-owning view of a given local vector with
* gko::make_dense_view.
*
* @param exec Executor associated with this vector
* @param comm Communicator associated with this vector
* @param local_vector The underlying local vector, the data will be moved
* into this
* into this.
*/
Vector(std::shared_ptr<const Executor> exec, mpi::communicator comm,
ptr_param<local_vector_type> local_vector);
std::unique_ptr<local_vector_type> local_vector);

void resize(dim<2> global_size, dim<2> local_size);

Expand Down
44 changes: 38 additions & 6 deletions test/mpi/distributed/vector.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -313,13 +313,12 @@ TYPED_TEST(VectorCreation, CanCreateFromLocalVectorAndSize)
using dense_type = typename TestFixture::dense_type;
auto local_vec = dense_type::create(this->exec);
local_vec->read(this->md_localized[this->comm.rank()]);
auto clone_local_vec = gko::clone(local_vec);

auto vec = dist_vec_type::create(this->exec, this->comm, gko::dim<2>{6, 2},
local_vec);
gko::clone(local_vec));

GKO_ASSERT_EQUAL_DIMENSIONS(vec, gko::dim<2>(6, 2));
GKO_ASSERT_MTX_NEAR(vec->get_local_vector(), clone_local_vec, 0);
GKO_ASSERT_MTX_NEAR(vec->get_local_vector(), local_vec, 0);
}


Expand All @@ -329,15 +328,48 @@ TYPED_TEST(VectorCreation, CanCreateFromLocalVectorWithoutSize)
using dense_type = typename TestFixture::dense_type;
auto local_vec = dense_type::create(this->exec);
local_vec->read(this->md_localized[this->comm.rank()]);
auto clone_local_vec = gko::clone(local_vec);

auto vec = dist_vec_type::create(this->exec, this->comm, local_vec);
auto vec =
dist_vec_type::create(this->exec, this->comm, gko::clone(local_vec));

GKO_ASSERT_EQUAL_DIMENSIONS(vec, gko::dim<2>(6, 2));
GKO_ASSERT_MTX_NEAR(vec->get_local_vector(), clone_local_vec, 0);
GKO_ASSERT_MTX_NEAR(vec->get_local_vector(), local_vec, 0);
}


// Checks that create_const with an explicit global size returns a
// pointer-to-const vector with the given size and the local data.
TYPED_TEST(VectorCreation, CanConstCreateFromLocalVectorAndSize)
{
    using dist_vec_type = typename TestFixture::dist_vec_type;
    using dense_type = typename TestFixture::dense_type;
    auto local_vec = dense_type::create(this->exec);
    local_vec->read(this->md_localized[this->comm.rank()]);

    // Pass a clone so local_vec stays valid as the reference result.
    auto vec = dist_vec_type::create_const(
        this->exec, this->comm, gko::dim<2>{6, 2}, gko::clone(local_vec));

    // Constness of the pointee is a property of the return type, so verify
    // it at compile time rather than with a runtime assertion.
    static_assert(
        std::is_const<
            typename std::remove_reference<decltype(*vec)>::type>::value,
        "create_const must return a pointer-to-const Vector");
    GKO_ASSERT_EQUAL_DIMENSIONS(vec, gko::dim<2>(6, 2));
    GKO_ASSERT_MTX_NEAR(vec->get_local_vector(), local_vec, 0);
}


// Checks that create_const without a global size returns a pointer-to-const
// vector whose global size is deduced from the local sizes.
TYPED_TEST(VectorCreation, CanConstCreateFromLocalVectorWithoutSize)
{
    using dist_vec_type = typename TestFixture::dist_vec_type;
    using dense_type = typename TestFixture::dense_type;
    auto local_vec = dense_type::create(this->exec);
    local_vec->read(this->md_localized[this->comm.rank()]);

    // Pass a clone so local_vec stays valid as the reference result.
    auto vec = dist_vec_type::create_const(this->exec, this->comm,
                                           gko::clone(local_vec));

    // Constness of the pointee is a property of the return type, so verify
    // it at compile time rather than with a runtime assertion.
    static_assert(
        std::is_const<
            typename std::remove_reference<decltype(*vec)>::type>::value,
        "create_const must return a pointer-to-const Vector");
    GKO_ASSERT_EQUAL_DIMENSIONS(vec, gko::dim<2>(6, 2));
    GKO_ASSERT_MTX_NEAR(vec->get_local_vector(), local_vec, 0);
}

template <typename ValueType>
class VectorCreationHelpers : public CommonMpiTestFixture {
public:
Expand Down
0