Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add target assigner operator for SSD detection. #8193

Merged
merged 4 commits into from
Feb 7, 2018
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions paddle/framework/mixed_vector.h
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,14 @@ class Vector : public std::vector<T> {
T *data() { return std::vector<T>::data(); }
const T *data() const { return std::vector<T>::data(); }

T *data(const platform::Place &place) {
if (platform::is_cpu_place(place)) {
return data();
} else {
return cuda_data();
}
}

/* Synchronize host vector to device vector */
void CopyToCUDA();
/* Synchronize device vector to host vector */
Expand Down
172 changes: 172 additions & 0 deletions paddle/operators/target_assign_op.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,172 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

2018

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.


Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/operators/target_assign_op.h"

namespace paddle {
namespace operators {

class TargetAssignOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;

void InferShape(framework::InferShapeContext* ctx) const override {
// checkout inputs
PADDLE_ENFORCE(ctx->HasInput("EncodedGTBBox"),
"Input(EncodedGTBBox) of TargetAssignOp should not be null");
PADDLE_ENFORCE(ctx->HasInput("GTScoreLabel"),
"Input(GTScoreLabel) of TargetAssignOp should not be null");
PADDLE_ENFORCE(ctx->HasInput("MatchIndices"),
"Input(MatchIndices) of TargetAssignOp should not be null");
PADDLE_ENFORCE(ctx->HasInput("NegIndices"),
"Input(NegIndices) of TargetAssignOp should not be null");

// checkout outputs
PADDLE_ENFORCE(
ctx->HasOutput("PredBBoxLabel"),
"Output(PredBBoxLabel) of TargetAssignOp should not be null.");
PADDLE_ENFORCE(
ctx->HasOutput("PredBBoxWeight"),
"Output(PredBBoxWeight) of TargetAssignOp should not be null.");
PADDLE_ENFORCE(
ctx->HasOutput("PredScoreLabel"),
"Output(PredScoreLabel) of TargetAssignOp should not be null.");
PADDLE_ENFORCE(
ctx->HasOutput("PredScoreWeight"),
"Output(PredScoreWeight) of TargetAssignOp should not be null.");

auto blabel_dims = ctx->GetInputDim("EncodedGTBBox");
auto slabel_dims = ctx->GetInputDim("GTScoreLabel");
auto mi_dims = ctx->GetInputDim("MatchIndices");
auto neg_dims = ctx->GetInputDim("NegIndices");

PADDLE_ENFORCE_EQ(blabel_dims.size(), 3UL,
"The rank of Input(EncodedGTBBox) must be 3.");
PADDLE_ENFORCE_EQ(slabel_dims.size(), 2UL,
"The rank of Input(GTScoreLabel) must be 2.");
PADDLE_ENFORCE_EQ(mi_dims.size(), 2UL,
"The rank of Input(MatchIndices) must be 2.");
PADDLE_ENFORCE_EQ(neg_dims.size(), 2UL,
"The rank of Input(NegIndices) must be 2.");

PADDLE_ENFORCE_EQ(blabel_dims[0], slabel_dims[0],
"The 1st dimension of Input(EncodedGTBBox) and "
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The 1st dimension --> The 1st dimension(means the number of ground-truth bounding boxes)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

"Input(GTScoreLabel) must be the same.");
PADDLE_ENFORCE_EQ(blabel_dims[1], mi_dims[1],
"The 2nd dimension of Input(EncodedGTBBox) and "
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The 2nd dimension(means the number of prior boxes)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

"Input(MatchIndices) must be the same.");
PADDLE_ENFORCE_EQ(blabel_dims[2], 4,
"The 3rd dimension of Input(EncodedGTBBox) must be 4.");

auto n = mi_dims[0];
auto np = mi_dims[1];
ctx->SetOutputDim("PredBBoxLabel", {n, np, 4});
ctx->SetOutputDim("PredBBoxWeight", {n, np, 1});
ctx->SetOutputDim("PredScoreLabel", {n, np, 1});
ctx->SetOutputDim("PredScoreWeight", {n, np, 1});
}

protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
framework::ToDataType(
ctx.Input<framework::LoDTensor>("EncodedGTBBox")->type()),
ctx.device_context());
}
};

