Feat/zero optimization in nn.Graph (#7165)
* debug

* modify graph.py

* fix bug about graph debug interface

* Fix nn graph variable bind (#6895)

* fix(AutoParallel): nn.Graph support auto_parallel change sbp

* fix(AutoParallel): use tensor.set_data interface and add print sbp info

* add comment

* hack check

* add test

* refine test

* refine test

* refine code

* add and refine zero

* fix test

* refine code

* rm debug log

* refine min size set

* add note

* debug zero

* fix cudnn config

* refine test doc

* add comment of check

* eager mode in graph pass

* format

* rebuild parameter according to sbp in synced plan

* auto format by CI

* fix code check

* fix test

* try init session at graph init

* refine and revert session init

* rm useless code

* add back print of sys conf

Co-authored-by: oneflow-ci-bot <ci-bot@oneflow.org>
Co-authored-by: grybd <52237830+grybd@users.noreply.github.com>
Co-authored-by: oneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com>
Co-authored-by: wyg1997 <wyg19970408@gmail.com>
5 people authored Jan 15, 2022
1 parent 30d90ee commit d20a5d9
Showing 10 changed files with 223 additions and 89 deletions.
97 changes: 64 additions & 33 deletions oneflow/core/framework/nn_graph.cpp
@@ -99,7 +99,7 @@ const std::vector<std::string>& NNGraph::outputs_tensor_meta_str() const {
return outputs_tensor_meta_str_;
}

int64_t NNGraph::variable_op_size() const { return variable_op_name2eager_blob_.size(); }
int64_t NNGraph::variable_op_size() const { return variable_op_names_.size(); }

Maybe<void> NNGraph::RegisterInputOpNamesAndTensors(
const std::vector<std::string>& inputs_op_names,
@@ -144,21 +144,12 @@ Maybe<void> NNGraph::RegisterVariableOpNamesAndTensors(
const std::vector<std::shared_ptr<one::Tensor>>& variable_tensors) {
JUST(vm::CurrentRankSync());
CHECK_EQ_OR_RETURN(variable_op_names.size(), variable_tensors.size());
CHECK_OR_RETURN(variable_op_name2eager_blob_.empty());
for (int32_t i = 0; i < variable_op_names.size(); ++i) {
const std::shared_ptr<one::Tensor>& var = variable_tensors.at(i);
CHECK_OR_RETURN(var->is_eager());
Blob* var_blob = nullptr;
if (var->is_consistent()) {
// NOTE(chengcheng): var_blob maybe nullptr when consistent tensor has no cur_rank_phy_tensor.
const std::shared_ptr<one::MirroredTensor> local_var = JUST(var->cur_rank_phy_tensor());
var_blob = JUST(local_var->eager_blob_object())->mut_blob();
} else {
var_blob = JUST(var->eager_blob_object())->mut_blob();
}
const std::string& var_name = variable_op_names.at(i);
CHECK_OR_RETURN(!var_name.empty());
CHECK_OR_RETURN(variable_op_name2eager_blob_.emplace(var_name, var_blob).second);
CHECK_OR_RETURN(variable_op_name2tensor_.emplace(var_name, var).second);
CHECK_OR_RETURN(variable_op_names_.insert(var_name).second);
}
return Maybe<void>::Ok();
@@ -172,24 +163,15 @@ Maybe<void> NNGraph::RegisterFreeEagerTensorsToVariableOpNames() {
const std::string& var_name = pair.first;
const std::shared_ptr<one::Tensor>& var = pair.second;
CHECK_OR_RETURN(var->is_eager());
Blob* var_blob = nullptr;
if (var->is_consistent()) {
const std::shared_ptr<one::MirroredTensor> local_var = JUST(var->cur_rank_phy_tensor());
var_blob = JUST(local_var->eager_blob_object())->mut_blob();
} else {
var_blob = JUST(var->eager_blob_object())->mut_blob();
}
CHECK_OR_RETURN(!var_name.empty());
CHECK_OR_RETURN(variable_op_name2eager_blob_.emplace(var_name, var_blob).second);
CHECK_OR_RETURN(variable_op_name2tensor_.emplace(var_name, var).second);
CHECK_OR_RETURN(variable_op_names_.insert(var_name).second);
}
return Maybe<void>::Ok();
}

Maybe<void> NNGraph::CreateAndRegisterNewVariableOpInJobPass() {
JUST(vm::CurrentRankSync());
// NOTE(chengcheng): New EagerTensor need set LazyMode false.
auto lazy_mode_disabled_guard = LazyMode::Guard(/* is_enabled */ false);
OpGraph op_graph(job_);
JUST(op_graph.MaybeForEachNode([&](OpNode* op_node) -> Maybe<void> {
if (op_node->op().op_conf().has_variable_conf() == false) { return Maybe<void>::Ok(); }
@@ -218,18 +200,19 @@ Maybe<void> NNGraph::CreateAndRegisterNewVariableOpInJobPass() {
} else {
OF_UNIMPLEMENTED();
}
std::shared_ptr<one::Tensor> tensor = JUST(one::functional::ConsistentConstant(
blob_desc.shape(), value, Symbol<DType>(dtype), placement, *sbp_tuple));
JUST(vm::CurrentRankSync());
const std::shared_ptr<one::MirroredTensor> local_var = JUST(tensor->cur_rank_phy_tensor());
Blob* var_blob = JUST(local_var->eager_blob_object())->mut_blob();
CHECK_OR_RETURN(variable_op_name2eager_blob_.emplace(var_name, var_blob).second);
CHECK_OR_RETURN(variable_op_names_.insert(var_name).second);

// NOTE(chengcheng): just for tensor lifetime hold by session context in graph lifetime valid.
Global<MultiClientSessionContext>::Get()->StoreFreeEagerTensorWithNameByGraphName(
name_, tensor, var_name);

{
// NOTE(chengcheng): New EagerTensor need set LazyMode false.
auto lazy_mode_disabled_guard = LazyMode::Guard(/* is_enabled */ false);
std::shared_ptr<one::Tensor> tensor = JUST(one::functional::ConsistentConstant(
blob_desc.shape(), value, Symbol<DType>(dtype), placement, *sbp_tuple));
CHECK_OR_RETURN(variable_op_name2tensor_.emplace(var_name, tensor).second);
CHECK_OR_RETURN(variable_op_names_.insert(var_name).second);

// NOTE(chengcheng): the session context holds the tensor only to keep it valid for the
// graph's lifetime.
Global<MultiClientSessionContext>::Get()->StoreFreeEagerTensorWithNameByGraphName(
name_, tensor, var_name);
}
VLOG(2) << "Lazy nn.Graph name " << name_ << " op : \n"
<< variable_op.op_conf().DebugString()
<< " created in JobPass, nn.Graph will new EagerTensor for this variable.\n";
@@ -297,11 +280,59 @@ Maybe<void> NNGraph::CompileAndInitRuntime() {
PlanUtil::PopulateOpAttribute(&plan_, plan_.job_id2op_attribute_ref_table());

NewRuntimeBuffers();

JUST(GetVariableRealBlobAfterSyncPlan());
runtime_.reset(new Runtime(plan_, variable_op_name2eager_blob_));
runtime_inited_ = true;
return Maybe<void>::Ok();
}

Maybe<void> NNGraph::GetVariableRealBlobAfterSyncPlan() {
CHECK_OR_RETURN(variable_op_name2eager_blob_.empty());
JUST(vm::CurrentRankSync());
JobBuildAndInferCtx* job_ctx = JUST(GetJobBuildAndInferCtx(name_));
auto job_id = job_ctx->job_id();
for (const std::string& var_name : variable_op_names_) {
auto iter = variable_op_name2tensor_.find(var_name);
CHECK_OR_RETURN(iter != variable_op_name2tensor_.end()) << var_name << " not found.";
std::shared_ptr<one::Tensor> tensor = iter->second;
Blob* var_blob = nullptr;
if (tensor->is_consistent()) {
cfg::NdSbpSignature var_nd_sbp_signature =
cfg::NdSbpSignature(plan_.job_id2op_attribute_ref_table()
.at(job_id)
.op_name2op_attribute()
.at(var_name)
.nd_sbp_signature());
cfg::NdSbp optimized_nd_sbp = var_nd_sbp_signature.bn_in_op2nd_sbp().at("out");
// Change the variable tensor's impl to the new sbp when a job pass has changed its sbp.
if (*JUST(tensor->nd_sbp()) != optimized_nd_sbp) {
LOG(INFO) << "Graph with name " << name_ << " variable with name `" << var_name
<< "` changes its' sbp from " << NdSbpToString(*JUST(tensor->nd_sbp())) << " to "
<< NdSbpToString(optimized_nd_sbp) << " after compile optimization.";
std::vector<Symbol<cfg::SbpParallel>> optimized_sbp_parallels;
for (int i = 0; i < optimized_nd_sbp.sbp_parallel_size(); ++i) {
optimized_sbp_parallels.emplace_back(optimized_nd_sbp.sbp_parallel(i));
}
{
auto lazy_mode_disabled_guard = LazyMode::Guard(/* is_enabled */ false);
const auto& new_tensor = JUST(one::functional::ToConsistent(
tensor, JUST(tensor->parallel_desc()), optimized_sbp_parallels, {}));
JUST(vm::CurrentRankSync());
// Use the tensor.set_data interface to replace the old TensorImpl with a new one.
JUST(tensor->set_data(new_tensor));
}
}
const std::shared_ptr<one::MirroredTensor> local_var = JUST(tensor->cur_rank_phy_tensor());
var_blob = JUST(local_var->eager_blob_object())->mut_blob();
} else {
var_blob = JUST(tensor->eager_blob_object())->mut_blob();
}
CHECK_OR_RETURN(variable_op_name2eager_blob_.emplace(var_name, var_blob).second);
}
return Maybe<void>::Ok();
}

void NNGraph::NewRuntimeBuffers() {
// NOTE(chengcheng):
// 1. The BufferSize comes from job_conf.concurrency_width configured by user (default = 128)
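The net effect of the nn_graph.cpp changes above is a two-step variable registration: RegisterVariableOpNamesAndTensors now only records op name -> eager tensor, and the eager blobs are resolved later in GetVariableRealBlobAfterSyncPlan, once the synced plan is known, so a variable whose sbp was rewritten by the ZeRO pass can first be rebuilt with the optimized sbp. The following plain-Python sketch mirrors that flow; the names and callbacks are hypothetical, not OneFlow API.

variable_op_name2tensor = {}      # filled at register time
variable_op_name2eager_blob = {}  # filled only after the plan is synced


def register_variable(var_name, tensor):
    # Step 1: only remember the eager tensor; do not touch its blob yet.
    assert var_name and var_name not in variable_op_name2tensor
    variable_op_name2tensor[var_name] = tensor


def get_variable_real_blob_after_sync_plan(optimized_sbp_of, convert_to_sbp, blob_of):
    # Step 2: a job pass (e.g. ZeRO) may have rewritten a variable's sbp,
    # for example broadcast -> split(0). Rebuild the tensor with the optimized
    # sbp (analogous to functional::ToConsistent + tensor->set_data) and only
    # then capture the blob that is handed to the runtime.
    for var_name, tensor in variable_op_name2tensor.items():
        target_sbp = optimized_sbp_of(var_name)
        if tensor.sbp != target_sbp:
            tensor = convert_to_sbp(tensor, target_sbp)
            variable_op_name2tensor[var_name] = tensor
        variable_op_name2eager_blob[var_name] = blob_of(tensor)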
2 changes: 2 additions & 0 deletions oneflow/core/framework/nn_graph.h
@@ -57,6 +57,7 @@ class NNGraph final : public NNGraphIf {
private:
Maybe<void> RegisterFreeEagerTensorsToVariableOpNames();
Maybe<void> CreateAndRegisterNewVariableOpInJobPass();
Maybe<void> GetVariableRealBlobAfterSyncPlan();

void NewRuntimeBuffers();
void CloseRuntimeBuffers();
@@ -68,6 +69,7 @@ class NNGraph final : public NNGraphIf {
std::vector<bool> output_tensors_valid_;
std::vector<std::string> inputs_tensor_meta_str_;
std::vector<std::string> outputs_tensor_meta_str_;
HashMap<std::string, std::shared_ptr<one::Tensor>> variable_op_name2tensor_;
HashMap<std::string, Blob*> variable_op_name2eager_blob_;
HashSet<std::string> variable_op_names_;
Job job_;
8 changes: 6 additions & 2 deletions oneflow/core/job/oneflow.cpp
@@ -432,10 +432,11 @@ void CheckNonDistributeOptimizerAvailable(const std::vector<std::shared_ptr<Job>
FOR_RANGE(int64_t, job_id, 0, jobs.size()) {
if (!IsEnabled(*jobs.at(job_id))) { continue; }
for (const OperatorConf& op_conf : jobs.at(job_id)->net().op()) {
if (op_conf.op_type_case() == OperatorConf::kVariableConf) { continue; }
if (op_conf.op_type_case() != OperatorConf::kVariableConf) { continue; }
if (var_names.find(op_conf.name()) == var_names.end()) {
var_names.emplace(op_conf.name());
} else {
// Two optimizer_placement_optimization jobs share the same variable.
LOG(FATAL)
<< "Only support optimizer_placement_optimization when jobs not sharing same variable";
}
@@ -444,8 +445,9 @@ void CheckNonDistributeOptimizerAvailable(const std::vector<std::shared_ptr<Job>
FOR_RANGE(int64_t, job_id, 0, jobs.size()) {
if (IsEnabled(*jobs.at(job_id))) { continue; }
for (const OperatorConf& op_conf : jobs.at(job_id)->net().op()) {
if (op_conf.op_type_case() == OperatorConf::kVariableConf) { continue; }
if (op_conf.op_type_case() != OperatorConf::kVariableConf) { continue; }
if (var_names.find(op_conf.name()) != var_names.end()) {
// Another job shares a variable with the optimizer_placement_optimization jobs.
LOG(FATAL)
<< "Only support optimizer_placement_optimization when jobs not sharing same variable";
}
@@ -913,6 +915,8 @@ REGISTER_FUNCTION_CONFIG_DEF().Bool("__is_user_function__", true, "is user defin
Maybe<void> CompileJobsAndMergePlans(const PbRpf<Job>& job_confs, Plan& plan) {
std::vector<std::shared_ptr<Job>> jobs(job_confs.size());
FOR_RANGE(int, i, 0, jobs.size()) { jobs.at(i).reset(new Job(job_confs.Get(i))); }
// These checks do not work in the nn.Graph API because only one job is compiled at a time.
// And nn.Graph supports sharing the same variable between training and evaluation.
if (jobs.size() > 1) { CheckNonDistributeOptimizerAvailable(jobs); }
HashMap<std::string, ParallelBlobConf> var_op_name2parallel_blob_conf;
FilterOpName2ParallelBlobConf({OperatorConf::kVariableConf}, jobs,
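The two flipped conditions above (`==` -> `!=`) make CheckNonDistributeOptimizerAvailable inspect only variable ops, as intended. Below is a minimal plain-Python restatement of the corrected check; the helper callbacks are hypothetical, and the rule assumed is the one stated in the error message: a variable owned by a ZeRO-enabled job must not appear in any other job.

def check_non_distribute_optimizer_available(jobs, zero_enabled, variable_names_of):
    # First pass: variables of ZeRO-enabled jobs must be pairwise distinct.
    var_names = set()
    for job in jobs:
        if not zero_enabled(job):
            continue
        for name in variable_names_of(job):
            if name in var_names:
                raise RuntimeError("Only support optimizer_placement_optimization "
                                   "when jobs not sharing same variable")
            var_names.add(name)
    # Second pass: no other job may reuse one of those variables.
    for job in jobs:
        if zero_enabled(job):
            continue
        for name in variable_names_of(job):
            if name in var_names:
                raise RuntimeError("Only support optimizer_placement_optimization "
                                   "when jobs not sharing same variable")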
3 changes: 2 additions & 1 deletion oneflow/core/job/plan_util.cpp
@@ -261,7 +261,8 @@ void PlanUtil::GenMemBlockAndChunkWithVariableOpNames4Plan(
CHECK(!var_name.empty());
CHECK_EQ(regst_desc->register_num(), 1);
CHECK_EQ(regst_desc->min_register_num(), 1);
CHECK_EQ(regst_desc->max_register_num(), 1);
// NOTE(xuxiaoyu): this check cannot pass when ZeRO is enabled
// CHECK_EQ(regst_desc->max_register_num(), 1) << var_name;
regst_desc->set_variable_op_name(var_name);
}

38 changes: 18 additions & 20 deletions oneflow/core/job_rewriter/optimizer_placement_optimization_pass.cpp
@@ -73,6 +73,7 @@ ParallelConf NonDistributedParallelConf4ParallelId(const ParallelDesc& pd,
Maybe<void> GetDataParallelVariableAndNaiveSuccNode(
const OpNode* start, const std::function<bool(const OpNode*)>& IsAllowed,
std::vector<const OpNode*>* out) {
// Find sequences like: variable -> cast_fp32_to_fp16
if (!start->op().op_conf().has_variable_conf()) { return Maybe<void>::Ok(); }
const ParallelDesc& pd = start->parallel_desc();
if (pd.device_type() != DeviceType::kCUDA) { return Maybe<void>::Ok(); }
@@ -100,22 +101,6 @@ Maybe<void> GetDataParallelVariableAndNaiveSuccNode(
return Maybe<void>::Ok();
}

void ForEachOutNodeConsumingLbi(
const OpNode* node, const LogicalBlobId& lbi,
const std::function<void(const OpNode* out_node, const std::string& ibn)>& Handler) {
node->ForEachNodeOnOutEdge([&](const OpNode* out_node) {
for (const std::string& ibn : out_node->op().input_bns()) {
if (out_node->op().BnInOp2Lbi(ibn) == lbi) { Handler(out_node, ibn); }
}
});
}

void ForEachOutNodeConsumingSoleOut(
const OpNode* node,
const std::function<void(const OpNode* out_node, const std::string& ibn)>& Handler) {
ForEachOutNodeConsumingLbi(node, node->op().BnInOp2Lbi(node->op().SoleObn()), Handler);
}

void SetBroadcastParallel4OpNodeIbn(JobBuilder* builder, const OpNode* node,
const std::string& ibn) {
OpBlobArg op_blob_arg;
@@ -127,10 +112,15 @@ void SetBroadcastParallel4OpNodeIbn(JobBuilder* builder, const OpNode* node,
}

void SetBroadcastParallel4Consumers(JobBuilder* builder, const SequencePtr& sequence) {
ForEachOutNodeConsumingSoleOut(sequence->GetLastNode(),
[&](const OpNode* out_node, const std::string& ibn) {
SetBroadcastParallel4OpNodeIbn(builder, out_node, ibn);
});
const OpNode* node = sequence->GetLastNode();
const LogicalBlobId& lbi = node->op().BnInOp2Lbi(node->op().SoleObn());
node->ForEachNodeOnOutEdge([&](const OpNode* out_node) {
for (const std::string& ibn : out_node->op().input_bns()) {
if (out_node->op().BnInOp2Lbi(ibn) == lbi) {
SetBroadcastParallel4OpNodeIbn(builder, out_node, ibn);
}
}
});
}

std::function<int64_t(const OpNode*)> MakeGetterOpNode2TopoOrder(const OpGraph& op_graph) {
@@ -158,6 +148,7 @@ void ForEachDataParallelNodeSequence(const OpGraph& op_graph,
auto OpNode2Order = MakeGetterOpNode2TopoOrder(op_graph);
op_graph.ForEachNode([&](const OpNode* node) {
std::vector<const OpNode*> nodes;
// Find sequences like: variable -> cast_fp32_to_fp16
CHECK_JUST(GetDataParallelVariableAndNaiveSuccNode(node, IsAllowed, &nodes));
if (nodes.empty()) { return; }
const int64_t order = GetMinConsumerOrder(op_graph, nodes.back(), OpNode2Order);
@@ -178,6 +169,7 @@ void ForEachParallelSortedNodeSequence(
const std::function<bool(const SequencePtr&, const SequencePtr&)>& Comp,
const std::function<void(const ParallelDesc&, std::vector<SequencePtr>&&)>& Handler) {
HashMap<ParallelDesc, std::vector<SequencePtr>> parallel_desc2sequences;
// Find sequences like: variable -> cast_fp32_to_fp16
ForEachDataParallelNodeSequence(op_graph, IsAllowed, [&](SequencePtr&& sequence) {
parallel_desc2sequences[sequence->parallel_desc()].emplace_back(std::move(sequence));
});
@@ -236,13 +228,18 @@ Maybe<void> RewriteDistributedSplit(const OpGraph& op_graph, JobBuilder* builder
if (n->op().op_conf().has_variable_conf()) {
const Shape shape(n->op().op_conf().variable_conf().shape());
const int64_t parallel_num = n->parallel_desc().parallel_num();
// The parameter must be evenly splittable along dim 0 and each slice size must be >= threshold
return shape.At(0) % parallel_num == 0 && shape.elem_cnt() >= threshold * parallel_num;
} else {
return IsS0SignatureSupported(n);
}
};
const auto PlacementSequencesAsSplitParallel = [&](const ParallelDesc& pd,
std::vector<SequencePtr>&& sorted_sequences) {
// For all sorted sequences, set the variable op in each sequence to S(0)
// and add ctrl edges to control the execution order between variable ops.
// A sequence is a variable op and its cast (fp32 to fp16) op. This is because the forward pass
// consumes the fp16 variable and the optimizer consumes the fp32 variable.
for (int64_t i = 0; i < sorted_sequences.size(); ++i) {
const OpNode* var_node = sorted_sequences.at(i)->GetVariableNode();
OperatorConf new_var_op_conf = var_node->op().op_conf();
Expand All @@ -255,6 +252,7 @@ Maybe<void> RewriteDistributedSplit(const OpGraph& op_graph, JobBuilder* builder
new_var_op_conf.add_ctrl_in_op_name(prev_op_name);
}
builder->MutOpsOnlyOnce({new_var_op_conf});
// Set consumers to consume the output of this variable op's cast op as Broadcast.
SetBroadcastParallel4Consumers(builder, sorted_sequences.at(i));
}
};
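In RewriteDistributedSplit above, a data-parallel variable (and its fp32-to-fp16 cast) is sharded only when its first dimension splits evenly across the ranks and each slice keeps at least `threshold` elements; eligible variables are then set to S(0) and their consumers read the cast output as broadcast. A hedged, plain-Python restatement of that eligibility predicate (function name is illustrative only):

from functools import reduce
from operator import mul


def zero_split_allowed(shape, parallel_num, threshold):
    """shape: variable shape as a tuple; threshold: min elements per slice after split."""
    elem_cnt = reduce(mul, shape, 1)
    return shape[0] % parallel_num == 0 and elem_cnt >= threshold * parallel_num


# A (1024, 1024) parameter on 4 ranks with threshold=1024 is sharded into
# four slices of 262144 elements each; a (1023, 1024) parameter is not,
# because dim 0 does not split evenly.
assert zero_split_allowed((1024, 1024), parallel_num=4, threshold=1024)
assert not zero_split_allowed((1023, 1024), parallel_num=4, threshold=1024)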
4 changes: 2 additions & 2 deletions oneflow/core/operator/operator.cpp
@@ -94,11 +94,11 @@ std::shared_ptr<const OperatorConf> Operator::shared_op_conf() const { return op
DeviceType Operator::device_type() const { return device_type_; }

const std::string& Operator::SoleIbn() const {
CHECK_EQ(input_bns().size(), 1);
CHECK_EQ(input_bns().size(), 1) << ", op_name " << op_name();
return input_bns().Get(0);
}
const std::string& Operator::SoleObn() const {
CHECK_EQ(output_bns().size(), 1);
CHECK_EQ(output_bns().size(), 1) << ", op_name " << op_name();
return output_bns().Get(0);
}
const std::string& Operator::SoleTbn() const {
10 changes: 10 additions & 0 deletions python/oneflow/nn/graph/graph_config.py
@@ -110,6 +110,16 @@ def set_zero_redundancy_optimizer_mode(self, mode: str = "distributed_split"):
assert mode in ("distributed_split", "non_distributed")
self.proto.set_optimizer_placement_optimization_mode(mode)

def set_zero_redundancy_optimizer_min_size_after_split(self, value):
"""Set the min size of optimizer state/grad/parameter after split.
Args:
value (int): min size value.
"""
assert isinstance(value, int)
assert value >= 1
self.proto.set_optimizer_placement_optimization_threshold(value)

def enable_xla_jit(self, value=True):
"""Whether use xla_jit in xrt or not. When this option enable, oneflow will check all operators is supported by
xla_jit or not. Clustering supported operators as subgraph, then runing subgraph by xla_jit.
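A hedged usage sketch of the two ZeRO-related config calls added above, as they might appear in a user-defined nn.Graph; the model, optimizer, build signature, and threshold value are placeholders, not part of this commit.

import oneflow as flow


class TrainGraph(flow.nn.Graph):
    def __init__(self, model, optimizer):
        super().__init__()
        self.model = model
        self.add_optimizer(optimizer)
        # Shard optimizer states / grads / parameters across data-parallel ranks (ZeRO).
        self.config.set_zero_redundancy_optimizer_mode("distributed_split")
        # Only shard variables whose per-rank slice keeps at least 2 ** 10 elements.
        self.config.set_zero_redundancy_optimizer_min_size_after_split(2 ** 10)

    def build(self, x, y):
        loss = self.model(x, y)
        loss.backward()
        return loss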
4 changes: 2 additions & 2 deletions python/oneflow/test/graph/test_graph_linear_train.py
@@ -107,8 +107,8 @@ def one_iter():
return check_list

iter_num = 3
module_check_list = train_with_module()
graph_check_list = train_with_graph()
module_check_list = train_with_module(iter_num)
graph_check_list = train_with_graph(iter_num)
for i in range(iter_num):
# check equal on loss
test_case.assertTrue(
