
nn.Graph: reuse eager lbn without creating duplicate variable op #6981

Closed
wants to merge 7 commits
Changes from 2 commits
4 changes: 4 additions & 0 deletions oneflow/core/framework/nn_graph.cpp
@@ -22,6 +22,7 @@ limitations under the License.
#include "oneflow/core/framework/instructions_builder.h"
#include "oneflow/core/framework/multi_client_session_context.h"
#include "oneflow/core/framework/nd_sbp.h"
#include "oneflow/core/framework/tensor_name_scope.h"
#include "oneflow/core/functional/functional.h"
#include "oneflow/core/graph/op_graph.h"
#include "oneflow/core/job/compiler.h"
@@ -253,6 +254,9 @@ Maybe<void> NNGraph::CompileAndInitRuntime() {
// TODO(chengcheng): CHECK job valid for each rank.
JUST(CreateAndRegisterNewVariableOpInJobPass());

// NOTE(chengcheng): TensorNameScope needs to be cleared after the current graph build.
one::TensorNameScope::Global()->Clear();

// NOTE(chengcheng): Global<JobDesc> need be clear before GlobalJobDescScope construct.
if (Global<JobDesc>::Get() != nullptr) { Global<JobDesc>::Delete(); }

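The NOTE above is the heart of this file's change: one::TensorNameScope is a process-wide registry, so an entry recorded while building one nn.Graph would otherwise still be visible when the next graph is built. Below is a minimal, runnable sketch of that failure mode, using a plain map as a stand-in for the real scope; the Tensor struct and the lbn string are invented for illustration.

#include <cassert>
#include <memory>
#include <string>
#include <unordered_map>

struct Tensor {};  // hypothetical stand-in for an eager tensor

// Stand-in for the process-global map inside one::TensorNameScope.
std::unordered_map<const Tensor*, std::string> g_scope;

int main() {
  auto t = std::make_shared<Tensor>();

  // Building graph A records an lbn for the eager tensor it captured.
  g_scope[t.get()] = "graph_a.variable_0/out";  // invented lbn format

  // This mirrors the Clear() call added to CompileAndInitRuntime; without it,
  // a later graph build would still find graph A's stale entry for `t`.
  g_scope.clear();
  assert(g_scope.find(t.get()) == g_scope.end());
  return 0;
}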
24 changes: 18 additions & 6 deletions oneflow/core/framework/op_interpreter/lazy_op_interpreter.cpp
@@ -252,6 +252,24 @@ Maybe<void> LazyInterpreter::ApplyImpl(const FeedVariableOpExpr& op_expr, const
const std::shared_ptr<Tensor>& input_tensor = inputs.at(0);
CHECK_OR_RETURN(input_tensor->is_eager());

auto infer_ctx = JUST(GetCurInferCtx());

// Check outputs num and setup output tensor properties.
CHECK_EQ_OR_RETURN(outputs->size(), 1);
CHECK_EQ_OR_RETURN(op_expr.output_size(), 1);
CHECK_OR_RETURN(!(*outputs)[0]);

const std::string& opt_lbn = TensorNameScope::Global()->Lookup(input_tensor);
if (!opt_lbn.empty()) {
// NOTE(chengcheng): This eager tensor has been fed as a variable op before, so we just use the
// lbn, and will NOT create a duplicate variable op again.
(*outputs)[0] = input_tensor;
Contributor: Remember, the plan was to change this to return a lazy tensor here?
VLOG(2) << "Lazy nn.Graph name " << infer_ctx->job().job_conf().job_name()
<< " try to add variable op name : \n: " << op_expr.op_name()
<< " but it has been created as : " << opt_lbn << ". \n So we just reuse this tensor.";
return Maybe<void>::Ok();
}

std::shared_ptr<Scope> scope = JUST(NewScopeWithParallelDescByTensor(input_tensor));

OperatorConf op_conf;
@@ -275,7 +293,6 @@
if (unlikely(l2 != 0.0)) { var_conf->mutable_regularizer()->mutable_l1_l2_conf()->set_l2(l2); }
}

auto infer_ctx = JUST(GetCurInferCtx());
VLOG(2) << "Lazy nn.Graph name " << infer_ctx->job().job_conf().job_name()
<< " try to add op: \n: " << op_conf.DebugString() << std::endl;
OpAttribute op_attr = *JUST(infer_ctx->AddAndInferConsistentOp(op_conf));
@@ -288,11 +305,6 @@ Maybe<void> LazyInterpreter::ApplyImpl(const FeedVariableOpExpr& op_expr, const
int64_t parallel_desc_sym_id = JUST(scope->GetParallelDescSymbolId(op_conf));
auto blob_parallel_desc = JUST(GetSymbol<cfg::ParallelConf, ParallelDesc>(parallel_desc_sym_id));

// Check outputs num and setup output tensor properties.
CHECK_EQ_OR_RETURN(outputs->size(), 1);
CHECK_EQ_OR_RETURN(op_expr.output_size(), 1);
CHECK_OR_RETURN(!(*outputs)[0]);

const std::string obn = "out"; // NOTE(chengcheng): obn is NOT op_expr.indexed_obns
(*outputs)[0] = JUST(BuildTensor(op_attr, obn, blob_parallel_desc, /* is_lazy= */ true,
/* is_local */ input_tensor->is_local()));
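To make the new control flow concrete, here is a condensed, runnable analogue of the reuse path added above. All names and the lbn format are invented, and the real implementation goes through OperatorConf, AddAndInferConsistentOp, and BuildTensor rather than a bare map.

#include <iostream>
#include <memory>
#include <string>
#include <unordered_map>

struct Tensor {};  // hypothetical stand-in for an eager tensor

std::unordered_map<const Tensor*, std::string> g_tensor_names;  // TensorNameScope analogue
int g_variable_op_count = 0;

// Analogue of the patched FeedVariableOpExpr handling: consult the scope
// first, and only create (and record) a variable op on a miss.
std::string FeedVariable(const std::shared_ptr<Tensor>& tensor) {
  auto it = g_tensor_names.find(tensor.get());
  if (it != g_tensor_names.end()) {
    return it->second;  // hit: reuse the recorded lbn, no duplicate variable op
  }
  std::string lbn = "model.variable_" + std::to_string(g_variable_op_count++) + "/out";
  g_tensor_names[tensor.get()] = lbn;  // like Record(): the next feed hits the cache
  return lbn;
}

int main() {
  auto weight = std::make_shared<Tensor>();
  std::cout << FeedVariable(weight) << "\n";  // first feed creates the op
  std::cout << FeedVariable(weight) << "\n";  // second feed reuses the lbn
  std::cout << "variable ops created: " << g_variable_op_count << "\n";  // prints 1
  return 0;
}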
5 changes: 5 additions & 0 deletions oneflow/core/framework/tensor_name_scope.cpp
@@ -42,5 +42,10 @@ void TensorNameScope::Record(const std::shared_ptr<Tensor>& tensor, const std::s
tensor_names_[key] = name;
}

void TensorNameScope::Clear() {
std::lock_guard<std::mutex> lock(mutex_);
tensor_names_.clear();
}

} // namespace one
} // namespace oneflow
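Clear() guards the map with the same mutex_ that Record() presumably holds, so a clear racing against a concurrent record stays well-defined. The toy class below demonstrates the same guard pattern; it is an invented stand-in, not the real TensorNameScope.

#include <mutex>
#include <string>
#include <thread>
#include <unordered_map>

class ToyScope {
 public:
  void Record(const void* key, const std::string& name) {
    std::lock_guard<std::mutex> lock(mutex_);  // same lock as Clear()
    names_[key] = name;
  }
  void Clear() {
    std::lock_guard<std::mutex> lock(mutex_);  // without this, clear() could
    names_.clear();                            // race Record()'s operator[]
  }

 private:
  std::mutex mutex_;
  std::unordered_map<const void*, std::string> names_;
};

int main() {
  ToyScope scope;
  int key = 0;
  std::thread writer([&] {
    for (int i = 0; i < 10000; ++i) scope.Record(&key, "a/out");
  });
  std::thread clearer([&] {
    for (int i = 0; i < 10000; ++i) scope.Clear();
  });
  writer.join();
  clearer.join();
  return 0;
}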
3 changes: 3 additions & 0 deletions oneflow/core/framework/tensor_name_scope.h
@@ -31,6 +31,9 @@ class TensorNameScope {

void Record(const std::shared_ptr<Tensor>& tensor, const std::string& name);

// NOTE(chengcheng): TensorNameScope needs to be cleared after the current graph build.
Contributor: There should be no need to add a usage-site comment here at the interface.
void Clear();

private:
TensorNameScope() : default_tensor_name_("") {}
virtual ~TensorNameScope() = default;
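The private constructor and destructor visible above mark TensorNameScope as a singleton, and the nn_graph.cpp hunk reaches it through Global(). How that accessor is implemented is not shown in this diff; the sketch below is one common pattern (a Meyers singleton), assumed rather than taken from the OneFlow source.

#include <string>

class TensorNameScope {
 public:
  // Assumed accessor: a lazily-constructed, process-wide instance.
  static TensorNameScope* Global() {
    static TensorNameScope scope;
    return &scope;
  }

  void Clear() { /* see tensor_name_scope.cpp above */ }

 private:
  TensorNameScope() : default_tensor_name_("") {}
  virtual ~TensorNameScope() = default;

  std::string default_tensor_name_;
};

int main() {
  TensorNameScope::Global()->Clear();  // the same call the graph-build path makes
  return 0;
}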