-
Notifications
You must be signed in to change notification settings - Fork 5.6k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
【BUAA】【Infer Symbolic Shape】add set_value_with_tensor, sigmoid_cross_entropy_with_logits,swiglu #67098
Conversation
你的PR提交成功,感谢你对开源项目的贡献! |
test_set_value_op.py, test_sigmoid_cross_entropy_with_logits_op.py,test_swiglu_op.py 中有对应optest单测, CI-Coverage 未覆盖位置并不是修改添加的op。 |
} else { | ||
std::vector<symbol::DimExpr> x_shape = x_shape_or_data.shape(); | ||
int x_last = static_cast<int>(x_shape[rank - 1].Get<std::int64_t>()); | ||
infer_context->AddEqualCstr(symbol::DimExpr{x_last % 2}, | ||
symbol::DimExpr{0}); | ||
x_shape[rank - 1] = symbol::DimExpr{x_last / 2}; | ||
infer_context->SetShapeOrDataForValue( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
直接对DimExpr添加约束,取int就默认是静态shape了
pos_shape_or_data.shape()[i]); | ||
} | ||
} | ||
infer_context->SetShapeOrDataForValue(op->result(0), input_shape_or_data); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
最好将Shape取出来单独设置,这里输入输出的data区应该不一样吧
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
已修改
// OP_DECLARE_INFER_SYMBOLIC_SHAPE(SetValueWithTensor) | ||
OP_DECLARE_INFER_SYMBOLIC_SHAPE(SetValue) | ||
OP_DECLARE_INFER_SYMBOLIC_SHAPE(SetValue_) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这个PR没有给set_value在yaml中加接口,下一个PR可以顺带加一下(或者给加set_value的同学说一声)
…entropy_with_logits,swiglu (PaddlePaddle#67098)
PR Category
CINN
PR Types
Others
Description
添加set_value_with_tensor, sigmoid_cross_entropy_with_logits,swiglu算子符号推导接口实现。