Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Hackathon No.28] implement logcumsumexp #42267

Merged
merged 27 commits into from
Jun 10, 2022
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
5c3b6bb
implement logcumsumexp
tiancaishaonvjituizi Apr 26, 2022
4054f7c
polish
tiancaishaonvjituizi Apr 26, 2022
b8ade29
Merge remote-tracking branch 'origin/develop' into logcumsumexp
tiancaishaonvjituizi Apr 30, 2022
1f98cc7
fix ci
tiancaishaonvjituizi May 3, 2022
518c75a
reformat
tiancaishaonvjituizi May 3, 2022
e94f42c
update
tiancaishaonvjituizi May 3, 2022
8c680e6
address reviews
tiancaishaonvjituizi May 9, 2022
442bc00
Merge remote-tracking branch 'origin/develop' into logcumsumexp
tiancaishaonvjituizi May 9, 2022
a3e50da
add OpTest
tiancaishaonvjituizi May 9, 2022
0b4b8ca
use user defined grad
tiancaishaonvjituizi May 13, 2022
3bf4cfe
add formula in docs, address reviews
tiancaishaonvjituizi May 13, 2022
34f57f1
remove 'reference' comment
tiancaishaonvjituizi May 14, 2022
661bff3
Update logcumsumexp_grad_kernel.h
tiancaishaonvjituizi May 14, 2022
30241bb
Update logcumsumexp_sig.cc
tiancaishaonvjituizi May 14, 2022
2454012
Update logcumsumexp_grad_impl.h
tiancaishaonvjituizi May 14, 2022
1734440
decrease input size, update python
tiancaishaonvjituizi May 16, 2022
d6a773e
Merge branch 'logcumsumexp' of github.com:tiancaishaonvjituizi/Paddle…
tiancaishaonvjituizi May 16, 2022
3e4953a
shrink test data size
tiancaishaonvjituizi May 17, 2022
790e616
fix sample code
tiancaishaonvjituizi May 17, 2022
9797d10
refine docs
tiancaishaonvjituizi May 18, 2022
3b4b8fe
update docs
tiancaishaonvjituizi May 25, 2022
57bb711
fix docs;test=document_fix
tiancaishaonvjituizi May 27, 2022
4601d17
Merge remote-tracking branch 'origin/develop' into logcumsumexp
tiancaishaonvjituizi May 27, 2022
250998c
Merge remote-tracking branch 'origin/develop' into logcumsumexp
tiancaishaonvjituizi Jun 7, 2022
6a3647c
set test timeout to 30s
tiancaishaonvjituizi Jun 7, 2022
d6c7aa7
Merge remote-tracking branch 'origin/develop' into logcumsumexp
tiancaishaonvjituizi Jun 8, 2022
13edc4f
reformat
tiancaishaonvjituizi Jun 9, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 67 additions & 1 deletion paddle/fluid/operators/cumsum_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,19 @@ class CumOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;
};

class CumGradOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;

void InferShape(framework::InferShapeContext* ctx) const override {
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "logsumexp");
OP_INOUT_CHECK(ctx->HasInput("Out"), "Input", "Out", "logsumexp");
OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Out")), "Input",
"Out@GRAD", "logsumexp");
ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X"));
}
};

class CumsumOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
Expand Down Expand Up @@ -74,17 +87,70 @@ class CumsumGradMaker : public framework::SingleGradOpMaker<T> {
}
};

class LogcumsumexpOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X", "Input of logcumsumexp operator");
AddOutput("Out", "Output of logcumsumexp operator");
AddAttr<int>("axis",
"The dimension to accumulate along. -1 means the last "
"dimension [default -1].")
.SetDefault(-1);
AddAttr<bool>("flatten",
"Whether to compute the logcumsumexp over the flattened array. "
"[default false].")
.SetDefault(false);
AddAttr<bool>("exclusive",
"Whether to perform exclusive logcumsumexp. [default false].")
.SetDefault(false);
AddAttr<bool>("reverse",
"If true, the logcumsumexp is performed in the reversed direction. "
"[default false].")
.SetDefault(false);
AddComment(R"DOC(
Returns the logarithm of the cumulative summation of the exponentiation of elements of input along the given axis.
By default, the first element of the result is the same of the first element of
the input. If exlusive is true, the first element of the result is the minimum value of dtype.
)DOC");
}
};

