Skip to content

Commit

Permalink
review updates:
Browse files Browse the repository at this point in the history
- documentation
- renaming

Co-authored-by: Pratik Nayak <pratik.nayak@kit.edu>
  • Loading branch information
MarcelKoch and pratikvn committed Apr 4, 2024
1 parent 81b3431 commit 1f133bf
Show file tree
Hide file tree
Showing 5 changed files with 20 additions and 45 deletions.
4 changes: 2 additions & 2 deletions core/distributed/index_map.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,13 +24,13 @@ namespace distributed {

template <typename LocalIndexType, typename GlobalIndexType>
array<LocalIndexType> index_map<LocalIndexType, GlobalIndexType>::get_local(
const array<GlobalIndexType>& global_ids, index_space is) const
const array<GlobalIndexType>& global_ids, index_space index_space_v) const
{
array<LocalIndexType> local_ids(exec_);

exec_->run(index_map_kernels::make_get_local(
partition_.get(), remote_target_ids_, remote_global_idxs_, rank_,
global_ids, is, local_ids));
global_ids, index_space_v, local_ids));

return local_ids;
}
Expand Down
12 changes: 7 additions & 5 deletions core/distributed/index_map_kernels.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,15 @@
//
// SPDX-License-Identifier: BSD-3-Clause

#ifndef INDEX_MAP_KERNELS_HPP
#define INDEX_MAP_KERNELS_HPP
#ifndef GKO_CORE_DISTRIBUTED_INDEX_MAP_KERNELS_HPP_
#define GKO_CORE_DISTRIBUTED_INDEX_MAP_KERNELS_HPP_


#include <ginkgo/core/distributed/index_map.hpp>


#include <ginkgo/core/base/array.hpp>
#include <ginkgo/core/base/collection.hpp>
#include <ginkgo/core/distributed/index_map.hpp>
#include <ginkgo/core/distributed/partition.hpp>


Expand All @@ -34,7 +36,7 @@ namespace kernels {
std::shared_ptr<const DefaultExecutor> exec, \
const experimental::distributed::Partition<_ltype, _gtype>* partition, \
const array<experimental::distributed::comm_index_type>& \
remote_targed_ids, \
remote_target_ids, \
const collection::array<_gtype>& remote_global_idxs, \
experimental::distributed::comm_index_type rank, \
const array<_gtype>& global_ids, \
Expand All @@ -59,4 +61,4 @@ GKO_DECLARE_FOR_ALL_EXECUTOR_NAMESPACES(index_map,
} // namespace kernels
} // namespace gko

#endif // INDEX_MAP_KERNELS_HPP
#endif // GKO_CORE_DISTRIBUTED_INDEX_MAP_KERNELS_HPP_
9 changes: 5 additions & 4 deletions include/ginkgo/core/distributed/index_map.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -66,11 +66,12 @@ struct index_map {
/**
* \brief Maps global indices to local indices
*
* \param global_ids the global indices to map
* \param is the index space in which the returned local indices are defined
* \param global_ids the global indices to map
* \param index_space_v the index space in which the returned local indices
* are defined
*
* \return the mapped local indices. Any global index that is not in the
* specified index space is mapped to invalid_index.
* \return the mapped local indices. Any global index that is not in the
* specified index space is mapped to invalid_index.
*/
[[nodiscard]] array<LocalIndexType> get_local(
const array<GlobalIndexType>& global_ids,
Expand Down
35 changes: 6 additions & 29 deletions reference/distributed/index_map_kernels.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,35 +28,12 @@ void build_mapping(
collection::array<GlobalIndexType>& remote_global_idxs)
{
using experimental::distributed::comm_index_type;
using partition_type =
experimental::distributed::Partition<LocalIndexType, GlobalIndexType>;
auto part_ids = part->get_part_ids();

std::vector<GlobalIndexType> unique_indices(recv_connections.get_size());
std::copy_n(recv_connections.get_const_data(), recv_connections.get_size(),
unique_indices.begin());

auto find_range = [](GlobalIndexType idx, const partition_type* partition,
size_type hint) {
auto range_bounds = partition->get_range_bounds();
auto num_ranges = partition->get_num_ranges();
if (range_bounds[hint] <= idx && idx < range_bounds[hint + 1]) {
return hint;
} else {
auto it = std::upper_bound(range_bounds + 1,
range_bounds + num_ranges + 1, idx);
return static_cast<size_type>(std::distance(range_bounds + 1, it));
}
};

auto map_to_local = [](GlobalIndexType idx, const partition_type* partition,
size_type range_id) {
auto range_bounds = partition->get_range_bounds();
auto range_starting_indices = partition->get_range_starting_indices();
return static_cast<LocalIndexType>(idx - range_bounds[range_id]) +
range_starting_indices[range_id];
};

auto find_part = [&](GlobalIndexType idx) {
auto range_id = find_range(idx, part, 0);
return part_ids[range_id];
Expand Down Expand Up @@ -139,7 +116,7 @@ void get_local(
std::shared_ptr<const DefaultExecutor> exec,
const experimental::distributed::Partition<LocalIndexType, GlobalIndexType>*
partition,
const array<experimental::distributed::comm_index_type>& remote_targed_ids,
const array<experimental::distributed::comm_index_type>& remote_target_ids,
const collection::array<GlobalIndexType>& remote_global_idxs,
experimental::distributed::comm_index_type rank,
const array<GlobalIndexType>& global_ids,
Expand Down Expand Up @@ -171,13 +148,13 @@ void get_local(
// the global indexing. So find the part-id that corresponds
// to the global index first
auto set_id = std::distance(
remote_targed_ids.get_const_data(),
std::lower_bound(remote_targed_ids.get_const_data(),
remote_targed_ids.get_const_data() +
remote_targed_ids.get_size(),
remote_target_ids.get_const_data(),
std::lower_bound(remote_target_ids.get_const_data(),
remote_target_ids.get_const_data() +
remote_target_ids.get_size(),
part_id));

if (set_id == remote_targed_ids.get_size()) {
if (set_id == remote_target_ids.get_size()) {
return invalid_index<LocalIndexType>();
}

Expand Down
5 changes: 0 additions & 5 deletions reference/test/distributed/index_map_kernels.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,6 @@
#include "core/distributed/index_map_kernels.hpp"
#include "core/test/utils.hpp"

namespace {


using comm_index_type = gko::experimental::distributed::comm_index_type;

Expand Down Expand Up @@ -164,6 +162,3 @@ TEST_F(IndexMap, CanGetLocalWithCombinedISWithInvalid)
gko::array<local_index_type> expected(ref, {2, 3, 0, 1, 2, 4, -1, 1});
GKO_ASSERT_ARRAY_EQ(local_ids, expected);
}


} // namespace

0 comments on commit 1f133bf

Please sign in to comment.