Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add modified huber loss operator #3987

Merged
merged 8 commits into from
Sep 20, 2017
Merged

Conversation

pkuyym
Copy link
Contributor

@pkuyym pkuyym commented Sep 9, 2017

Resolves #3923

};

template <typename T>
struct ModifiedHuberLossForward {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

为什么ModifiedHuberLossForward单独列出来,而backward却没有呢?

Copy link
Contributor Author

@pkuyym pkuyym Sep 13, 2017

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

前向是因为CPU和GPU共用,但是BP时两者分离,而且CPU逻辑比较简单,所以直接写循环了,GPU实现用的thrust,所以单列了

} else if (inter_val_ptr[i] < 1) {
x_grad_ptr[i] = -2 * (1 - inter_val_ptr[i]) * (2 * y_ptr[i] - 1) *
out_grad_ptr[i];
} else {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这个反向写的不对吧。92行和95行,都不需要再乘一个out_grad_ptr[i]

Copy link
Contributor

@qingqing01 qingqing01 Sep 12, 2017

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@luotao1 在operator里,loss的写法和Paddle Layer的写法不一样,output_grad即使是1.0也是框架自动设置,或loss op外面自动设置的,所以需要乘out_grad_ptr[i]


PADDLE_ENFORCE_EQ(x->dims(), y->dims(),
"Dimensions of X and Y must be the same.");
PADDLE_ENFORCE_EQ(framework::arity(x->dims()), 2,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

framework::arity(x->dims()) --> x->dims().size()

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

"Dimensions of X and Y must be the same.");
PADDLE_ENFORCE_EQ(framework::arity(x->dims()), 2,
"Tensor rank of X must be 2.");
PADDLE_ENFORCE_EQ(x->dims()[1], 1, "Second dimension of X must be 1.");
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

--> The second
or The 2nd for short.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "Input value of ModifiedHuberLossOp.");
AddInput("Y", "Target labels of ModifiedHuberLossOp.");
AddOutput("intermediate_val",
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we have the same naming style? e.g. IntermediateVal ?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.


def modified_huber_loss_forward(val):
if val < -1:
return -4 * a
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what is a?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

a --> val. Done.

luotao1
luotao1 previously approved these changes Sep 20, 2017
Copy link
Contributor

@luotao1 luotao1 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

Copy link
Contributor

@luotao1 luotao1 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@pkuyym pkuyym merged commit 51f1148 into PaddlePaddle:develop Sep 20, 2017
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants