Skip to content

Commit

Permalink
merge develop
Browse files Browse the repository at this point in the history
Change-Id: I483d79dc37052d2ccca62e81a1a567a0b315546e
  • Loading branch information
seiriosPlus committed Apr 15, 2021
1 parent b371177 commit 05c0838
Showing 1 changed file with 27 additions and 7 deletions.
34 changes: 27 additions & 7 deletions paddle/fluid/distributed/table/depends/large_scale_kv.h
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -118,7 +118,7 @@ class ValueBlock {
value_dims_(value_dims),
value_offsets_(value_offsets),
value_idx_(value_idx) {
for (int x = 0; x < value_dims.size(); ++x) {
for (size_t x = 0; x < value_dims.size(); ++x) {
value_length_ += value_dims[x];
}

Expand All @@ -127,13 +127,15 @@ class ValueBlock {
auto slices = string::split_string<std::string>(entry_attr, ":");
if (slices[0] == "none") {
entry_func_ = std::bind(&count_entry, std::placeholders::_1, 0);
threshold_ = 0;
} else if (slices[0] == "count_filter_entry") {
int threshold = std::stoi(slices[1]);
entry_func_ = std::bind(&count_entry, std::placeholders::_1, threshold);
threshold_ = std::stoi(slices[1]);
entry_func_ =
std::bind(&count_entry, std::placeholders::_1, threshold_);
} else if (slices[0] == "probability_entry") {
float threshold = std::stof(slices[1]);
threshold_ = std::stof(slices[1]);
entry_func_ =
std::bind(&probility_entry, std::placeholders::_1, threshold);
std::bind(&probility_entry, std::placeholders::_1, threshold_);
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"Not supported Entry Type : %s, Only support [CountFilterEntry, "
Expand Down Expand Up @@ -201,6 +203,21 @@ class ValueBlock {
return value.data_;
}

VALUE *InitGet(const uint64_t &id, const bool with_update = true,
const int counter = 1) {
if (!Has(id)) {
values_.emplace(std::make_pair(id, VALUE(value_length_)));
}

auto &value = values_.at(id);

if (with_update) {
AttrUpdate(&value, counter);
}

return &value;
}

void AttrUpdate(VALUE *value, const int counter) {
// update state
value->unseen_days_ = 0;
Expand All @@ -210,7 +227,7 @@ class ValueBlock {
value->is_entry_ = entry_func_(value);
if (value->is_entry_) {
// initialize
for (int x = 0; x < value_names_.size(); ++x) {
for (size_t x = 0; x < value_names_.size(); ++x) {
initializers_[x]->GetValue(value->data_ + value_offsets_[x],
value_dims_[x]);
}
Expand Down Expand Up @@ -255,6 +272,8 @@ class ValueBlock {
return;
}

float GetThreshold() { return threshold_; }

private:
bool Has(const uint64_t id) {
auto got = values_.find(id);
Expand All @@ -277,6 +296,7 @@ class ValueBlock {

std::function<bool(VALUE *)> entry_func_;
std::vector<std::shared_ptr<Initializer>> initializers_;
float threshold_;
};

} // namespace distributed
Expand Down

0 comments on commit 05c0838

Please sign in to comment.