Add function to merge samplers outputs #252

Merged · 8 commits · Sep 15, 2023
2 changes: 1 addition & 1 deletion CHANGELOG.md
@@ -5,7 +5,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

## [0.3.0] - 2023-MM-DD
### Added
- Added low-level support for distributed neighborhood sampling ([#246](https://github.com/pyg-team/pyg-lib/pull/246), [#253](https://github.com/pyg-team/pyg-lib/pull/253))
- Added low-level support for distributed neighborhood sampling ([#246](https://github.com/pyg-team/pyg-lib/pull/246), [#252](https://github.com/pyg-team/pyg-lib/pull/252), [#253](https://github.com/pyg-team/pyg-lib/pull/253))
- Added support for homogeneous and heterogeneous biased neighborhood sampling ([#247](https://github.com/pyg-team/pyg-lib/pull/247), [#251](https://github.com/pyg-team/pyg-lib/pull/251))
- Added dispatch for XPU device in `index_sort` ([#243](https://github.com/pyg-team/pyg-lib/pull/243))
- Added `metis` partitioning ([#229](https://github.com/pyg-team/pyg-lib/pull/229))
170 changes: 170 additions & 0 deletions pyg_lib/csrc/sampler/cpu/dist_merge_outputs_kernel.cpp
@@ -0,0 +1,170 @@
#include <ATen/ATen.h>
#include <ATen/Parallel.h>
#include <torch/library.h>

#include "parallel_hashmap/phmap.h"

#include "pyg_lib/csrc/sampler/cpu/mapper.h"
#include "pyg_lib/csrc/utils/cpu/convert.h"
#include "pyg_lib/csrc/utils/types.h"

namespace pyg {
namespace sampler {

namespace {

template <bool disjoint>
std::tuple<at::Tensor,
at::Tensor,
c10::optional<at::Tensor>,
std::vector<int64_t>>
merge_outputs(
const std::vector<at::Tensor>& node_ids,
const std::vector<at::Tensor>& edge_ids,
const std::vector<std::vector<int64_t>>& cumsum_neighbors_per_node,
const std::vector<int64_t>& partition_ids,
const std::vector<int64_t>& partition_orders,
const int64_t num_partitions,
const int64_t num_neighbors,
const c10::optional<at::Tensor>& batch) {
at::Tensor out_node_id;
at::Tensor out_edge_id;
c10::optional<at::Tensor> out_batch = c10::nullopt;

auto offset = num_neighbors;

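// `offset` is the stride reserved per seed in the intermediate buffers
// below. With a fixed fan-out it equals `num_neighbors`; with a variable
// fan-out (`num_neighbors < 0`) we fall back to the maximum number of
// neighbors sampled by any seed across all partitions: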
if (num_neighbors < 0) {
// find maximum population
std::vector<std::vector<int64_t>> population(num_partitions);
std::vector<int64_t> max_populations(num_partitions);

at::parallel_for(0, num_partitions, 1, [&](size_t _s, size_t _e) {
for (auto p_id = _s; p_id < _e; p_id++) {
auto cumsum1 =
std::vector<int64_t>(cumsum_neighbors_per_node[p_id].begin() + 1,
cumsum_neighbors_per_node[p_id].end());
auto cumsum2 =
std::vector<int64_t>(cumsum_neighbors_per_node[p_id].begin(),
cumsum_neighbors_per_node[p_id].end() - 1);
std::transform(cumsum1.begin(), cumsum1.end(), cumsum2.begin(),
std::back_inserter(population[p_id]),
[](int64_t a, int64_t b) { return std::abs(a - b); });
auto max =
*max_element(population[p_id].begin(), population[p_id].end());
max_populations[p_id] = max;
}
});
offset = *max_element(max_populations.begin(), max_populations.end());
}

const auto p_size = partition_ids.size();
std::vector<int64_t> sampled_neighbors_per_node(p_size);

const auto scalar_type = node_ids[0].scalar_type();
AT_DISPATCH_INTEGRAL_TYPES(scalar_type, "merge_outputs_kernel", [&] {
std::vector<scalar_t> sampled_node_ids(p_size * offset, -1);
std::vector<scalar_t> sampled_edge_ids(p_size * offset, -1);
std::vector<std::vector<scalar_t>> sampled_node_ids_vec(p_size);
std::vector<std::vector<scalar_t>> sampled_edge_ids_vec(p_size);

std::vector<scalar_t> sampled_batch;
if constexpr (disjoint) {
sampled_batch = std::vector<scalar_t>(p_size * offset, -1);
}
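// In disjoint mode, `batch[j]` holds the subgraph ID of the j-th global
// seed; it is broadcast to all of that seed's sampled neighbors below.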
const auto batch_data =
disjoint ? batch.value().data_ptr<scalar_t>() : nullptr;

for (auto p_id = 0; p_id < num_partitions; p_id++) {
sampled_node_ids_vec[p_id] =
pyg::utils::to_vector<scalar_t>(node_ids[p_id]);
sampled_edge_ids_vec[p_id] =
pyg::utils::to_vector<scalar_t>(edge_ids[p_id]);
}
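// Scatter each seed's neighbors into a fixed-stride layout: output j
// occupies slots [j * offset, (j + 1) * offset). Slots a seed does not
// fill keep the -1 sentinel and are erased after the parallel copy.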
at::parallel_for(0, p_size, 1, [&](size_t _s, size_t _e) {
for (auto j = _s; j < _e; j++) {
auto p_id = partition_ids[j];
auto p_order = partition_orders[j];

// For node IDs and batch, we omit the seed nodes; for edge IDs, we take
// all sampled edges into account.
auto begin_node = cumsum_neighbors_per_node[p_id][p_order];
auto begin_edge = begin_node - cumsum_neighbors_per_node[p_id][0];

auto end_node = cumsum_neighbors_per_node[p_id][p_order + 1];
auto end_edge = end_node - cumsum_neighbors_per_node[p_id][0];
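// Example: cumsum_neighbors_per_node[p_id] = {3, 5, 9, 10} means this
// partition served 3 local seeds whose neighbors follow the 3 leading
// seed entries of node_ids[p_id]. For p_order = 1: begin_node = 5 and
// end_node = 9 (4 sampled nodes), while the edge range drops the seed
// prefix: begin_edge = 2 and end_edge = 6.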

std::copy(sampled_node_ids_vec[p_id].begin() + begin_node,
sampled_node_ids_vec[p_id].begin() + end_node,
sampled_node_ids.begin() + j * offset);
std::copy(sampled_edge_ids_vec[p_id].begin() + begin_edge,
sampled_edge_ids_vec[p_id].begin() + end_edge,
sampled_edge_ids.begin() + j * offset);

if constexpr (disjoint) {
std::fill(sampled_batch.begin() + j * offset,
sampled_batch.begin() + j * offset + end_node - begin_node,
batch_data[j]);
}

sampled_neighbors_per_node[j] = end_node - begin_node;
}
});

// Remove the auxiliary -1 sentinels:
auto neg =
std::remove(sampled_node_ids.begin(), sampled_node_ids.end(), -1);
sampled_node_ids.erase(neg, sampled_node_ids.end());
out_node_id = pyg::utils::from_vector(sampled_node_ids);

neg = std::remove(sampled_edge_ids.begin(), sampled_edge_ids.end(), -1);
sampled_edge_ids.erase(neg, sampled_edge_ids.end());
out_edge_id = pyg::utils::from_vector(sampled_edge_ids);

if constexpr (disjoint) {
neg = std::remove(sampled_batch.begin(), sampled_batch.end(), -1);
sampled_batch.erase(neg, sampled_batch.end());
out_batch = pyg::utils::from_vector(sampled_batch);
}
});

return std::make_tuple(out_node_id, out_edge_id, out_batch,
sampled_neighbors_per_node);
}

#define DISPATCH_MERGE_OUTPUTS(disjoint, ...) \
if (disjoint) \
return merge_outputs<true>(__VA_ARGS__); \
if (!disjoint) \
return merge_outputs<false>(__VA_ARGS__);

} // namespace

std::tuple<at::Tensor,
at::Tensor,
c10::optional<at::Tensor>,
std::vector<int64_t>>
merge_sampler_outputs_kernel(
const std::vector<at::Tensor>& node_ids,
const std::vector<at::Tensor>& edge_ids,
const std::vector<std::vector<int64_t>>& cumsum_neighbors_per_node,
const std::vector<int64_t>& partition_ids,
const std::vector<int64_t>& partition_orders,
const int64_t num_partitions,
const int64_t num_neighbors,
const c10::optional<at::Tensor>& batch,
bool disjoint) {
DISPATCH_MERGE_OUTPUTS(
disjoint, node_ids, edge_ids, cumsum_neighbors_per_node, partition_ids,
partition_orders, num_partitions, num_neighbors, batch);
}

// We use `BackendSelect` as a fallback to the dispatcher logic as automatic
// dispatching of std::vector<at::Tensor> is not yet supported by PyTorch.
// See: pytorch/aten/src/ATen/templates/RegisterBackendSelect.cpp.
TORCH_LIBRARY_IMPL(pyg, BackendSelect, m) {
m.impl(TORCH_SELECTIVE_NAME("pyg::merge_sampler_outputs"),
TORCH_FN(merge_sampler_outputs_kernel));
}

} // namespace sampler
} // namespace pyg
59 changes: 59 additions & 0 deletions pyg_lib/csrc/sampler/dist_merge_outputs.cpp
@@ -0,0 +1,59 @@
#include "dist_merge_outputs.h"

#include <ATen/core/dispatch/Dispatcher.h>
#include <torch/library.h>

#include "pyg_lib/csrc/utils/check.h"

namespace pyg {
namespace sampler {

std::tuple<at::Tensor,
at::Tensor,
c10::optional<at::Tensor>,
std::vector<int64_t>>
merge_sampler_outputs(
const std::vector<at::Tensor>& node_ids,
const std::vector<at::Tensor>& edge_ids,
const std::vector<std::vector<int64_t>>& cumsum_neighbors_per_node,
const std::vector<int64_t>& partition_ids,
const std::vector<int64_t>& partition_orders,
const int64_t num_partitions,
const int64_t num_neighbors,
const c10::optional<at::Tensor>& batch,
bool disjoint) {
std::vector<at::TensorArg> node_ids_args;
std::vector<at::TensorArg> edge_ids_args;
pyg::utils::fill_tensor_args(node_ids_args, node_ids, "node_ids", 0);
pyg::utils::fill_tensor_args(edge_ids_args, edge_ids, "edge_ids", 0);

at::CheckedFrom c{"merge_sampler_outputs"};
at::checkAllDefined(c, {node_ids_args});
at::checkAllDefined(c, {edge_ids_args});

TORCH_CHECK(partition_ids.size() == partition_orders.size(),
"Every partition ID must be assigned a sampling order");

if (disjoint) {
TORCH_CHECK(batch.has_value(),
"Disjoint sampling requires 'batch' to be specified");
}

static auto op = c10::Dispatcher::singleton()
.findSchemaOrThrow("pyg::merge_sampler_outputs", "")
.typed<decltype(merge_sampler_outputs)>();
return op.call(node_ids, edge_ids, cumsum_neighbors_per_node, partition_ids,
partition_orders, num_partitions, num_neighbors, batch,
disjoint);
}

TORCH_LIBRARY_FRAGMENT(pyg, m) {
m.def(TORCH_SELECTIVE_SCHEMA(
"pyg::merge_sampler_outputs(Tensor[] node_ids, Tensor[] edge_ids, "
"int[][] cumsum_neighbors_per_node, int[] partition_ids, int[] "
"partition_orders, int num_partitions, int num_neighbors, Tensor? "
"batch, bool disjoint) -> (Tensor, Tensor, Tensor?, int[])"));
}

} // namespace sampler
} // namespace pyg
34 changes: 34 additions & 0 deletions pyg_lib/csrc/sampler/dist_merge_outputs.h
@@ -0,0 +1,34 @@
#pragma once

#include <ATen/ATen.h>
#include "pyg_lib/csrc/macros.h"
#include "pyg_lib/csrc/utils/types.h"

namespace pyg {
namespace sampler {

// For distributed training purposes. Merges sampler outputs from different
// partitions so that they are sorted according to the sampling order.
// Removes seed nodes from the sampled nodes and calculates how many
// neighbors were sampled by each source node, based on the cumulative sum
// of sampled neighbors for each input node.
// Returns the unified node, edge and batch indices, as well as the merged
// cumulative sum of sampled neighbors.
PYG_API
std::tuple<at::Tensor,
at::Tensor,
c10::optional<at::Tensor>,
std::vector<int64_t>>
merge_sampler_outputs(
const std::vector<at::Tensor>& node_ids,
const std::vector<at::Tensor>& edge_ids,
const std::vector<std::vector<int64_t>>& cumsum_neighbors_per_node,
const std::vector<int64_t>& partition_ids,
const std::vector<int64_t>& partition_orders,
const int64_t num_partitions,
const int64_t num_neighbors,
const c10::optional<at::Tensor>& batch = c10::nullopt,
bool disjoint = false);

} // namespace sampler
} // namespace pyg
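To make the contract above concrete, the following is a minimal usage sketch (not part of this PR) that merges the outputs of two partitions for three global seeds; all tensor values and the `merge_example` wrapper are invented for illustration, and `num_neighbors = -1` exercises the variable fan-out path:

#include <torch/torch.h>
#include "pyg_lib/csrc/sampler/dist_merge_outputs.h"

void merge_example() {
  // Partition 0 served local seeds {0, 2}; partition 1 served seed {1}.
  // Each node_ids tensor lists its seeds first, then the sampled neighbors;
  // cumsum[p] starts at the seed count and grows by each seed's fan-out.
  std::vector<at::Tensor> node_ids = {
      torch::tensor({0, 2, 10, 11, 12}, torch::kLong),  // 2 seeds + 3 nbrs
      torch::tensor({1, 20, 21}, torch::kLong)};        // 1 seed  + 2 nbrs
  std::vector<at::Tensor> edge_ids = {
      torch::tensor({100, 101, 102}, torch::kLong),
      torch::tensor({200, 201}, torch::kLong)};
  std::vector<std::vector<int64_t>> cumsum = {{2, 3, 5}, {1, 3}};
  // Global seed j was served by partition partition_ids[j] as its
  // partition_orders[j]-th local seed.
  std::vector<int64_t> partition_ids = {0, 1, 0};
  std::vector<int64_t> partition_orders = {0, 0, 1};

  auto [node_id, edge_id, batch, neighbors_per_node] =
      pyg::sampler::merge_sampler_outputs(
          node_ids, edge_ids, cumsum, partition_ids, partition_orders,
          /*num_partitions=*/2, /*num_neighbors=*/-1,
          /*batch=*/c10::nullopt, /*disjoint=*/false);
  // node_id == {10, 20, 21, 11, 12}: seeds removed, neighbors reordered
  // into the global sampling order; neighbors_per_node == {1, 2, 2}.
}

When `disjoint = true`, `batch` must be provided (the wrapper in dist_merge_outputs.cpp enforces this with a TORCH_CHECK), and the merged subgraph IDs come back as the third tuple element.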
5 changes: 5 additions & 0 deletions pyg_lib/csrc/sampler/neighbor.h
@@ -61,6 +61,11 @@ hetero_neighbor_sample(
std::string strategy = "uniform",
bool return_edge_id = true);

// For distributed sampling purposes. Leverages the `neighbor_sample` function
// internally. Samples one-hop neighborhoods with duplicates from all node
// indices in `seed` in the graph given by `(rowptr, col)`.
// Returns the original node and edge indices for all sampled nodes and edges.
// Lastly, returns the cumulative sum of sampled neighbors for each input node.
PYG_API
std::tuple<at::Tensor, at::Tensor, std::vector<int64_t>> dist_neighbor_sample(
const at::Tensor& rowptr,
...
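For context, a hedged sketch (not part of this PR) of how this entry point composes with the merge function above; the signature is truncated in the diff, so the arguments after `num_neighbors` are assumed to default as in the Python wrapper removed below, and `dist_sample_example` is invented for illustration:

#include <torch/torch.h>
#include "pyg_lib/csrc/sampler/neighbor.h"

void dist_sample_example() {
  // A triangle graph in CSR form: every node has two neighbors.
  auto rowptr = torch::tensor({0, 2, 4, 6}, torch::kLong);
  auto col = torch::tensor({1, 2, 0, 2, 0, 1}, torch::kLong);
  auto seed = torch::tensor({0, 1}, torch::kLong);

  auto [node, edge, cumsum] = pyg::sampler::dist_neighbor_sample(
      rowptr, col, seed, /*num_neighbors=*/2);
  // `node` and `edge` hold the original indices of the sampled one-hop
  // neighborhood (with duplicates); `cumsum` is the per-seed cumulative
  // neighbor count that `merge_sampler_outputs` later consumes as one
  // entry of `cumsum_neighbors_per_node`.
}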
45 changes: 0 additions & 45 deletions pyg_lib/sampler/__init__.py
@@ -165,50 +165,6 @@ def hetero_neighbor_sample(
num_nodes_per_hop_dict, num_edges_per_hop_dict)


def dist_neighbor_sample(
rowptr: Tensor,
col: Tensor,
seed: Tensor,
num_neighbors: int,
time: Optional[Tensor] = None,
seed_time: Optional[Tensor] = None,
edge_weight: Optional[Tensor] = None,
csc: bool = False,
replace: bool = False,
directed: bool = True,
disjoint: bool = False,
temporal_strategy: str = 'uniform',
) -> Tuple[Tensor, Tensor, List[int]]:
r"""For distributed sampling purposes. Leverages :meth:`neighbor_sample`
internally. Samples one-hop neighborhoods with duplicates from all node
indices in :obj:`seed` in the graph given by :obj:`(rowptr, col)`.

Args:
num_neighbors (int): Maximum number of neighbors to sample in the
current layer.
kwargs: Arguments of :meth:`neighbor_sample`.

Returns:
(torch.Tensor, torch.Tensor, List[int]): Returns the original node and
edge indices for all sampled nodes and edges. Lastly, returns the
cumulative sum of the number of sampled neighbors for each input node.
"""
return torch.ops.pyg.dist_neighbor_sample(
rowptr,
col,
seed,
num_neighbors,
time,
seed_time,
edge_weight,
csc,
replace,
directed,
disjoint,
temporal_strategy,
)


def subgraph(
rowptr: Tensor,
col: Tensor,
...
@@ -262,7 +218,6 @@ def random_walk(rowptr: Tensor, col: Tensor, seed: Tensor, walk_length: int,
__all__ = [
'neighbor_sample',
'hetero_neighbor_sample',
'dist_neighbor_sample',
'subgraph',
'random_walk',
]