Skip to content

Commit

Permalink
Merge pull request #32 from jiaoxuewu/paddlebox
Browse files Browse the repository at this point in the history
merge fused
  • Loading branch information
qingshui authored Jan 26, 2022
2 parents a43c1ab + 6ba0254 commit cd4ff54
Show file tree
Hide file tree
Showing 4 changed files with 191 additions and 5 deletions.
74 changes: 74 additions & 0 deletions paddle/fluid/framework/fleet/box_wrapper.cc
Original file line number Diff line number Diff line change
Expand Up @@ -995,6 +995,75 @@ class MaskMetricMsg : public MetricMsg {
std::string mask_varname_;
};

class MultiMaskMetricMsg : public MetricMsg {
 public:
  // AUC metric that only accumulates instances whose mask variables all
  // match the configured values.
  //
  // mask_varname_list:  space-separated variable names, e.g. "mask_a mask_b".
  // mask_varvalue_list: space-separated integer values, one per mask
  //                     variable; an instance contributes to the AUC only
  //                     when, for every i, the value of mask_varname_list[i]
  //                     equals mask_varvalue_list[i].
  MultiMaskMetricMsg(const std::string& label_varname,
                     const std::string& pred_varname, int metric_phase,
                     const std::string& mask_varname_list,
                     const std::string& mask_varvalue_list,
                     int bucket_size = 1000000,
                     bool mode_collect_in_gpu = false, int max_batch_size = 0) {
    label_varname_ = label_varname;
    pred_varname_ = pred_varname;
    mask_varname_list_ = string::split_string(mask_varname_list, " ");
    const std::vector<std::string> tmp_val_lst =
        string::split_string(mask_varvalue_list, " ");
    for (const auto& it : tmp_val_lst) {
      mask_varvalue_list_.emplace_back(atoi(it.c_str()));
    }
    PADDLE_ENFORCE_EQ(
        mask_varname_list_.size(), mask_varvalue_list_.size(),
        platform::errors::PreconditionNotMet(
            "mask var num[%zu] should be equal to mask val num[%zu]",
            mask_varname_list_.size(), mask_varvalue_list_.size()));

    metric_phase_ = metric_phase;
    calculator = new BasicAucCalculator(mode_collect_in_gpu);
    calculator->init(bucket_size);
  }
  virtual ~MultiMaskMetricMsg() {}
  // Pulls label/pred/mask tensors from the scope and feeds every instance
  // whose masks all match into the AUC calculator.
  void add_data(const Scope* exe_scope,
                const paddle::platform::Place& place) override {
    std::vector<int64_t> label_data;
    get_data<int64_t>(exe_scope, label_varname_, &label_data);

    std::vector<float> pred_data;
    get_data<float>(exe_scope, pred_varname_, &pred_data);

    PADDLE_ENFORCE_EQ(label_data.size(), pred_data.size(),
                      platform::errors::PreconditionNotMet(
                          "the predict data length should be consistent with "
                          "the label data length"));

    // Fetch every mask variable; each must be batch-aligned with the label.
    std::vector<std::vector<int64_t>> mask_value_data_list(
        mask_varname_list_.size());
    for (size_t name_idx = 0; name_idx < mask_varname_list_.size();
         ++name_idx) {
      get_data<int64_t>(exe_scope, mask_varname_list_[name_idx],
                        &mask_value_data_list[name_idx]);
      // Fixed format args: print the mask var name with %s and both lengths
      // with %zu (the original passed the mask length where the name was
      // expected and never printed the name at all).
      PADDLE_ENFORCE_EQ(
          label_data.size(), mask_value_data_list[name_idx].size(),
          platform::errors::PreconditionNotMet(
              "the label data length[%zu] should be consistent with "
              "the %s length[%zu]",
              label_data.size(), mask_varname_list_[name_idx],
              mask_value_data_list[name_idx].size()));
    }
    auto cal = GetCalculator();
    std::lock_guard<std::mutex> lock(cal->table_mutex());
    size_t batch_size = label_data.size();
    for (size_t ins_idx = 0; ins_idx < batch_size; ++ins_idx) {
      // An instance is counted only when every mask matches its target value.
      bool matched = true;
      for (size_t val_idx = 0; val_idx < mask_varvalue_list_.size();
           ++val_idx) {
        if (mask_value_data_list[val_idx][ins_idx] !=
            mask_varvalue_list_[val_idx]) {
          matched = false;
          break;
        }
      }
      if (matched) {
        cal->add_unlock_data(pred_data[ins_idx], label_data[ins_idx]);
      }
    }
  }

