Skip to content

Commit

Permalink
Refactor add straighten algo interface (#8435)
Browse files Browse the repository at this point in the history
* feat(*): export straighten nodes algorithm inferface

* export documentation

* Update python/oneflow/nn/graph/graph_config.py

Co-authored-by: Yipeng Li <jamesonli1313@gmail.com>

Co-authored-by: Yipeng Li <jamesonli1313@gmail.com>
  • Loading branch information
wyg1997 and Yipeng1994 authored Jun 16, 2022
1 parent c532e7c commit 60e7800
Show file tree
Hide file tree
Showing 6 changed files with 18 additions and 4 deletions.
1 change: 1 addition & 0 deletions docs/source/graph.rst
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ Base class for running neural networks in Static Graph Mode.
set_zero_redundancy_optimizer_mode,
set_zero_redundancy_optimizer_min_size_after_split,
enable_cudnn_conv_heuristic_search_algo,
enable_random_straighten_nodes,
:member-order: bysource


Expand Down
4 changes: 2 additions & 2 deletions oneflow/core/graph/task_graph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -420,7 +420,7 @@ void ForEachOpGraphNecessaryCtrlEdge(

} // namespace

TaskGraph::TaskGraph() {
TaskGraph::TaskGraph(bool random_straighten_nodes) {
OpGraph* op_graph = Global<OpGraph>::Get();
sub_tsk_gph_builder_ctx_.reset(new SubTskGphBuilderCtx(this));
boxing_logger_ = CreateBoxingLogger();
Expand Down Expand Up @@ -451,7 +451,7 @@ TaskGraph::TaskGraph() {
}
});

if (ParseBooleanFromEnv("ONEFLOW_RANDOM_STRAIGHTEN_NODES", false)) {
if (random_straighten_nodes) {
SetOrderInGraphForEachNode();
} else {
StraightenNodes(this, &ordered_task_nodes_);
Expand Down
2 changes: 1 addition & 1 deletion oneflow/core/graph/task_graph.h
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ class TaskGraph final : public Graph<TaskNode, TaskEdge> {
OF_DISALLOW_COPY_AND_MOVE(TaskGraph);
~TaskGraph() override;

explicit TaskGraph();
explicit TaskGraph(bool random_straighten_nodes);

const char* TypeName() const override { return "TaskGraph"; }
void RemoveEmptyRegsts();
Expand Down
3 changes: 2 additions & 1 deletion oneflow/core/job/compiler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,8 @@ void Compiler::Compile(Job* job, Plan* plan, bool need_job_complete) const {

// Step3: build task_gph.
// TODO(levi): we can rewrite this part of code in visitor pattern.
auto task_gph = std::make_unique<TaskGraph>();
auto task_gph =
std::make_unique<TaskGraph>(job->job_conf().random_straighten_nodes_in_task_graph());
using std::placeholders::_1;
task_gph->ForEachNode(std::bind(&TaskNode::ProduceAllRegstsAndBindEdges, _1));
task_gph->ForEachNode(std::bind(&TaskNode::ConsumeAllRegsts, _1));
Expand Down
2 changes: 2 additions & 0 deletions oneflow/core/job/job_conf.proto
Original file line number Diff line number Diff line change
Expand Up @@ -239,6 +239,8 @@ message JobConfigProto {
optional bool cudnn_conv_enable_pseudo_half = 600 [default = true];
optional bool enable_auto_mixed_precision = 602 [default = false];
optional bool enable_quantization_aware_training = 603 [default = false];

optional bool random_straighten_nodes_in_task_graph = 700 [default = false];

optional int64 concurrency_width = 1000 [default = 128];

Expand Down
10 changes: 10 additions & 0 deletions python/oneflow/nn/graph/graph_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -269,6 +269,16 @@ def build(self, x):
"""
self.proto.cudnn_conv_heuristic_search_algo = mode

def enable_random_straighten_nodes(self, mode: bool = False):
r""" Whether turn off the straighten algorithm.
If using nccl compute stream, turning it on might not speed up the training.
If not using nccl compute stream, turning it on might slow down data parallelism by 0.6% and slow down model parallelism by 6%.
The switch is off by default (i.e. use the straighten algorithm by default).
"""
self.proto.random_straighten_nodes_in_task_graph = mode

def _generate_optimizer_and_variable_configs(
self, opt_dict: OptDict = None, variables_conf: OrderedDict = None,
):
Expand Down

0 comments on commit 60e7800

Please sign in to comment.