class TargetAssignOpMaker : public framework::OpProtoAndCheckerMaker {
public:
TargetAssignOpMaker(OpProto* proto, OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("EncodedGTBBox",
"(LoDTensor), The encoded ground-truth bounding boxes with shape "
"[Ng, Np, 4], where Ng is the total number of ground-truth boxes "
"in this mini-batch, Np the number of predictions, 4 is the "
"number of coordinate in [xmin, ymin, xmax, ymax] layout.");
AddInput("GTScoreLabel",
"(LoDTensor, default LoDTensor<int>), The input ground-truth "
"labels with shape [Ng, 1], where the Ng is the same as it in "
"the input of EncodedGTBBox.");
AddInput("MatchIndices",
"(Tensor, default LoDTensor<int>), The input matched indices "
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

default Tensor<int>

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

"with shape [N, Np], where N is the batch size, Np is the same "
"as it in the input of EncodedGTBBox. If MatchIndices[i][j] "
"is -1, the j-th prior box is not matched to any ground-truh "
"box in i-th instance.");
AddInput("NegIndices",
"(LoDTensor, default LoDTensor<int>), The input negative example "
"indics with shape [Neg, 1], where is the total number of "
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

indics --> indices

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

"negative example indices.");
AddAttr<int>("background_label",
"(int, default 0), Label id for background class.")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Label id of background class

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

.SetDefault(0);
AddOutput("PredBBoxLabel",
"(Tensor), The output encoded ground-truth labels "
"with shape [N, Np, 4], N is the batch size and Np, 4 is the "
"same as they in input of EncodedGTBBox. If MatchIndices[i][j] "
"is -1, the PredBBoxLabel[i][j][:] is the encoded ground-truth "
"box for background_label_id in i-th instance.");
AddOutput("PredBBoxWeight",
"(Tensor), The weight for PredBBoxLabel with the shape "
"of [N, Np, 1]");
AddOutput("PredScoreLabel",
"(Tensor, default Tensor<int>), The output score labels for "
"each predictions with shape [N, Np, 1]. If MatchIndices[i][j] "
"is -1, PredScoreLabel[i][j] = background_label_id.");
AddOutput("PredScoreWeight",
"(Tensor), The weight for PredScoreLabel with the shape "
"of [N, Np, 1]");
AddComment(R"DOC(
This operator is, for given the encoded boxes between prior boxes and
ground-truth boxes and ground-truth class labels, to assign classification
and regression targets to each prior box as well as weights to each
prior box. The weights is used to specify which prior box would not contribute
to training loss.

TODO(dang qingqing) add an example.

)DOC");
}
};

template <typename T>
struct UpdateTargetLabelFunctor<platform::CPUDeviceContext, T> {
void operator()(const platform::CPUDeviceContext& ctx, const int* neg_indices,
const size_t* lod, const int num, const int num_prior_box,
const int background_label, int* out_label, T* out_label_wt) {
for (int i = 0; i < num; ++i) {
for (int j = lod[i]; j < lod[i + 1]; ++j) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

size_t j

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

int id = neg_indices[j];
out_label[i * num_prior_box + id] = background_label;
out_label_wt[i * num_prior_box + id] = static_cast<T>(1.0);
}
}
}
};

template struct UpdateTargetLabelFunctor<platform::CPUDeviceContext, float>;
template struct UpdateTargetLabelFunctor<platform::CPUDeviceContext, double>;

} // namespace operators
} // namespace paddle

namespace ops = paddle::operators;
REGISTER_OP_WITHOUT_GRADIENT(target_assign, ops::TargetAssignOp,
ops::TargetAssignOpMaker);
REGISTER_OP_CPU_KERNEL(
target_assign,
ops::TargetAssignKernel<paddle::platform::CPUDeviceContext, float>,
ops::TargetAssignKernel<paddle::platform::CPUDeviceContext, double>);
61 changes: 61 additions & 0 deletions paddle/operators/target_assign_op.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

2018

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.


Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/operators/target_assign_op.h"

namespace paddle {
namespace operators {

template <typename T>
__global__ void UpdateTargetLabelKernel(const int* neg_indices,
const size_t* lod, const int num,
const int num_prior_box,
const int background_label,
int* out_label, T* out_label_wt) {
int bidx = blockIdx.x;
int st = lod[bidx];
int ed = lod[bidx + 1];

for (int i = st + threadIdx.x; i < ed; i += blockDim.x) {
int id = neg_indices[i];
out_label[bidx * num_prior_box + id] = background_label;
out_label_wt[bidx * num_prior_box + id] = 1.;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

bidx * num_prior_box has appeared many times and it is inside a loop, so it should be optimized.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

}
}

template <typename T>
struct UpdateTargetLabelFunctor<platform::CUDADeviceContext, T> {
void operator()(const platform::CUDADeviceContext& ctx,
const int* neg_indices, const size_t* lod, const int num,
const int num_prior_box, const int background_label,
int* out_label, T* out_label_wt) {
const int block_size = 256;
const int grid_size = num;
UpdateTargetLabelKernel<T><<<grid_size, block_size, 0, ctx.stream()>>>(
neg_indices, lod, num, num_prior_box, background_label, out_label,
out_label_wt);
}
};

template struct UpdateTargetLabelFunctor<platform::CUDADeviceContext, float>;
template struct UpdateTargetLabelFunctor<platform::CUDADeviceContext, double>;

} // namespace operators
} // namespace paddle

namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(
target_assign,
ops::TargetAssignKernel<paddle::platform::CUDADeviceContext, float>,
ops::TargetAssignKernel<paddle::platform::CUDADeviceContext, double>);
Loading