-
Notifications
You must be signed in to change notification settings - Fork 5.7k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add C++ EinsumOp which support 2 operands einsum. #42105
Conversation
…ast 3. MultiVariable
#include "paddle/fluid/framework/op_registry.h" | ||
#include "paddle/fluid/framework/operator.h" | ||
#include "paddle/phi/core/ddim.h" | ||
#include "paddle/phi/kernels/impl/einsum_impl.h" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
<memory>,<unordered_map>,, einsum_impl.h这4个头文件是必需的么?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
目前einsum impl中有InferMeta,后续会提PR将这个函数移动到 unary.cc 中,下个删除
paddle/fluid/operators/einsum_op.cc
Outdated
void Make() override { | ||
AddInput("Operands", "(Tensor), The input tensor of svd op.") | ||
.AsDuplicable(); | ||
AddOutput("Out", "(Tensor), The output VH tensor of svd op."); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
AddOutput("Out", "(Tensor), The output VH tensor of svd op."); | |
AddOutput("Out", "(Tensor), The output tensor of einsum op."); |
paddle/fluid/operators/einsum_op.cc
Outdated
using framework::OperatorWithKernel::OperatorWithKernel; | ||
void InferShape(framework::InferShapeContext* ctx) const override { | ||
auto in_x = "Operands"; | ||
auto out_x_g_n = framework::GradVarName(in_x); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
auto out_x_g_n = framework::GradVarName(in_x); | |
auto x_name = "Operands"; | |
auto x_grad_name = framework::GradVarName(x_name); |
注意命名规范
|
||
#include "paddle/phi/backends/gpu/gpu_context.h" | ||
#include "paddle/phi/core/kernel_registry.h" | ||
#include "paddle/phi/kernels/einsum_kernel.h" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
#include "paddle/phi/kernels/einsum_kernel.h" 要放在第一行,且空行隔开
|
||
#include "paddle/phi/backends/gpu/gpu_context.h" | ||
#include "paddle/phi/core/kernel_registry.h" | ||
#include "paddle/phi/kernels/einsum_kernel.h" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
#include "paddle/phi/kernels/einsum_kernel.h" 要放在第一行,且空行隔开
} else { | ||
PADDLE_ENFORCE_EQ((*labelshape)[c], | ||
op_dim[dim_ptr], | ||
phi::errors::InvalidArgument("")); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这里的报错信息需要完善下
PADDLE_ENFORCE_EQ(v == 1 || broadcast_dims->at(idx) == 1 || | ||
broadcast_dims->at(idx) == v, | ||
true, | ||
phi::errors::InvalidArgument("can't broad cast.")); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这里的报错信息需要完善下
const LabelMap& labelshape, | ||
std::vector<int>* output_dims) { | ||
for (int c : right) { | ||
if (c == '.') |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
建议block用{} 括住
std::vector<int>* output_dims, | ||
std::string* right) { | ||
auto results = paddle::string::split_string(equation, "->"); | ||
auto left = results[0]; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这里是否需要先ENFORCE下 results的size是否符合预期?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这个步骤在 ValidEinsum 函数中已经验证了。这里就没验证了。
VLOG(5) << "Einsum Infershape: output dims:" | ||
<< paddle::string::join_strings(output_dims, ","); | ||
print_label(all_labels, labeltype); | ||
print_label(all_labels, labelshape); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这两个print放到VLOG下来触发吧,另外,op的vlog等级可以设置为3
namespace ops = paddle::operators; | ||
|
||
DECLARE_INFER_SHAPE_FUNCTOR(einsum, EinsumInferShapeFunctor, | ||
PD_INFER_META(phi::EinsumInferShape)); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
是不是少提交了一个文件?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM for set_tests_properties(test_einsum_op PROPERTIES TIMEOUT 120)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM for set_tests_properties(test_einsum_op PROPERTIES TIMEOUT 120)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
细节麻烦追加完善下
} | ||
} | ||
|
||
inline void EinsumInferShape(const std::vector<const MetaTensor*>& inputs, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
建议移到infermeta子目录
auto it = std::find(all_labels.begin(), all_labels.end(), c); | ||
PADDLE_ENFORCE_NE(it, | ||
all_labels.end(), | ||
phi::errors::InvalidArgument("Must in all_labels.")); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
报错信息过短的问题后面也一并完善下
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
* full api fix * when out is None, go old dygraph mode * by static check * first version: support 2-inputs forwards. TODO: 1. backward 2. BroadCast 3. MultiVariable * time out -> 120
PR types
New features
PR changes
APIs
Describe
This PR developed C++ EinsumOp.
TODO: