
Add soft-label support for cross-entropy operator. #4081

Merged: 6 commits merged into PaddlePaddle:develop on Sep 19, 2017

Conversation

@xinghai-sun (Contributor) commented Sep 13, 2017

Resolve #4080
Resolve #3898

@qingqing01 (Contributor) left a comment

Need to update and fix conflicts.

auto *X = ctx.Input<Tensor>("X");
auto *label = ctx.Input<Tensor>("label");
auto *x = ctx.Input<Tensor>("X");
auto *label = ctx.Input<Tensor>("Label");
Contributor

Please add not-null checks for Input(X) and Input(Label). Thanks!

Contributor Author

Done.
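A minimal sketch of what such checks could look like in InferShape; the exact enforce macro, accessor, and message wording are assumptions for illustration, not necessarily the merged code:

// Sketch only: reject null inputs before touching their dims.
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"),
                        "Input(X) of CrossEntropyOp must not be null.");
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Label"),
                        "Input(Label) of CrossEntropyOp must not be null.");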

@@ -17,59 +17,71 @@ limitations under the License. */
namespace paddle {
namespace operators {

class OnehotCrossEntropyOp : public framework::OperatorWithKernel {
class CrossEntropyOp : public framework::OperatorWithKernel {
Contributor

@luotao1 changed CrossEntropyOp -> OnehotCrossEntropyOp. Please use the new name.

Contributor Author

Since it supports not only one-hot cross-entropy but also soft-label cross-entropy, it would be better to use CrossEntropyOp instead of OnehotCrossEntropyOp.

// normal cross entropy
PADDLE_ENFORCE_EQ(x->dims()[0], label->dims()[0]);
}
ctx.Output<Tensor>("Y")->Resize({x->dims()[0]});
Contributor

Output<framework::LoDTensor>

Output<framework::LoDTensor> must now be used for outputs in both the forward and backward InferShape.

Contributor Author

Done.
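For reference, a one-line sketch of the requested change applied to the line above (assuming the rest of InferShape stays unchanged):

ctx.Output<framework::LoDTensor>("Y")->Resize({x->dims()[0]});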

Contributor

@qingqing01 Should the output Y be renamed to Out to follow the new naming convention? (Earlier, when 益群 used Y as the output of FC, he was also asked to change it to Out. Should we unify this?)

Contributor

After discussing with @Xreki , we both prefer "Loss" as the output name rather than "Out". I think Loss is more meaningful than "Out".

public:
using framework::OperatorWithKernel::OperatorWithKernel;

protected:
void InferShape(const framework::InferShapeContext &ctx) const override {
auto dX = ctx.Output<Tensor>(framework::GradVarName("X"));
auto X = ctx.Input<Tensor>("X");
auto dx = ctx.Output<Tensor>(framework::GradVarName("X"));
Contributor

Output<framework::LoDTensor>

Contributor Author

Done.

auto dX = ctx.Output<Tensor>(framework::GradVarName("X"));
auto X = ctx.Input<Tensor>("X");
auto dx = ctx.Output<Tensor>(framework::GradVarName("X"));
auto x = ctx.Input<Tensor>("X");
Contributor

Also add a not-null check for Input(X). Thanks!

Contributor Author

Done.

Y[i] = -log(X[i][j])
The second input (Label tensor) supports two kinds of shapes:
1) Rank(Label) = 1, Label[i] indicates the class index for sample i:
Y[i] = -log(X[i, Label[i]])
Contributor

Please add a space before and after the formula.

Contributor Author

Done.


2) Rank(Label) = 2, Label[i, j] indicates the soft label of class j
for sample i:
Y[i] = \sum_j{-Label[i, j] * log(X[i, j])}
Contributor

Please add a space before and after the formula.
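Spelled out with proper spacing, the two docstring formulas amount to the following (a transcription of the text above, not new math):

Y_i = -\log\left(X_{i,\,\mathrm{Label}_i}\right) \qquad \text{(hard / index label, Rank(Label) = 1)}
Y_i = -\sum_{j} \mathrm{Label}_{i,j}\,\log\left(X_{i,j}\right) \qquad \text{(soft label, Rank(Label) = 2)}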

@@ -21,17 +21,16 @@ namespace operators {
using Tensor = framework::Tensor;

template <typename T>
__host__ __device__ T clipping_log(const T x) {
__host__ __device__ T tolerable_value(const T x) {
Contributor

Include paddle/platform/hostdevice.h, then use HOSTDEVICE.

HOSTDEVICE T tolerable_value(const T x) {

Contributor Author

Done.

Contributor

I have a question here: if this function uses __host__ __device__ in the declaration, why do we need to implement it again in *.cc?

Contributor

If we switch to HOSTDEVICE, then according to its definition:

#ifdef __CUDACC__
#define HOSTDEVICE __host__ __device__
#define HOST __host__
#else
#define HOSTDEVICE
#define HOST
#endif

the CPU and GPU kernels can indeed share this tolerable_value.

Contributor

@lcy-seso lcy-seso Sep 18, 2017

Ah, I see. HOSTDEVICE expands to nothing on the CPU.
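A sketch of how the shared helper might look once it sits behind HOSTDEVICE; the clipping constant and the exact guard are illustrative, not necessarily the merged implementation:

#include <math.h>
#include "paddle/platform/hostdevice.h"

// Clip +/-inf (e.g. produced by log(0)) to a large finite value so the
// same function can be called from both the CPU and GPU kernels.
template <typename T>
HOSTDEVICE T tolerable_value(const T x) {
  const T kApproInf = 1e20;
  if (x == INFINITY) return kApproInf;
  if (x == -INFINITY) return -kApproInf;
  return x;
}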

self.check_output()

def test_check_grad(self):
self.check_grad(['X'], 'Y', max_relative_error=0.05)
Contributor

Could you tune max_relative_error to be smaller?

Contributor Author

Done.

@lcy-seso
Contributor

When the label becomes soft there is no longer any discretization step, so can we just call Eigen directly?

@xinghai-sun (Contributor Author) left a comment

All done. Thanks.

auto *X = ctx.Input<Tensor>("X");
auto *label = ctx.Input<Tensor>("label");
auto *x = ctx.Input<Tensor>("X");
auto *label = ctx.Input<Tensor>("Label");
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

// normal cross entropy
PADDLE_ENFORCE_EQ(x->dims()[0], label->dims()[0]);
}
ctx.Output<Tensor>("Y")->Resize({x->dims()[0]});
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

public:
using framework::OperatorWithKernel::OperatorWithKernel;

protected:
void InferShape(const framework::InferShapeContext &ctx) const override {
auto dX = ctx.Output<Tensor>(framework::GradVarName("X"));
auto X = ctx.Input<Tensor>("X");
auto dx = ctx.Output<Tensor>(framework::GradVarName("X"));
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

auto dX = ctx.Output<Tensor>(framework::GradVarName("X"));
auto X = ctx.Input<Tensor>("X");
auto dx = ctx.Output<Tensor>(framework::GradVarName("X"));
auto x = ctx.Input<Tensor>("X");
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

Y[i] = -log(X[i][j])
The second input (Label tensor) supports two kinds of shapes:
1) Rank(Label) = 1, Label[i] indicates the class index for sample i:
Y[i] = -log(X[i, Label[i]])
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

@@ -21,17 +21,16 @@ namespace operators {
using Tensor = framework::Tensor;

template <typename T>
__host__ __device__ T clipping_log(const T x) {
__host__ __device__ T tolerable_value(const T x) {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

self.check_output()

def test_check_grad(self):
self.check_grad(['X'], 'Y', max_relative_error=0.05)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

@xinghai-sun
Contributor Author

@lcy-seso I think that would work too, if we are willing to tolerate one if with two branches: one going through the CUDA code path and the other through Eigen.

@lcy-seso
Contributor

lcy-seso commented Sep 16, 2017

caffe2 splits this into two branches. What probably matters more is which of the two is computationally more efficient, and that is unknown for now.

// TODO(qingqing): define CUDA_1D_KERNEL_LOOP macro in a common file.
// CUDA_1D_KERNEL_LOOP(i, N) {
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < N;
     i += blockDim.x * gridDim.x) {
Contributor

@lcy-seso lcy-seso Sep 17, 2017

A small question for @qingqing01:

  • The for loop inside this kernel function does not produce wrong results, but it feels logically odd.
  • The grid always has some blocks whose threads are left over and do not line up exactly with the input data; here the for loop effectively skips over those unaligned parts.
  • With i += blockDim.x * gridDim.x, once the loop variable i is incremented even a single time it already exceeds the total number of threads, so in practice this kernel never loops more than once and only computes one position of the output vector. Logically that is equivalent to returning immediately when i >= batch_size. What is the reasoning behind writing it as a loop?

Contributor

With the grid and thread configuration shown below, the for loop is indeed unnecessary. But if the grid below is set to a fixed number, i.e. a fixed total number of threads is launched, the for loop becomes useful: one thread may then compute multiple outputs. Since the kernel already handles the boundary this way, it does not need to change.

int block = 512;
int grid = (n + block - 1) / block;

Contributor

@lcy-seso lcy-seso Sep 18, 2017

Got it. Indeed, this cross-entropy kernel is fairly simple and special, and the grid size is already computed exactly.
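Putting the pieces of this thread together, a grid-stride version of the soft-label kernel might look like the sketch below; the kernel name and launch configuration are illustrative, not the merged code:

// Each thread handles rows i, i + stride, i + 2*stride, ... with
// stride = blockDim.x * gridDim.x. With grid = (N + block - 1) / block
// every thread executes the body at most once, but a fixed grid also works.
template <typename T>
__global__ void SoftCrossEntropyKernel(T* Y, const T* X, const T* label,
                                       const int N, const int D) {
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < N;
       i += blockDim.x * gridDim.x) {
    T sum = static_cast<T>(0);
    for (int j = 0; j < D; j++) {
      sum += -label[i * D + j] * tolerable_value(log(X[i * D + j]));
    }
    Y[i] = sum;
  }
}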

@lcy-seso lcy-seso mentioned this pull request Sep 17, 2017
auto *label = ctx.Input<Tensor>("Label");

PADDLE_ENFORCE_EQ(x->dims().size(), 2, "X's rank must be 2.");
PADDLE_ASSERT(label->dims().size() == 1 || label->dims().size() == 2);
Contributor

As discussed this morning, we should also use a rank-2 label for the int label. Please help modify it. If so, there is no way to tell normal cross entropy from soft cross entropy by rank; can we switch to using an attr?

Contributor Author

Done.
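One way the attribute-based switch could look in InferShape; the attribute name "soft_label" and the Attr accessor on the context are assumptions for illustration, not necessarily what was merged:

// Label is always rank 2 now; an attribute (hypothetical name) picks the
// interpretation instead of the rank.
PADDLE_ENFORCE_EQ(label->dims().size(), 2, "Label's rank must be 2.");
if (ctx.Attr<bool>("soft_label")) {
  // soft label: one probability per class, so widths must match
  PADDLE_ENFORCE_EQ(label->dims()[1], x->dims()[1]);
} else {
  // hard label: a single class index per sample
  PADDLE_ENFORCE_EQ(label->dims()[1], 1);
}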

using Tensor = framework::Tensor;

template <typename T>
HOSTDEVICE T tolerable_value(const T x) {
Contributor

As @lcy-seso said, both the CPU and GPU kernels can use this common function if we use HOSTDEVICE. Use this function in paddle/operators/cross_entropy_op.h and delete the copy in paddle/operators/cross_entropy_op.cc.

Contributor Author

Done.

sum += label[i * D + j] * log(X[i * D + j]);
}
Y[i] = -tolerable_value(sum);
}
Contributor

Put tolerable_value after log:

for (int j = 0; j < D; j++) {
  sum += - label[i * D + j] * tolerable_value(log(X[i * D + j]));
}

Contributor Author

Done.

CrossEntropy Operator.

The second input (Label tensor) supports two kinds of shapes:
1) Rank(Label) = 1, Label[i] indicates the class index for sample i:
Contributor

If the int-label rank is modified, these doc comments also need to be updated.

Contributor Author

Done.

for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < N;
i += blockDim.x * gridDim.x) {
T sum = static_cast<T>(0);
for (int j = 0; j < D; j++) {
Contributor

Please add a TODO about optimizing this kernel.

Contributor Author

Done.

T sum = static_cast<T>(0);
for (int j = 0; j < class_num; ++j) {
sum += label_data[index] * std::log(x_data[index]);
y_data[i] = -tolerable_value(sum);
Contributor

Apply tolerable_value to the result of std::log:

sum += - label_data[index] * tolerable_value(std::log(x_data[index]));

Contributor Author

Done.
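For completeness, a sketch of the corrected CPU loop for the soft-label case, with tolerable_value applied to each log as suggested. Variable names follow the hunk above; batch_size and class_num are assumed to come from x->dims()[0] and x->dims()[1]:

for (int i = 0; i < batch_size; ++i) {
  T sum = static_cast<T>(0);
  for (int j = 0; j < class_num; ++j) {
    int index = i * class_num + j;
    sum += -label_data[index] * tolerable_value(std::log(x_data[index]));
  }
  y_data[i] = sum;
}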

qingqing01 previously approved these changes Sep 19, 2017
@qingqing01 merged commit 5b42d2b into PaddlePaddle:develop Sep 19, 2017
@qingqing01
Contributor

qingqing01 commented Sep 19, 2017

Merging this PR; if any issues remain, we will fix them later.