 protected:
  std::vector<int> mask_varvalue_list_;
  std::vector<std::string> mask_varname_list_;
  // NOTE(review): never read or written in this class — looks copied from the
  // CmatchRank* metrics. Kept for compatibility; TODO confirm and remove.
  std::string cmatch_rank_varname_;
};

class CmatchRankMaskMetricMsg : public MetricMsg {
public:
CmatchRankMaskMetricMsg(const std::string& label_varname,
Expand Down Expand Up @@ -1142,6 +1211,11 @@ void BoxWrapper::InitMetric(const std::string& method, const std::string& name,
name, new MaskMetricMsg(label_varname, pred_varname, metric_phase,
mask_varname, bucket_size, mode_collect_in_gpu,
max_batch_size));
} else if (method == "MultiMaskAucCalculator") {
metric_lists_.emplace(
name, new MultiMaskMetricMsg(label_varname, pred_varname, metric_phase,
mask_varname, cmatch_rank_group, bucket_size, mode_collect_in_gpu,
max_batch_size));
} else if (method == "CmatchRankMaskAucCalculator") {
metric_lists_.emplace(
name, new CmatchRankMaskMetricMsg(
Expand Down
2 changes: 2 additions & 0 deletions paddle/fluid/operators/fused/fused_seqpool_cvm_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -123,9 +123,11 @@ class FusedSeqpoolCVMOpMaker : public framework::OpProtoAndCheckerMaker {
.SetDefault(0.0);
AddAttr<bool>("use_cvm", "bool, use cvm or not").SetDefault(true);
AddAttr<bool>("need_filter", "(bool, default false)").SetDefault(false);
AddAttr<bool>("embed_threshold_filter", "(bool, default false)").SetDefault(false);
AddAttr<float>("show_coeff", "(float, default 0.2)").SetDefault(0.2);
AddAttr<float>("clk_coeff", "(float, default 1)").SetDefault(1);
AddAttr<float>("threshold", "(float, default 0.96)").SetDefault(0.96);
AddAttr<float>("embed_threshold", "(float, default 0)").SetDefault(0);
AddAttr<int>("cvm_offset", "(int, default 2)").SetDefault(2);
AddAttr<int>("quant_ratio", "(int, default 128)").SetDefault(0);
AddAttr<bool>("clk_filter", "(bool, default false)").SetDefault(false);
Expand Down
116 changes: 111 additions & 5 deletions paddle/fluid/operators/fused/fused_seqpool_cvm_op.cu
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,104 @@ __global__ void FusedSeqpoolKernelQuantFilter(
*(seqpool_output_values[x] + y * embedding_size + offset) = val;
}
}
// Embed quant filter: sum-pool each (slot, instance) sequence with CVM
// show/click filtering plus an embedding-norm filter, quantizing pooled
// embedx values to a 1/quant_ratio grid.
//
// One thread produces one output element (slot x, instance y, dim offset).
// An input row k is skipped when either
//   (show - click) * show_coeff + click * clk_coeff < threshold, or
//   sqrt(sum of squared dims after cvm_offset) + |embedw| < embed_threshold
// where embedw is the scalar at dim cvm_offset.
template <typename T>
__global__ void FusedSeqpoolKernelEmbedQuantFilter(
    const size_t N, T **input_values, T **seqpool_output_values,
    size_t **lods_values, const int batch_size, const int embedding_size,
    const float pad_value, const int cvm_offset, const float show_coeff,
    const float clk_coeff, const float threshold, const int quant_ratio,
    const float embed_threshold) {
  CUDA_KERNEL_LOOP(i, N) {
    int key = i / embedding_size;
    int offset = i % embedding_size;  // embedx id
    int x = key / batch_size;         // slot id
    int y = key % batch_size;         // ins id
    auto &start = *(lods_values[x] + y);
    auto &end = *(lods_values[x] + y + 1);

    double val = pad_value;
    for (auto k = start; k < end; ++k) {
      T &show = *(input_values[x] + k * embedding_size);
      T &click = *(input_values[x] + k * embedding_size + 1);
      if ((show - click) * show_coeff + click * clk_coeff < threshold) {
        continue;
      }
      T &embedw = *(input_values[x] + k * embedding_size + cvm_offset);
      T embedx_weight_score = 0.0;
      // Use a distinct index name: the original inner loop re-declared `i`,
      // shadowing the CUDA_KERNEL_LOOP induction variable. Also square with
      // a plain multiply instead of pow(x, 2).
      for (int d = cvm_offset + 1; d < embedding_size; ++d) {
        T v = *(input_values[x] + k * embedding_size + d);
        embedx_weight_score += v * v;
      }
      embedx_weight_score = std::sqrt(embedx_weight_score) + std::abs(embedw);
      if (embedx_weight_score < embed_threshold) {
        continue;
      }
      if (offset < cvm_offset) {  // show & click
        val += *(input_values[x] + k * embedding_size + offset);
      } else {
        // Round to the nearest multiple of 1/quant_ratio.
        val += ((static_cast<int>(
                    *(input_values[x] + k * embedding_size + offset) *
                        quant_ratio +
                    0.5)) /
                static_cast<float>(quant_ratio));
      }
    }
    *(seqpool_output_values[x] + y * embedding_size + offset) = val;
  }
}
// Embed quant filter (opt variant): same semantics as
// FusedSeqpoolKernelEmbedQuantFilter — a row contributes only when both its
// show/click score and its embedding-norm score pass their thresholds.
//
// BUG FIX: the previous implementation precomputed a per-row flag array
// `bool is_filter[end - start]`, which was (a) a runtime-sized local array,
// illegal in device code, (b) written one element past the end by
// `is_filter[end - start] = {false}`, which also initialized only that single
// element, and (c) populated only by threads with offset == 0 even though the
// array is thread-local — every other thread read uninitialized memory. The
// filter decision is now computed inline per row, which is correct for all
// threads.
template <typename T>
__global__ void FusedSeqpoolKernelEmbedQuantOptFilter(
    const size_t N, T **input_values, T **seqpool_output_values,
    size_t **lods_values, const int batch_size, const int embedding_size,
    const float pad_value, const int cvm_offset, const float show_coeff,
    const float clk_coeff, const float threshold, const int quant_ratio,
    const float embed_threshold) {
  CUDA_KERNEL_LOOP(i, N) {
    int key = i / embedding_size;
    int offset = i % embedding_size;  // embedx id
    int x = key / batch_size;         // slot id
    int y = key % batch_size;         // ins id
    auto &start = *(lods_values[x] + y);
    auto &end = *(lods_values[x] + y + 1);

    double val = pad_value;
    for (auto k = start; k < end; ++k) {
      T &show = *(input_values[x] + k * embedding_size);
      T &click = *(input_values[x] + k * embedding_size + 1);
      T &embedw = *(input_values[x] + k * embedding_size + cvm_offset);
      T embedx_weight_score = 0.0;
      // Squared L2 norm of the embedx dims (those after cvm_offset).
      for (int d = cvm_offset + 1; d < embedding_size; ++d) {
        T v = *(input_values[x] + k * embedding_size + d);
        embedx_weight_score += v * v;
      }
      T show_click_score = (show - click) * show_coeff + click * clk_coeff;
      embedx_weight_score = std::sqrt(embedx_weight_score) + std::abs(embedw);
      if (show_click_score < threshold ||
          embedx_weight_score < embed_threshold) {
        continue;  // row filtered out
      }
      if (offset < cvm_offset) {  // show & click
        val += *(input_values[x] + k * embedding_size + offset);
      } else {
        // Round to the nearest multiple of 1/quant_ratio.
        val += ((static_cast<int>(
                    *(input_values[x] + k * embedding_size + offset) *
                        quant_ratio +
                    0.5)) /
                static_cast<float>(quant_ratio));
      }
    }
    *(seqpool_output_values[x] + y * embedding_size + offset) = val;
  }
}
// join need show click input
template <typename T>
__global__ void FusedCVMKernelWithCVM(const size_t N, T **output_values,
Expand Down Expand Up @@ -190,8 +288,8 @@ void FusedSeqpoolCVM(const paddle::platform::Place &place,
std::vector<const size_t *> lods, const int batch_size,
const int slot_num, const int embedding_size,
const float padding_value, const bool use_cvm,
const int cvm_offset, float need_filter, float show_coeff,
float clk_coeff, float threshold, const int quant_ratio,
const int cvm_offset, float need_filter, const bool embed_threshold_filter, float show_coeff,
float clk_coeff, float threshold, float embed_threshold, const int quant_ratio,
const bool clk_filter) {
auto stream = dynamic_cast<platform::CUDADeviceContext *>(
platform::DeviceContextPool::Instance().Get(
Expand Down Expand Up @@ -224,7 +322,13 @@ void FusedSeqpoolCVM(const paddle::platform::Place &place,

size_t N = static_cast<size_t>(batch_size * slot_num * embedding_size);
// first sum pool
if (need_filter) { // quant need filter
if (need_filter && embed_threshold_filter) { // embed quant filter
FusedSeqpoolKernelEmbedQuantFilter<<<GET_BLOCK(N), PADDLE_CUDA_NUM_THREADS, 0,
stream>>>(
N, gpu_input_values, gpu_seqpool_output_values, lods_values, batch_size,
embedding_size, padding_value, cvm_offset, show_coeff, clk_coeff,
threshold, quant_ratio, embed_threshold);
} else if (need_filter) { // quant need filter
FusedSeqpoolKernelQuantFilter<<<GET_BLOCK(N), PADDLE_CUDA_NUM_THREADS, 0,
stream>>>(
N, gpu_input_values, gpu_seqpool_output_values, lods_values, batch_size,
Expand Down Expand Up @@ -414,9 +518,11 @@ class FusedSeqpoolCVMCUDAKernel : public framework::OpKernel<T> {
auto padding_value = ctx.Attr<float>("pad_value");
auto use_cvm = ctx.Attr<bool>("use_cvm");
bool need_filter = ctx.Attr<bool>("need_filter");
bool embed_threshold_filter = ctx.Attr<bool>("embed_threshold_filter");
float show_coeff = ctx.Attr<float>("show_coeff");
float clk_coeff = ctx.Attr<float>("clk_coeff");
float threshold = ctx.Attr<float>("threshold");
float embed_threshold = ctx.Attr<float>("embed_threshold");
const int cvm_offset = ctx.Attr<int>("cvm_offset");
const int quant_ratio = ctx.Attr<int>("quant_ratio");
bool clk_filter = ctx.Attr<bool>("clk_filter");
Expand Down Expand Up @@ -458,8 +564,8 @@ class FusedSeqpoolCVMCUDAKernel : public framework::OpKernel<T> {
FusedSeqpoolCVM(ctx.GetPlace(), input_data, output_data,
seqpool_output_data, lods_data, batch_size, slot_size,
embedding_size, padding_value, use_cvm, cvm_offset,
need_filter, show_coeff, clk_coeff, threshold, quant_ratio,
clk_filter);
need_filter, embed_threshold_filter, show_coeff, clk_coeff,
threshold, embed_threshold, quant_ratio, clk_filter);
}
};

Expand Down
4 changes: 4 additions & 0 deletions python/paddle/fluid/contrib/layers/nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -1549,9 +1549,11 @@ def fused_seqpool_cvm(input,
pad_value=0.0,
use_cvm=True,
need_filter=False,
embed_threshold_filter=False,
show_coeff=0.2,
clk_coeff=1.0,
threshold=0.96,
embed_threshold=0,
cvm_offset=2,
quant_ratio=0,
clk_filter=False):
Expand Down Expand Up @@ -1603,9 +1605,11 @@ def fused_seqpool_cvm(input,
"use_cvm": use_cvm,
"cvm_offset": cvm_offset,
"need_filter": need_filter,
"embed_threshold_filter": embed_threshold_filter,
"show_coeff": show_coeff,
"clk_coeff": clk_coeff,
"threshold": threshold,
"embed_threshold": embed_threshold,
"quant_ratio": quant_ratio,
"clk_filter": clk_filter
})
Expand Down

0 comments on commit cd4ff54

Please sign in to comment.