Skip to content

Commit

Permalink
Feat/debug pass (#7054)
Browse files Browse the repository at this point in the history
* add pass debug

* debug pass

* refine comment of fuse add pass

* auto format by CI

Co-authored-by: oneflow-ci-bot <ci-bot@oneflow.org>
Co-authored-by: oneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com>
  • Loading branch information
3 people authored Dec 24, 2021
1 parent 120ecad commit ab63596
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 3 deletions.
34 changes: 31 additions & 3 deletions oneflow/core/job/job_build_and_infer_ctx.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -965,8 +965,36 @@ Maybe<void> LazyJobBuildAndInferCtx::Complete() {
Global<JobDesc>::Delete();
auto scope = std::make_unique<GlobalJobDescScope>(mut_job()->job_conf(), job_id());
JobPassCtx job_pass_ctx(GlobalJobDesc());
auto DoPass = [&](const std::string& pass_name) -> Maybe<void> {
return JobPass4Name(pass_name)(mut_job(), &job_pass_ctx);
const auto& job_name = job().job_conf().job_name();
// Dumps the current job for debugging. The dump name is
// "<job_name>-job_id_<job_id>-<name_suffix>"; the job itself is written through
// TeePersistentLogStream and its op graph is exported to "<dump name>.dot".
auto LogJob = [&](const std::string& name_suffix) -> void {
  const std::string dump_name =
      job_name + "-job_id_" + std::to_string(job_id()) + "-" + name_suffix;
  TeePersistentLogStream::Create(dump_name)->Write(job());
  // The global OpGraph is created only for the dot export and deleted right after.
  Global<OpGraph>::New(job());
  Global<OpGraph>::Get()->ToDotWithFilePath(dump_name + ".dot");
  Global<OpGraph>::Delete();
};
std::string debug_pass_name = GetStringFromEnv("ONEFLOW_DEBUG_PASS", "");
// Returns true when the job should be dumped around `pass_name`: either
// debug_pass_name is the literal "ALL", or it equals `pass_name` exactly.
auto NeedLogJob = [&](const std::string& pass_name) -> bool {
  return debug_pass_name == "ALL" || debug_pass_name == pass_name;
};
// Runs the job pass registered under `pass_name` on the mutable job.
// When NeedLogJob(pass_name) holds, the job is dumped via LogJob both before
// and after the pass. `cnt` is appended to the pass name in the dump name when
// it is greater than zero (e.g. to tell apart repeated runs of the same pass).
// Any error from the pass is propagated to the caller by JUST.
auto DoPass = [&](const std::string& pass_name, int32_t cnt = 0) -> Maybe<void> {
  const bool dump_job = NeedLogJob(pass_name);
  std::string dump_prefix;
  if (unlikely(dump_job)) {
    // Build the "<pass_name>[cnt]" prefix once and reuse it for both dumps.
    dump_prefix = pass_name;
    if (cnt > 0) { dump_prefix += std::to_string(cnt); }
    LogJob(dump_prefix + "-before");
  }
  JUST(JobPass4Name(pass_name)(mut_job(), &job_pass_ctx));
  if (unlikely(dump_job)) { LogJob(dump_prefix + "-after"); }
  return Maybe<void>::Ok();
};

if (Global<ResourceDesc, ForSession>::Get()->enable_debug_mode()
Expand Down Expand Up @@ -1007,7 +1035,7 @@ Maybe<void> LazyJobBuildAndInferCtx::Complete() {
JUST(DoPass("FuseAddToOutputPass"));
// run this pass again to fuse ops created in the first run.
// TODO(guoran): loop multiple times inside the pass
JUST(DoPass("FuseAddToOutputPass"));
JUST(DoPass("FuseAddToOutputPass", 1));
JUST(DoPass("IndexedSlicesOptimizerRewritePass"));
JUST(DoPass("SplitSparseSoftmaxCrossEntropyOpPass"));
JUST(DoPass("DoParallelCastBeforeWideningTypeCast"));
Expand Down
5 changes: 5 additions & 0 deletions oneflow/core/job_rewriter/fuse_add_to_output_pass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,8 @@ Maybe<void> FuseAddToOutputPass::Apply(const OpGraph& op_graph, JobBuilder* job_
if (user_op_conf.has_input("_add_to_output", 0)) { return false; }
return true;
};

// Collect the ctrl_in_op_name entries of every op into a set.
HashSet<std::string> ctrl_in_op_names;
op_graph.ForEachNode([&](const OpNode* op_node) {
for (const std::string& ctrl_in_op_name : op_node->op().op_conf().ctrl_in_op_name()) {
Expand Down Expand Up @@ -113,6 +115,7 @@ Maybe<void> FuseAddToOutputPass::Apply(const OpGraph& op_graph, JobBuilder* job_
} else {
return;
}
// Make a new_add_to_op to fuse add_n into this op.
OperatorConf new_add_to_op_conf = add_to_node->op().op_conf();
*(*(new_add_to_op_conf.mutable_user_conf()->mutable_input()))["_add_to_output"]
.mutable_s()
Expand All @@ -124,6 +127,7 @@ Maybe<void> FuseAddToOutputPass::Apply(const OpGraph& op_graph, JobBuilder* job_
if (op_name2op_conf.find(consumer_op_name) == op_name2op_conf.end()) {
op_name2op_conf[consumer_op_name] = consumer->op().op_conf();
}
// Make the add_n op's consumers consume the new_add_to_op instead.
for (const std::string& ibn : consumer->op().input_bns()) {
if (consumer->op().BnInOp2Lbi(ibn) == out) {
OperatorConf& consumer_op_conf = op_name2op_conf.at(consumer_op_name);
Expand All @@ -133,6 +137,7 @@ Maybe<void> FuseAddToOutputPass::Apply(const OpGraph& op_graph, JobBuilder* job_
}
}
}
// Add the add_n op to the list of ops to delete.
delete_ops.emplace_back(op_conf);
});
job_builder->DelOps(delete_ops);
Expand Down

0 comments on commit ab63596

Please sign in to comment.