diff --git a/pyg_lib/csrc/sampler/cpu/dist_merge_outputs_kernel.cpp b/pyg_lib/csrc/sampler/cpu/dist_merge_outputs_kernel.cpp index 1c3884289..e198de8f3 100644 --- a/pyg_lib/csrc/sampler/cpu/dist_merge_outputs_kernel.cpp +++ b/pyg_lib/csrc/sampler/cpu/dist_merge_outputs_kernel.cpp @@ -26,7 +26,7 @@ merge_outputs( const int64_t partitions_num, const int64_t one_hop_num, const c10::optional>& edge_ids, - const c10::optional>& batch) { + const c10::optional& batch) { at::Tensor out_node; c10::optional out_edge_id = c10::nullopt; c10::optional out_batch = c10::nullopt; @@ -65,7 +65,6 @@ merge_outputs( std::vector sampled_batch; std::vector> sampled_nodes_vec(p_size); std::vector> edge_ids_vec; - std::vector> batch_vec(p_size); if constexpr (with_edge) { sampled_edge_ids = std::vector(p_size * offset, -1); @@ -73,8 +72,9 @@ merge_outputs( } if constexpr (disjoint) { sampled_batch = std::vector(p_size * offset, -1); - batch_vec = std::vector>(p_size); } + const auto batch_data = + disjoint ? batch.value().data_ptr() : nullptr; for (auto p_id = 0; p_id < partitions_num; p_id++) { sampled_nodes_vec[p_id] = pyg::utils::to_vector(nodes[p_id]); @@ -82,9 +82,6 @@ merge_outputs( if constexpr (with_edge) edge_ids_vec[p_id] = pyg::utils::to_vector(edge_ids.value()[p_id]); - - if constexpr (disjoint) - batch_vec[p_id] = pyg::utils::to_vector(batch.value()[p_id]); } at::parallel_for(0, p_size, 1, [&](size_t _s, size_t _e) { for (auto j = _s; j < _e; j++) { @@ -107,9 +104,9 @@ merge_outputs( edge_ids_vec[p_id].begin() + end_edge, sampled_edge_ids.begin() + j * offset); if constexpr (disjoint) - std::copy(batch_vec[p_id].begin() + begin, - batch_vec[p_id].begin() + end, - sampled_batch.begin() + j * offset); + std::fill(sampled_batch.begin() + j * offset, + sampled_batch.begin() + j * offset + end - begin, + batch_data[j]); sampled_nbrs_per_node[j] = end - begin; } @@ -164,7 +161,7 @@ merge_sampler_outputs_kernel( const int64_t partitions_num, const int64_t one_hop_num, const c10::optional>& edge_ids, - const c10::optional>& batch, + const c10::optional& batch, bool disjoint, bool with_edge) { DISPATCH_MERGE_OUTPUTS(disjoint, with_edge, nodes, cumm_sampled_nbrs_per_node, diff --git a/pyg_lib/csrc/sampler/cpu/dist_merge_outputs_kernel.h b/pyg_lib/csrc/sampler/cpu/dist_merge_outputs_kernel.h index 878160e59..1936e23d5 100644 --- a/pyg_lib/csrc/sampler/cpu/dist_merge_outputs_kernel.h +++ b/pyg_lib/csrc/sampler/cpu/dist_merge_outputs_kernel.h @@ -17,7 +17,7 @@ merge_sampler_outputs_kernel( const int64_t partitions_num, const int64_t one_hop_num, const c10::optional>& edge_ids, - const c10::optional>& batch, + const c10::optional& batch, bool disjoint, bool with_edge); diff --git a/pyg_lib/csrc/sampler/dist_merge_outputs.cpp b/pyg_lib/csrc/sampler/dist_merge_outputs.cpp index ccd5fd90f..984fdf617 100644 --- a/pyg_lib/csrc/sampler/dist_merge_outputs.cpp +++ b/pyg_lib/csrc/sampler/dist_merge_outputs.cpp @@ -20,7 +20,7 @@ merge_sampler_outputs( const int64_t partitions_num, const int64_t one_hop_num, const c10::optional>& edge_ids, - const c10::optional>& batch, + const c10::optional& batch, bool disjoint, bool with_edge) { std::vector nodes_args; @@ -32,9 +32,12 @@ merge_sampler_outputs( TORCH_CHECK(partition_ids.size() == partition_orders.size(), "Each id must be assigned a sampling order'"); - if (disjoint) + if (disjoint) { TORCH_CHECK(batch.has_value(), "I case of disjoint sampling batch needs to be specified"); + TORCH_CHECK(batch.value().numel() == partition_ids.size(), + "Each src node must belong to a subgraph'"); + } if (with_edge) TORCH_CHECK(edge_ids.has_value(), "No edge ids specified"); @@ -52,7 +55,7 @@ TORCH_LIBRARY_FRAGMENT(pyg, m) { "pyg::merge_sampler_outputs(Tensor[] nodes, " "int[][] cumm_sampled_nbrs_per_node, int[] partition_ids, int[] " "partition_orders, int partitions_num, int one_hop_num, Tensor[]? " - "edge_ids, Tensor[]? batch, bool disjoint, bool with_edge) -> (Tensor, " + "edge_ids, Tensor? batch, bool disjoint, bool with_edge) -> (Tensor, " "Tensor?, Tensor?, int[])")); } diff --git a/pyg_lib/csrc/sampler/dist_merge_outputs.h b/pyg_lib/csrc/sampler/dist_merge_outputs.h index 244bc4e5a..5770025a1 100644 --- a/pyg_lib/csrc/sampler/dist_merge_outputs.h +++ b/pyg_lib/csrc/sampler/dist_merge_outputs.h @@ -24,7 +24,7 @@ merge_sampler_outputs( const int64_t partitions_num, const int64_t one_hop_num, const c10::optional>& edge_ids = c10::nullopt, - const c10::optional>& batch = c10::nullopt, + const c10::optional& batch = c10::nullopt, bool disjoint = false, bool with_edge = true); diff --git a/test/csrc/sampler/test_dist_merge_outputs.cpp b/test/csrc/sampler/test_dist_merge_outputs.cpp index 8a40d7f84..a0d21ec0e 100644 --- a/test/csrc/sampler/test_dist_merge_outputs.cpp +++ b/test/csrc/sampler/test_dist_merge_outputs.cpp @@ -88,9 +88,7 @@ TEST(DistDisjointMergeOutputsTest, BasicAssertions) { const std::vector nodes = {at::tensor({2, 7, 8}, options), at::tensor({0, 1, 4, 5, 6}, options), at::tensor({3, 9, 10}, options)}; - const std::vector batch = {at::tensor({2, 2, 2}, options), - at::tensor({0, 1, 0, 0, 1}, options), - at::tensor({3, 3, 3}, options)}; + const auto batch = at::tensor({0, 1, 2, 3}, options); const std::vector> cumm_sampled_nbrs_per_node = { {1, 3}, {2, 4, 5}, {1, 3}};