Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add target assigner operator for SSD detection. #8193

Merged
merged 4 commits into from
Feb 7, 2018

Conversation

qingqing01
Copy link
Contributor

@qingqing01 qingqing01 commented Feb 6, 2018

Fix #8192

  • Support CPU and GPU.
  • There are 4 outputs:
    • target bboxes with shape [N, Np, 4].
    • weight for target bboxes with shape [N, Np, 1].
    • target labels with shape [N, Np, 1].
    • weight for target labels with shape [N, Np, 1].

@qingqing01 qingqing01 changed the title Add target_assign_op for SSD detection. Add target assigner operator for SSD detection. Feb 6, 2018
@@ -0,0 +1,172 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

2018

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

"labels with shape [Ng, 1], where the Ng is the same as it in "
"the input of EncodedGTBBox.");
AddInput("MatchIndices",
"(Tensor, default LoDTensor<int>), The input matched indices "
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

default Tensor<int>

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

@@ -0,0 +1,155 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

2018

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

@@ -0,0 +1,61 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

2018

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.


for n in range(batch_size):
gt_num = gt_lod[n + 1] - gt_lod[n]
ids = random.sample([i for i in range(num_prior)], gt_num)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sample是截取指定长度的片断,是不是sample前可以再加一个shuffle?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sample采样出来是乱序的:

>>> import random
>>> print random.sample([i for i in range(10)], 4)
[7, 3, 5, 6]

"indics with shape [Neg, 1], where is the total number of "
"negative example indices.");
AddAttr<int>("background_label",
"(int, default 0), Label id for background class.")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Label id of background class

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

const size_t* lod, const int num, const int num_prior_box,
const int background_label, int* out_label, T* out_label_wt) {
for (int i = 0; i < num; ++i) {
for (int j = lod[i]; j < lod[i + 1]; ++j) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

size_t j

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

auto* out_label_wt = ctx.Output<framework::Tensor>("PredScoreWeight");

PADDLE_ENFORCE_EQ(enc_gt_box->lod().size(), 1UL);
PADDLE_ENFORCE_EQ(gt_label->lod().size(), 1UL);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

检查 gt_label->lod()和 enc_gt_box->lod() 的值是一致的

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done in following code.


import unittest
import numpy as np
import math
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

没有用math,可以删除

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

import unittest
import numpy as np
import math
import sys
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

没有用sys,可以删除

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

T* obox = out_box_ + (row * num_prior_box_ + col) * 4;
int* olabel = out_label_ + row * num_prior_box_ + col;
T* obox_wt = out_box_wt_ + row * num_prior_box_ + col;
T* olabel_wt = out_label_wt_ + row * num_prior_box_ + col;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

row * num_prior_box_ has computed many times in kernel. It can be optimized.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

for (int i = st + threadIdx.x; i < ed; i += blockDim.x) {
int id = neg_indices[i];
out_label[bidx * num_prior_box + id] = background_label;
out_label_wt[bidx * num_prior_box + id] = 1.;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

bidx * num_prior_box has appeared many times and it is inside a loop, so it should be optimized.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

Copy link
Contributor Author

@qingqing01 qingqing01 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@wanghaox @chengduoZH Thanks for your review.

@@ -0,0 +1,172 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

"The rank of Input(NegIndices) must be 2.");

PADDLE_ENFORCE_EQ(blabel_dims[0], slabel_dims[0],
"The 1st dimension of Input(EncodedGTBBox) and "
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

"The 1st dimension of Input(EncodedGTBBox) and "
"Input(GTScoreLabel) must be the same.");
PADDLE_ENFORCE_EQ(blabel_dims[1], mi_dims[1],
"The 2nd dimension of Input(EncodedGTBBox) and "
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

"box in i-th instance.");
AddInput("NegIndices",
"(LoDTensor, default LoDTensor<int>), The input negative example "
"indics with shape [Neg, 1], where is the total number of "
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

"indics with shape [Neg, 1], where is the total number of "
"negative example indices.");
AddAttr<int>("background_label",
"(int, default 0), Label id for background class.")
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

const size_t* lod, const int num, const int num_prior_box,
const int background_label, int* out_label, T* out_label_wt) {
for (int i = 0; i < num; ++i) {
for (int j = lod[i]; j < lod[i + 1]; ++j) {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

@@ -0,0 +1,61 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

@@ -0,0 +1,155 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

for (int i = st + threadIdx.x; i < ed; i += blockDim.x) {
int id = neg_indices[i];
out_label[bidx * num_prior_box + id] = background_label;
out_label_wt[bidx * num_prior_box + id] = 1.;
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

T* obox = out_box_ + (row * num_prior_box_ + col) * 4;
int* olabel = out_label_ + row * num_prior_box_ + col;
T* obox_wt = out_box_wt_ + row * num_prior_box_ + col;
T* olabel_wt = out_label_wt_ + row * num_prior_box_ + col;
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

@qingqing01 qingqing01 merged commit ae0740c into PaddlePaddle:develop Feb 7, 2018
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants