Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
rusty1s committed Sep 5, 2022
1 parent 8725707 commit f2bb685
Showing 1 changed file with 5 additions and 6 deletions.
11 changes: 5 additions & 6 deletions pyg_lib/csrc/sampler/cpu/neighbor_kernel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -152,9 +152,6 @@ sample(const at::Tensor& rowptr,

auto mapper = Mapper<node_t, scalar_t>(/*num_nodes=*/rowptr.size(0) - 1);
std::vector<node_t> sampled_nodes;
auto sampler =
NeighborSampler<node_t, scalar_t, replace, directed, return_edge_id>(
rowptr.data_ptr<scalar_t>(), col.data_ptr<scalar_t>());

const auto seed_data = seed.data_ptr<scalar_t>();
for (size_t i = 0; i < seed.numel(); i++) {
Expand All @@ -167,6 +164,10 @@ sample(const at::Tensor& rowptr,
}
}

auto sampler =
NeighborSampler<node_t, scalar_t, replace, directed, return_edge_id>(
rowptr.data_ptr<scalar_t>(), col.data_ptr<scalar_t>());

size_t begin = 0, end = seed.size(0);
for (size_t ell = 0; ell < num_neighbors.size(); ++ell) {
const auto count = num_neighbors[ell];
Expand All @@ -191,6 +192,7 @@ sample(const at::Tensor& rowptr,
if (directed) {
std::tie(out_row, out_col, out_edge_id) = sampler.get_sampled_edges();
} else {
TORCH_CHECK(!disjoint, "Disjoint subgraphs are not yet supported");
std::tie(out_row, out_col, out_edge_id) =
pyg::sampler::subgraph(rowptr, col, out_node_id, return_edge_id);
}
Expand All @@ -207,9 +209,6 @@ neighbor_sample_kernel(const at::Tensor& rowptr,
bool directed,
bool disjoint,
bool return_edge_id) {
if (disjoint && !directed)
AT_ERROR("Disjoint subgraphs are currently not supported");

if (disjoint)
return sample<false, true, true, true>(rowptr, col, seed, num_neighbors);
else
Expand Down

0 comments on commit f2bb685

Please sign in to comment.