assign dst nodes to a subgraph during disjoint sampling
kgajdamo committed Sep 8, 2023
1 parent b6b55e8 commit 7bb03aa
Showing 5 changed files with 16 additions and 18 deletions.
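In short, this commit changes how subgraph membership enters the distributed merge step: `batch` is no longer a list of per-partition tensors carrying one subgraph id per sampled node, but a single tensor carrying one subgraph id per destination (seed) node, from which the kernel derives the per-neighbor labels itself. The sketch below contrasts the two layouts using the fixture values from the updated test file; it is an illustration of the data shapes only, not library code.

```cpp
// Illustration only: old vs. new layout of the `batch` argument,
// using the fixture values from test_dist_merge_outputs.cpp below.
#include <cstdint>
#include <vector>

int main() {
  // Old layout: one vector per partition, one subgraph id per *sampled* node.
  const std::vector<std::vector<int64_t>> batch_per_partition = {
      {2, 2, 2},        // subgraph ids of sampled nodes {2, 7, 8}
      {0, 1, 0, 0, 1},  // subgraph ids of sampled nodes {0, 1, 4, 5, 6}
      {3, 3, 3}};       // subgraph ids of sampled nodes {3, 9, 10}

  // New layout: a single vector with one subgraph id per *dst (seed)* node.
  // The merge kernel now propagates these ids to the sampled neighbors.
  const std::vector<int64_t> batch_per_dst_node = {0, 1, 2, 3};
  return 0;
}
```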
17 changes: 7 additions & 10 deletions pyg_lib/csrc/sampler/cpu/dist_merge_outputs_kernel.cpp
@@ -26,7 +26,7 @@ merge_outputs(
const int64_t partitions_num,
const int64_t one_hop_num,
const c10::optional<std::vector<at::Tensor>>& edge_ids,
const c10::optional<std::vector<at::Tensor>>& batch) {
const c10::optional<at::Tensor>& batch) {
at::Tensor out_node;
c10::optional<at::Tensor> out_edge_id = c10::nullopt;
c10::optional<at::Tensor> out_batch = c10::nullopt;
@@ -65,26 +65,23 @@ merge_outputs(
std::vector<scalar_t> sampled_batch;
std::vector<std::vector<scalar_t>> sampled_nodes_vec(p_size);
std::vector<std::vector<scalar_t>> edge_ids_vec;
std::vector<std::vector<scalar_t>> batch_vec(p_size);

if constexpr (with_edge) {
sampled_edge_ids = std::vector<scalar_t>(p_size * offset, -1);
edge_ids_vec = std::vector<std::vector<scalar_t>>(p_size);
}
if constexpr (disjoint) {
sampled_batch = std::vector<scalar_t>(p_size * offset, -1);
batch_vec = std::vector<std::vector<scalar_t>>(p_size);
}
const auto batch_data =
disjoint ? batch.value().data_ptr<scalar_t>() : nullptr;

for (auto p_id = 0; p_id < partitions_num; p_id++) {
sampled_nodes_vec[p_id] = pyg::utils::to_vector<scalar_t>(nodes[p_id]);

if constexpr (with_edge)
edge_ids_vec[p_id] =
pyg::utils::to_vector<scalar_t>(edge_ids.value()[p_id]);

if constexpr (disjoint)
batch_vec[p_id] = pyg::utils::to_vector<scalar_t>(batch.value()[p_id]);
}
at::parallel_for(0, p_size, 1, [&](size_t _s, size_t _e) {
for (auto j = _s; j < _e; j++) {
@@ -107,9 +104,9 @@ merge_outputs(
edge_ids_vec[p_id].begin() + end_edge,
sampled_edge_ids.begin() + j * offset);
if constexpr (disjoint)
std::copy(batch_vec[p_id].begin() + begin,
batch_vec[p_id].begin() + end,
sampled_batch.begin() + j * offset);
std::fill(sampled_batch.begin() + j * offset,
sampled_batch.begin() + j * offset + end - begin,
batch_data[j]);

sampled_nbrs_per_node[j] = end - begin;
}
@@ -164,7 +161,7 @@ merge_sampler_outputs_kernel(
const int64_t partitions_num,
const int64_t one_hop_num,
const c10::optional<std::vector<at::Tensor>>& edge_ids,
const c10::optional<std::vector<at::Tensor>>& batch,
const c10::optional<at::Tensor>& batch,
bool disjoint,
bool with_edge) {
DISPATCH_MERGE_OUTPUTS(disjoint, with_edge, nodes, cumm_sampled_nbrs_per_node,
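The kernel-side effect of the new argument is visible in the hunks above: instead of copying per-partition batch vectors into the merged output, each destination node's slice of `sampled_batch` is filled with that node's own subgraph id, `batch_data[j]`. The standalone sketch below mimics that inner loop with plain `std::vector`s; the neighbor counts and subgraph ids are made-up example values, and the `p_size * offset` buffer padded with `-1` mirrors the kernel's intermediate layout.

```cpp
// Sketch of the new batch assignment in the merge loop (not library code).
#include <algorithm>
#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  // Example values (hypothetical): 4 dst (seed) nodes, at most 2 sampled
  // neighbors each, and one subgraph id per dst node as in the new API.
  const int64_t p_size = 4;               // number of dst nodes
  const int64_t offset = 2;               // one_hop_num: max neighbors per node
  const std::vector<int64_t> batch_data = {0, 1, 2, 3};    // subgraph per dst
  const std::vector<int64_t> nbrs_per_node = {2, 1, 2, 0};  // sampled counts

  // Intermediate buffer padded with -1, exactly p_size * offset entries.
  std::vector<int64_t> sampled_batch(p_size * offset, -1);

  for (int64_t j = 0; j < p_size; ++j) {
    const auto begin = j * offset;
    const auto end = begin + nbrs_per_node[j];
    // Every neighbor sampled for dst node j inherits j's subgraph id.
    std::fill(sampled_batch.begin() + begin, sampled_batch.begin() + end,
              batch_data[j]);
  }

  for (const auto b : sampled_batch)
    std::cout << b << ' ';  // prints: 0 0 1 -1 2 2 -1 -1
  std::cout << '\n';
  return 0;
}
```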
2 changes: 1 addition & 1 deletion pyg_lib/csrc/sampler/cpu/dist_merge_outputs_kernel.h
@@ -17,7 +17,7 @@ merge_sampler_outputs_kernel(
const int64_t partitions_num,
const int64_t one_hop_num,
const c10::optional<std::vector<at::Tensor>>& edge_ids,
const c10::optional<std::vector<at::Tensor>>& batch,
const c10::optional<at::Tensor>& batch,
bool disjoint,
bool with_edge);

9 changes: 6 additions & 3 deletions pyg_lib/csrc/sampler/dist_merge_outputs.cpp
@@ -20,7 +20,7 @@ merge_sampler_outputs(
const int64_t partitions_num,
const int64_t one_hop_num,
const c10::optional<std::vector<at::Tensor>>& edge_ids,
const c10::optional<std::vector<at::Tensor>>& batch,
const c10::optional<at::Tensor>& batch,
bool disjoint,
bool with_edge) {
std::vector<at::TensorArg> nodes_args;
@@ -32,9 +32,12 @@
TORCH_CHECK(partition_ids.size() == partition_orders.size(),
"Each id must be assigned a sampling order'");

if (disjoint)
if (disjoint) {
TORCH_CHECK(batch.has_value(),
"I case of disjoint sampling batch needs to be specified");
TORCH_CHECK(batch.value().numel() == partition_ids.size(),
"Each src node must belong to a subgraph'");
}

if (with_edge)
TORCH_CHECK(edge_ids.has_value(), "No edge ids specified");
@@ -52,7 +55,7 @@ TORCH_LIBRARY_FRAGMENT(pyg, m) {
"pyg::merge_sampler_outputs(Tensor[] nodes, "
"int[][] cumm_sampled_nbrs_per_node, int[] partition_ids, int[] "
"partition_orders, int partitions_num, int one_hop_num, Tensor[]? "
"edge_ids, Tensor[]? batch, bool disjoint, bool with_edge) -> (Tensor, "
"edge_ids, Tensor? batch, bool disjoint, bool with_edge) -> (Tensor, "
"Tensor?, Tensor?, int[])"));
}

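The added check requires the new batch tensor to provide exactly one subgraph id per node being merged, i.e. its length must match `partition_ids.size()`. Below is a minimal plain-C++ sketch of that invariant; the `partition_ids` values are hypothetical.

```cpp
// Plain-C++ sketch of the new disjoint-mode invariant (illustrative values).
#include <cassert>
#include <cstdint>
#include <vector>

int main() {
  const bool disjoint = true;
  // One subgraph id per dst (seed) node, mirroring the updated test fixture.
  const std::vector<int64_t> batch = {0, 1, 2, 3};
  // Hypothetical: which partition produced each node's samples.
  const std::vector<int64_t> partition_ids = {1, 0, 2, 1};

  if (disjoint) {
    // Mirrors TORCH_CHECK(batch.value().numel() == partition_ids.size(), ...)
    assert(batch.size() == partition_ids.size());
  }
  return 0;
}
```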
2 changes: 1 addition & 1 deletion pyg_lib/csrc/sampler/dist_merge_outputs.h
@@ -24,7 +24,7 @@ merge_sampler_outputs(
const int64_t partitions_num,
const int64_t one_hop_num,
const c10::optional<std::vector<at::Tensor>>& edge_ids = c10::nullopt,
const c10::optional<std::vector<at::Tensor>>& batch = c10::nullopt,
const c10::optional<at::Tensor>& batch = c10::nullopt,
bool disjoint = false,
bool with_edge = true);

4 changes: 1 addition & 3 deletions test/csrc/sampler/test_dist_merge_outputs.cpp
@@ -88,9 +88,7 @@ TEST(DistDisjointMergeOutputsTest, BasicAssertions) {
const std::vector<at::Tensor> nodes = {at::tensor({2, 7, 8}, options),
at::tensor({0, 1, 4, 5, 6}, options),
at::tensor({3, 9, 10}, options)};
const std::vector<at::Tensor> batch = {at::tensor({2, 2, 2}, options),
at::tensor({0, 1, 0, 0, 1}, options),
at::tensor({3, 3, 3}, options)};
const auto batch = at::tensor({0, 1, 2, 3}, options);

const std::vector<std::vector<int64_t>> cumm_sampled_nbrs_per_node = {
{1, 3}, {2, 4, 5}, {1, 3}};
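In the updated test the batch argument collapses to `{0, 1, 2, 3}`: one subgraph id per seed node instead of one per sampled node. Since disjoint sampling typically assigns each seed its own subgraph, for a fixture like this the tensor is just the sequence 0..num_seeds-1; a caller could build it as below (a sketch, not part of the test itself).

```cpp
// Sketch: building the seed-level batch vector when each of the num_seeds
// seed nodes forms its own disjoint subgraph (as in the fixture above).
#include <cstdint>
#include <numeric>
#include <vector>

int main() {
  const int64_t num_seeds = 4;
  std::vector<int64_t> batch(num_seeds);
  std::iota(batch.begin(), batch.end(), int64_t{0});  // {0, 1, 2, 3}
  return 0;
}
```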
