Skip to content

Commit

Permalink
fix bug of random sample k
Browse files Browse the repository at this point in the history
  • Loading branch information
Liwb5 committed Jul 12, 2021
1 parent fc74f83 commit d743606
Showing 1 changed file with 20 additions and 34 deletions.
54 changes: 20 additions & 34 deletions paddle/fluid/distributed/table/graph/graph_weighted_sampler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -24,47 +24,33 @@ void RandomSampler::build(GraphEdgeBlob *edges) { this->edges = edges; }

/**
 * Draws up to `k` distinct indices uniformly at random, without replacement,
 * from the range [0, edges->size()).
 *
 * @param k   Requested number of samples. Values <= 0 yield an empty result.
 * @param rng Random engine used for every draw; must not be null.
 * @return    If `k >= edges->size()`, every index 0..size-1 in ascending
 *            order. Otherwise `k` distinct indices in draw order.
 *
 * The partial case is a Fisher-Yates shuffle over a *virtual* array
 * a[i] == i. Only positions whose content differs from their index are stored
 * in `replace_map`, so extra memory is proportional to `k` rather than to the
 * number of edges.
 */
std::vector<int> RandomSampler::sample_k(int k, const std::shared_ptr<std::mt19937_64> rng) {
  int n = static_cast<int>(edges->size());
  std::vector<int> sample_result;

  // A negative k would otherwise drive `n` below 1 inside the loop and build
  // a distribution with an invalid range.
  if (k <= 0) {
    return sample_result;
  }

  // Asking for at least as many samples as there are candidates: return all.
  if (k >= n) {
    sample_result.reserve(n);
    for (int i = 0; i < n; ++i) {
      sample_result.push_back(i);
    }
    return sample_result;
  }

  sample_result.reserve(k);
  // replace_map[p] is the value currently stored at virtual position p;
  // an absent key means position p still holds p itself.
  std::unordered_map<int, int> replace_map;
  while (k--) {
    // The live part of the virtual array is [0, n - 1] and shrinks each round.
    std::uniform_int_distribution<int> distrib(0, n - 1);
    const int rand_int = distrib(*rng);

    auto iter = replace_map.find(rand_int);
    sample_result.push_back(iter == replace_map.end() ? rand_int
                                                      : iter->second);

    // Move the current tail value into the slot just consumed. The value is
    // copied out first: operator[] below may insert and rehash, which would
    // invalidate `iter`.
    iter = replace_map.find(n - 1);
    const int tail_value = (iter == replace_map.end()) ? n - 1 : iter->second;
    replace_map[rand_int] = tail_value;
    --n;
  }
  return sample_result;
}
Expand Down

0 comments on commit d743606

Please sign in to comment.