-
Notifications
You must be signed in to change notification settings - Fork 5.7k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[NewIR]new ir dygraph to static supoort gpu #55620
[NewIR]new ir dygraph to static supoort gpu #55620
Conversation
… add_kernel_dialect
… add_kernel_dialect
… lower_pd_op_to_kernel_dialect
… lower_pd_op_to_kernel_dialect
… lower_pd_op_to_kernel_dialect
… lower_pd_op_to_kernel_dialect
… new_interprector_support_new_IR
… new_interprector_support_new_IR
… new_interprector_support_new_IR
… new_interprector_support_new_IR
… dygraph2static_support_new_ir
… fix_new_ir_train_step_bug
… fix_new_ir_train_step_bug
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
} | ||
// for (size_t block_idx = 0; block_idx < legacy_program_->Size(); | ||
// block_idx++) { | ||
// const BlockDesc& block = legacy_program_->Block(block_idx); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
无用注释是否需要删除捏?
@@ -2461,6 +2464,10 @@ | |||
extra : | |||
attrs : [bool use_mkldnn=false] | |||
|
|||
- op : shaddow_output |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
shadow
// translator here | ||
|
||
std::unique_ptr<::ir::Program> ir_program; | ||
if (FLAGS_enable_new_ir_in_executor) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这部分新 IR 的逻辑适配逻辑可否放到一个单独的函数或工具类中,包括后续其他模块的适配,都不建议这样直接把实现代码放进来,后续不利于快速迭代,以及未来的旧代码下线,因为不清楚「新旧边界」在哪里。
auto program = std::make_unique<::ir::Program>(ir_ctx); | ||
|
||
std::set<std::string> set_output_names; | ||
// TODO(phlrain): no end add all the input |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Not important,但辛苦也注意下所有comment的语法、语义的规范性,后续可能会有很多外部开发者来协作。
std::set<std::string> set_output_names; | ||
// TODO(phlrain): no end add all the input | ||
for (auto op_desc : forward_program->Block(0).AllOps()) { | ||
for (const auto &n : op_desc->Outputs()) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
for (const auto &n : op_desc->Outputs()) { | |
for (const auto &out : op_desc->Outputs()) { |
n 不是一个好的命名。
std::set<std::string> set_parameter_names; | ||
// TODO(phlrain): no need add all the input | ||
for (auto op_desc : backward_program->Block(0).AllOps()) { | ||
for (const auto &n : op_desc->Inputs()) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
for (const auto &n : op_desc->Inputs()) { | |
for (const auto &input : op_desc->Inputs()) { |
} | ||
} | ||
|
||
for (auto &t : output_names) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
for (auto &t : output_names) { | |
for (auto &out_name : output_names) { |
@@ -362,7 +389,8 @@ void BuildScope(const ir::Block& block, | |||
|
|||
if (op_name == "pd.feed" || op_name == "pd.fetch" || | |||
op_name == "builtin.combine" || op_name == "builtin.set_parameter" || | |||
op_name == "builtin.get_parameter" || op_name == "builtin.slice") { | |||
op_name == "builtin.get_parameter" || op_name == "builtin.slice" || | |||
op_name == "pd.feed_with_place" || op_name == "pd.shaddow_output") { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
同上
special_handlers["one_hot_v2"] = OneHotTranscriber(); | ||
special_handlers["add_n"] = AddNOpTranscriber(); | ||
special_handlers["sum"] = AddNOpTranscriber(); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这里后续我们是否可以按照字母序排列?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这里我后面重构下吧,把OpTranscriber拆分到单独文件里面,方便维护
// for (size_t block_idx = 0; block_idx < legacy_program_->Size(); | ||
// block_idx++) { | ||
// const BlockDesc& block = legacy_program_->Block(block_idx); | ||
// SetStopGradientAttributeForAllValue(block); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这里为什么注释了?似乎没有TODO
@@ -24,6 +24,12 @@ template <typename T, typename Context> | |||
void FeedWithPlaceKernel(const Context& ctx, | |||
int64_t index, | |||
phi::DataType data_type, | |||
// std::string name, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这里是要删掉么?还是知识临时加的?
@@ -159,8 +160,63 @@ def test_with_new_ir(self): | |||
np.testing.assert_array_equal(out[0], gold_res) | |||
|
|||
|
|||
class TestNewIrDygraph(unittest.TestCase): | |||
def test_with_new_ir(self): | |||
paddle.disable_static() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
更推荐的做法是重写 SetUp() 和 TearDown() 来控制disable_static 和 enable_static
@@ -165,6 +170,11 @@ void ProgramTranslator::InsertOperationToSingleBlock(const BlockDesc& block) { | |||
auto& op_translator = OpTranslator::instance(); | |||
for (auto op : block.AllOps()) { | |||
OpTranslateFn& fn = op_translator[op->Type()]; | |||
if (op->Type() == "shaddow_output") { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这里为什么需要跳过shaddow_output,我看ShaddowOutputOpTranscriber中少了一段把输出记录在param_map的逻辑,这是什么原因?
auto create_op_info = ctx->GetRegisteredOpInfo(ir::SetParameterOp::name()); | ||
ir::Operation* operation = | ||
ir::Operation::Create(op_inputs, attribute_map, {}, create_op_info); | ||
program->block()->push_back(operation); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
shaddow_output是一个有输出的Op,这里为什么没有处理输出?
special_handlers["one_hot_v2"] = OneHotTranscriber(); | ||
special_handlers["add_n"] = AddNOpTranscriber(); | ||
special_handlers["sum"] = AddNOpTranscriber(); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这里我后面重构下吧,把OpTranscriber拆分到单独文件里面,方便维护
@@ -121,6 +123,9 @@ void ProgramTranslator::GetParameterForSingleBlock(const BlockDesc& block) { | |||
std::unordered_set<std::string> inner_defining_variables; | |||
|
|||
for (auto op_desc : block.AllOps()) { | |||
if (op_desc->Type() == "feed") { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
为什么跳过feed?跳过的话可能导致多余的GetParameterOp吧
… fix_new_ir_train_step_bug
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM overall
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM For yaml
* add kernel dialect * change DenseTensorTypeStorage to DenseTensorType * add test case` * add first pd_op to kernel dialect * lower pd op to kernel dialect * update * update * remove useless code * add attrite print test * fix bug * update * update * update * update * polish code * fix bug * polish code and add python test * add test * fix test error * relax constraint when inserting get_parameter * add env flag * fix bug * dygraph2static support new ir * fix bug * revert test env * change cc_test_old to cc_test * update * fix build_static bug * update test * fix type test error * udpate cmake * disable test in windows * fix inference compile * fix program translator error * only run on cpu, not support gpu yet * fix conflict * polish code * fix bug * add feed with place op * update * remove useless unitest * udpate mkldnn * update * update * align mkldnn version * new ir support builtin slice op * fix bug * fix phi kernel adaptor bug * add enable static * add enable_static * remove useless test case * change feed list to single variable * update * add feed with place and shaddow output op * fix bug * remove usless code * support gpu * fix bug * fix bug * remove template * add more data type * fix cimpile bug * udpate * remove useless code * revert dygraph2st test * remove usless code * revert op * fix bug * new ir dygraph2static support gpu * remove usless code * code polish * add const * revert code and remove useless code * revert code * revert legacy op yaml * remove useless code * delete std::move --------- Co-authored-by: kangguangli <kangguangli@hotmail.com>
* add kernel dialect * change DenseTensorTypeStorage to DenseTensorType * add test case` * add first pd_op to kernel dialect * lower pd op to kernel dialect * update * update * remove useless code * add attrite print test * fix bug * update * update * update * update * polish code * fix bug * polish code and add python test * add test * fix test error * relax constraint when inserting get_parameter * add env flag * fix bug * dygraph2static support new ir * fix bug * revert test env * change cc_test_old to cc_test * update * fix build_static bug * update test * fix type test error * udpate cmake * disable test in windows * fix inference compile * fix program translator error * only run on cpu, not support gpu yet * fix conflict * polish code * fix bug * add feed with place op * update * remove useless unitest * udpate mkldnn * update * update * align mkldnn version * new ir support builtin slice op * fix bug * fix phi kernel adaptor bug * add enable static * add enable_static * remove useless test case * change feed list to single variable * update * add feed with place and shaddow output op * fix bug * remove usless code * support gpu * fix bug * fix bug * remove template * add more data type * fix cimpile bug * udpate * remove useless code * revert dygraph2st test * remove usless code * revert op * fix bug * new ir dygraph2static support gpu * remove usless code * code polish * add const * revert code and remove useless code * revert code * revert legacy op yaml * remove useless code * delete std::move --------- Co-authored-by: kangguangli <kangguangli@hotmail.com>
PR types
Bug fixes
PR changes
Others
Description
Others
Pcard-67164