template <typename T>
class LogcumsumexpGradMaker : public framework::SingleGradOpMaker<T> {
public:
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;

protected:
void Apply(GradOpPtr<T> grad_op) const override {
grad_op->SetType("logcumsumexp_grad");
grad_op->SetInput("X", this->OutputGrad("Out"));
grad_op->SetOutput("Out", this->InputGrad("X"));
grad_op->SetAttr("axis", BOOST_GET_CONST(int, this->GetAttr("axis")));
grad_op->SetAttr("flatten",
BOOST_GET_CONST(bool, this->GetAttr("flatten")));
grad_op->SetAttr("reverse",
BOOST_GET_CONST(bool, this->GetAttr("reverse")));
grad_op->SetAttr("exclusive",
BOOST_GET_CONST(bool, this->GetAttr("exclusive")));
}
};

} // namespace operators
} // namespace paddle

namespace ops = paddle::operators;
using CPU = paddle::platform::CPUDeviceContext;
DECLARE_INFER_SHAPE_FUNCTOR(cumsum, CumsumInferShapeFunctor,
PD_INFER_META(phi::CumsumInferMeta));
PD_INFER_META(phi::CumInferMeta));
REGISTER_OPERATOR(cumsum, ops::CumOp, ops::CumsumOpMaker,
ops::CumsumGradMaker<paddle::framework::OpDesc>,
ops::CumsumGradMaker<paddle::imperative::OpBase>,
CumsumInferShapeFunctor);
REGISTER_OPERATOR(logcumsumexp, ops::CumOp, ops::LogcumsumexpOpMaker,
ops::LogcumsumexpGradMaker<paddle::framework::OpDesc>,
ops::LogcumsumexpGradMaker<paddle::imperative::OpBase>,
CumsumInferShapeFunctor);
REGISTER_OPERATOR(logcumsumexp_grad, ops::CumGradOp);

REGISTER_OP_VERSION(cumsum)
.AddCheckpoint(
Expand Down
12 changes: 6 additions & 6 deletions paddle/phi/infermeta/unary.cc
Original file line number Diff line number Diff line change
Expand Up @@ -234,12 +234,12 @@ void CreateLikeInferMeta(const MetaTensor& x, DataType dtype, MetaTensor* out) {
out->set_layout(x.layout());
}

void CumsumInferMeta(const MetaTensor& x,
int axis,
bool flatten,
bool exclusive,
bool reverse,
MetaTensor* out) {
void CumInferMeta(const MetaTensor& x,
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Cumsum 和 Logcumsumexp 复用同一个 infer meta 函数

int axis,
bool flatten,
bool exclusive,
bool reverse,
MetaTensor* out) {
auto x_dims = x.dims();
if (flatten) {
out->set_dims(phi::make_ddim({phi::product(x_dims)}));
Expand Down
12 changes: 6 additions & 6 deletions paddle/phi/infermeta/unary.h
Original file line number Diff line number Diff line change
Expand Up @@ -60,12 +60,12 @@ void CholeskyInferMeta(const MetaTensor& x, bool upper, MetaTensor* out);

void CreateLikeInferMeta(const MetaTensor& x, DataType dtype, MetaTensor* out);

void CumsumInferMeta(const MetaTensor& x,
int axis,
bool flatten,
bool exclusive,
bool reverse,
MetaTensor* out);
void CumInferMeta(const MetaTensor& x,
int axis,
bool flatten,
bool exclusive,
bool reverse,
MetaTensor* out);

void DiagInferMeta(const MetaTensor& x,
int offset,
Expand Down
Loading