-
Notifications
You must be signed in to change notification settings - Fork 5.7k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add target assigner operator for SSD detection. #8193
Conversation
… ssd_target_assign
paddle/operators/target_assign_op.cc
Outdated
@@ -0,0 +1,172 @@ | |||
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
2018
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done.
paddle/operators/target_assign_op.cc
Outdated
"labels with shape [Ng, 1], where the Ng is the same as it in " | ||
"the input of EncodedGTBBox."); | ||
AddInput("MatchIndices", | ||
"(Tensor, default LoDTensor<int>), The input matched indices " |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
default Tensor<int>
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done.
paddle/operators/target_assign_op.h
Outdated
@@ -0,0 +1,155 @@ | |||
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
2018
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done.
paddle/operators/target_assign_op.cu
Outdated
@@ -0,0 +1,61 @@ | |||
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
2018
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done.
|
||
for n in range(batch_size): | ||
gt_num = gt_lod[n + 1] - gt_lod[n] | ||
ids = random.sample([i for i in range(num_prior)], gt_num) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
sample是截取指定长度的片断,是不是sample前可以再加一个shuffle?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
sample采样出来是乱序的:
>>> import random
>>> print random.sample([i for i in range(10)], 4)
[7, 3, 5, 6]
paddle/operators/target_assign_op.cc
Outdated
"indics with shape [Neg, 1], where is the total number of " | ||
"negative example indices."); | ||
AddAttr<int>("background_label", | ||
"(int, default 0), Label id for background class.") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Label id of background class
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done.
paddle/operators/target_assign_op.cc
Outdated
const size_t* lod, const int num, const int num_prior_box, | ||
const int background_label, int* out_label, T* out_label_wt) { | ||
for (int i = 0; i < num; ++i) { | ||
for (int j = lod[i]; j < lod[i + 1]; ++j) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
size_t j
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done.
auto* out_label_wt = ctx.Output<framework::Tensor>("PredScoreWeight"); | ||
|
||
PADDLE_ENFORCE_EQ(enc_gt_box->lod().size(), 1UL); | ||
PADDLE_ENFORCE_EQ(gt_label->lod().size(), 1UL); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
检查 gt_label->lod()和 enc_gt_box->lod() 的值是一致的
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done in following code.
|
||
import unittest | ||
import numpy as np | ||
import math |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
没有用math,可以删除
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done.
import unittest | ||
import numpy as np | ||
import math | ||
import sys |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
没有用sys,可以删除
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done.
paddle/operators/target_assign_op.h
Outdated
T* obox = out_box_ + (row * num_prior_box_ + col) * 4; | ||
int* olabel = out_label_ + row * num_prior_box_ + col; | ||
T* obox_wt = out_box_wt_ + row * num_prior_box_ + col; | ||
T* olabel_wt = out_label_wt_ + row * num_prior_box_ + col; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
row * num_prior_box_
has computed many times in kernel. It can be optimized.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done.
paddle/operators/target_assign_op.cu
Outdated
for (int i = st + threadIdx.x; i < ed; i += blockDim.x) { | ||
int id = neg_indices[i]; | ||
out_label[bidx * num_prior_box + id] = background_label; | ||
out_label_wt[bidx * num_prior_box + id] = 1.; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
bidx * num_prior_box
has appeared many times and it is inside a loop, so it should be optimized.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@wanghaox @chengduoZH Thanks for your review.
paddle/operators/target_assign_op.cc
Outdated
@@ -0,0 +1,172 @@ | |||
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done.
paddle/operators/target_assign_op.cc
Outdated
"The rank of Input(NegIndices) must be 2."); | ||
|
||
PADDLE_ENFORCE_EQ(blabel_dims[0], slabel_dims[0], | ||
"The 1st dimension of Input(EncodedGTBBox) and " |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done.
paddle/operators/target_assign_op.cc
Outdated
"The 1st dimension of Input(EncodedGTBBox) and " | ||
"Input(GTScoreLabel) must be the same."); | ||
PADDLE_ENFORCE_EQ(blabel_dims[1], mi_dims[1], | ||
"The 2nd dimension of Input(EncodedGTBBox) and " |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done.
paddle/operators/target_assign_op.cc
Outdated
"box in i-th instance."); | ||
AddInput("NegIndices", | ||
"(LoDTensor, default LoDTensor<int>), The input negative example " | ||
"indics with shape [Neg, 1], where is the total number of " |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done.
paddle/operators/target_assign_op.cc
Outdated
"indics with shape [Neg, 1], where is the total number of " | ||
"negative example indices."); | ||
AddAttr<int>("background_label", | ||
"(int, default 0), Label id for background class.") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done.
paddle/operators/target_assign_op.cc
Outdated
const size_t* lod, const int num, const int num_prior_box, | ||
const int background_label, int* out_label, T* out_label_wt) { | ||
for (int i = 0; i < num; ++i) { | ||
for (int j = lod[i]; j < lod[i + 1]; ++j) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done.
paddle/operators/target_assign_op.cu
Outdated
@@ -0,0 +1,61 @@ | |||
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done.
paddle/operators/target_assign_op.h
Outdated
@@ -0,0 +1,155 @@ | |||
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done.
paddle/operators/target_assign_op.cu
Outdated
for (int i = st + threadIdx.x; i < ed; i += blockDim.x) { | ||
int id = neg_indices[i]; | ||
out_label[bidx * num_prior_box + id] = background_label; | ||
out_label_wt[bidx * num_prior_box + id] = 1.; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done.
paddle/operators/target_assign_op.h
Outdated
T* obox = out_box_ + (row * num_prior_box_ + col) * 4; | ||
int* olabel = out_label_ + row * num_prior_box_ + col; | ||
T* obox_wt = out_box_wt_ + row * num_prior_box_ + col; | ||
T* olabel_wt = out_label_wt_ + row * num_prior_box_ + col; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done.
Fix #8192