From a470ff43f2705914e57af5fda6176e047bee4da1 Mon Sep 17 00:00:00 2001 From: strint Date: Fri, 8 Apr 2022 15:23:10 +0800 Subject: [PATCH 01/46] add zero limit --- .../optimizer_placement_optimization_pass.cpp | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/oneflow/core/job_rewriter/optimizer_placement_optimization_pass.cpp b/oneflow/core/job_rewriter/optimizer_placement_optimization_pass.cpp index 1ca857fd11f..f19d1554504 100644 --- a/oneflow/core/job_rewriter/optimizer_placement_optimization_pass.cpp +++ b/oneflow/core/job_rewriter/optimizer_placement_optimization_pass.cpp @@ -228,8 +228,14 @@ Maybe RewriteDistributedSplit(const OpGraph& op_graph, JobBuilder* builder if (n->op().op_conf().has_variable_conf()) { const Shape shape(n->op().op_conf().variable_conf().shape()); const int64_t parallel_num = n->parallel_desc().parallel_num(); + bool is_1d_broadcast = false; + if (n->op().op_conf().variable_conf().nd_sbp_size() == 1 && + n->op().op_conf().variable_conf().nd_sbp(0) == "B") { + // NOTE(strint): Only 1D Broadcast Variable will be split by ZeRO. + is_1d_broadcast = true; + } // Parameter needs to be able to evenly splited and one slice size >= threshold - return shape.At(0) % parallel_num == 0 && shape.elem_cnt() >= threshold * parallel_num; + return is_1d_broadcast && shape.At(0) % parallel_num == 0 && shape.elem_cnt() >= threshold * parallel_num; } else { return IsS0SignatureSupported(n); } From 9447157a2b290f25c529c0618e2b47322c44eb18 Mon Sep 17 00:00:00 2001 From: strint Date: Tue, 12 Apr 2022 21:10:53 +0800 Subject: [PATCH 02/46] add debug --- python/oneflow/test/graph/test_graph_zero.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/python/oneflow/test/graph/test_graph_zero.py b/python/oneflow/test/graph/test_graph_zero.py index 20fc7366bab..a246dad1aa1 100644 --- a/python/oneflow/test/graph/test_graph_zero.py +++ b/python/oneflow/test/graph/test_graph_zero.py @@ -28,6 +28,7 @@ def train_with_graph(iter_num=1): S0 = flow.sbp.split(0) linear = flow.nn.Linear(8, 4) linear = linear.to_global(placement=P, sbp=B) + #linear_mp = flow.nn.Linear(4, 8) flow.nn.init.constant_(linear.weight, 2.068758) flow.nn.init.constant_(linear.bias, 0.23) of_sgd = flow.optim.SGD(linear.parameters(), lr=0.001, momentum=0.9) @@ -76,6 +77,7 @@ def build(self, x): return out linear_t_g = LinearTrainGraphWithZeRO() + linear_t_g.debug(2) linear_e_g = LinearEvalGraphWithZeRO() def one_train_iter(): @@ -106,10 +108,10 @@ class TestLinearTrainGraphWithZeRO(oneflow.unittest.TestCase): def test_linear_train_graph_with_zero_1(test_case): _test_linear_train_graph_with_zero(test_case, 1) - def test_linear_train_graph_with_zero_2(test_case): + def _test_linear_train_graph_with_zero_2(test_case): _test_linear_train_graph_with_zero(test_case, 2) - def test_linear_train_graph_with_zero_3(test_case): + def _test_linear_train_graph_with_zero_3(test_case): _test_linear_train_graph_with_zero(test_case, 3) From e3acaa9a993b3aed4e2269675d37d5f101b06dfd Mon Sep 17 00:00:00 2001 From: strint Date: Tue, 12 Apr 2022 22:05:46 +0800 Subject: [PATCH 03/46] add mix zero test --- python/oneflow/test/graph/test_graph_zero.py | 32 ++++++++++++++------ 1 file changed, 22 insertions(+), 10 deletions(-) diff --git a/python/oneflow/test/graph/test_graph_zero.py b/python/oneflow/test/graph/test_graph_zero.py index a246dad1aa1..ca535dd1de5 100644 --- a/python/oneflow/test/graph/test_graph_zero.py +++ b/python/oneflow/test/graph/test_graph_zero.py @@ -26,12 +26,18 @@ def train_with_graph(iter_num=1): P = flow.placement("cuda", ranks=[0, 1]) B = flow.sbp.broadcast S0 = flow.sbp.split(0) - linear = flow.nn.Linear(8, 4) - linear = linear.to_global(placement=P, sbp=B) - #linear_mp = flow.nn.Linear(4, 8) - flow.nn.init.constant_(linear.weight, 2.068758) - flow.nn.init.constant_(linear.bias, 0.23) - of_sgd = flow.optim.SGD(linear.parameters(), lr=0.001, momentum=0.9) + + linear_dp = flow.nn.Linear(8, 4) + linear_dp = linear_dp.to_global(placement=P, sbp=B) + flow.nn.init.constant_(linear_dp.weight, 2.068758) + flow.nn.init.constant_(linear_dp.bias, 0.23) + + linear_mp = flow.nn.Linear(2, 8) + linear_mp = linear_mp.to_global(placement=P, sbp=S0) + flow.nn.init.constant_(linear_mp.weight, 2.068758) + flow.nn.init.constant_(linear_mp.bias, 0.23) + + of_sgd = flow.optim.SGD([{"params": linear_dp.parameters()}, {"params": linear_mp.parameters()}], lr=0.001, momentum=0.9) grad_scaler = flow.amp.StaticGradScaler(200) x = flow.randint(1, 100, (4, 8), dtype=flow.float32, placement=P, sbp=S0) @@ -39,7 +45,8 @@ def train_with_graph(iter_num=1): class LinearTrainGraphWithZeRO(flow.nn.Graph): def __init__(self): super().__init__() - self.linear = linear + self.linear_dp = linear_dp + self.linear_mp = linear_mp self.add_optimizer(of_sgd) self.config.enable_amp(True) @@ -60,7 +67,9 @@ def __init__(self): flow.boxing.nccl.disable_group_boxing_by_dst_parallel(True) def build(self, x): - out = self.linear(x) + out = self.linear_dp(x) + out = out.to_global(placement=P, sbp=B) + out = self.linear_mp(x) loss = out.sum() loss.backward() return out @@ -68,12 +77,15 @@ def build(self, x): class LinearEvalGraphWithZeRO(flow.nn.Graph): def __init__(self): super().__init__() - self.linear = linear + self.linear_dp = linear_dp + self.linear_mp = linear_mp self.config.enable_amp(True) def build(self, x): - out = self.linear(x) + out = self.linear_dp(x) + out = out.to_global(placement=P, sbp=B) + out = self.linear_mp(x) return out linear_t_g = LinearTrainGraphWithZeRO() From b481a7e2709bda6f24230e8c05708b9fcc4d09f7 Mon Sep 17 00:00:00 2001 From: strint Date: Wed, 13 Apr 2022 12:01:50 +0800 Subject: [PATCH 04/46] refactor zero api --- docs/source/graph.rst | 3 +- .../optimizer_placement_optimization_pass.cpp | 18 +-- python/oneflow/nn/graph/graph_config.py | 123 +++++++++--------- python/oneflow/test/graph/test_graph_zero.py | 24 +--- .../test/graph/test_optimization_conf.py | 2 +- 5 files changed, 78 insertions(+), 92 deletions(-) diff --git a/docs/source/graph.rst b/docs/source/graph.rst index b5ae269e4bb..b4d024bc52c 100644 --- a/docs/source/graph.rst +++ b/docs/source/graph.rst @@ -20,12 +20,11 @@ Base class for running neural networks in Static Graph Mode. .. autoclass:: oneflow.nn.graph.graph_config.GraphConfig :members: enable_amp, + enable_zero, allow_fuse_model_update_ops, allow_fuse_add_to_output, allow_fuse_cast_scale, set_gradient_accumulation_steps, - set_zero_redundancy_optimizer_mode, - set_zero_redundancy_optimizer_min_size_after_split, enable_xla_jit, enable_tensorrt, enable_openvino, diff --git a/oneflow/core/job_rewriter/optimizer_placement_optimization_pass.cpp b/oneflow/core/job_rewriter/optimizer_placement_optimization_pass.cpp index f19d1554504..0adc3a57039 100644 --- a/oneflow/core/job_rewriter/optimizer_placement_optimization_pass.cpp +++ b/oneflow/core/job_rewriter/optimizer_placement_optimization_pass.cpp @@ -228,14 +228,15 @@ Maybe RewriteDistributedSplit(const OpGraph& op_graph, JobBuilder* builder if (n->op().op_conf().has_variable_conf()) { const Shape shape(n->op().op_conf().variable_conf().shape()); const int64_t parallel_num = n->parallel_desc().parallel_num(); - bool is_1d_broadcast = false; - if (n->op().op_conf().variable_conf().nd_sbp_size() == 1 && - n->op().op_conf().variable_conf().nd_sbp(0) == "B") { - // NOTE(strint): Only 1D Broadcast Variable will be split by ZeRO. - is_1d_broadcast = true; - } + // bool is_1d_broadcast = false; + // if (n->op().op_conf().variable_conf().nd_sbp_size() == 1 && + // n->op().op_conf().variable_conf().nd_sbp(0) == "B") { + // // NOTE(strint): Only 1D Broadcast Variable will be split by ZeRO. + // is_1d_broadcast = true; + // } // Parameter needs to be able to evenly splited and one slice size >= threshold - return is_1d_broadcast && shape.At(0) % parallel_num == 0 && shape.elem_cnt() >= threshold * parallel_num; + return shape.At(0) % parallel_num == 0 && shape.elem_cnt() >= threshold * parallel_num; + // return is_1d_broadcast && shape.At(0) % parallel_num == 0 && shape.elem_cnt() >= threshold * parallel_num; } else { return IsS0SignatureSupported(n); } @@ -319,7 +320,8 @@ class OptimizerPlacementOptimizationPass final : public JobPass { Maybe Apply(Job* job, JobPassCtx* ctx) const override { if (!(ctx->job_desc().IsTrain() - && ctx->job_desc().job_conf().has_optimizer_placement_optimization_mode())) { + && ctx->job_desc().job_conf().has_optimizer_placement_optimization_mode() + && ctx->job_desc().job_conf().optimizer_placement_optimization_mode() != "none")) { return Maybe::Ok(); } const std::string& mode = ctx->job_desc().job_conf().optimizer_placement_optimization_mode(); diff --git a/python/oneflow/nn/graph/graph_config.py b/python/oneflow/nn/graph/graph_config.py index 74289801bb5..be3c9c82e9e 100644 --- a/python/oneflow/nn/graph/graph_config.py +++ b/python/oneflow/nn/graph/graph_config.py @@ -17,6 +17,7 @@ from collections import OrderedDict +import oneflow from oneflow.nn.graph.optimizer import OptDict import oneflow._oneflow_internal.oneflow.core.job.job_conf as job_conf_cfg @@ -45,24 +46,39 @@ def training(self): return False raise NotImplementedError - def set_outputs_buffer_size(self, value: int = 2): - r"""Set the outputs buffer size of ``nn.Graph``. + def enable_amp(self, mode: bool = True): + r"""If set to true, then graph will use mixed precision mode, it means use both float16 and float32 during model training. - When graph's outputs buffer size is greater than 2, multiple call on the graph can work like a pipeline. This makes multiple call takes less time. + For example: - The default outputs buffer size is 2. + .. code-block:: python - # TODO (lixiang): Explain the meaning of the size of buffer size and add sample code. - # The size of the buffer size indicates the maximum number of iterations that the output of the Graph and the Graph actually executed asynchronously can overlap. - # If the buffer size is 1, there is no pipeline. A size of 2 means that it can execute 1 iter ahead of time. A size of 3 means that two iters can be executed ahead of time. + import oneflow as flow + + class Graph(flow.nn.Graph): + def __init__(self): + super().__init__() + self.linear = flow.nn.Linear(3, 8, False) + self.config.enable_amp(True) # Use mixed precision mode. + def build(self, x): + return self.linear(x) + + graph = Graph() Args: - value (int): graph ouputs buffer size. + mode (bool, optional): The default vaule is True. + """ - self._outputs_buffer_size = value + assert type(mode) is bool + self.proto.set_enable_auto_mixed_precision(mode) - def enable_amp(self, mode: bool = True): - r"""If set to true, then graph will use mixed precision mode, it means use both float16 and float32 during model training. + def enable_zero(self, mode: bool = True, *, stage: int = 2, min_splited_size: int = 1024): + r"""Enable ZeRO redundancy optimizer. + + This optimzation will reduce optimizer states memory consumption as described + by ZeRO https://arxiv.org/abs/1910.02054 . + + The default zero stage is 2. For example: @@ -74,17 +90,36 @@ class Graph(flow.nn.Graph): def __init__(self): super().__init__() self.linear = flow.nn.Linear(3, 8, False) - self.config.enable_amp(True) # Use mixed precision mode. + self.config.enable_zero() def build(self, x): return self.linear(x) graph = Graph() Args: - mode (bool, optional): The default vaule is True. + mode (bool): if set to true, optimizer states of Data Parallel will be sharded across devices. + stage (int): optimization stage, range from 1 to 3. + min_splited_size (int): min size of a sharded optimizer state. """ - assert type(mode) is bool - self.proto.set_enable_auto_mixed_precision(mode) + if not mode: + self.proto.set_optimizer_placement_optimization_mode("none") + return + assert stage >= 1 and stage <= 3, "ZeRO stage must range form 1 to 3." + assert min_splited_size > 0, "ZeRO min size of a sharded optimizer state must > 0." + if stage == 1: + print("zero stage 1 optimization") + self.proto.set_optimizer_placement_optimization_mode("distributed_split") + self.proto.set_optimizer_placement_optimization_threshold(min_splited_size) + elif stage == 2: + self.proto.set_optimizer_placement_optimization_mode("distributed_split") + self.proto.set_optimizer_placement_optimization_threshold(min_splited_size) + oneflow.boxing.nccl.enable_use_compute_stream(True) + elif stage == 3: + print("zero stage 3 optimization") + self.proto.set_optimizer_placement_optimization_mode("distributed_split") + self.proto.set_optimizer_placement_optimization_threshold(min_splited_size) + oneflow.boxing.nccl.enable_use_compute_stream(True) + oneflow.boxing.nccl.disable_group_boxing_by_dst_parallel(True) def allow_fuse_model_update_ops(self, mode: bool = True): r"""If set to true, try to fuse cast + scale + l1_l2_regularize_gradient + model_update to one op to improve performance. @@ -188,61 +223,21 @@ def build(self, x): """ self.proto.set_num_gradient_accumulation_steps(value) - def set_zero_redundancy_optimizer_mode(self, mode: str = "distributed_split"): - r"""Set mode to remove redundancy of optimizer states. - This optimzation will reduce optimizer states memory consumption as described - by ZeRO https://arxiv.org/abs/1910.02054 . - - For example: - - .. code-block:: python - - import oneflow as flow - - class Graph(flow.nn.Graph): - def __init__(self): - super().__init__() - self.linear = flow.nn.Linear(3, 8, False) - self.config.set_zero_redundancy_optimizer_mode("distributed_split") - def build(self, x): - return self.linear(x) - - graph = Graph() - - Args: - mode (str): "distributed_split" or "non_distributed". "distributed_split" mode - will shard each optimizer state across devices. "non_distributed" mode - will place each optimizer state to only one device. - """ - assert mode in ("distributed_split", "non_distributed") - self.proto.set_optimizer_placement_optimization_mode(mode) - - def set_zero_redundancy_optimizer_min_size_after_split(self, value): - r"""Set the min size of optimizer state/grad/parameter after split. - - For example: - - .. code-block:: python + def set_outputs_buffer_size(self, value: int = 2): + r"""Set the outputs buffer size of ``nn.Graph``. - import oneflow as flow + When graph's outputs buffer size is greater than 2, multiple call on the graph can work like a pipeline. This makes multiple call takes less time. - class Graph(flow.nn.Graph): - def __init__(self): - super().__init__() - self.linear = flow.nn.Linear(3, 8, False) - self.config.set_zero_redundancy_optimizer_mode("distributed_split") - self.config.set_zero_redundancy_optimizer_min_size_after_split(1) - def build(self, x): - return self.linear(x) + The default outputs buffer size is 2. - graph = Graph() + # TODO (lixiang): Explain the meaning of the size of buffer size and add sample code. + # The size of the buffer size indicates the maximum number of iterations that the output of the Graph and the Graph actually executed asynchronously can overlap. + # If the buffer size is 1, there is no pipeline. A size of 2 means that it can execute 1 iter ahead of time. A size of 3 means that two iters can be executed ahead of time. Args: - value (int): min size value. + value (int): graph ouputs buffer size. """ - assert isinstance(value, int) - assert value >= 1 - self.proto.set_optimizer_placement_optimization_threshold(value) + self._outputs_buffer_size = value def enable_xla_jit(self, value=True): r"""Whether use xla_jit in xrt or not. diff --git a/python/oneflow/test/graph/test_graph_zero.py b/python/oneflow/test/graph/test_graph_zero.py index ca535dd1de5..ca3e6629f0b 100644 --- a/python/oneflow/test/graph/test_graph_zero.py +++ b/python/oneflow/test/graph/test_graph_zero.py @@ -32,7 +32,7 @@ def train_with_graph(iter_num=1): flow.nn.init.constant_(linear_dp.weight, 2.068758) flow.nn.init.constant_(linear_dp.bias, 0.23) - linear_mp = flow.nn.Linear(2, 8) + linear_mp = flow.nn.Linear(8, 2) linear_mp = linear_mp.to_global(placement=P, sbp=S0) flow.nn.init.constant_(linear_mp.weight, 2.068758) flow.nn.init.constant_(linear_mp.bias, 0.23) @@ -51,20 +51,8 @@ def __init__(self): self.config.enable_amp(True) self.set_grad_scaler(grad_scaler) - if zero_stage == 1: - print("zero stage 1 optimization") - self.config.set_zero_redundancy_optimizer_mode("distributed_split") - self.config.set_zero_redundancy_optimizer_min_size_after_split(1) - if zero_stage == 2: - self.config.set_zero_redundancy_optimizer_mode("distributed_split") - self.config.set_zero_redundancy_optimizer_min_size_after_split(1) - flow.boxing.nccl.enable_use_compute_stream(True) - if zero_stage == 3: - print("zero stage 3 optimization") - self.config.set_zero_redundancy_optimizer_mode("distributed_split") - self.config.set_zero_redundancy_optimizer_min_size_after_split(1) - flow.boxing.nccl.enable_use_compute_stream(True) - flow.boxing.nccl.disable_group_boxing_by_dst_parallel(True) + self.config.enable_zero(True, stage=zero_stage, min_splited_size=1) + self.debug(2) def build(self, x): out = self.linear_dp(x) @@ -103,8 +91,10 @@ def one_eval_iter(): # After pass rewrite in training graph, parameters' sbp has been # changed from flow.sbp.broadcast to flow.sbp.split(0) - test_case.assertEqual(linear.weight.sbp[0], S0) - test_case.assertEqual(linear.bias.sbp[0], S0) + test_case.assertEqual(linear_dp.weight.sbp[0], S0) + test_case.assertEqual(linear_dp.bias.sbp[0], S0) + test_case.assertEqual(linear_mp.weight.sbp[0], S0) + test_case.assertEqual(linear_mp.bias.sbp[0], S0) # In evaluation graph, paramters's sbp are flow.sbp.split(0). # But their consumer will consum them as flow.sbp.broadcast. diff --git a/python/oneflow/test/graph/test_optimization_conf.py b/python/oneflow/test/graph/test_optimization_conf.py index da6348b7033..a60d339be8b 100644 --- a/python/oneflow/test/graph/test_optimization_conf.py +++ b/python/oneflow/test/graph/test_optimization_conf.py @@ -66,7 +66,7 @@ def __init__(self): self.config.allow_fuse_add_to_output(True) self.config.allow_fuse_cast_scale(True) self.config.set_gradient_accumulation_steps(100) - self.config.set_zero_redundancy_optimizer_mode("distributed_split") + self.config.enable_zero(True) self.config.enable_cudnn_conv_heuristic_search_algo(False) def build(self, x): From 3b564680d741d1ebee50da1402778ebfb90ee53c Mon Sep 17 00:00:00 2001 From: strint Date: Thu, 14 Apr 2022 15:49:18 +0800 Subject: [PATCH 05/46] zero test with mp --- .../optimizer_placement_optimization_pass.cpp | 7 ------- .../oneflow/framework/multi_client_session.py | 3 +++ python/oneflow/test/graph/test_graph_zero.py | 17 +++++++---------- 3 files changed, 10 insertions(+), 17 deletions(-) diff --git a/oneflow/core/job_rewriter/optimizer_placement_optimization_pass.cpp b/oneflow/core/job_rewriter/optimizer_placement_optimization_pass.cpp index 0adc3a57039..3d7dae11e5d 100644 --- a/oneflow/core/job_rewriter/optimizer_placement_optimization_pass.cpp +++ b/oneflow/core/job_rewriter/optimizer_placement_optimization_pass.cpp @@ -228,15 +228,8 @@ Maybe RewriteDistributedSplit(const OpGraph& op_graph, JobBuilder* builder if (n->op().op_conf().has_variable_conf()) { const Shape shape(n->op().op_conf().variable_conf().shape()); const int64_t parallel_num = n->parallel_desc().parallel_num(); - // bool is_1d_broadcast = false; - // if (n->op().op_conf().variable_conf().nd_sbp_size() == 1 && - // n->op().op_conf().variable_conf().nd_sbp(0) == "B") { - // // NOTE(strint): Only 1D Broadcast Variable will be split by ZeRO. - // is_1d_broadcast = true; - // } // Parameter needs to be able to evenly splited and one slice size >= threshold return shape.At(0) % parallel_num == 0 && shape.elem_cnt() >= threshold * parallel_num; - // return is_1d_broadcast && shape.At(0) % parallel_num == 0 && shape.elem_cnt() >= threshold * parallel_num; } else { return IsS0SignatureSupported(n); } diff --git a/python/oneflow/framework/multi_client_session.py b/python/oneflow/framework/multi_client_session.py index 72c6e093779..64a82c12b27 100644 --- a/python/oneflow/framework/multi_client_session.py +++ b/python/oneflow/framework/multi_client_session.py @@ -124,4 +124,7 @@ def update_resource_eagerly(self, resource_config): self._session_ctx.update_resource(config_proto_str) def __del__(self): + if self._env.is_shutting_down(): + # After python shutting down, it's not safe to call oneflow + return self._TryClose() diff --git a/python/oneflow/test/graph/test_graph_zero.py b/python/oneflow/test/graph/test_graph_zero.py index ca3e6629f0b..adbd1926b96 100644 --- a/python/oneflow/test/graph/test_graph_zero.py +++ b/python/oneflow/test/graph/test_graph_zero.py @@ -27,20 +27,18 @@ def train_with_graph(iter_num=1): B = flow.sbp.broadcast S0 = flow.sbp.split(0) - linear_dp = flow.nn.Linear(8, 4) + linear_dp = flow.nn.Linear(8, 4, bias=False) linear_dp = linear_dp.to_global(placement=P, sbp=B) flow.nn.init.constant_(linear_dp.weight, 2.068758) - flow.nn.init.constant_(linear_dp.bias, 0.23) - linear_mp = flow.nn.Linear(8, 2) + linear_mp = flow.nn.Linear(4, 5, bias=False) linear_mp = linear_mp.to_global(placement=P, sbp=S0) flow.nn.init.constant_(linear_mp.weight, 2.068758) - flow.nn.init.constant_(linear_mp.bias, 0.23) of_sgd = flow.optim.SGD([{"params": linear_dp.parameters()}, {"params": linear_mp.parameters()}], lr=0.001, momentum=0.9) grad_scaler = flow.amp.StaticGradScaler(200) - x = flow.randint(1, 100, (4, 8), dtype=flow.float32, placement=P, sbp=S0) + x = flow.randint(1, 100, (6, 8), dtype=flow.float32, placement=P, sbp=S0) class LinearTrainGraphWithZeRO(flow.nn.Graph): def __init__(self): @@ -57,7 +55,7 @@ def __init__(self): def build(self, x): out = self.linear_dp(x) out = out.to_global(placement=P, sbp=B) - out = self.linear_mp(x) + out = self.linear_mp(out) loss = out.sum() loss.backward() return out @@ -73,12 +71,13 @@ def __init__(self): def build(self, x): out = self.linear_dp(x) out = out.to_global(placement=P, sbp=B) - out = self.linear_mp(x) + out = self.linear_mp(out) return out linear_t_g = LinearTrainGraphWithZeRO() - linear_t_g.debug(2) + linear_t_g.debug(1) linear_e_g = LinearEvalGraphWithZeRO() + linear_e_g.debug(1) def one_train_iter(): out = linear_t_g(x) @@ -92,9 +91,7 @@ def one_eval_iter(): # After pass rewrite in training graph, parameters' sbp has been # changed from flow.sbp.broadcast to flow.sbp.split(0) test_case.assertEqual(linear_dp.weight.sbp[0], S0) - test_case.assertEqual(linear_dp.bias.sbp[0], S0) test_case.assertEqual(linear_mp.weight.sbp[0], S0) - test_case.assertEqual(linear_mp.bias.sbp[0], S0) # In evaluation graph, paramters's sbp are flow.sbp.split(0). # But their consumer will consum them as flow.sbp.broadcast. From 66c8ac3b8f619affb799b078448e785c63f0012d Mon Sep 17 00:00:00 2001 From: strint Date: Thu, 14 Apr 2022 23:48:34 +0800 Subject: [PATCH 06/46] add 2d test --- python/oneflow/test/graph/test_graph_zero.py | 84 ++++++++++++++++++++ 1 file changed, 84 insertions(+) diff --git a/python/oneflow/test/graph/test_graph_zero.py b/python/oneflow/test/graph/test_graph_zero.py index adbd1926b96..085e1c5d3e3 100644 --- a/python/oneflow/test/graph/test_graph_zero.py +++ b/python/oneflow/test/graph/test_graph_zero.py @@ -100,6 +100,85 @@ def one_eval_iter(): iter_num = 1 graph_check_list = train_with_graph(iter_num) +def _test_linear_train_graph_2d_with_zero(test_case, zero_stage=1): + def train_with_graph(iter_num=1): + P = flow.placement("cuda", ranks=[[0, 1], [2, 3]]) + B = flow.sbp.broadcast + S0 = flow.sbp.split(0) + S1 = flow.sbp.split(1) + + linear_dp_mp = flow.nn.Linear(8, 4, bias=False) + linear_dp_mp = linear_dp_mp.to_global(placement=P, sbp=[B, S0]) + flow.nn.init.constant_(linear_dp_mp.weight, 2.068758) + + linear_mp_dp = flow.nn.Linear(8, 5, bias=False) + linear_mp_dp = linear_mp_dp.to_global(placement=P, sbp=[S0, B]) + flow.nn.init.constant_(linear_mp_dp.weight, 2.068758) + + of_sgd = flow.optim.SGD([{"params": linear_dp_mp.parameters()}, {"params": linear_mp_dp.parameters()}], lr=0.001, momentum=0.9) + grad_scaler = flow.amp.StaticGradScaler(200) + + x = flow.randint(1, 100, (6, 8), dtype=flow.float32, placement=P, sbp=[S0, B]) + + class LinearTrainGraph2DWithZeRO(flow.nn.Graph): + def __init__(self): + super().__init__() + self.linear_dp_mp = linear_dp_mp + self.linear_mp_dp = linear_mp_dp + self.add_optimizer(of_sgd) + + self.config.enable_amp(True) + self.set_grad_scaler(grad_scaler) + #self.config.enable_zero(True, stage=zero_stage, min_splited_size=1) + self.debug(2) + + def build(self, x): + out = self.linear_dp_mp(x) + out = out.to_global(placement=P, sbp=[B, S0]) + out = self.linear_mp_dp(out) + loss = out.sum() + loss.backward() + return out + + class LinearEvalGraph2DWithZeRO(flow.nn.Graph): + def __init__(self): + super().__init__() + self.linear_dp_mp = linear_dp_mp + self.linear_mp_dp = linear_mp_dp + + self.config.enable_amp(True) + + def build(self, x): + out = self.linear_dp_mp(x) + out = out.to_global(placement=P, sbp=[B, S0]) + out = self.linear_mp_dp(out) + return out + + linear_t_g = LinearTrainGraph2DWithZeRO() + linear_t_g.debug(1) + linear_e_g = LinearEvalGraph2DWithZeRO() + linear_e_g.debug(1) + + def one_train_iter(): + out = linear_t_g(x) + + def one_eval_iter(): + out = linear_e_g(x) + + for i in range(iter_num): + one_train_iter() + + # After pass rewrite in training graph, parameters' sbp has been + # changed from flow.sbp.broadcast to flow.sbp.split(0) + #test_case.assertEqual(linear_dp.weight.sbp[0], S0) + #test_case.assertEqual(linear_mp.weight.sbp[0], S0) + + # In evaluation graph, paramters's sbp are flow.sbp.split(0). + # But their consumer will consum them as flow.sbp.broadcast. + one_eval_iter() + + iter_num = 1 + graph_check_list = train_with_graph(iter_num) @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") @flow.unittest.skip_unless_1n2d() @@ -113,6 +192,11 @@ def _test_linear_train_graph_with_zero_2(test_case): def _test_linear_train_graph_with_zero_3(test_case): _test_linear_train_graph_with_zero(test_case, 3) +@unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") +@flow.unittest.skip_unless_1n4d() +class TestLinearTrainGraph2DWithZeRO(oneflow.unittest.TestCase): + def test_linear_train_graph_2d_with_zero_1(test_case): + _test_linear_train_graph_2d_with_zero(test_case, 1) if __name__ == "__main__": unittest.main() From 4f56df294636cb0a4e48300cb0cf2f0af848bc72 Mon Sep 17 00:00:00 2001 From: strint Date: Fri, 15 Apr 2022 23:45:39 +0800 Subject: [PATCH 07/46] add zero nd --- oneflow/core/job/job_build_and_infer_ctx.cpp | 3 + oneflow/core/job/job_builder.cpp | 1 + .../optimizer_placement_optimization_pass.cpp | 64 +++++++++++++++++-- python/oneflow/nn/graph/graph_config.py | 8 ++- python/oneflow/test/graph/test_graph_zero.py | 43 +++++++++---- 5 files changed, 99 insertions(+), 20 deletions(-) diff --git a/oneflow/core/job/job_build_and_infer_ctx.cpp b/oneflow/core/job/job_build_and_infer_ctx.cpp index 6288f153060..6e723cd9818 100644 --- a/oneflow/core/job/job_build_and_infer_ctx.cpp +++ b/oneflow/core/job/job_build_and_infer_ctx.cpp @@ -998,6 +998,9 @@ Maybe LazyJobBuildAndInferCtx::Complete() { }; int32_t pass_cnt = 0; auto DoPass = [&](const std::string& pass_name, int32_t cnt = 0) -> Maybe { + VLOG(1) << job_name << " is compiling with pass" + << " pass_cnt_" + std::to_string(pass_cnt) + "-" + pass_name + << (cnt > 0 ? std::to_string(cnt) : ""); if (unlikely(NeedLogJob(pass_name))) { std::string cnt_str = cnt > 0 ? std::to_string(cnt) : ""; LogJob("pass_cnt_" + std::to_string(pass_cnt) + "-" + pass_name + cnt_str + "-before"); diff --git a/oneflow/core/job/job_builder.cpp b/oneflow/core/job/job_builder.cpp index 6a8d9097c03..5fc1ce1f521 100644 --- a/oneflow/core/job/job_builder.cpp +++ b/oneflow/core/job/job_builder.cpp @@ -18,6 +18,7 @@ limitations under the License. #include "oneflow/core/common/util.h" #include "oneflow/core/common/container_util.h" #include "oneflow/core/job/job.pb.h" +#include "oneflow/core/job/sbp_parallel.pb.h" #include "oneflow/core/operator/operator.h" namespace oneflow { diff --git a/oneflow/core/job_rewriter/optimizer_placement_optimization_pass.cpp b/oneflow/core/job_rewriter/optimizer_placement_optimization_pass.cpp index 3d7dae11e5d..e624a482236 100644 --- a/oneflow/core/job_rewriter/optimizer_placement_optimization_pass.cpp +++ b/oneflow/core/job_rewriter/optimizer_placement_optimization_pass.cpp @@ -14,6 +14,7 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/common/util.h" +#include "oneflow/core/job/sbp_parallel.pb.h" #include "oneflow/core/job_rewriter/job_pass.h" #include "oneflow/core/graph/op_graph.h" #include "oneflow/core/job/job_desc.h" @@ -85,12 +86,33 @@ Maybe GetDataParallelVariableAndNaiveSuccNode( if (cur_node->in_edges().size() > 1) { break; } if (cur_node->op().input_bns().size() != 1) { break; } const std::string& sole_ibn = cur_node->op().SoleIbn(); - if (!cur_node->SbpParallel4BnInOp(sole_ibn).has_broadcast_parallel()) { break; } + LOG(ERROR) << cur_node->op().op_name() + << " has sbp: " << cur_node->NdSbp4BnInOp(sole_ibn).DebugString(); + // if (!cur_node->SbpParallel4BnInOp(sole_ibn).has_broadcast_parallel()) { break; } + // if (!(ibn_nd_sbp.sbp_parallel_size() == 1 && + // ibn_nd_sbp.sbp_parallel(0).has_broadcast_parallel())) { break; } + const NdSbp& ibn_nd_sbp = cur_node->NdSbp4BnInOp(sole_ibn); + if (ibn_nd_sbp.sbp_parallel_size() == 0) { break; } + bool has_broadcast = false; + FOR_RANGE(int, i, 0, ibn_nd_sbp.sbp_parallel_size()) { + if (ibn_nd_sbp.sbp_parallel(i).has_broadcast_parallel()) { has_broadcast = true; }; + } + if (!has_broadcast) { break; } } if (!IsAllowed(cur_node)) { break; } if (cur_node->op().output_bns().size() != 1) { break; } const std::string& sole_obn = cur_node->op().SoleObn(); - if (!cur_node->SbpParallel4BnInOp(sole_obn).has_broadcast_parallel()) { break; } + LOG(ERROR) << cur_node->op().op_name() + << " has sbp: " << cur_node->NdSbp4BnInOp(sole_obn).DebugString(); + // if (!cur_node->SbpParallel4BnInOp(sole_obn).has_broadcast_parallel()) { break; } + const NdSbp& obn_nd_sbp = cur_node->NdSbp4BnInOp(sole_obn); + // if (!(obn_nd_sbp.sbp_parallel_size() == 1 && + // obn_nd_sbp.sbp_parallel(0).has_broadcast_parallel())) { break; } + bool has_broadcast = false; + FOR_RANGE(int, i, 0, obn_nd_sbp.sbp_parallel_size()) { + if (obn_nd_sbp.sbp_parallel(i).has_broadcast_parallel()) { has_broadcast = true; }; + } + if (!has_broadcast) { break; } out->emplace_back(cur_node); if (cur_node->out_edges().size() == 1) { cur_node = cur_node->SoleOutEdge()->dst_node(); @@ -123,6 +145,26 @@ void SetBroadcastParallel4Consumers(JobBuilder* builder, const SequencePtr& sequ }); } +void SetNdSbp4OpNodeIbn(JobBuilder* builder, const OpNode* node, const std::string& ibn, + const NdSbp& nd_sbp) { + OpBlobArg op_blob_arg; + op_blob_arg.set_op_name(node->op().op_name()); + op_blob_arg.set_bn_in_op(ibn); + builder->SetNdSbp4Oba(op_blob_arg, nd_sbp); +} + +void SetNdSbp4Consumers(JobBuilder* builder, const SequencePtr& sequence, const NdSbp& nd_sbp) { + const OpNode* node = sequence->GetLastNode(); + const LogicalBlobId& lbi = node->op().BnInOp2Lbi(node->op().SoleObn()); + node->ForEachNodeOnOutEdge([&](const OpNode* out_node) { + for (const std::string& ibn : out_node->op().input_bns()) { + if (out_node->op().BnInOp2Lbi(ibn) == lbi) { + SetNdSbp4OpNodeIbn(builder, out_node, ibn, nd_sbp); + } + } + }); +} + std::function MakeGetterOpNode2TopoOrder(const OpGraph& op_graph) { HashMap op_node2topo_order; int64_t node_cnt = 0; @@ -243,9 +285,18 @@ Maybe RewriteDistributedSplit(const OpGraph& op_graph, JobBuilder* builder for (int64_t i = 0; i < sorted_sequences.size(); ++i) { const OpNode* var_node = sorted_sequences.at(i)->GetVariableNode(); OperatorConf new_var_op_conf = var_node->op().op_conf(); - CHECK_EQ(pd.hierarchy()->NumAxes(), 1); - new_var_op_conf.mutable_variable_conf()->clear_nd_sbp(); - *new_var_op_conf.mutable_variable_conf()->add_nd_sbp() = "S(0)"; + const std::string& sole_obn = var_node->op().SoleObn(); + LOG(ERROR) << var_node->op().op_name() + << " has sbp: " << var_node->NdSbp4BnInOp(sole_obn).DebugString(); + const NdSbp& var_nd_sbp = var_node->NdSbp4BnInOp(sole_obn); + // CHECK_EQ(pd.hierarchy()->NumAxes(), 1); + FOR_RANGE(int, i, 0, new_var_op_conf.variable_conf().nd_sbp_size()) { + if (new_var_op_conf.variable_conf().nd_sbp(i) == "B") { + *new_var_op_conf.mutable_variable_conf()->mutable_nd_sbp(i) = "S(0)"; + } + } + // new_var_op_conf.mutable_variable_conf()->clear_nd_sbp(); + // *new_var_op_conf.mutable_variable_conf()->add_nd_sbp() = "S(0)"; if (i != 0) { const std::string& prev_op_name = sorted_sequences.at(i - 1)->GetVariableNode()->op().op_name(); @@ -253,7 +304,8 @@ Maybe RewriteDistributedSplit(const OpGraph& op_graph, JobBuilder* builder } builder->MutOpsOnlyOnce({new_var_op_conf}); // Set consumers to consum this variable op's cast op's output as Broadcast. - SetBroadcastParallel4Consumers(builder, sorted_sequences.at(i)); + // SetBroadcastParallel4Consumers(builder, sorted_sequences.at(i)); + SetNdSbp4Consumers(builder, sorted_sequences.at(i), var_nd_sbp); } }; ForEachParallelSortedNodeSequence(op_graph, IsAllowed, SequenceCompSortedByOrderAsc, diff --git a/python/oneflow/nn/graph/graph_config.py b/python/oneflow/nn/graph/graph_config.py index be3c9c82e9e..f4bb36f73d1 100644 --- a/python/oneflow/nn/graph/graph_config.py +++ b/python/oneflow/nn/graph/graph_config.py @@ -72,7 +72,9 @@ def build(self, x): assert type(mode) is bool self.proto.set_enable_auto_mixed_precision(mode) - def enable_zero(self, mode: bool = True, *, stage: int = 2, min_splited_size: int = 1024): + def enable_zero( + self, mode: bool = True, *, stage: int = 2, min_splited_size: int = 1024 + ): r"""Enable ZeRO redundancy optimizer. This optimzation will reduce optimizer states memory consumption as described @@ -105,7 +107,9 @@ def build(self, x): self.proto.set_optimizer_placement_optimization_mode("none") return assert stage >= 1 and stage <= 3, "ZeRO stage must range form 1 to 3." - assert min_splited_size > 0, "ZeRO min size of a sharded optimizer state must > 0." + assert ( + min_splited_size > 0 + ), "ZeRO min size of a sharded optimizer state must > 0." if stage == 1: print("zero stage 1 optimization") self.proto.set_optimizer_placement_optimization_mode("distributed_split") diff --git a/python/oneflow/test/graph/test_graph_zero.py b/python/oneflow/test/graph/test_graph_zero.py index 085e1c5d3e3..39d221be58c 100644 --- a/python/oneflow/test/graph/test_graph_zero.py +++ b/python/oneflow/test/graph/test_graph_zero.py @@ -27,18 +27,22 @@ def train_with_graph(iter_num=1): B = flow.sbp.broadcast S0 = flow.sbp.split(0) - linear_dp = flow.nn.Linear(8, 4, bias=False) + linear_dp = flow.nn.Linear(800, 400, bias=False) linear_dp = linear_dp.to_global(placement=P, sbp=B) flow.nn.init.constant_(linear_dp.weight, 2.068758) - linear_mp = flow.nn.Linear(4, 5, bias=False) + linear_mp = flow.nn.Linear(400, 500, bias=False) linear_mp = linear_mp.to_global(placement=P, sbp=S0) flow.nn.init.constant_(linear_mp.weight, 2.068758) - of_sgd = flow.optim.SGD([{"params": linear_dp.parameters()}, {"params": linear_mp.parameters()}], lr=0.001, momentum=0.9) + of_sgd = flow.optim.SGD( + [{"params": linear_dp.parameters()}, {"params": linear_mp.parameters()}], + lr=0.001, + momentum=0.9, + ) grad_scaler = flow.amp.StaticGradScaler(200) - x = flow.randint(1, 100, (6, 8), dtype=flow.float32, placement=P, sbp=S0) + x = flow.randint(1, 100, (6, 800), dtype=flow.float32, placement=P, sbp=S0) class LinearTrainGraphWithZeRO(flow.nn.Graph): def __init__(self): @@ -81,6 +85,8 @@ def build(self, x): def one_train_iter(): out = linear_t_g(x) + if flow.env.get_rank() == 0: + print(linear_t_g) def one_eval_iter(): out = linear_e_g(x) @@ -100,6 +106,7 @@ def one_eval_iter(): iter_num = 1 graph_check_list = train_with_graph(iter_num) + def _test_linear_train_graph_2d_with_zero(test_case, zero_stage=1): def train_with_graph(iter_num=1): P = flow.placement("cuda", ranks=[[0, 1], [2, 3]]) @@ -107,18 +114,25 @@ def train_with_graph(iter_num=1): S0 = flow.sbp.split(0) S1 = flow.sbp.split(1) - linear_dp_mp = flow.nn.Linear(8, 4, bias=False) + linear_dp_mp = flow.nn.Linear(800, 400, bias=False) linear_dp_mp = linear_dp_mp.to_global(placement=P, sbp=[B, S0]) flow.nn.init.constant_(linear_dp_mp.weight, 2.068758) - linear_mp_dp = flow.nn.Linear(8, 5, bias=False) + linear_mp_dp = flow.nn.Linear(800, 500, bias=False) linear_mp_dp = linear_mp_dp.to_global(placement=P, sbp=[S0, B]) flow.nn.init.constant_(linear_mp_dp.weight, 2.068758) - of_sgd = flow.optim.SGD([{"params": linear_dp_mp.parameters()}, {"params": linear_mp_dp.parameters()}], lr=0.001, momentum=0.9) + of_sgd = flow.optim.SGD( + [ + {"params": linear_dp_mp.parameters()}, + {"params": linear_mp_dp.parameters()}, + ], + lr=0.001, + momentum=0.9, + ) grad_scaler = flow.amp.StaticGradScaler(200) - x = flow.randint(1, 100, (6, 8), dtype=flow.float32, placement=P, sbp=[S0, B]) + x = flow.randint(1, 100, (6, 800), dtype=flow.float32, placement=P, sbp=[S0, B]) class LinearTrainGraph2DWithZeRO(flow.nn.Graph): def __init__(self): @@ -129,7 +143,7 @@ def __init__(self): self.config.enable_amp(True) self.set_grad_scaler(grad_scaler) - #self.config.enable_zero(True, stage=zero_stage, min_splited_size=1) + self.config.enable_zero(True, stage=zero_stage, min_splited_size=1) self.debug(2) def build(self, x): @@ -161,6 +175,8 @@ def build(self, x): def one_train_iter(): out = linear_t_g(x) + if flow.env.get_rank() == 0: + print(linear_t_g) def one_eval_iter(): out = linear_e_g(x) @@ -170,8 +186,8 @@ def one_eval_iter(): # After pass rewrite in training graph, parameters' sbp has been # changed from flow.sbp.broadcast to flow.sbp.split(0) - #test_case.assertEqual(linear_dp.weight.sbp[0], S0) - #test_case.assertEqual(linear_mp.weight.sbp[0], S0) + # test_case.assertEqual(linear_dp.weight.sbp[0], S0) + # test_case.assertEqual(linear_mp.weight.sbp[0], S0) # In evaluation graph, paramters's sbp are flow.sbp.split(0). # But their consumer will consum them as flow.sbp.broadcast. @@ -180,6 +196,7 @@ def one_eval_iter(): iter_num = 1 graph_check_list = train_with_graph(iter_num) + @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") @flow.unittest.skip_unless_1n2d() class TestLinearTrainGraphWithZeRO(oneflow.unittest.TestCase): @@ -192,11 +209,13 @@ def _test_linear_train_graph_with_zero_2(test_case): def _test_linear_train_graph_with_zero_3(test_case): _test_linear_train_graph_with_zero(test_case, 3) + @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") @flow.unittest.skip_unless_1n4d() class TestLinearTrainGraph2DWithZeRO(oneflow.unittest.TestCase): def test_linear_train_graph_2d_with_zero_1(test_case): - _test_linear_train_graph_2d_with_zero(test_case, 1) + _test_linear_train_graph_2d_with_zero(test_case, 2) + if __name__ == "__main__": unittest.main() From 2834289c90982d5ccc34194fa1ad1c82532729f5 Mon Sep 17 00:00:00 2001 From: strint Date: Fri, 22 Apr 2022 00:43:49 +0800 Subject: [PATCH 08/46] add nd zero --- .../optimizer_placement_optimization_pass.cpp | 13 ++++++++----- python/oneflow/test/graph/test_graph_zero.py | 8 ++++---- 2 files changed, 12 insertions(+), 9 deletions(-) diff --git a/oneflow/core/job_rewriter/optimizer_placement_optimization_pass.cpp b/oneflow/core/job_rewriter/optimizer_placement_optimization_pass.cpp index e624a482236..90e9cfb1269 100644 --- a/oneflow/core/job_rewriter/optimizer_placement_optimization_pass.cpp +++ b/oneflow/core/job_rewriter/optimizer_placement_optimization_pass.cpp @@ -270,6 +270,7 @@ Maybe RewriteDistributedSplit(const OpGraph& op_graph, JobBuilder* builder if (n->op().op_conf().has_variable_conf()) { const Shape shape(n->op().op_conf().variable_conf().shape()); const int64_t parallel_num = n->parallel_desc().parallel_num(); + // TODO(strint): zero with nd check size // Parameter needs to be able to evenly splited and one slice size >= threshold return shape.At(0) % parallel_num == 0 && shape.elem_cnt() >= threshold * parallel_num; } else { @@ -286,17 +287,19 @@ Maybe RewriteDistributedSplit(const OpGraph& op_graph, JobBuilder* builder const OpNode* var_node = sorted_sequences.at(i)->GetVariableNode(); OperatorConf new_var_op_conf = var_node->op().op_conf(); const std::string& sole_obn = var_node->op().SoleObn(); - LOG(ERROR) << var_node->op().op_name() + LOG(ERROR) << var_node->op().op_name() << " can be splited, " << " has sbp: " << var_node->NdSbp4BnInOp(sole_obn).DebugString(); const NdSbp& var_nd_sbp = var_node->NdSbp4BnInOp(sole_obn); // CHECK_EQ(pd.hierarchy()->NumAxes(), 1); FOR_RANGE(int, i, 0, new_var_op_conf.variable_conf().nd_sbp_size()) { if (new_var_op_conf.variable_conf().nd_sbp(i) == "B") { + // TODO(strint): zero with nd choose dim *new_var_op_conf.mutable_variable_conf()->mutable_nd_sbp(i) = "S(0)"; + LOG(ERROR) << var_node->op().op_name() << " ranks dim " << i << " sbp is changed form B to S(0) " << new_var_op_conf.variable_conf().DebugString(); + // Only split one more dim. + break; } } - // new_var_op_conf.mutable_variable_conf()->clear_nd_sbp(); - // *new_var_op_conf.mutable_variable_conf()->add_nd_sbp() = "S(0)"; if (i != 0) { const std::string& prev_op_name = sorted_sequences.at(i - 1)->GetVariableNode()->op().op_name(); @@ -304,8 +307,8 @@ Maybe RewriteDistributedSplit(const OpGraph& op_graph, JobBuilder* builder } builder->MutOpsOnlyOnce({new_var_op_conf}); // Set consumers to consum this variable op's cast op's output as Broadcast. - // SetBroadcastParallel4Consumers(builder, sorted_sequences.at(i)); - SetNdSbp4Consumers(builder, sorted_sequences.at(i), var_nd_sbp); + bool limit_consume_b = ParseBooleanFromEnv("ZERO_LIMIT_B", true); + if (limit_consume_b) { SetNdSbp4Consumers(builder, sorted_sequences.at(i), var_nd_sbp); } } }; ForEachParallelSortedNodeSequence(op_graph, IsAllowed, SequenceCompSortedByOrderAsc, diff --git a/python/oneflow/test/graph/test_graph_zero.py b/python/oneflow/test/graph/test_graph_zero.py index 39d221be58c..46b082b87ce 100644 --- a/python/oneflow/test/graph/test_graph_zero.py +++ b/python/oneflow/test/graph/test_graph_zero.py @@ -144,11 +144,11 @@ def __init__(self): self.config.enable_amp(True) self.set_grad_scaler(grad_scaler) self.config.enable_zero(True, stage=zero_stage, min_splited_size=1) - self.debug(2) + self.debug(1) def build(self, x): out = self.linear_dp_mp(x) - out = out.to_global(placement=P, sbp=[B, S0]) + #out = out.to_global(placement=P, sbp=[B, S0]) out = self.linear_mp_dp(out) loss = out.sum() loss.backward() @@ -164,7 +164,7 @@ def __init__(self): def build(self, x): out = self.linear_dp_mp(x) - out = out.to_global(placement=P, sbp=[B, S0]) + #out = out.to_global(placement=P, sbp=[B, S0]) out = self.linear_mp_dp(out) return out @@ -214,7 +214,7 @@ def _test_linear_train_graph_with_zero_3(test_case): @flow.unittest.skip_unless_1n4d() class TestLinearTrainGraph2DWithZeRO(oneflow.unittest.TestCase): def test_linear_train_graph_2d_with_zero_1(test_case): - _test_linear_train_graph_2d_with_zero(test_case, 2) + _test_linear_train_graph_2d_with_zero(test_case, 1) if __name__ == "__main__": From b805256e4811f3c81ee1852b6c31c992e47d6219 Mon Sep 17 00:00:00 2001 From: strint Date: Fri, 22 Apr 2022 02:53:43 +0800 Subject: [PATCH 09/46] add sbp cast --- .../optimizer_placement_optimization_pass.cpp | 52 ++++++++++++++++--- 1 file changed, 44 insertions(+), 8 deletions(-) diff --git a/oneflow/core/job_rewriter/optimizer_placement_optimization_pass.cpp b/oneflow/core/job_rewriter/optimizer_placement_optimization_pass.cpp index 90e9cfb1269..5fc1e5d77c5 100644 --- a/oneflow/core/job_rewriter/optimizer_placement_optimization_pass.cpp +++ b/oneflow/core/job_rewriter/optimizer_placement_optimization_pass.cpp @@ -14,6 +14,7 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/common/util.h" +#include "oneflow/core/framework/user_op_conf.h" #include "oneflow/core/job/sbp_parallel.pb.h" #include "oneflow/core/job_rewriter/job_pass.h" #include "oneflow/core/graph/op_graph.h" @@ -156,13 +157,47 @@ void SetNdSbp4OpNodeIbn(JobBuilder* builder, const OpNode* node, const std::stri void SetNdSbp4Consumers(JobBuilder* builder, const SequencePtr& sequence, const NdSbp& nd_sbp) { const OpNode* node = sequence->GetLastNode(); const LogicalBlobId& lbi = node->op().BnInOp2Lbi(node->op().SoleObn()); - node->ForEachNodeOnOutEdge([&](const OpNode* out_node) { - for (const std::string& ibn : out_node->op().input_bns()) { - if (out_node->op().BnInOp2Lbi(ibn) == lbi) { - SetNdSbp4OpNodeIbn(builder, out_node, ibn, nd_sbp); + int limit_consumer_mode = ParseIntegerFromEnv("ZERO_LIMIT_CONSUMER_MODE", 2); + // If limit_consumer_mode == 0, no limit on consumer + if (limit_consumer_mode == 1) { + // Soft limt consumer to consume weight as Broadcast. + const auto parallel_cast_op = + user_op::UserOpConfWrapperBuilder("System-ZeRO-ParallelCast-" + node->op().op_name() + "-" + NewUniqueId()) + .Op("hierarchical_parallel_cast") + .Input("in", GenLogicalBlobName(lbi)) + .Output("out") + .Attr>("nd_sbp", NdSbpToStringList(nd_sbp)) + .Attr("grad_mode", "auto") + .Attr>("grad_nd_sbp", std::vector()) + .ScopeSymbolId(node->op().op_conf().scope_symbol_id()) + .Build(); + builder->AddOps(node->parallel_desc().parallel_conf(), {parallel_cast_op.op_conf()}); + auto out_lbn = GenLogicalBlobName(parallel_cast_op.op_name(), "out"); + node->ForEachNodeOnOutEdge([&](const OpNode* out_node) { + for (const std::string& ibn : out_node->op().input_bns()) { + if (out_node->op().BnInOp2Lbi(ibn) == lbi) { + // SetNdSbp4OpNodeIbn(builder, out_node, ibn, nd_sbp); + if (!CHECK_JUST(builder->IsInMutOpTransaction(out_node->op().op_name()))) { + CHECK_JUST(builder->MutOpTransactionMut(out_node->op().op_conf())); + } + OperatorConf& mut_consumer_op = + *CHECK_JUST(builder->MutOpTransactionGet(out_node->op().op_name())); + const auto& old_lbn = ReplaceInputLbnInOpCustomizedConf(&mut_consumer_op, ibn, out_lbn); + CHECK_EQ(old_lbn, GenLogicalBlobName(lbi)); + } } - } - }); + }); + } else if (limit_consumer_mode == 2) { + // Hard limt consumer to consume weight as Broadcast. + // Default is 2. + node->ForEachNodeOnOutEdge([&](const OpNode* out_node) { + for (const std::string& ibn : out_node->op().input_bns()) { + if (out_node->op().BnInOp2Lbi(ibn) == lbi) { + SetNdSbp4OpNodeIbn(builder, out_node, ibn, nd_sbp); + } + } + }); + } } std::function MakeGetterOpNode2TopoOrder(const OpGraph& op_graph) { @@ -305,14 +340,15 @@ Maybe RewriteDistributedSplit(const OpGraph& op_graph, JobBuilder* builder sorted_sequences.at(i - 1)->GetVariableNode()->op().op_name(); new_var_op_conf.add_ctrl_in_op_name(prev_op_name); } + // TODO(strint): rewrite with MutOpTransactioin builder->MutOpsOnlyOnce({new_var_op_conf}); // Set consumers to consum this variable op's cast op's output as Broadcast. - bool limit_consume_b = ParseBooleanFromEnv("ZERO_LIMIT_B", true); - if (limit_consume_b) { SetNdSbp4Consumers(builder, sorted_sequences.at(i), var_nd_sbp); } + SetNdSbp4Consumers(builder, sorted_sequences.at(i), var_nd_sbp); } }; ForEachParallelSortedNodeSequence(op_graph, IsAllowed, SequenceCompSortedByOrderAsc, PlacementSequencesAsSplitParallel); + JUST(builder->MutOpTransactionCommit()); return Maybe::Ok(); } From 2ede3545260164c9214b853e006b81d13859306b Mon Sep 17 00:00:00 2001 From: strint Date: Fri, 22 Apr 2022 16:40:31 +0800 Subject: [PATCH 10/46] test passed soft limit consumer --- .../job_rewriter/optimizer_placement_optimization_pass.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/oneflow/core/job_rewriter/optimizer_placement_optimization_pass.cpp b/oneflow/core/job_rewriter/optimizer_placement_optimization_pass.cpp index 5fc1e5d77c5..7310f98acbb 100644 --- a/oneflow/core/job_rewriter/optimizer_placement_optimization_pass.cpp +++ b/oneflow/core/job_rewriter/optimizer_placement_optimization_pass.cpp @@ -167,12 +167,12 @@ void SetNdSbp4Consumers(JobBuilder* builder, const SequencePtr& sequence, const .Input("in", GenLogicalBlobName(lbi)) .Output("out") .Attr>("nd_sbp", NdSbpToStringList(nd_sbp)) - .Attr("grad_mode", "auto") + .Attr("grad_mode", "identity") // don't do ndsbp cast at backward .Attr>("grad_nd_sbp", std::vector()) .ScopeSymbolId(node->op().op_conf().scope_symbol_id()) .Build(); builder->AddOps(node->parallel_desc().parallel_conf(), {parallel_cast_op.op_conf()}); - auto out_lbn = GenLogicalBlobName(parallel_cast_op.op_name(), "out"); + auto out_lbn = parallel_cast_op.output("out", 0); node->ForEachNodeOnOutEdge([&](const OpNode* out_node) { for (const std::string& ibn : out_node->op().input_bns()) { if (out_node->op().BnInOp2Lbi(ibn) == lbi) { From 0227f5487784893e55c08487b3fc56acfb555874 Mon Sep 17 00:00:00 2001 From: strint Date: Fri, 22 Apr 2022 22:49:00 +0800 Subject: [PATCH 11/46] refine size api --- python/oneflow/nn/graph/graph_config.py | 12 ++++++------ python/oneflow/test/graph/test_graph_zero.py | 4 ++-- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/python/oneflow/nn/graph/graph_config.py b/python/oneflow/nn/graph/graph_config.py index f4bb36f73d1..81a36725d60 100644 --- a/python/oneflow/nn/graph/graph_config.py +++ b/python/oneflow/nn/graph/graph_config.py @@ -73,7 +73,7 @@ def build(self, x): self.proto.set_enable_auto_mixed_precision(mode) def enable_zero( - self, mode: bool = True, *, stage: int = 2, min_splited_size: int = 1024 + self, mode: bool = True, *, stage: int = 2, min_shard_size: int = 1024 ): r"""Enable ZeRO redundancy optimizer. @@ -101,27 +101,27 @@ def build(self, x): Args: mode (bool): if set to true, optimizer states of Data Parallel will be sharded across devices. stage (int): optimization stage, range from 1 to 3. - min_splited_size (int): min size of a sharded optimizer state. + min_shard_size (int): min size of a shard of an optimizer state. """ if not mode: self.proto.set_optimizer_placement_optimization_mode("none") return assert stage >= 1 and stage <= 3, "ZeRO stage must range form 1 to 3." assert ( - min_splited_size > 0 + min_shard_size > 0 ), "ZeRO min size of a sharded optimizer state must > 0." if stage == 1: print("zero stage 1 optimization") self.proto.set_optimizer_placement_optimization_mode("distributed_split") - self.proto.set_optimizer_placement_optimization_threshold(min_splited_size) + self.proto.set_optimizer_placement_optimization_threshold(min_shard_size) elif stage == 2: self.proto.set_optimizer_placement_optimization_mode("distributed_split") - self.proto.set_optimizer_placement_optimization_threshold(min_splited_size) + self.proto.set_optimizer_placement_optimization_threshold(min_shard_size) oneflow.boxing.nccl.enable_use_compute_stream(True) elif stage == 3: print("zero stage 3 optimization") self.proto.set_optimizer_placement_optimization_mode("distributed_split") - self.proto.set_optimizer_placement_optimization_threshold(min_splited_size) + self.proto.set_optimizer_placement_optimization_threshold(min_shard_size) oneflow.boxing.nccl.enable_use_compute_stream(True) oneflow.boxing.nccl.disable_group_boxing_by_dst_parallel(True) diff --git a/python/oneflow/test/graph/test_graph_zero.py b/python/oneflow/test/graph/test_graph_zero.py index 46b082b87ce..a0d0a4e6aaa 100644 --- a/python/oneflow/test/graph/test_graph_zero.py +++ b/python/oneflow/test/graph/test_graph_zero.py @@ -53,7 +53,7 @@ def __init__(self): self.config.enable_amp(True) self.set_grad_scaler(grad_scaler) - self.config.enable_zero(True, stage=zero_stage, min_splited_size=1) + self.config.enable_zero(True, stage=zero_stage, min_shard_size=1) self.debug(2) def build(self, x): @@ -143,7 +143,7 @@ def __init__(self): self.config.enable_amp(True) self.set_grad_scaler(grad_scaler) - self.config.enable_zero(True, stage=zero_stage, min_splited_size=1) + self.config.enable_zero(True, stage=zero_stage, min_shard_size=1) self.debug(1) def build(self, x): From 7036e041cafa51dfdcd394682ff2aff2ad374236 Mon Sep 17 00:00:00 2001 From: strint Date: Thu, 28 Apr 2022 18:15:34 +0800 Subject: [PATCH 12/46] zero use stage 2 --- python/oneflow/test/graph/test_graph_zero.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/python/oneflow/test/graph/test_graph_zero.py b/python/oneflow/test/graph/test_graph_zero.py index a0d0a4e6aaa..fed270bdf25 100644 --- a/python/oneflow/test/graph/test_graph_zero.py +++ b/python/oneflow/test/graph/test_graph_zero.py @@ -148,7 +148,6 @@ def __init__(self): def build(self, x): out = self.linear_dp_mp(x) - #out = out.to_global(placement=P, sbp=[B, S0]) out = self.linear_mp_dp(out) loss = out.sum() loss.backward() @@ -214,7 +213,7 @@ def _test_linear_train_graph_with_zero_3(test_case): @flow.unittest.skip_unless_1n4d() class TestLinearTrainGraph2DWithZeRO(oneflow.unittest.TestCase): def test_linear_train_graph_2d_with_zero_1(test_case): - _test_linear_train_graph_2d_with_zero(test_case, 1) + _test_linear_train_graph_2d_with_zero(test_case, 2) if __name__ == "__main__": From c26763eef227980cc171e6d54e771dceec766822 Mon Sep 17 00:00:00 2001 From: strint Date: Fri, 29 Apr 2022 17:45:12 +0800 Subject: [PATCH 13/46] add limit consumer api --- python/oneflow/test/graph/test_graph_zero.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/oneflow/test/graph/test_graph_zero.py b/python/oneflow/test/graph/test_graph_zero.py index fed270bdf25..21bec2cd8c8 100644 --- a/python/oneflow/test/graph/test_graph_zero.py +++ b/python/oneflow/test/graph/test_graph_zero.py @@ -53,7 +53,7 @@ def __init__(self): self.config.enable_amp(True) self.set_grad_scaler(grad_scaler) - self.config.enable_zero(True, stage=zero_stage, min_shard_size=1) + self.config.enable_zero(True, stage=zero_stage, min_shard_size=1, parameter_consumer_limit_level=0) self.debug(2) def build(self, x): @@ -143,7 +143,7 @@ def __init__(self): self.config.enable_amp(True) self.set_grad_scaler(grad_scaler) - self.config.enable_zero(True, stage=zero_stage, min_shard_size=1) + self.config.enable_zero(True, stage=zero_stage, min_shard_size=1, parameter_consumer_limit_level=2) self.debug(1) def build(self, x): From d84e8a97a03877689b73ee8817d75dbe5760963b Mon Sep 17 00:00:00 2001 From: strint Date: Fri, 29 Apr 2022 17:46:09 +0800 Subject: [PATCH 14/46] add new api --- oneflow/core/job/eager_nccl_comm_manager.cpp | 3 +++ oneflow/core/job/job_conf.proto | 1 + .../optimizer_placement_optimization_pass.cpp | 11 +++++------ python/oneflow/nn/graph/graph_config.py | 18 +++++++----------- 4 files changed, 16 insertions(+), 17 deletions(-) diff --git a/oneflow/core/job/eager_nccl_comm_manager.cpp b/oneflow/core/job/eager_nccl_comm_manager.cpp index d8b77cdbb72..959a7837010 100644 --- a/oneflow/core/job/eager_nccl_comm_manager.cpp +++ b/oneflow/core/job/eager_nccl_comm_manager.cpp @@ -71,6 +71,9 @@ void CreateNcclComm(ncclComm_t* comm, const int dev, const std::string& key, << ", nccl_unique_id = " << NcclUniqueId2String(nccl_unique_id) << ", rank = " << rank << ", key = {" << key << "}\n"; OF_NCCL_CHECK(ncclCommInitRank(comm, device_vec.size(), nccl_unique_id, rank)); + VLOG(2) << " EagerNcclCommMgr::ncclCommInitRank succeed device_vec.size() = " << device_vec.size() + << ", nccl_unique_id = " << NcclUniqueId2String(nccl_unique_id) << ", rank = " << rank + << ", key = {" << key << "}\n"; } } // namespace diff --git a/oneflow/core/job/job_conf.proto b/oneflow/core/job/job_conf.proto index 664a0ac5989..939297053a6 100644 --- a/oneflow/core/job/job_conf.proto +++ b/oneflow/core/job/job_conf.proto @@ -219,6 +219,7 @@ message JobConfigProto { optional bool enable_gradients_stats_aggregation = 106 [default = true]; optional string optimizer_placement_optimization_mode = 107; optional int64 optimizer_placement_optimization_threshold = 108 [default = 1024]; + optional int64 optimizer_placement_optimization_comsumer_limit_level = 110 [default = 2]; optional QatConfig qat_config = 109; diff --git a/oneflow/core/job_rewriter/optimizer_placement_optimization_pass.cpp b/oneflow/core/job_rewriter/optimizer_placement_optimization_pass.cpp index 7310f98acbb..d8e82712005 100644 --- a/oneflow/core/job_rewriter/optimizer_placement_optimization_pass.cpp +++ b/oneflow/core/job_rewriter/optimizer_placement_optimization_pass.cpp @@ -87,7 +87,7 @@ Maybe GetDataParallelVariableAndNaiveSuccNode( if (cur_node->in_edges().size() > 1) { break; } if (cur_node->op().input_bns().size() != 1) { break; } const std::string& sole_ibn = cur_node->op().SoleIbn(); - LOG(ERROR) << cur_node->op().op_name() + VLOG(3) << cur_node->op().op_name() << " has sbp: " << cur_node->NdSbp4BnInOp(sole_ibn).DebugString(); // if (!cur_node->SbpParallel4BnInOp(sole_ibn).has_broadcast_parallel()) { break; } // if (!(ibn_nd_sbp.sbp_parallel_size() == 1 && @@ -103,7 +103,7 @@ Maybe GetDataParallelVariableAndNaiveSuccNode( if (!IsAllowed(cur_node)) { break; } if (cur_node->op().output_bns().size() != 1) { break; } const std::string& sole_obn = cur_node->op().SoleObn(); - LOG(ERROR) << cur_node->op().op_name() + VLOG(3) << cur_node->op().op_name() << " has sbp: " << cur_node->NdSbp4BnInOp(sole_obn).DebugString(); // if (!cur_node->SbpParallel4BnInOp(sole_obn).has_broadcast_parallel()) { break; } const NdSbp& obn_nd_sbp = cur_node->NdSbp4BnInOp(sole_obn); @@ -157,7 +157,7 @@ void SetNdSbp4OpNodeIbn(JobBuilder* builder, const OpNode* node, const std::stri void SetNdSbp4Consumers(JobBuilder* builder, const SequencePtr& sequence, const NdSbp& nd_sbp) { const OpNode* node = sequence->GetLastNode(); const LogicalBlobId& lbi = node->op().BnInOp2Lbi(node->op().SoleObn()); - int limit_consumer_mode = ParseIntegerFromEnv("ZERO_LIMIT_CONSUMER_MODE", 2); + const int64_t limit_consumer_mode = builder->job().job_conf().optimizer_placement_optimization_comsumer_limit_level(); // If limit_consumer_mode == 0, no limit on consumer if (limit_consumer_mode == 1) { // Soft limt consumer to consume weight as Broadcast. @@ -176,7 +176,6 @@ void SetNdSbp4Consumers(JobBuilder* builder, const SequencePtr& sequence, const node->ForEachNodeOnOutEdge([&](const OpNode* out_node) { for (const std::string& ibn : out_node->op().input_bns()) { if (out_node->op().BnInOp2Lbi(ibn) == lbi) { - // SetNdSbp4OpNodeIbn(builder, out_node, ibn, nd_sbp); if (!CHECK_JUST(builder->IsInMutOpTransaction(out_node->op().op_name()))) { CHECK_JUST(builder->MutOpTransactionMut(out_node->op().op_conf())); } @@ -322,7 +321,7 @@ Maybe RewriteDistributedSplit(const OpGraph& op_graph, JobBuilder* builder const OpNode* var_node = sorted_sequences.at(i)->GetVariableNode(); OperatorConf new_var_op_conf = var_node->op().op_conf(); const std::string& sole_obn = var_node->op().SoleObn(); - LOG(ERROR) << var_node->op().op_name() << " can be splited, " + VLOG(3) << var_node->op().op_name() << " can be splited, " << " has sbp: " << var_node->NdSbp4BnInOp(sole_obn).DebugString(); const NdSbp& var_nd_sbp = var_node->NdSbp4BnInOp(sole_obn); // CHECK_EQ(pd.hierarchy()->NumAxes(), 1); @@ -330,7 +329,7 @@ Maybe RewriteDistributedSplit(const OpGraph& op_graph, JobBuilder* builder if (new_var_op_conf.variable_conf().nd_sbp(i) == "B") { // TODO(strint): zero with nd choose dim *new_var_op_conf.mutable_variable_conf()->mutable_nd_sbp(i) = "S(0)"; - LOG(ERROR) << var_node->op().op_name() << " ranks dim " << i << " sbp is changed form B to S(0) " << new_var_op_conf.variable_conf().DebugString(); + VLOG(3) << var_node->op().op_name() << " ranks dim " << i << " sbp is changed form B to S(0) " << new_var_op_conf.variable_conf().DebugString(); // Only split one more dim. break; } diff --git a/python/oneflow/nn/graph/graph_config.py b/python/oneflow/nn/graph/graph_config.py index 81a36725d60..37085b9de1a 100644 --- a/python/oneflow/nn/graph/graph_config.py +++ b/python/oneflow/nn/graph/graph_config.py @@ -73,7 +73,7 @@ def build(self, x): self.proto.set_enable_auto_mixed_precision(mode) def enable_zero( - self, mode: bool = True, *, stage: int = 2, min_shard_size: int = 1024 + self, mode: bool = True, *, stage: int = 2, min_shard_size: int = 1024, parameter_consumer_limit_level: int = 2, ): r"""Enable ZeRO redundancy optimizer. @@ -102,6 +102,7 @@ def build(self, x): mode (bool): if set to true, optimizer states of Data Parallel will be sharded across devices. stage (int): optimization stage, range from 1 to 3. min_shard_size (int): min size of a shard of an optimizer state. + parameter_consumer_limit_level (int): limit consumer to comsume sharded parameter with Broadcast, level 2 is hard limit, level 1 is soft limit, level 0 is no limit. Note that this paremeter is at pre-alpha stage and is not stable. """ if not mode: self.proto.set_optimizer_placement_optimization_mode("none") @@ -110,19 +111,14 @@ def build(self, x): assert ( min_shard_size > 0 ), "ZeRO min size of a sharded optimizer state must > 0." - if stage == 1: - print("zero stage 1 optimization") - self.proto.set_optimizer_placement_optimization_mode("distributed_split") - self.proto.set_optimizer_placement_optimization_threshold(min_shard_size) - elif stage == 2: - self.proto.set_optimizer_placement_optimization_mode("distributed_split") - self.proto.set_optimizer_placement_optimization_threshold(min_shard_size) - oneflow.boxing.nccl.enable_use_compute_stream(True) - elif stage == 3: - print("zero stage 3 optimization") + assert stage >= 1 and stage <= 3, "ZeRO stage must range form 1 to 3." + if stage >= 1: self.proto.set_optimizer_placement_optimization_mode("distributed_split") self.proto.set_optimizer_placement_optimization_threshold(min_shard_size) + self.proto.set_optimizer_placement_optimization_comsumer_limit_level(parameter_consumer_limit_level) + if stage >= 2: oneflow.boxing.nccl.enable_use_compute_stream(True) + if stage >= 3: oneflow.boxing.nccl.disable_group_boxing_by_dst_parallel(True) def allow_fuse_model_update_ops(self, mode: bool = True): From ac0b9d22fef84d0cef0b04ed7a56319b20adf76e Mon Sep 17 00:00:00 2001 From: strint Date: Fri, 29 Apr 2022 23:11:27 +0800 Subject: [PATCH 15/46] refine zero s select --- oneflow/core/job/job_build_and_infer_ctx.cpp | 5 ++ .../optimizer_placement_optimization_pass.cpp | 65 +++++++++++++------ python/oneflow/nn/graph/graph_config.py | 11 +++- python/oneflow/test/graph/test_graph_zero.py | 16 ++++- 4 files changed, 71 insertions(+), 26 deletions(-) diff --git a/oneflow/core/job/job_build_and_infer_ctx.cpp b/oneflow/core/job/job_build_and_infer_ctx.cpp index 6e723cd9818..6b2f24edb77 100644 --- a/oneflow/core/job/job_build_and_infer_ctx.cpp +++ b/oneflow/core/job/job_build_and_infer_ctx.cpp @@ -13,6 +13,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ +#include "glog/logging.h" +#include "oneflow/api/python/env/env.h" #include "oneflow/core/common/protobuf.h" #include "oneflow/core/vm/symbol_storage.h" #include "oneflow/core/framework/config_def.h" @@ -997,6 +999,7 @@ Maybe LazyJobBuildAndInferCtx::Complete() { } }; int32_t pass_cnt = 0; + const int64_t prev_v = FLAGS_v; auto DoPass = [&](const std::string& pass_name, int32_t cnt = 0) -> Maybe { VLOG(1) << job_name << " is compiling with pass" << " pass_cnt_" + std::to_string(pass_cnt) + "-" + pass_name @@ -1004,9 +1007,11 @@ Maybe LazyJobBuildAndInferCtx::Complete() { if (unlikely(NeedLogJob(pass_name))) { std::string cnt_str = cnt > 0 ? std::to_string(cnt) : ""; LogJob("pass_cnt_" + std::to_string(pass_cnt) + "-" + pass_name + cnt_str + "-before"); + FLAGS_v = 3; } JUST(JobPass4Name(pass_name)(mut_job(), &job_pass_ctx)); if (unlikely(NeedLogJob(pass_name))) { + FLAGS_v = prev_v; std::string cnt_str = cnt > 0 ? std::to_string(cnt) : ""; LogJob("pass_cnt_" + std::to_string(pass_cnt) + "-" + pass_name + cnt_str + "-after"); } diff --git a/oneflow/core/job_rewriter/optimizer_placement_optimization_pass.cpp b/oneflow/core/job_rewriter/optimizer_placement_optimization_pass.cpp index d8e82712005..6270f266139 100644 --- a/oneflow/core/job_rewriter/optimizer_placement_optimization_pass.cpp +++ b/oneflow/core/job_rewriter/optimizer_placement_optimization_pass.cpp @@ -15,6 +15,7 @@ limitations under the License. */ #include "oneflow/core/common/util.h" #include "oneflow/core/framework/user_op_conf.h" +#include "oneflow/core/job/sbp_parallel.h" #include "oneflow/core/job/sbp_parallel.pb.h" #include "oneflow/core/job_rewriter/job_pass.h" #include "oneflow/core/graph/op_graph.h" @@ -88,10 +89,7 @@ Maybe GetDataParallelVariableAndNaiveSuccNode( if (cur_node->op().input_bns().size() != 1) { break; } const std::string& sole_ibn = cur_node->op().SoleIbn(); VLOG(3) << cur_node->op().op_name() - << " has sbp: " << cur_node->NdSbp4BnInOp(sole_ibn).DebugString(); - // if (!cur_node->SbpParallel4BnInOp(sole_ibn).has_broadcast_parallel()) { break; } - // if (!(ibn_nd_sbp.sbp_parallel_size() == 1 && - // ibn_nd_sbp.sbp_parallel(0).has_broadcast_parallel())) { break; } + << " has sbp: " << cur_node->NdSbp4BnInOp(sole_ibn).DebugString(); const NdSbp& ibn_nd_sbp = cur_node->NdSbp4BnInOp(sole_ibn); if (ibn_nd_sbp.sbp_parallel_size() == 0) { break; } bool has_broadcast = false; @@ -104,11 +102,8 @@ Maybe GetDataParallelVariableAndNaiveSuccNode( if (cur_node->op().output_bns().size() != 1) { break; } const std::string& sole_obn = cur_node->op().SoleObn(); VLOG(3) << cur_node->op().op_name() - << " has sbp: " << cur_node->NdSbp4BnInOp(sole_obn).DebugString(); - // if (!cur_node->SbpParallel4BnInOp(sole_obn).has_broadcast_parallel()) { break; } + << " has sbp: " << cur_node->NdSbp4BnInOp(sole_obn).DebugString(); const NdSbp& obn_nd_sbp = cur_node->NdSbp4BnInOp(sole_obn); - // if (!(obn_nd_sbp.sbp_parallel_size() == 1 && - // obn_nd_sbp.sbp_parallel(0).has_broadcast_parallel())) { break; } bool has_broadcast = false; FOR_RANGE(int, i, 0, obn_nd_sbp.sbp_parallel_size()) { if (obn_nd_sbp.sbp_parallel(i).has_broadcast_parallel()) { has_broadcast = true; }; @@ -157,12 +152,14 @@ void SetNdSbp4OpNodeIbn(JobBuilder* builder, const OpNode* node, const std::stri void SetNdSbp4Consumers(JobBuilder* builder, const SequencePtr& sequence, const NdSbp& nd_sbp) { const OpNode* node = sequence->GetLastNode(); const LogicalBlobId& lbi = node->op().BnInOp2Lbi(node->op().SoleObn()); - const int64_t limit_consumer_mode = builder->job().job_conf().optimizer_placement_optimization_comsumer_limit_level(); + const int64_t limit_consumer_mode = + builder->job().job_conf().optimizer_placement_optimization_comsumer_limit_level(); // If limit_consumer_mode == 0, no limit on consumer if (limit_consumer_mode == 1) { // Soft limt consumer to consume weight as Broadcast. const auto parallel_cast_op = - user_op::UserOpConfWrapperBuilder("System-ZeRO-ParallelCast-" + node->op().op_name() + "-" + NewUniqueId()) + user_op::UserOpConfWrapperBuilder("System-ZeRO-ParallelCast-" + node->op().op_name() + "-" + + NewUniqueId()) .Op("hierarchical_parallel_cast") .Input("in", GenLogicalBlobName(lbi)) .Output("out") @@ -313,7 +310,7 @@ Maybe RewriteDistributedSplit(const OpGraph& op_graph, JobBuilder* builder }; const auto PlacementSequencesAsSplitParallel = [&](const ParallelDesc& pd, std::vector&& sorted_sequences) { - // For all sorted sequnence, set the variable op in the sequence to S(0) + // For all sorted sequnence, set the variable op in the sequence to S // and add ctrl edge to control the exectuion order between variable ops. // A sequence is a variable op and its cast(fp32 to fp16) op. This is because the forward pass // consume the fp16 variable and the optimizer consume the fp32 variable. @@ -321,19 +318,43 @@ Maybe RewriteDistributedSplit(const OpGraph& op_graph, JobBuilder* builder const OpNode* var_node = sorted_sequences.at(i)->GetVariableNode(); OperatorConf new_var_op_conf = var_node->op().op_conf(); const std::string& sole_obn = var_node->op().SoleObn(); - VLOG(3) << var_node->op().op_name() << " can be splited, " - << " has sbp: " << var_node->NdSbp4BnInOp(sole_obn).DebugString(); const NdSbp& var_nd_sbp = var_node->NdSbp4BnInOp(sole_obn); - // CHECK_EQ(pd.hierarchy()->NumAxes(), 1); - FOR_RANGE(int, i, 0, new_var_op_conf.variable_conf().nd_sbp_size()) { - if (new_var_op_conf.variable_conf().nd_sbp(i) == "B") { - // TODO(strint): zero with nd choose dim - *new_var_op_conf.mutable_variable_conf()->mutable_nd_sbp(i) = "S(0)"; - VLOG(3) << var_node->op().op_name() << " ranks dim " << i << " sbp is changed form B to S(0) " << new_var_op_conf.variable_conf().DebugString(); + std::string new_split_signature = ""; + int64_t split_dim = 0; + if (new_var_op_conf.variable_conf().nd_sbp_size() == 1 + && new_var_op_conf.variable_conf().nd_sbp(i) == "B") { + new_split_signature = "S(0)"; + split_dim = 0; + } else { + FOR_RANGE(int64_t, j, 0, new_var_op_conf.variable_conf().nd_sbp_size()) { + if (new_var_op_conf.variable_conf().nd_sbp(j) == "B") { + std::vector adjacent_dim{j - 1, j + 1}; + for (auto const& dim_to_try : adjacent_dim) { + if (dim_to_try >= 0 && dim_to_try < new_var_op_conf.variable_conf().nd_sbp_size()) { + SbpParallel sbp; + if (ParseSbpParallelFromString(new_var_op_conf.variable_conf().nd_sbp(dim_to_try), + &sbp) + && sbp.has_split_parallel()) { + new_split_signature = new_var_op_conf.variable_conf().nd_sbp(dim_to_try); + split_dim = j; + } + } + if (new_split_signature != "") break; + } + } // Only split one more dim. - break; + if (new_split_signature != "") break; } } + if (new_split_signature != "") { + *new_var_op_conf.mutable_variable_conf()->mutable_nd_sbp(split_dim) = new_split_signature; + VLOG(3) << var_node->op().op_name() << " succeed to change form B to " + << new_split_signature << " on ranks dim " << split_dim << " with op conf " + << new_var_op_conf.variable_conf().DebugString(); + } else { + VLOG(3) << var_node->op().op_name() << " failed to change form B to S " + << " with op conf " << new_var_op_conf.variable_conf().DebugString(); + } if (i != 0) { const std::string& prev_op_name = sorted_sequences.at(i - 1)->GetVariableNode()->op().op_name(); @@ -342,7 +363,9 @@ Maybe RewriteDistributedSplit(const OpGraph& op_graph, JobBuilder* builder // TODO(strint): rewrite with MutOpTransactioin builder->MutOpsOnlyOnce({new_var_op_conf}); // Set consumers to consum this variable op's cast op's output as Broadcast. - SetNdSbp4Consumers(builder, sorted_sequences.at(i), var_nd_sbp); + if (new_split_signature != "") { + SetNdSbp4Consumers(builder, sorted_sequences.at(i), var_nd_sbp); + } } }; ForEachParallelSortedNodeSequence(op_graph, IsAllowed, SequenceCompSortedByOrderAsc, diff --git a/python/oneflow/nn/graph/graph_config.py b/python/oneflow/nn/graph/graph_config.py index 37085b9de1a..34b90e664d0 100644 --- a/python/oneflow/nn/graph/graph_config.py +++ b/python/oneflow/nn/graph/graph_config.py @@ -73,7 +73,12 @@ def build(self, x): self.proto.set_enable_auto_mixed_precision(mode) def enable_zero( - self, mode: bool = True, *, stage: int = 2, min_shard_size: int = 1024, parameter_consumer_limit_level: int = 2, + self, + mode: bool = True, + *, + stage: int = 2, + min_shard_size: int = 1024, + parameter_consumer_limit_level: int = 2, ): r"""Enable ZeRO redundancy optimizer. @@ -115,7 +120,9 @@ def build(self, x): if stage >= 1: self.proto.set_optimizer_placement_optimization_mode("distributed_split") self.proto.set_optimizer_placement_optimization_threshold(min_shard_size) - self.proto.set_optimizer_placement_optimization_comsumer_limit_level(parameter_consumer_limit_level) + self.proto.set_optimizer_placement_optimization_comsumer_limit_level( + parameter_consumer_limit_level + ) if stage >= 2: oneflow.boxing.nccl.enable_use_compute_stream(True) if stage >= 3: diff --git a/python/oneflow/test/graph/test_graph_zero.py b/python/oneflow/test/graph/test_graph_zero.py index 21bec2cd8c8..7712b08f706 100644 --- a/python/oneflow/test/graph/test_graph_zero.py +++ b/python/oneflow/test/graph/test_graph_zero.py @@ -53,7 +53,12 @@ def __init__(self): self.config.enable_amp(True) self.set_grad_scaler(grad_scaler) - self.config.enable_zero(True, stage=zero_stage, min_shard_size=1, parameter_consumer_limit_level=0) + self.config.enable_zero( + True, + stage=zero_stage, + min_shard_size=1, + parameter_consumer_limit_level=0, + ) self.debug(2) def build(self, x): @@ -143,7 +148,12 @@ def __init__(self): self.config.enable_amp(True) self.set_grad_scaler(grad_scaler) - self.config.enable_zero(True, stage=zero_stage, min_shard_size=1, parameter_consumer_limit_level=2) + self.config.enable_zero( + True, + stage=zero_stage, + min_shard_size=1, + parameter_consumer_limit_level=0, + ) self.debug(1) def build(self, x): @@ -163,7 +173,7 @@ def __init__(self): def build(self, x): out = self.linear_dp_mp(x) - #out = out.to_global(placement=P, sbp=[B, S0]) + # out = out.to_global(placement=P, sbp=[B, S0]) out = self.linear_mp_dp(out) return out From dd0a865f8bf4db1eb56598bbb5837fe78208456b Mon Sep 17 00:00:00 2001 From: strint Date: Thu, 5 May 2022 16:02:23 +0800 Subject: [PATCH 16/46] fix index out of range --- .../core/job_rewriter/optimizer_placement_optimization_pass.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/oneflow/core/job_rewriter/optimizer_placement_optimization_pass.cpp b/oneflow/core/job_rewriter/optimizer_placement_optimization_pass.cpp index 6270f266139..4cb709a457d 100644 --- a/oneflow/core/job_rewriter/optimizer_placement_optimization_pass.cpp +++ b/oneflow/core/job_rewriter/optimizer_placement_optimization_pass.cpp @@ -322,7 +322,7 @@ Maybe RewriteDistributedSplit(const OpGraph& op_graph, JobBuilder* builder std::string new_split_signature = ""; int64_t split_dim = 0; if (new_var_op_conf.variable_conf().nd_sbp_size() == 1 - && new_var_op_conf.variable_conf().nd_sbp(i) == "B") { + && new_var_op_conf.variable_conf().nd_sbp(0) == "B") { new_split_signature = "S(0)"; split_dim = 0; } else { From 501518fd8f27370bd258e88570adfae30819e1ac Mon Sep 17 00:00:00 2001 From: strint Date: Fri, 6 May 2022 16:58:15 +0800 Subject: [PATCH 17/46] rm zero limit on device type --- .../core/job_rewriter/optimizer_placement_optimization_pass.cpp | 1 - 1 file changed, 1 deletion(-) diff --git a/oneflow/core/job_rewriter/optimizer_placement_optimization_pass.cpp b/oneflow/core/job_rewriter/optimizer_placement_optimization_pass.cpp index e624a482236..30b46fc0f78 100644 --- a/oneflow/core/job_rewriter/optimizer_placement_optimization_pass.cpp +++ b/oneflow/core/job_rewriter/optimizer_placement_optimization_pass.cpp @@ -77,7 +77,6 @@ Maybe GetDataParallelVariableAndNaiveSuccNode( // Find sequence like: vairable -> cast_fp32_to_fp16 if (!start->op().op_conf().has_variable_conf()) { return Maybe::Ok(); } const ParallelDesc& pd = start->parallel_desc(); - if (pd.device_type() != DeviceType::kCUDA) { return Maybe::Ok(); } if (pd.parallel_num() == 1) { return Maybe::Ok(); } const OpNode* cur_node = start; while (cur_node != nullptr) { From e3eed8cea4048e04b3ea8df5bb68f2a513b65a5c Mon Sep 17 00:00:00 2001 From: strint Date: Sat, 7 May 2022 14:57:37 +0800 Subject: [PATCH 18/46] zero test with activation checkpointing --- python/oneflow/test/graph/test_graph_zero.py | 59 +++++++++++++------- 1 file changed, 40 insertions(+), 19 deletions(-) diff --git a/python/oneflow/test/graph/test_graph_zero.py b/python/oneflow/test/graph/test_graph_zero.py index 7712b08f706..50502e21677 100644 --- a/python/oneflow/test/graph/test_graph_zero.py +++ b/python/oneflow/test/graph/test_graph_zero.py @@ -119,18 +119,37 @@ def train_with_graph(iter_num=1): S0 = flow.sbp.split(0) S1 = flow.sbp.split(1) - linear_dp_mp = flow.nn.Linear(800, 400, bias=False) - linear_dp_mp = linear_dp_mp.to_global(placement=P, sbp=[B, S0]) - flow.nn.init.constant_(linear_dp_mp.weight, 2.068758) - - linear_mp_dp = flow.nn.Linear(800, 500, bias=False) - linear_mp_dp = linear_mp_dp.to_global(placement=P, sbp=[S0, B]) - flow.nn.init.constant_(linear_mp_dp.weight, 2.068758) + def get_mixed_linear(): + linear_dp_mp = flow.nn.Linear(800, 400, bias=False) + linear_dp_mp = linear_dp_mp.to_global(placement=P, sbp=[B, S0]) + flow.nn.init.constant_(linear_dp_mp.weight, 2.068758) + + linear_mp_dp = flow.nn.Linear(800, 400, bias=False) + linear_mp_dp = linear_mp_dp.to_global(placement=P, sbp=[S0, B]) + flow.nn.init.constant_(linear_mp_dp.weight, 2.068758) + + class MixedLinear(flow.nn.Module): + def __init__(self): + super().__init__() + self.dp_mp = linear_dp_mp + self.mp_dp = linear_mp_dp + + def forward(self, x): + x = self.dp_mp(x) + x = flow.relu(x) + x = self.mp_dp(x) + x = flow.relu(x) + return x + + return MixedLinear() + + mixed_linear0 = get_mixed_linear() + mixed_linear1 = get_mixed_linear() of_sgd = flow.optim.SGD( [ - {"params": linear_dp_mp.parameters()}, - {"params": linear_mp_dp.parameters()}, + {"params": mixed_linear0.parameters()}, + {"params": mixed_linear1.parameters()}, ], lr=0.001, momentum=0.9, @@ -139,11 +158,14 @@ def train_with_graph(iter_num=1): x = flow.randint(1, 100, (6, 800), dtype=flow.float32, placement=P, sbp=[S0, B]) + #flow.boxing.nccl.enable_use_compute_stream(True) class LinearTrainGraph2DWithZeRO(flow.nn.Graph): def __init__(self): super().__init__() - self.linear_dp_mp = linear_dp_mp - self.linear_mp_dp = linear_mp_dp + self.mixed_linear0 = mixed_linear0 + self.mixed_linear0.config.activation_checkpointing = True + self.mixed_linear1 = mixed_linear1 + self.mixed_linear1.config.activation_checkpointing = True self.add_optimizer(of_sgd) self.config.enable_amp(True) @@ -152,13 +174,13 @@ def __init__(self): True, stage=zero_stage, min_shard_size=1, - parameter_consumer_limit_level=0, + parameter_consumer_limit_level=1, ) self.debug(1) def build(self, x): - out = self.linear_dp_mp(x) - out = self.linear_mp_dp(out) + out = self.mixed_linear0(x) + out = self.mixed_linear1(out) loss = out.sum() loss.backward() return out @@ -166,15 +188,14 @@ def build(self, x): class LinearEvalGraph2DWithZeRO(flow.nn.Graph): def __init__(self): super().__init__() - self.linear_dp_mp = linear_dp_mp - self.linear_mp_dp = linear_mp_dp + self.mixed_linear0 = mixed_linear0 + self.mixed_linear1 = mixed_linear1 self.config.enable_amp(True) def build(self, x): - out = self.linear_dp_mp(x) - # out = out.to_global(placement=P, sbp=[B, S0]) - out = self.linear_mp_dp(out) + out = self.mixed_linear0(x) + out = self.mixed_linear1(out) return out linear_t_g = LinearTrainGraph2DWithZeRO() From ebc9ff9306e0d69faef889519957cffc5d45ad36 Mon Sep 17 00:00:00 2001 From: strint Date: Sat, 21 May 2022 13:29:35 +0800 Subject: [PATCH 19/46] add indentity when dp sequence len is 1 --- .../optimizer_placement_optimization_pass.cpp | 26 +++++++++++++++++-- 1 file changed, 24 insertions(+), 2 deletions(-) diff --git a/oneflow/core/job_rewriter/optimizer_placement_optimization_pass.cpp b/oneflow/core/job_rewriter/optimizer_placement_optimization_pass.cpp index a4ec8101e9b..7b71e1eb772 100644 --- a/oneflow/core/job_rewriter/optimizer_placement_optimization_pass.cpp +++ b/oneflow/core/job_rewriter/optimizer_placement_optimization_pass.cpp @@ -20,6 +20,7 @@ limitations under the License. #include "oneflow/core/job_rewriter/job_pass.h" #include "oneflow/core/graph/op_graph.h" #include "oneflow/core/job/job_desc.h" +#include "oneflow/core/operator/operator.h" namespace oneflow { @@ -38,6 +39,7 @@ class DataParallelNodeSequence final { const OpNode* var_node = nodes_.front(); CHECK(var_node->op().op_conf().has_variable_conf()); model_size_ = GetSoleOutBlobSize(var_node); + len_ = nodes_.size(); } ~DataParallelNodeSequence() = default; @@ -53,10 +55,13 @@ class DataParallelNodeSequence final { int64_t model_size() const { return model_size_; } + int64_t len() const { return len_; } + private: std::vector nodes_; int64_t order_; int64_t model_size_; + int64_t len_; }; using SequencePtr = std::shared_ptr; @@ -155,12 +160,27 @@ void SetNdSbp4Consumers(JobBuilder* builder, const SequencePtr& sequence, const builder->job().job_conf().optimizer_placement_optimization_comsumer_limit_level(); // If limit_consumer_mode == 0, no limit on consumer if (limit_consumer_mode == 1) { - // Soft limt consumer to consume weight as Broadcast. + // input lbn for parallel cast op + std::string parallel_cast_input_lbn = GenLogicalBlobName(lbi); + // Add indentity to enable mem reuse of boxing op when there is no op between var op and boxing. + if (sequence->len() == 1) { + LOG(ERROR) << "ZeRO find a data-parallel sequence only has one variable " << sequence->GetVariableNode()->op().op_name(); + const auto var_identity_op = user_op::UserOpConfWrapperBuilder("System-ZeRO-Identity-" + node->op().op_name() + "-" + + NewUniqueId()) + .Op("identity") + .Input("in", GenLogicalBlobName(lbi)) + .Output("out") + .ScopeSymbolId(node->op().op_conf().scope_symbol_id()) + .Build(); + builder->AddOps(node->parallel_desc().parallel_conf(), {var_identity_op.op_conf()}); + parallel_cast_input_lbn = var_identity_op.output("out", 0); + } + // Add parallel cast op to make soft limt on consumer to consume weight with Broadcast SBP. const auto parallel_cast_op = user_op::UserOpConfWrapperBuilder("System-ZeRO-ParallelCast-" + node->op().op_name() + "-" + NewUniqueId()) .Op("hierarchical_parallel_cast") - .Input("in", GenLogicalBlobName(lbi)) + .Input("in", parallel_cast_input_lbn) .Output("out") .Attr>("nd_sbp", NdSbpToStringList(nd_sbp)) .Attr("grad_mode", "identity") // don't do ndsbp cast at backward @@ -168,6 +188,8 @@ void SetNdSbp4Consumers(JobBuilder* builder, const SequencePtr& sequence, const .ScopeSymbolId(node->op().op_conf().scope_symbol_id()) .Build(); builder->AddOps(node->parallel_desc().parallel_conf(), {parallel_cast_op.op_conf()}); + + // Make consumers to consume parallel cast op auto out_lbn = parallel_cast_op.output("out", 0); node->ForEachNodeOnOutEdge([&](const OpNode* out_node) { for (const std::string& ibn : out_node->op().input_bns()) { From 2011e2ce3b6e23aafff2d1e67e7bf45dbb575fd8 Mon Sep 17 00:00:00 2001 From: strint Date: Thu, 26 May 2022 22:26:00 +0800 Subject: [PATCH 20/46] move to base with master --- oneflow/core/job/nd_sbp_util.cpp | 101 ++++++ oneflow/core/job/nd_sbp_util.h | 8 + .../insert_nccl_logical_op_pass.cpp | 59 +++- oneflow/ir/include/OneFlow/OneFlowUserOps.td | 22 +- .../kernels/nccl_logical_send_recv_kernel.cpp | 291 ++++++++++++++++++ oneflow/user/ops/nccl_logical_ops.cpp | 33 ++ 6 files changed, 497 insertions(+), 17 deletions(-) create mode 100644 oneflow/user/kernels/nccl_logical_send_recv_kernel.cpp diff --git a/oneflow/core/job/nd_sbp_util.cpp b/oneflow/core/job/nd_sbp_util.cpp index 26b543b27ed..3a2224061bf 100644 --- a/oneflow/core/job/nd_sbp_util.cpp +++ b/oneflow/core/job/nd_sbp_util.cpp @@ -122,4 +122,105 @@ TensorSliceView GetBroadcastTensorSliceView(const BlobDesc& blob_desc) { return TensorSliceView(blob_desc.shape()); } +bool NdSbpHasPartialParallel(const NdSbp& nd_sbp) { + CHECK_GT(nd_sbp.sbp_parallel_size(), 0); + FOR_RANGE(int64_t, i, 0, nd_sbp.sbp_parallel_size()) { + if (nd_sbp.sbp_parallel(i).has_partial_sum_parallel()) { return true; } + } + return false; +} + +namespace { +// Go through all the ranks while transfer between two nd sbps with no PartialSum under the same +// placement. +// NOTE: We need to make sure no partial sums in the sbps of the producer and consumer. +void DfsTraverseRanks4NdSbp( + int32_t depth, std::vector& in_parallel_ids, + const std::vector& out_parallel_ids, const Shape& parallel_hierarchy, + const NdIndexOffsetHelper& hierarchy_index_helper, + const NdSbp& in_nd_sbp, const std::function& visit) { + if (depth >= parallel_hierarchy.NumAxes()) { + visit(hierarchy_index_helper.NdIndexToOffset(out_parallel_ids.data(), + parallel_hierarchy.NumAxes()), + hierarchy_index_helper.NdIndexToOffset(in_parallel_ids.data(), + parallel_hierarchy.NumAxes())); + return; + } + if (in_nd_sbp.sbp_parallel(depth).has_broadcast_parallel()) { + // If Broadcast in the sbp of the producer, only visit those ranks with the same id as the + // current rank along the depth-dimension. + in_parallel_ids[depth] = out_parallel_ids[depth]; + DfsTraverseRanks4NdSbp(depth + 1, in_parallel_ids, out_parallel_ids, parallel_hierarchy, + hierarchy_index_helper, in_nd_sbp, visit); + } else { + // If Split or PartialSum, go through all the ranks along the depth-dimension. + for (int64_t i = 0; i < parallel_hierarchy.dim_vec().at(depth); i++) { + in_parallel_ids[depth] = i; + DfsTraverseRanks4NdSbp(depth + 1, in_parallel_ids, out_parallel_ids, parallel_hierarchy, + hierarchy_index_helper, in_nd_sbp, visit); + } + } +} + +void DfsTraverse4NdSbp(int64_t out_id, const std::shared_ptr parallel_hierarchy, + const NdSbp& in_nd_sbp, const std::function& visit) { + int32_t hierarchy_dimension = parallel_hierarchy->NumAxes(); + const NdIndexOffsetHelper hierarchy_index_helper( + parallel_hierarchy->dim_vec().data(), hierarchy_dimension); + std::vector in_parallel_ids(hierarchy_dimension); + std::vector out_parallel_ids(hierarchy_dimension); + hierarchy_index_helper.OffsetToNdIndex(out_id, out_parallel_ids.data(), hierarchy_dimension); + DfsTraverseRanks4NdSbp(0, in_parallel_ids, out_parallel_ids, *parallel_hierarchy, + hierarchy_index_helper, in_nd_sbp, visit); +} + +bool NdSbpNoPartialParallel(const NdSbp& nd_sbp) { + CHECK_GT(nd_sbp.sbp_parallel_size(), 0); + FOR_RANGE(int64_t, i, 0, nd_sbp.sbp_parallel_size()) { + if (nd_sbp.sbp_parallel(i).has_partial_sum_parallel()) { return false; } + } + return true; +} + +} // namespace + +void GetSendRecvIntersection(int64_t parallel_id, const std::shared_ptr parallel_hierarchy, + const NdSbp& src_nd_sbp, const NdSbp& dst_nd_sbp, + const Shape& logical_shape, + std::vector* src_send_intersections, + std::vector* dst_recv_intersections) { + CHECK(parallel_hierarchy); + const int64_t parallel_num = parallel_hierarchy->elem_cnt(); + CHECK_LT(parallel_id, parallel_num); + + const std::vector& out_slices = + GetTensorSliceView(*parallel_hierarchy, dst_nd_sbp, logical_shape); + const std::vector& in_slices = + GetTensorSliceView(*parallel_hierarchy, src_nd_sbp, logical_shape); + + // cur_out_slice recv from + dst_recv_intersections->resize(parallel_num); + const TensorSliceView& cur_rank_out_slice = out_slices.at(parallel_id); + const auto& add_to_dst_recv_intersections = [&](int32_t out_id, int32_t in_id) { + CHECK_EQ(out_id, parallel_id); + const TensorSliceView& in_slice = in_slices.at(in_id); + const TensorSliceView& intersection = cur_rank_out_slice.Intersect(in_slice); + dst_recv_intersections->at(in_id) = intersection; + }; + DfsTraverse4NdSbp(parallel_id, parallel_hierarchy, src_nd_sbp, add_to_dst_recv_intersections); + + // cur_in_slice send to + src_send_intersections->resize(parallel_num); + const TensorSliceView& cur_rank_in_slice = in_slices.at(parallel_id); + const auto& add_to_src_send_intersections = [&](int32_t out_id, int32_t in_id) { + if (in_id != parallel_id) { return; } + const TensorSliceView& out_slice = out_slices.at(out_id); + const TensorSliceView& intersection = out_slice.Intersect(cur_rank_in_slice); + src_send_intersections->at(out_id) = intersection; + }; + for (int64_t i = 0; i < parallel_num; ++i) { + DfsTraverse4NdSbp(i, parallel_hierarchy, src_nd_sbp, add_to_src_send_intersections); + } +} + } // namespace oneflow diff --git a/oneflow/core/job/nd_sbp_util.h b/oneflow/core/job/nd_sbp_util.h index 990c7ba6798..abeddf09066 100644 --- a/oneflow/core/job/nd_sbp_util.h +++ b/oneflow/core/job/nd_sbp_util.h @@ -33,6 +33,14 @@ TensorSliceView GetTensorSliceView4ParallelId(const Shape& parallel_hierarchy, c const Shape& logical_shape, int64_t parallel_id); TensorSliceView GetBroadcastTensorSliceView(const BlobDesc& blob_desc); +bool NdSbpHasPartialParallel(const NdSbp& nd_sbp); + +void GetSendRecvIntersection(int64_t parallel_id, const std::shared_ptr parallel_hierarchy, + const NdSbp& src_nd_sbp, const NdSbp& dst_nd_sbp, + const Shape& logical_shape, + std::vector* src_send_intersections, + std::vector* dst_recv_intersections); + } // namespace oneflow #endif // ONEFLOW_CORE_JOB_SBP_PARALLEL_H_ diff --git a/oneflow/core/job_rewriter/insert_nccl_logical_op_pass.cpp b/oneflow/core/job_rewriter/insert_nccl_logical_op_pass.cpp index 3b270853dd8..d4033adeba1 100644 --- a/oneflow/core/job_rewriter/insert_nccl_logical_op_pass.cpp +++ b/oneflow/core/job_rewriter/insert_nccl_logical_op_pass.cpp @@ -330,6 +330,26 @@ bool TryBuildNcclBy2DHierarchySameDim1(OperatorConf* ret, const NdSbp& src_nd_sb return false; } +bool TryBuildNcclBy2DHierarchyOthers(OperatorConf* ret, const NdSbp& src_nd_sbp, + const NdSbp& dst_nd_sbp, + const std::shared_ptr& hierarchy, + const std::string& lbn, const int64_t scope_symbol_id, + const BlobDesc& logical_blob_desc) { + CHECK_EQ(src_nd_sbp.sbp_parallel_size(), 2); + CHECK_EQ(dst_nd_sbp.sbp_parallel_size(), 2); + *ret = + user_op::UserOpConfWrapperBuilder(kNcclLogicalOpNamePrefix + "-(Send)2(Recv)-" + NewUniqueId()) + .Op("_nccl_logical_send_recv") + .Input("in", lbn) + .Output("out") + .Attr>("src_nd_sbp", NdSbpToStringList(src_nd_sbp)) + .Attr>("dst_nd_sbp", NdSbpToStringList(dst_nd_sbp)) + .ScopeSymbolId(scope_symbol_id) + .Build() + .op_conf(); + return true; +} + Maybe BuildScopeWithReducedParallelDesc(int64_t old_scope_symbol_id, const ParallelDesc& parallel_desc) { auto* scope_storage = Global>::Get(); @@ -366,7 +386,7 @@ bool TryBuildNcclLogicalOpConf(OperatorConf* ret, const OpNode* src_node, const std::shared_ptr dst_reduced_hierarchy = dst_reduced_parallel_desc->hierarchy(); if ((*src_reduced_hierarchy) == (*dst_reduced_hierarchy) - && src_reduced_nd_sbp == dst_reduced_nd_sbp) { + && (*src_reduced_nd_sbp) == (*dst_reduced_nd_sbp)) { // one to one return false; } @@ -389,18 +409,27 @@ bool TryBuildNcclLogicalOpConf(OperatorConf* ret, const OpNode* src_node, const logical_blob_desc, src_reduced_parallel_desc->parallel_num()); } else if (src_reduced_hierarchy->NumAxes() == 2 && (*src_reduced_hierarchy == *dst_reduced_hierarchy)) { + bool got_nccl = false; if (src_reduced_nd_sbp->sbp_parallel(0) == dst_reduced_nd_sbp->sbp_parallel(0)) { - return TryBuildNcclBy2DHierarchySameDim0(ret, *src_reduced_nd_sbp, *dst_reduced_nd_sbp, + got_nccl = TryBuildNcclBy2DHierarchySameDim0(ret, *src_reduced_nd_sbp, *dst_reduced_nd_sbp, src_reduced_hierarchy, lbn, scope_symbol_id, logical_blob_desc); } else if (src_reduced_nd_sbp->sbp_parallel(1) == dst_reduced_nd_sbp->sbp_parallel(1)) { if (!(NdSbpAllSameSplitParallel(*src_reduced_nd_sbp) || NdSbpAllSameSplitParallel(*dst_reduced_nd_sbp))) { - return TryBuildNcclBy2DHierarchySameDim1(ret, *src_reduced_nd_sbp, *dst_reduced_nd_sbp, + got_nccl = TryBuildNcclBy2DHierarchySameDim1(ret, *src_reduced_nd_sbp, *dst_reduced_nd_sbp, src_reduced_hierarchy, lbn, scope_symbol_id, logical_blob_desc); } } + if (!got_nccl && ParseBooleanFromEnv("LOGICAL_SR", false)) { + got_nccl = TryBuildNcclBy2DHierarchyOthers(ret, *src_reduced_nd_sbp, *dst_reduced_nd_sbp, + src_reduced_hierarchy, lbn, scope_symbol_id, + logical_blob_desc); + } + return got_nccl; + } else { + VLOG(3) << "Cannot get nccl logical for src nd sbp " << NdSbpToString(*src_reduced_nd_sbp) << ", dst nd sbp " << NdSbpToString(*dst_reduced_nd_sbp) << "."; } return false; } @@ -460,12 +489,12 @@ void InsertNcclLogicalOpsAsCloseAsPossibleToSrcNode( } if (Global::Get()->enable_debug_mode()) { - VLOG(3) << " insert nccl op: " << nccl_op.name() << " from: [" << src_op_name - << "](order=" << src_order - << ", nd_sbp=" << NdSbpToString(src_node->NdSbp4Lbi(lbi)) << ")->[" << dst_op_name - << "](order=" << node2subgraph_order.at(dst_node) - << ", nd_sbp=" << NdSbpToString(dst_node->NdSbp4Lbi(lbi)) << ") and before: [" - << next_op_name << "](order=" << src_order + 1 << ")\n"; + VLOG(3) << " insert nccl op: " << nccl_op.name() << " from [" << src_op_name + << ", order=" << src_order + << ", sbp=" << NdSbpToString(src_node->NdSbp4Lbi(lbi)) << "] to [" << dst_op_name + << ", order=" << node2subgraph_order.at(dst_node) + << ", sbp=" << NdSbpToString(dst_node->NdSbp4Lbi(lbi)) << "] and before [" + << next_op_name << ", order=" << src_order + 1 << "]\n"; } nccl_op_confs->emplace_back(nccl_op); nccl_op_parallel_confs->emplace_back(src_reduced_parallel_desc.parallel_conf()); @@ -527,9 +556,9 @@ void InsertNcclLogicalOpsAsCloseAsPossibleToDstNode( } if (Global::Get()->enable_debug_mode()) { - VLOG(3) << " insert nccl op: " << nccl_op.name() << " from: [" << src_op_name << "](" - << node2subgraph_order.at(src_node) << ")->[" << dst_op_name << "](" << dst_order - << ") and after: [" << pre_op_name << "](" << dst_order - 1 << ")\n"; + VLOG(3) << " insert nccl op: " << nccl_op.name() << " from [" << src_op_name << ", order=" + << node2subgraph_order.at(src_node) << "] to [" << dst_op_name << ", order=" << dst_order + << "] and after [" << pre_op_name << ", order=" << dst_order - 1 << "]\n"; } nccl_op_confs->emplace_back(nccl_op); // NOTE(chengcheng, guoran): set nccl op as src_node parallel_conf (hierarchy) may check @@ -617,10 +646,10 @@ void InsertNcclLogicalOpsAfterAcc(const OpGraph& op_graph, nccl_op_info.nccl_parallel_conf = src_reduced_parallel_desc.parallel_conf(); nccl_op_info.order = op_node2global_order.at(src_node); nccl_op_info.debug_str = - (" After ACC insert nccl op: " + nccl_op.name() + " from: [" + src_op_name + "](" - + NdSbpToString(src_node->NdSbp4Lbi(lbi)) + ")->[" + dst_op_name + "](" + (" After ACC insert nccl op: " + nccl_op.name() + " from [" + src_op_name + ", sbp=" + + NdSbpToString(src_node->NdSbp4Lbi(lbi)) + "] to [" + dst_op_name + ", sbp=" + NdSbpToString(dst_node->NdSbp4Lbi(lbi)) - + "), src_order = " + std::to_string(nccl_op_info.order) + "\n"); + + ", src_order=" + std::to_string(nccl_op_info.order) + "]\n"); nccl_op_infos.emplace_back(nccl_op_info); } diff --git a/oneflow/ir/include/OneFlow/OneFlowUserOps.td b/oneflow/ir/include/OneFlow/OneFlowUserOps.td index b13273000cb..71cf8d9ddae 100644 --- a/oneflow/ir/include/OneFlow/OneFlowUserOps.td +++ b/oneflow/ir/include/OneFlow/OneFlowUserOps.td @@ -5144,8 +5144,8 @@ def OneFlow_StackGradOp : OneFlow_BaseOp<"stack_grad", [NoSideEffect, NoGrad, De #endif // GET_ONEFLOW_MISC_OP_DEFINITIONS // Group: NCCL -// _nccl_logical_2D_same_dim0_all2all, _nccl_logical_2D_same_dim0_all_gather, _nccl_logical_2D_same_dim0_all_gather_noncontinuous, _nccl_logical_2D_same_dim0_all_reduce, _nccl_logical_2D_same_dim1_all_reduce, _nccl_logical_all_gather, _nccl_logical_all_gather_noncontinuous, _nccl_logical_all_reduce, _nccl_logical_reduce_scatter, _nccl_logical_s2s -// Total: 10 +// _nccl_logical_2D_same_dim0_all2all, _nccl_logical_2D_same_dim0_all_gather, _nccl_logical_2D_same_dim0_all_gather_noncontinuous, _nccl_logical_2D_same_dim0_all_reduce, _nccl_logical_2D_same_dim1_all_reduce, _nccl_logical_all_gather, _nccl_logical_all_gather_noncontinuous, _nccl_logical_all_reduce, _nccl_logical_reduce_scatter, _nccl_logical_s2s, _nccl_logical_send_recv +// Total: 11 #ifdef GET_ONEFLOW_NCCL_OP_DEFINITIONS @@ -5329,6 +5329,24 @@ def OneFlow__ncclLogicalS2sOp : OneFlow_BaseOp<"_nccl_logical_s2s", [NoSideEffec let has_nd_sbp_infer_fn = 1; } +def OneFlow__ncclLogicalSendRecvOp : OneFlow_BaseOp<"_nccl_logical_send_recv", [NoSideEffect, NoGrad, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$in + ); + let output = (outs + OneFlow_Tensor:$out + ); + let attrs = (ins + StrArrayAttr:$src_nd_sbp, + StrArrayAttr:$dst_nd_sbp + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; + let has_device_and_stream_infer_fn = 1; + let has_nd_sbp_infer_fn = 1; +} + #endif // GET_ONEFLOW_NCCL_OP_DEFINITIONS // Group: NORMALIZATION diff --git a/oneflow/user/kernels/nccl_logical_send_recv_kernel.cpp b/oneflow/user/kernels/nccl_logical_send_recv_kernel.cpp new file mode 100644 index 00000000000..fdc839cf9c0 --- /dev/null +++ b/oneflow/user/kernels/nccl_logical_send_recv_kernel.cpp @@ -0,0 +1,291 @@ +/* +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ +#include "oneflow/core/framework/framework.h" +#include "oneflow/core/device/nccl_util.h" +#include "oneflow/core/job/parallel_desc.h" +#include "oneflow/user/ops/nccl_logical_util.h" +#include "oneflow/core/framework/infer_util.h" +#include "oneflow/core/framework/op_kernel.h" +#include "oneflow/core/job/eager_nccl_comm_manager.h" +#include "oneflow/core/job/nd_sbp_util.h" +#include "oneflow/core/register/tensor_slice_copier.h" +#include "oneflow/core/ep/include/primitive/memset.h" +#include "oneflow/core/ep/include/primitive/add.h" + +#if defined(WITH_CUDA) && NCCL_VERSION_CODE > 2700 + +namespace oneflow { + +class NcclLogicalSendRecvState final : public user_op::OpKernelState { + public: + explicit NcclLogicalSendRecvState(user_op::KernelInitContext* ctx); + const std::vector>& in_tensor_slice_copier_vec() const { + return in_tensor_slice_copier_vec_; + } + const std::vector>& out_tensor_slice_copier_vec() const { + return out_tensor_slice_copier_vec_; + } + bool src_nd_sbp_has_no_partial_parallel() const { return src_nd_sbp_no_partial_parallel_; } + const std::vector& send_elem_cnts() const { return send_elem_cnts_; } + const std::vector& recv_elem_cnts() const { return recv_elem_cnts_; } + ncclComm_t comm() const { return GetOrCreateComm().comm; } + + private: + struct Comm { + Comm(ncclComm_t comm) : comm(comm) {} + ncclComm_t comm; + }; + void InitComm() const; + const Comm& GetOrCreateComm() const { + if (!comm_) { InitComm(); } + return *comm_; + } + + bool has_independent_stream_; + std::string stream_name_; + ParallelDesc parallel_desc_; + mutable std::unique_ptr comm_; + bool src_nd_sbp_no_partial_parallel_; + std::vector> in_tensor_slice_copier_vec_; + std::vector> out_tensor_slice_copier_vec_; + std::vector send_elem_cnts_; + std::vector recv_elem_cnts_; +}; + +NcclLogicalSendRecvState::NcclLogicalSendRecvState(user_op::KernelInitContext* ctx) : +parallel_desc_(ctx->parallel_desc()){ + has_independent_stream_ = ctx->op_conf().has_stream_name_hint(); + if (has_independent_stream_) { stream_name_ = ctx->op_conf().stream_name_hint(); } + NdSbp src_nd_sbp; + NdSbp dst_nd_sbp; + CHECK_JUST(GetNcclLogicalNdSbpFromAttr(ctx, "src_reduced_nd_sbp", &src_nd_sbp)); + CHECK_JUST(GetNcclLogicalNdSbpFromAttr(ctx, "dst_reduced_nd_sbp", &dst_nd_sbp)); + src_nd_sbp_no_partial_parallel_ = !NdSbpHasPartialParallel(src_nd_sbp); + const int64_t parallel_id = ctx->parallel_ctx().parallel_id(); + const auto& parallel_hierarchy = parallel_desc_.hierarchy(); + CHECK_EQ(src_nd_sbp.sbp_parallel_size(), parallel_hierarchy->NumAxes()); + CHECK_EQ(dst_nd_sbp.sbp_parallel_size(), parallel_hierarchy->NumAxes()); + const user_op::TensorDesc* in_logical_desc = ctx->LogicalTensorDesc4ArgNameAndIndex("in", 0); + + const Shape& logical_shape = Shape(in_logical_desc->shape()); + std::vector src_send_intersections; + std::vector dst_recv_intersections; + GetSendRecvIntersection(parallel_id, parallel_desc_.hierarchy(), src_nd_sbp, dst_nd_sbp, + logical_shape, &src_send_intersections, &dst_recv_intersections); + + const DataType data_type = in_logical_desc->data_type(); + const DeviceType device_type = parallel_desc_.device_type(); + const int64_t parallel_num = parallel_desc_.parallel_num(); + CHECK_EQ(src_send_intersections.size(), parallel_num); + send_elem_cnts_.resize(parallel_num); + in_tensor_slice_copier_vec_.resize(parallel_num); + const TensorSliceView& cur_rank_in_slice = + GetTensorSliceView4ParallelId(*parallel_hierarchy, src_nd_sbp, logical_shape, parallel_id); + for (int64_t i = 0; i < parallel_num; ++i) { + const TensorSliceView& intersection = src_send_intersections.at(i); + if (!intersection.IsEmpty()) { + send_elem_cnts_.at(i) = intersection.shape().elem_cnt(); + in_tensor_slice_copier_vec_.at(i).reset( + new TensorSliceCopier(intersection, cur_rank_in_slice, data_type, device_type)); + } + } + + CHECK_EQ(dst_recv_intersections.size(), parallel_num); + recv_elem_cnts_.resize(parallel_num); + out_tensor_slice_copier_vec_.resize(parallel_num); + const TensorSliceView& cur_rank_out_slice = + GetTensorSliceView4ParallelId(*parallel_hierarchy, dst_nd_sbp, logical_shape, parallel_id); + for (int64_t i = 0; i < parallel_num; ++i) { + const TensorSliceView& intersection = dst_recv_intersections.at(i); + if (!intersection.IsEmpty()) { + recv_elem_cnts_.at(i) = intersection.shape().elem_cnt(); + out_tensor_slice_copier_vec_.at(i).reset( + new TensorSliceCopier(cur_rank_out_slice, intersection, data_type, device_type)); + } + } +} + +void NcclLogicalSendRecvState::InitComm() const { + std::set> device_set; + for (int64_t parallel_id = 0; parallel_id < parallel_desc_.parallel_num(); ++parallel_id) { + int64_t machine_id = CHECK_JUST(parallel_desc_.MachineId4ParallelId(parallel_id)); + int64_t device_id = CHECK_JUST(parallel_desc_.DeviceId4ParallelId(parallel_id)); + device_set.emplace(std::make_pair(machine_id, device_id)); + } + EagerNcclCommMgr* comm_mgr = CHECK_NOTNULL(Global::Get()); + ncclComm_t comm = nullptr; + if (has_independent_stream_) { + comm = comm_mgr->GetCommForDeviceAndStreamName(device_set, stream_name_); + } else { + comm = comm_mgr->GetCommForDevice(device_set); + } + comm_.reset(new Comm(comm)); +} + +class NcclLogicalSendRecv final : public user_op::OpKernel { + public: + OF_DISALLOW_COPY_AND_MOVE(NcclLogicalSendRecv); + NcclLogicalSendRecv() = default; + ~NcclLogicalSendRecv() override = default; + + std::shared_ptr CreateOpKernelState( + user_op::KernelInitContext* ctx) const override { + return std::make_shared(ctx); + } + + private: + void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState* state, + const user_op::OpKernelCache*) const override; + bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } +}; + +void NcclLogicalSendRecv::Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState* state, + const user_op::OpKernelCache*) const { + auto* kernel_state = dynamic_cast(state); + CHECK_NOTNULL(kernel_state); + const user_op::Tensor* in = ctx->Tensor4ArgNameAndIndex("in", 0); + user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex("out", 0); + user_op::Tensor* tmp_buffer = ctx->Tensor4ArgNameAndIndex("tmp_buffer", 0); + ncclComm_t comm = kernel_state->comm(); + cudaStream_t cuda_stream = ctx->stream()->As()->cuda_stream(); + const std::vector& send_elem_cnts = kernel_state->send_elem_cnts(); + const std::vector& recv_elem_cnts = kernel_state->recv_elem_cnts(); + const int64_t parallel_num = send_elem_cnts.size(); + const DataType data_type = in->data_type(); + + std::vector send_in_ptr; + std::vector recv_out_ptr; + char* buf_ptr = tmp_buffer->mut_dptr(); + int64_t offset = 0; + for (int64_t i = 0; i < parallel_num; ++i) { + void* send_ptr = reinterpret_cast(buf_ptr + offset); + send_in_ptr.push_back(send_ptr); + offset += send_elem_cnts.at(i) * GetSizeOfDataType(data_type); + } + for (int64_t i = 0; i < parallel_num; ++i) { + void* recv_ptr = reinterpret_cast(buf_ptr + offset); + recv_out_ptr.push_back(recv_ptr); + offset += recv_elem_cnts.at(i) * GetSizeOfDataType(data_type); + } + + const std::vector>& in_tensor_slice_copier_vec = + kernel_state->in_tensor_slice_copier_vec(); + for (int64_t i = 0; i < parallel_num; ++i) { + if (in_tensor_slice_copier_vec.at(i)) { + in_tensor_slice_copier_vec.at(i)->Copy(ctx->stream(), send_in_ptr.at(i), in->dptr()); + } + } + const int64_t parallel_id = ctx->parallel_ctx().parallel_id(); + OF_NCCL_CHECK(ncclGroupStart()); + for (int64_t i = 0; i < parallel_num; ++i) { + if (send_elem_cnts.at(i) != 0) { + LOG(INFO) << parallel_id << " send " << send_elem_cnts.at(i) << " to " << i; + OF_NCCL_CHECK(ncclSend(send_in_ptr.at(i), send_elem_cnts.at(i), GetNcclDataType(data_type), i, + comm, cuda_stream)); + } + if (recv_elem_cnts.at(i) != 0) { + LOG(INFO) << parallel_id << " recv " << recv_elem_cnts.at(i) << " from " << i; + OF_NCCL_CHECK(ncclRecv(recv_out_ptr.at(i), recv_elem_cnts.at(i), GetNcclDataType(data_type), + i, comm, cuda_stream)); + } + } + OF_NCCL_CHECK(ncclGroupEnd()); + const std::vector>& out_tensor_slice_copier_vec = + kernel_state->out_tensor_slice_copier_vec(); + + if (kernel_state->src_nd_sbp_has_no_partial_parallel()) { + for (int64_t i = 0; i < parallel_num; ++i) { + if (out_tensor_slice_copier_vec.at(i)) { + out_tensor_slice_copier_vec.at(i)->Copy(ctx->stream(), out->mut_dptr(), recv_out_ptr.at(i)); + } + } + } else { + std::unique_ptr add_primitive = + ep::primitive::NewPrimitive(ctx->stream()->device_type(), + out->data_type()); + CHECK(add_primitive); + std::unique_ptr memset_primitive = + ep::primitive::NewPrimitive(ctx->stream()->device_type()); + CHECK(memset_primitive); + bool is_first_slice = true; + for (int64_t i = 0; i < parallel_num; ++i) { + if (out_tensor_slice_copier_vec.at(i)) { + if (is_first_slice) { + is_first_slice = false; + if (recv_elem_cnts.at(i) != out->shape().elem_cnt()) { + // if not same shape, memset out + memset_primitive->Launch(ctx->stream(), out->mut_dptr(), 0, + out->shape().elem_cnt() * GetSizeOfDataType(data_type)); + } + out_tensor_slice_copier_vec.at(i)->Copy(ctx->stream(), out->mut_dptr(), + recv_out_ptr.at(i)); + } else { + if (recv_elem_cnts.at(i) == out->shape().elem_cnt()) { + add_primitive->Launch(ctx->stream(), out->dptr(), recv_out_ptr.at(i), out->mut_dptr(), + out->shape().elem_cnt()); + } else { + void* out_buf = reinterpret_cast(buf_ptr + offset); + memset_primitive->Launch(ctx->stream(), out_buf, 0, + out->shape().elem_cnt() * GetSizeOfDataType(data_type)); + out_tensor_slice_copier_vec.at(i)->Copy(ctx->stream(), out_buf, recv_out_ptr.at(i)); + add_primitive->Launch(ctx->stream(), out->dptr(), out_buf, out->mut_dptr(), + out->shape().elem_cnt()); + } + } + } + } + } +} + +size_t InferTmpBufferSize(user_op::InferContext* ctx) { + const Shape* out_shape = ctx->OutputShape("out", 0); + const user_op::TensorDesc* logical_in_tensor = ctx->LogicalTensorDesc4ArgNameAndIndex("in", 0); + const Shape& logical_shape = logical_in_tensor->shape(); + + const NdSbp& src_nd_sbp = ctx->NdSbp4ArgNameAndIndex("in", 0); + const NdSbp& dst_nd_sbp = ctx->NdSbp4ArgNameAndIndex("out", 0); + const int64_t parallel_num = ctx->parallel_num(); + const int64_t parallel_id = ctx->parallel_ctx().parallel_id(); + + std::vector src_send_intersections; + std::vector dst_recv_intersections; + GetSendRecvIntersection(parallel_id, ctx->parallel_desc().hierarchy(), src_nd_sbp, dst_nd_sbp, + logical_shape, &src_send_intersections, &dst_recv_intersections); + int64_t buf_count = 0; + CHECK_EQ(src_send_intersections.size(), parallel_num); + for (int64_t i = 0; i < parallel_num; ++i) { + const TensorSliceView& intersection = src_send_intersections.at(i); + if (!intersection.IsEmpty()) { buf_count += intersection.shape().elem_cnt(); } + } + for (int64_t i = 0; i < parallel_num; ++i) { + const TensorSliceView& intersection = dst_recv_intersections.at(i); + if (!intersection.IsEmpty()) { buf_count += intersection.shape().elem_cnt(); } + } + if (NdSbpHasPartialParallel(src_nd_sbp)) { + // Note: when src_nd_sbp has partial_sum, need a out_size buffer to copy and add to out. + buf_count += out_shape->elem_cnt(); + } + return buf_count; +} + +REGISTER_USER_KERNEL("_nccl_logical_send_recv") + .SetCreateFn() + .SetIsMatchedHob(user_op::HobDeviceType() == DeviceType::kCUDA) + .SetInferTmpSizeFn(InferTmpBufferSize); + +} // namespace oneflow + +#endif // WITH_CUDA && NCCL_VERSION_CODE > 2700 diff --git a/oneflow/user/ops/nccl_logical_ops.cpp b/oneflow/user/ops/nccl_logical_ops.cpp index a36bd5eb24d..b3832688088 100644 --- a/oneflow/user/ops/nccl_logical_ops.cpp +++ b/oneflow/user/ops/nccl_logical_ops.cpp @@ -221,4 +221,37 @@ namespace oneflow { return DeviceAndStreamInferFn<&SyncLaunched>(ctx); } +/* static */ Maybe _ncclLogicalSendRecvOp::InferLogicalTensorDesc( + user_op::InferContext* ctx) { + *ctx->OutputShape("out", 0) = ctx->InputShape("in", 0); + *ctx->OutputIsDynamic("out", 0) = ctx->InputIsDynamic("in", 0); + return Maybe::Ok(); +} + +/* static */ Maybe _ncclLogicalSendRecvOp::GetSbp(user_op::SbpContext* ctx) { + return user_op::GetSbpFnUtil::DefaultBroadcastToBroadcast(ctx); +} + +/* static */ Maybe _ncclLogicalSendRecvOp::InferNdSbp(user_op::InferNdSbpFnContext* ctx) { + NdSbp* input_nd_sbp = ctx->NdSbp4ArgNameAndIndex("in", 0); + NdSbp* output_nd_sbp = ctx->NdSbp4ArgNameAndIndex("out", 0); + input_nd_sbp->clear_sbp_parallel(); + output_nd_sbp->clear_sbp_parallel(); + + JUST(GetNcclLogicalNdSbpFromAttr(ctx, "src_nd_sbp", input_nd_sbp)); + JUST(GetNcclLogicalNdSbpFromAttr(ctx, "dst_nd_sbp", output_nd_sbp)); + + return Maybe::Ok(); +} + +/* static */ Maybe _ncclLogicalSendRecvOp::InferDataType(user_op::InferContext* ctx) { + *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0); + return Maybe::Ok(); +} + +/* static */ Maybe> _ncclLogicalSendRecvOp::InferDeviceAndStream( + user_op::DeviceAndStreamInferContext* ctx) { + return DeviceAndStreamInferFn<&SyncLaunched>(ctx); +} + } // namespace oneflow From ffe2094b6a10f4839ff37ea3f50dbff00e5c4c06 Mon Sep 17 00:00:00 2001 From: strint Date: Fri, 27 May 2022 01:14:53 +0800 Subject: [PATCH 21/46] fix --- oneflow/user/kernels/nccl_logical_send_recv_kernel.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/oneflow/user/kernels/nccl_logical_send_recv_kernel.cpp b/oneflow/user/kernels/nccl_logical_send_recv_kernel.cpp index fdc839cf9c0..a663da3a51b 100644 --- a/oneflow/user/kernels/nccl_logical_send_recv_kernel.cpp +++ b/oneflow/user/kernels/nccl_logical_send_recv_kernel.cpp @@ -71,8 +71,8 @@ parallel_desc_(ctx->parallel_desc()){ if (has_independent_stream_) { stream_name_ = ctx->op_conf().stream_name_hint(); } NdSbp src_nd_sbp; NdSbp dst_nd_sbp; - CHECK_JUST(GetNcclLogicalNdSbpFromAttr(ctx, "src_reduced_nd_sbp", &src_nd_sbp)); - CHECK_JUST(GetNcclLogicalNdSbpFromAttr(ctx, "dst_reduced_nd_sbp", &dst_nd_sbp)); + CHECK_JUST(GetNcclLogicalNdSbpFromAttr(ctx, "src_nd_sbp", &src_nd_sbp)); + CHECK_JUST(GetNcclLogicalNdSbpFromAttr(ctx, "dst_nd_sbp", &dst_nd_sbp)); src_nd_sbp_no_partial_parallel_ = !NdSbpHasPartialParallel(src_nd_sbp); const int64_t parallel_id = ctx->parallel_ctx().parallel_id(); const auto& parallel_hierarchy = parallel_desc_.hierarchy(); From b58b48ab666844a0a25c0d09657e00c855f21583 Mon Sep 17 00:00:00 2001 From: strint Date: Fri, 27 May 2022 01:18:19 +0800 Subject: [PATCH 22/46] fix --- .../core/job_rewriter/optimizer_placement_optimization_pass.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/oneflow/core/job_rewriter/optimizer_placement_optimization_pass.cpp b/oneflow/core/job_rewriter/optimizer_placement_optimization_pass.cpp index 7b71e1eb772..0cd7f990438 100644 --- a/oneflow/core/job_rewriter/optimizer_placement_optimization_pass.cpp +++ b/oneflow/core/job_rewriter/optimizer_placement_optimization_pass.cpp @@ -198,7 +198,7 @@ void SetNdSbp4Consumers(JobBuilder* builder, const SequencePtr& sequence, const CHECK_JUST(builder->MutOpTransactionMut(out_node->op().op_conf())); } OperatorConf& mut_consumer_op = - *CHECK_JUST(builder->MutOpTransactionGet(out_node->op().op_name())); + CHECK_JUST(builder->MutOpTransactionGet(out_node->op().op_name())); const auto& old_lbn = ReplaceInputLbnInOpCustomizedConf(&mut_consumer_op, ibn, out_lbn); CHECK_EQ(old_lbn, GenLogicalBlobName(lbi)); } From 6975f33c09c099761245c41a948da95778da973f Mon Sep 17 00:00:00 2001 From: strint Date: Fri, 27 May 2022 01:33:09 +0800 Subject: [PATCH 23/46] fix --- python/oneflow/nn/graph/graph_config.py | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/python/oneflow/nn/graph/graph_config.py b/python/oneflow/nn/graph/graph_config.py index 884a586ae65..72ff402b4fd 100644 --- a/python/oneflow/nn/graph/graph_config.py +++ b/python/oneflow/nn/graph/graph_config.py @@ -70,7 +70,7 @@ def build(self, x): """ assert type(mode) is bool - self.proto.set_enable_auto_mixed_precision(mode) + self.proto.enable_auto_mixed_precision = mode def enable_zero( self, @@ -110,7 +110,7 @@ def build(self, x): parameter_consumer_limit_level (int): limit consumer to comsume sharded parameter with Broadcast, level 2 is hard limit, level 1 is soft limit, level 0 is no limit. Note that this paremeter is at pre-alpha stage and is not stable. """ if not mode: - self.proto.set_optimizer_placement_optimization_mode("none") + self.proto.optimizer_placement_optimization_mode = "none" return assert stage >= 1 and stage <= 3, "ZeRO stage must range form 1 to 3." assert ( @@ -118,11 +118,9 @@ def build(self, x): ), "ZeRO min size of a sharded optimizer state must > 0." assert stage >= 1 and stage <= 3, "ZeRO stage must range form 1 to 3." if stage >= 1: - self.proto.set_optimizer_placement_optimization_mode("distributed_split") - self.proto.set_optimizer_placement_optimization_threshold(min_shard_size) - self.proto.set_optimizer_placement_optimization_comsumer_limit_level( - parameter_consumer_limit_level - ) + self.proto.optimizer_placement_optimization_mode = "distributed_split" + self.proto.optimizer_placement_optimization_threshold = min_shard_size + self.proto.optimizer_placement_optimization_comsumer_limit_level = parameter_consumer_limit_level if stage >= 2: oneflow.boxing.nccl.enable_use_compute_stream(True) if stage >= 3: From cce8efd876d79fdb4c58ffc8a5ff02c45d5ab66a Mon Sep 17 00:00:00 2001 From: strint Date: Fri, 27 May 2022 17:13:38 +0800 Subject: [PATCH 24/46] add test --- .../kernels/nccl_logical_send_recv_kernel.cpp | 28 ++--- .../test/graph/test_nccl_logical_send_recv.py | 103 ++++++++++++++++++ 2 files changed, 117 insertions(+), 14 deletions(-) create mode 100644 python/oneflow/test/graph/test_nccl_logical_send_recv.py diff --git a/oneflow/user/kernels/nccl_logical_send_recv_kernel.cpp b/oneflow/user/kernels/nccl_logical_send_recv_kernel.cpp index a663da3a51b..eff2a77e545 100644 --- a/oneflow/user/kernels/nccl_logical_send_recv_kernel.cpp +++ b/oneflow/user/kernels/nccl_logical_send_recv_kernel.cpp @@ -56,7 +56,7 @@ class NcclLogicalSendRecvState final : public user_op::OpKernelState { bool has_independent_stream_; std::string stream_name_; - ParallelDesc parallel_desc_; + std::unique_ptr parallel_desc_; mutable std::unique_ptr comm_; bool src_nd_sbp_no_partial_parallel_; std::vector> in_tensor_slice_copier_vec_; @@ -65,30 +65,30 @@ class NcclLogicalSendRecvState final : public user_op::OpKernelState { std::vector recv_elem_cnts_; }; -NcclLogicalSendRecvState::NcclLogicalSendRecvState(user_op::KernelInitContext* ctx) : -parallel_desc_(ctx->parallel_desc()){ +NcclLogicalSendRecvState::NcclLogicalSendRecvState(user_op::KernelInitContext* ctx) { has_independent_stream_ = ctx->op_conf().has_stream_name_hint(); if (has_independent_stream_) { stream_name_ = ctx->op_conf().stream_name_hint(); } + const int64_t parallel_id = ctx->parallel_ctx().parallel_id(); + parallel_desc_ = std::make_unique(ctx->parallel_desc()); NdSbp src_nd_sbp; - NdSbp dst_nd_sbp; CHECK_JUST(GetNcclLogicalNdSbpFromAttr(ctx, "src_nd_sbp", &src_nd_sbp)); + NdSbp dst_nd_sbp; CHECK_JUST(GetNcclLogicalNdSbpFromAttr(ctx, "dst_nd_sbp", &dst_nd_sbp)); + const auto& parallel_hierarchy = parallel_desc_->hierarchy(); src_nd_sbp_no_partial_parallel_ = !NdSbpHasPartialParallel(src_nd_sbp); - const int64_t parallel_id = ctx->parallel_ctx().parallel_id(); - const auto& parallel_hierarchy = parallel_desc_.hierarchy(); CHECK_EQ(src_nd_sbp.sbp_parallel_size(), parallel_hierarchy->NumAxes()); CHECK_EQ(dst_nd_sbp.sbp_parallel_size(), parallel_hierarchy->NumAxes()); const user_op::TensorDesc* in_logical_desc = ctx->LogicalTensorDesc4ArgNameAndIndex("in", 0); - + const DataType data_type = in_logical_desc->data_type(); const Shape& logical_shape = Shape(in_logical_desc->shape()); + const DeviceType device_type = parallel_desc_->device_type(); + const int64_t parallel_num = parallel_desc_->parallel_num(); + std::vector src_send_intersections; std::vector dst_recv_intersections; - GetSendRecvIntersection(parallel_id, parallel_desc_.hierarchy(), src_nd_sbp, dst_nd_sbp, + GetSendRecvIntersection(parallel_id, parallel_desc_->hierarchy(), src_nd_sbp, dst_nd_sbp, logical_shape, &src_send_intersections, &dst_recv_intersections); - const DataType data_type = in_logical_desc->data_type(); - const DeviceType device_type = parallel_desc_.device_type(); - const int64_t parallel_num = parallel_desc_.parallel_num(); CHECK_EQ(src_send_intersections.size(), parallel_num); send_elem_cnts_.resize(parallel_num); in_tensor_slice_copier_vec_.resize(parallel_num); @@ -120,9 +120,9 @@ parallel_desc_(ctx->parallel_desc()){ void NcclLogicalSendRecvState::InitComm() const { std::set> device_set; - for (int64_t parallel_id = 0; parallel_id < parallel_desc_.parallel_num(); ++parallel_id) { - int64_t machine_id = CHECK_JUST(parallel_desc_.MachineId4ParallelId(parallel_id)); - int64_t device_id = CHECK_JUST(parallel_desc_.DeviceId4ParallelId(parallel_id)); + for (int64_t parallel_id = 0; parallel_id < parallel_desc_->parallel_num(); ++parallel_id) { + int64_t machine_id = CHECK_JUST(parallel_desc_->MachineId4ParallelId(parallel_id)); + int64_t device_id = CHECK_JUST(parallel_desc_->DeviceId4ParallelId(parallel_id)); device_set.emplace(std::make_pair(machine_id, device_id)); } EagerNcclCommMgr* comm_mgr = CHECK_NOTNULL(Global::Get()); diff --git a/python/oneflow/test/graph/test_nccl_logical_send_recv.py b/python/oneflow/test/graph/test_nccl_logical_send_recv.py new file mode 100644 index 00000000000..db0f8b2cbe1 --- /dev/null +++ b/python/oneflow/test/graph/test_nccl_logical_send_recv.py @@ -0,0 +1,103 @@ +""" +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +import unittest +from collections import OrderedDict +import oneflow +import numpy as np +import oneflow as flow +import oneflow.unittest +from oneflow.test_utils.test_util import GenArgList + +from oneflow.test_utils.automated_test_util import * +import time +import os + +os.environ["ONEFLOW_BOXING_DISABLE_MIDDLE_NODE_AND_CHECK"] = "1" +os.environ["LOGICAL_SR"] = "1" + + +def _test_nccl_logical_send_recv(test_case, src_nd_sbp, dst_nd_sbp): + # can not process p in dst + if flow.sbp.partial_sum() in dst_nd_sbp: + return + # skip src == dst + if src_nd_sbp == dst_nd_sbp: + return + # in this case, use intra group boxing + if src_nd_sbp[0] == dst_nd_sbp[0]: + return + # in this case, use inter group boxing + if ( + src_nd_sbp[1] == dst_nd_sbp[1] + and src_nd_sbp[0] != src_nd_sbp[1] + and src_nd_sbp[0] != src_nd_sbp[1] + ): + return + # in this case, use 1d boxing + if src_nd_sbp[0] == src_nd_sbp[1] and dst_nd_sbp[0] == dst_nd_sbp[1]: + return + placement = flow.placement("cuda", ranks=[[0, 1], [2, 3]]) + + flow.boxing.nccl.enable_use_compute_stream(True) + class TestNcclLogicalSendRecvGraph(flow.nn.Graph): + def __init__(self): + super().__init__() + + def build(self, x): + y = x.to_global(sbp=dst_nd_sbp, placement=placement) + return y + + x = flow.tensor( + np.arange(12 * 16 * 16).reshape(12, 16, 16), + sbp=src_nd_sbp, + placement=placement, + ) + graph = TestNcclLogicalSendRecvGraph() + y = graph(x) + graph.debug(3) + print("graph repr:\n", graph) + test_case.assertTrue(np.array_equal(y.numpy(), x.numpy())) + + +def gen_nd_sbp(): + sbp_list = [ + flow.sbp.partial_sum(), + flow.sbp.broadcast(), + flow.sbp.split(0), + flow.sbp.split(1), + flow.sbp.split(2), + ] + nd_sbp_list = [] + for sbp0 in sbp_list: + for sbp1 in sbp_list: + nd_sbp_list.append([sbp0, sbp1]) + return nd_sbp_list + + +@flow.unittest.skip_unless_1n4d() +@unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") +class TestNcclLogicalSendRecv(flow.unittest.TestCase): + def test_nccl_logical_send_recv(test_case): + arg_dict = OrderedDict() + arg_dict["src_nd_sbp"] = gen_nd_sbp() + arg_dict["dst_nd_sbp"] = gen_nd_sbp() + for arg in GenArgList(arg_dict): + _test_nccl_logical_send_recv(test_case, *arg) + + +if __name__ == "__main__": + unittest.main() From a30b0c0228d12076e13597fc4b50ec0677948a9e Mon Sep 17 00:00:00 2001 From: strint Date: Fri, 27 May 2022 22:41:45 +0800 Subject: [PATCH 25/46] debug bad case --- .../insert_nccl_logical_op_pass.cpp | 2 +- .../test/graph/test_nccl_logical_send_recv.py | 48 +++++++++++++++++-- 2 files changed, 45 insertions(+), 5 deletions(-) diff --git a/oneflow/core/job_rewriter/insert_nccl_logical_op_pass.cpp b/oneflow/core/job_rewriter/insert_nccl_logical_op_pass.cpp index d4033adeba1..a789368f8b9 100644 --- a/oneflow/core/job_rewriter/insert_nccl_logical_op_pass.cpp +++ b/oneflow/core/job_rewriter/insert_nccl_logical_op_pass.cpp @@ -489,7 +489,7 @@ void InsertNcclLogicalOpsAsCloseAsPossibleToSrcNode( } if (Global::Get()->enable_debug_mode()) { - VLOG(3) << " insert nccl op: " << nccl_op.name() << " from [" << src_op_name + LOG(ERROR) << " insert nccl op: " << nccl_op.name() << " from [" << src_op_name << ", order=" << src_order << ", sbp=" << NdSbpToString(src_node->NdSbp4Lbi(lbi)) << "] to [" << dst_op_name << ", order=" << node2subgraph_order.at(dst_node) diff --git a/python/oneflow/test/graph/test_nccl_logical_send_recv.py b/python/oneflow/test/graph/test_nccl_logical_send_recv.py index db0f8b2cbe1..2ae7cb41203 100644 --- a/python/oneflow/test/graph/test_nccl_logical_send_recv.py +++ b/python/oneflow/test/graph/test_nccl_logical_send_recv.py @@ -34,12 +34,15 @@ def _test_nccl_logical_send_recv(test_case, src_nd_sbp, dst_nd_sbp): # can not process p in dst if flow.sbp.partial_sum() in dst_nd_sbp: return + # skip src == dst if src_nd_sbp == dst_nd_sbp: return + # in this case, use intra group boxing if src_nd_sbp[0] == dst_nd_sbp[0]: return + # in this case, use inter group boxing if ( src_nd_sbp[1] == dst_nd_sbp[1] @@ -47,9 +50,38 @@ def _test_nccl_logical_send_recv(test_case, src_nd_sbp, dst_nd_sbp): and src_nd_sbp[0] != src_nd_sbp[1] ): return + # in this case, use 1d boxing if src_nd_sbp[0] == src_nd_sbp[1] and dst_nd_sbp[0] == dst_nd_sbp[1]: return + + # bad case: S with P + if src_nd_sbp[0] == flow.sbp.partial_sum and src_nd_sbp[1] == flow.sbp.split(0): + return + if src_nd_sbp[0] == flow.sbp.partial_sum and src_nd_sbp[1] == flow.sbp.split(1): + return + if src_nd_sbp[0] == flow.sbp.partial_sum and src_nd_sbp[1] == flow.sbp.split(2): + return + if src_nd_sbp[0] == flow.sbp.split(0) and src_nd_sbp[1] == flow.sbp.partial_sum: + return + if src_nd_sbp[0] == flow.sbp.split(1) and src_nd_sbp[1] == flow.sbp.partial_sum: + return + if src_nd_sbp[0] == flow.sbp.split(2) and src_nd_sbp[1] == flow.sbp.partial_sum: + return + # bad case: diff S + if src_nd_sbp[0] == flow.sbp.split(0) and src_nd_sbp[1] == flow.sbp.split(1): + return + if src_nd_sbp[0] == flow.sbp.split(0) and src_nd_sbp[1] == flow.sbp.split(2): + return + if src_nd_sbp[0] == flow.sbp.split(1) and src_nd_sbp[1] == flow.sbp.split(2): + return + if src_nd_sbp[0] == flow.sbp.split(1) and src_nd_sbp[1] == flow.sbp.split(0): + return + if src_nd_sbp[0] == flow.sbp.split(2) and src_nd_sbp[1] == flow.sbp.split(0): + return + if src_nd_sbp[0] == flow.sbp.split(2) and src_nd_sbp[1] == flow.sbp.split(1): + return + placement = flow.placement("cuda", ranks=[[0, 1], [2, 3]]) flow.boxing.nccl.enable_use_compute_stream(True) @@ -58,19 +90,27 @@ def __init__(self): super().__init__() def build(self, x): + # from src nd sbp to dst nd sbp y = x.to_global(sbp=dst_nd_sbp, placement=placement) return y + in_np = np.arange(4 * 4 * 4).reshape(4, 4, 4) x = flow.tensor( - np.arange(12 * 16 * 16).reshape(12, 16, 16), + in_np, sbp=src_nd_sbp, placement=placement, ) + graph = TestNcclLogicalSendRecvGraph() y = graph(x) - graph.debug(3) - print("graph repr:\n", graph) - test_case.assertTrue(np.array_equal(y.numpy(), x.numpy())) + out_np = y.numpy() + equal = np.array_equal(out_np, in_np) + if not equal: + print("graph repr:\n", graph) + print("in np:\n", in_np) + print("diff np:\n", out_np - in_np) + test_case.assertTrue(equal) + def gen_nd_sbp(): From c73013f4bb0dfb05dda2f8690eb87199d8dd850d Mon Sep 17 00:00:00 2001 From: strint Date: Fri, 27 May 2022 23:19:22 +0800 Subject: [PATCH 26/46] refine test for eager and graph boxing --- .../test/graph/test_nccl_logical_send_recv.py | 29 +++++++++++-------- 1 file changed, 17 insertions(+), 12 deletions(-) diff --git a/python/oneflow/test/graph/test_nccl_logical_send_recv.py b/python/oneflow/test/graph/test_nccl_logical_send_recv.py index 2ae7cb41203..359c12739ae 100644 --- a/python/oneflow/test/graph/test_nccl_logical_send_recv.py +++ b/python/oneflow/test/graph/test_nccl_logical_send_recv.py @@ -54,8 +54,22 @@ def _test_nccl_logical_send_recv(test_case, src_nd_sbp, dst_nd_sbp): # in this case, use 1d boxing if src_nd_sbp[0] == src_nd_sbp[1] and dst_nd_sbp[0] == dst_nd_sbp[1]: return + + # input + placement = flow.placement("cuda", ranks=[[0, 1], [2, 3]]) + in_np = np.arange(4 * 4 * 4).reshape(4, 4, 4) + x = flow.tensor( + in_np, + sbp=src_nd_sbp, + placement=placement, + ) + + # check eager boxing + test_case.assertTrue(np.array_equal(x.numpy(), in_np)) + eager_out = x.to_global(sbp=dst_nd_sbp, placement=placement) + assert np.array_equal(eager_out.numpy(), in_np) - # bad case: S with P + # bad case of graph: S with P if src_nd_sbp[0] == flow.sbp.partial_sum and src_nd_sbp[1] == flow.sbp.split(0): return if src_nd_sbp[0] == flow.sbp.partial_sum and src_nd_sbp[1] == flow.sbp.split(1): @@ -68,7 +82,7 @@ def _test_nccl_logical_send_recv(test_case, src_nd_sbp, dst_nd_sbp): return if src_nd_sbp[0] == flow.sbp.split(2) and src_nd_sbp[1] == flow.sbp.partial_sum: return - # bad case: diff S + # bad case of graph: diff S if src_nd_sbp[0] == flow.sbp.split(0) and src_nd_sbp[1] == flow.sbp.split(1): return if src_nd_sbp[0] == flow.sbp.split(0) and src_nd_sbp[1] == flow.sbp.split(2): @@ -82,8 +96,7 @@ def _test_nccl_logical_send_recv(test_case, src_nd_sbp, dst_nd_sbp): if src_nd_sbp[0] == flow.sbp.split(2) and src_nd_sbp[1] == flow.sbp.split(1): return - placement = flow.placement("cuda", ranks=[[0, 1], [2, 3]]) - + # check graph boxing flow.boxing.nccl.enable_use_compute_stream(True) class TestNcclLogicalSendRecvGraph(flow.nn.Graph): def __init__(self): @@ -93,14 +106,6 @@ def build(self, x): # from src nd sbp to dst nd sbp y = x.to_global(sbp=dst_nd_sbp, placement=placement) return y - - in_np = np.arange(4 * 4 * 4).reshape(4, 4, 4) - x = flow.tensor( - in_np, - sbp=src_nd_sbp, - placement=placement, - ) - graph = TestNcclLogicalSendRecvGraph() y = graph(x) out_np = y.numpy() From 08b1f692a3315052b3c8c8eeb0d4995351f590e2 Mon Sep 17 00:00:00 2001 From: strint Date: Sat, 28 May 2022 00:52:20 +0800 Subject: [PATCH 27/46] test case ready --- .../test/graph/test_nccl_logical_send_recv.py | 30 +++++++++++-------- 1 file changed, 18 insertions(+), 12 deletions(-) diff --git a/python/oneflow/test/graph/test_nccl_logical_send_recv.py b/python/oneflow/test/graph/test_nccl_logical_send_recv.py index 359c12739ae..1b9abd80bbc 100644 --- a/python/oneflow/test/graph/test_nccl_logical_send_recv.py +++ b/python/oneflow/test/graph/test_nccl_logical_send_recv.py @@ -57,17 +57,13 @@ def _test_nccl_logical_send_recv(test_case, src_nd_sbp, dst_nd_sbp): # input placement = flow.placement("cuda", ranks=[[0, 1], [2, 3]]) - in_np = np.arange(4 * 4 * 4).reshape(4, 4, 4) - x = flow.tensor( - in_np, - sbp=src_nd_sbp, - placement=placement, - ) + local_np = np.arange(4 * 4 * 4).reshape(4, 4, 4) + # NOTE(strint): flow.tensor(numpy, sbp) is not valid when sbp contains partial_sum + x = flow.tensor(local_np).to_global(sbp=src_nd_sbp, placement=placement) # check eager boxing - test_case.assertTrue(np.array_equal(x.numpy(), in_np)) eager_out = x.to_global(sbp=dst_nd_sbp, placement=placement) - assert np.array_equal(eager_out.numpy(), in_np) + test_case.assertTrue(np.array_equal(eager_out.numpy(), x.numpy())) # bad case of graph: S with P if src_nd_sbp[0] == flow.sbp.partial_sum and src_nd_sbp[1] == flow.sbp.split(0): @@ -103,17 +99,27 @@ def __init__(self): super().__init__() def build(self, x): - # from src nd sbp to dst nd sbp y = x.to_global(sbp=dst_nd_sbp, placement=placement) return y + graph = TestNcclLogicalSendRecvGraph() y = graph(x) out_np = y.numpy() + in_np = x.numpy() equal = np.array_equal(out_np, in_np) + # Debug log + if flow.env.get_rank() == 1: + if equal: + print("test boxing passed form ", src_nd_sbp, " to ", dst_nd_sbp) + else: + print("graph repr:\n", graph) + print("local in data:\n", x.to_local().numpy()) + print("local out data:\n", y.to_local().numpy()) + print("global in np:\n", in_np) + print("global out np:\n", out_np) + print("global diff np:\n", out_np - in_np) if not equal: - print("graph repr:\n", graph) - print("in np:\n", in_np) - print("diff np:\n", out_np - in_np) + print("error rank: ", flow.env.get_rank()) test_case.assertTrue(equal) From 821a8f4e434d66fb4248ebbd20738ca148d41df5 Mon Sep 17 00:00:00 2001 From: strint Date: Mon, 30 May 2022 10:36:37 +0800 Subject: [PATCH 28/46] simplify --- oneflow/core/job/nd_sbp_util.cpp | 65 ++++++++----------- oneflow/core/job/nd_sbp_util.h | 7 +- .../kernels/nccl_logical_send_recv_kernel.cpp | 4 +- 3 files changed, 34 insertions(+), 42 deletions(-) diff --git a/oneflow/core/job/nd_sbp_util.cpp b/oneflow/core/job/nd_sbp_util.cpp index 3a2224061bf..a4aa997aca3 100644 --- a/oneflow/core/job/nd_sbp_util.cpp +++ b/oneflow/core/job/nd_sbp_util.cpp @@ -138,11 +138,9 @@ void DfsTraverseRanks4NdSbp( int32_t depth, std::vector& in_parallel_ids, const std::vector& out_parallel_ids, const Shape& parallel_hierarchy, const NdIndexOffsetHelper& hierarchy_index_helper, - const NdSbp& in_nd_sbp, const std::function& visit) { + const NdSbp& in_nd_sbp, const std::function& visit) { if (depth >= parallel_hierarchy.NumAxes()) { - visit(hierarchy_index_helper.NdIndexToOffset(out_parallel_ids.data(), - parallel_hierarchy.NumAxes()), - hierarchy_index_helper.NdIndexToOffset(in_parallel_ids.data(), + visit(hierarchy_index_helper.NdIndexToOffset(in_parallel_ids.data(), parallel_hierarchy.NumAxes())); return; } @@ -162,64 +160,57 @@ void DfsTraverseRanks4NdSbp( } } -void DfsTraverse4NdSbp(int64_t out_id, const std::shared_ptr parallel_hierarchy, - const NdSbp& in_nd_sbp, const std::function& visit) { +void DfsTraverse4NdSbp(int64_t recv_id, const std::shared_ptr& parallel_hierarchy, + const NdSbp& in_nd_sbp, const std::function& visit) { int32_t hierarchy_dimension = parallel_hierarchy->NumAxes(); const NdIndexOffsetHelper hierarchy_index_helper( parallel_hierarchy->dim_vec().data(), hierarchy_dimension); std::vector in_parallel_ids(hierarchy_dimension); std::vector out_parallel_ids(hierarchy_dimension); - hierarchy_index_helper.OffsetToNdIndex(out_id, out_parallel_ids.data(), hierarchy_dimension); + hierarchy_index_helper.OffsetToNdIndex(recv_id, out_parallel_ids.data(), hierarchy_dimension); DfsTraverseRanks4NdSbp(0, in_parallel_ids, out_parallel_ids, *parallel_hierarchy, hierarchy_index_helper, in_nd_sbp, visit); } -bool NdSbpNoPartialParallel(const NdSbp& nd_sbp) { - CHECK_GT(nd_sbp.sbp_parallel_size(), 0); - FOR_RANGE(int64_t, i, 0, nd_sbp.sbp_parallel_size()) { - if (nd_sbp.sbp_parallel(i).has_partial_sum_parallel()) { return false; } - } - return true; -} - } // namespace -void GetSendRecvIntersection(int64_t parallel_id, const std::shared_ptr parallel_hierarchy, +void GetRankSendRecvIntersection(int64_t parallel_id, const std::shared_ptr& parallel_hierarchy, const NdSbp& src_nd_sbp, const NdSbp& dst_nd_sbp, const Shape& logical_shape, - std::vector* src_send_intersections, - std::vector* dst_recv_intersections) { + std::vector* send_intersections, + std::vector* recv_intersections) { CHECK(parallel_hierarchy); const int64_t parallel_num = parallel_hierarchy->elem_cnt(); CHECK_LT(parallel_id, parallel_num); - const std::vector& out_slices = - GetTensorSliceView(*parallel_hierarchy, dst_nd_sbp, logical_shape); const std::vector& in_slices = GetTensorSliceView(*parallel_hierarchy, src_nd_sbp, logical_shape); + const std::vector& out_slices = + GetTensorSliceView(*parallel_hierarchy, dst_nd_sbp, logical_shape); - // cur_out_slice recv from - dst_recv_intersections->resize(parallel_num); + // cur rank recv from + recv_intersections->resize(parallel_num); const TensorSliceView& cur_rank_out_slice = out_slices.at(parallel_id); - const auto& add_to_dst_recv_intersections = [&](int32_t out_id, int32_t in_id) { - CHECK_EQ(out_id, parallel_id); - const TensorSliceView& in_slice = in_slices.at(in_id); + const auto& add_to_recv_intersections = [&](int32_t send_id) { + const TensorSliceView& in_slice = in_slices.at(send_id); const TensorSliceView& intersection = cur_rank_out_slice.Intersect(in_slice); - dst_recv_intersections->at(in_id) = intersection; + if (intersection.IsEmpty()) { return; } + recv_intersections->at(send_id) = intersection; }; - DfsTraverse4NdSbp(parallel_id, parallel_hierarchy, src_nd_sbp, add_to_dst_recv_intersections); + DfsTraverse4NdSbp(parallel_id, parallel_hierarchy, src_nd_sbp, add_to_recv_intersections); - // cur_in_slice send to - src_send_intersections->resize(parallel_num); + // cur rank send to + send_intersections->resize(parallel_num); const TensorSliceView& cur_rank_in_slice = in_slices.at(parallel_id); - const auto& add_to_src_send_intersections = [&](int32_t out_id, int32_t in_id) { - if (in_id != parallel_id) { return; } - const TensorSliceView& out_slice = out_slices.at(out_id); - const TensorSliceView& intersection = out_slice.Intersect(cur_rank_in_slice); - src_send_intersections->at(out_id) = intersection; - }; - for (int64_t i = 0; i < parallel_num; ++i) { - DfsTraverse4NdSbp(i, parallel_hierarchy, src_nd_sbp, add_to_src_send_intersections); + for (int64_t recv_i = 0; recv_i < parallel_num; ++recv_i) { + const auto& add_to_send_intersections = [&](int32_t send_id) { + if (send_id != parallel_id) { return; } + const TensorSliceView& out_slice = out_slices.at(recv_i); + const TensorSliceView& intersection = out_slice.Intersect(cur_rank_in_slice); + if (intersection.IsEmpty()) { return; } + send_intersections->at(recv_i) = intersection; + }; + DfsTraverse4NdSbp(recv_i, parallel_hierarchy, src_nd_sbp, add_to_send_intersections); } } diff --git a/oneflow/core/job/nd_sbp_util.h b/oneflow/core/job/nd_sbp_util.h index abeddf09066..731eddd1aca 100644 --- a/oneflow/core/job/nd_sbp_util.h +++ b/oneflow/core/job/nd_sbp_util.h @@ -35,11 +35,12 @@ TensorSliceView GetBroadcastTensorSliceView(const BlobDesc& blob_desc); bool NdSbpHasPartialParallel(const NdSbp& nd_sbp); -void GetSendRecvIntersection(int64_t parallel_id, const std::shared_ptr parallel_hierarchy, + +void GetRankSendRecvIntersection(int64_t parallel_id, const std::shared_ptr& parallel_hierarchy, const NdSbp& src_nd_sbp, const NdSbp& dst_nd_sbp, const Shape& logical_shape, - std::vector* src_send_intersections, - std::vector* dst_recv_intersections); + std::vector* send_intersections, + std::vector* recv_intersections); } // namespace oneflow diff --git a/oneflow/user/kernels/nccl_logical_send_recv_kernel.cpp b/oneflow/user/kernels/nccl_logical_send_recv_kernel.cpp index eff2a77e545..331e2974357 100644 --- a/oneflow/user/kernels/nccl_logical_send_recv_kernel.cpp +++ b/oneflow/user/kernels/nccl_logical_send_recv_kernel.cpp @@ -86,7 +86,7 @@ NcclLogicalSendRecvState::NcclLogicalSendRecvState(user_op::KernelInitContext* c std::vector src_send_intersections; std::vector dst_recv_intersections; - GetSendRecvIntersection(parallel_id, parallel_desc_->hierarchy(), src_nd_sbp, dst_nd_sbp, + GetRankSendRecvIntersection(parallel_id, parallel_desc_->hierarchy(), src_nd_sbp, dst_nd_sbp, logical_shape, &src_send_intersections, &dst_recv_intersections); CHECK_EQ(src_send_intersections.size(), parallel_num); @@ -262,7 +262,7 @@ size_t InferTmpBufferSize(user_op::InferContext* ctx) { std::vector src_send_intersections; std::vector dst_recv_intersections; - GetSendRecvIntersection(parallel_id, ctx->parallel_desc().hierarchy(), src_nd_sbp, dst_nd_sbp, + GetRankSendRecvIntersection(parallel_id, ctx->parallel_desc().hierarchy(), src_nd_sbp, dst_nd_sbp, logical_shape, &src_send_intersections, &dst_recv_intersections); int64_t buf_count = 0; CHECK_EQ(src_send_intersections.size(), parallel_num); From 29079a003b5252a1562ad21921021aa080c93731 Mon Sep 17 00:00:00 2001 From: strint Date: Mon, 30 May 2022 11:00:34 +0800 Subject: [PATCH 29/46] refine test --- python/oneflow/test/graph/test_nccl_logical_send_recv.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/python/oneflow/test/graph/test_nccl_logical_send_recv.py b/python/oneflow/test/graph/test_nccl_logical_send_recv.py index 1b9abd80bbc..f8d875fd2f7 100644 --- a/python/oneflow/test/graph/test_nccl_logical_send_recv.py +++ b/python/oneflow/test/graph/test_nccl_logical_send_recv.py @@ -58,8 +58,7 @@ def _test_nccl_logical_send_recv(test_case, src_nd_sbp, dst_nd_sbp): # input placement = flow.placement("cuda", ranks=[[0, 1], [2, 3]]) local_np = np.arange(4 * 4 * 4).reshape(4, 4, 4) - # NOTE(strint): flow.tensor(numpy, sbp) is not valid when sbp contains partial_sum - x = flow.tensor(local_np).to_global(sbp=src_nd_sbp, placement=placement) + x = flow.tensor(local_np, sbp=src_nd_sbp, placement=placement) # check eager boxing eager_out = x.to_global(sbp=dst_nd_sbp, placement=placement) @@ -67,7 +66,7 @@ def _test_nccl_logical_send_recv(test_case, src_nd_sbp, dst_nd_sbp): # bad case of graph: S with P if src_nd_sbp[0] == flow.sbp.partial_sum and src_nd_sbp[1] == flow.sbp.split(0): - return + pass if src_nd_sbp[0] == flow.sbp.partial_sum and src_nd_sbp[1] == flow.sbp.split(1): return if src_nd_sbp[0] == flow.sbp.partial_sum and src_nd_sbp[1] == flow.sbp.split(2): From e49d38073e2f78203b5261e97f0ee6ec93e90bc3 Mon Sep 17 00:00:00 2001 From: strint Date: Mon, 30 May 2022 13:16:49 +0800 Subject: [PATCH 30/46] fix buff size --- .../insert_nccl_logical_op_pass.cpp | 2 +- .../kernels/nccl_logical_send_recv_kernel.cpp | 5 ++- .../test/graph/test_nccl_logical_send_recv.py | 43 +------------------ 3 files changed, 6 insertions(+), 44 deletions(-) diff --git a/oneflow/core/job_rewriter/insert_nccl_logical_op_pass.cpp b/oneflow/core/job_rewriter/insert_nccl_logical_op_pass.cpp index a789368f8b9..d4033adeba1 100644 --- a/oneflow/core/job_rewriter/insert_nccl_logical_op_pass.cpp +++ b/oneflow/core/job_rewriter/insert_nccl_logical_op_pass.cpp @@ -489,7 +489,7 @@ void InsertNcclLogicalOpsAsCloseAsPossibleToSrcNode( } if (Global::Get()->enable_debug_mode()) { - LOG(ERROR) << " insert nccl op: " << nccl_op.name() << " from [" << src_op_name + VLOG(3) << " insert nccl op: " << nccl_op.name() << " from [" << src_op_name << ", order=" << src_order << ", sbp=" << NdSbpToString(src_node->NdSbp4Lbi(lbi)) << "] to [" << dst_op_name << ", order=" << node2subgraph_order.at(dst_node) diff --git a/oneflow/user/kernels/nccl_logical_send_recv_kernel.cpp b/oneflow/user/kernels/nccl_logical_send_recv_kernel.cpp index 331e2974357..8bfc9a69dbd 100644 --- a/oneflow/user/kernels/nccl_logical_send_recv_kernel.cpp +++ b/oneflow/user/kernels/nccl_logical_send_recv_kernel.cpp @@ -13,6 +13,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ +#include "oneflow/core/common/data_type.h" +#include "oneflow/core/common/data_type.pb.h" #include "oneflow/core/framework/framework.h" #include "oneflow/core/device/nccl_util.h" #include "oneflow/core/job/parallel_desc.h" @@ -254,6 +256,7 @@ size_t InferTmpBufferSize(user_op::InferContext* ctx) { const Shape* out_shape = ctx->OutputShape("out", 0); const user_op::TensorDesc* logical_in_tensor = ctx->LogicalTensorDesc4ArgNameAndIndex("in", 0); const Shape& logical_shape = logical_in_tensor->shape(); + const DataType data_type = logical_in_tensor->data_type(); const NdSbp& src_nd_sbp = ctx->NdSbp4ArgNameAndIndex("in", 0); const NdSbp& dst_nd_sbp = ctx->NdSbp4ArgNameAndIndex("out", 0); @@ -278,7 +281,7 @@ size_t InferTmpBufferSize(user_op::InferContext* ctx) { // Note: when src_nd_sbp has partial_sum, need a out_size buffer to copy and add to out. buf_count += out_shape->elem_cnt(); } - return buf_count; + return buf_count * GetSizeOfDataType(data_type); } REGISTER_USER_KERNEL("_nccl_logical_send_recv") diff --git a/python/oneflow/test/graph/test_nccl_logical_send_recv.py b/python/oneflow/test/graph/test_nccl_logical_send_recv.py index f8d875fd2f7..288bd28266e 100644 --- a/python/oneflow/test/graph/test_nccl_logical_send_recv.py +++ b/python/oneflow/test/graph/test_nccl_logical_send_recv.py @@ -64,33 +64,6 @@ def _test_nccl_logical_send_recv(test_case, src_nd_sbp, dst_nd_sbp): eager_out = x.to_global(sbp=dst_nd_sbp, placement=placement) test_case.assertTrue(np.array_equal(eager_out.numpy(), x.numpy())) - # bad case of graph: S with P - if src_nd_sbp[0] == flow.sbp.partial_sum and src_nd_sbp[1] == flow.sbp.split(0): - pass - if src_nd_sbp[0] == flow.sbp.partial_sum and src_nd_sbp[1] == flow.sbp.split(1): - return - if src_nd_sbp[0] == flow.sbp.partial_sum and src_nd_sbp[1] == flow.sbp.split(2): - return - if src_nd_sbp[0] == flow.sbp.split(0) and src_nd_sbp[1] == flow.sbp.partial_sum: - return - if src_nd_sbp[0] == flow.sbp.split(1) and src_nd_sbp[1] == flow.sbp.partial_sum: - return - if src_nd_sbp[0] == flow.sbp.split(2) and src_nd_sbp[1] == flow.sbp.partial_sum: - return - # bad case of graph: diff S - if src_nd_sbp[0] == flow.sbp.split(0) and src_nd_sbp[1] == flow.sbp.split(1): - return - if src_nd_sbp[0] == flow.sbp.split(0) and src_nd_sbp[1] == flow.sbp.split(2): - return - if src_nd_sbp[0] == flow.sbp.split(1) and src_nd_sbp[1] == flow.sbp.split(2): - return - if src_nd_sbp[0] == flow.sbp.split(1) and src_nd_sbp[1] == flow.sbp.split(0): - return - if src_nd_sbp[0] == flow.sbp.split(2) and src_nd_sbp[1] == flow.sbp.split(0): - return - if src_nd_sbp[0] == flow.sbp.split(2) and src_nd_sbp[1] == flow.sbp.split(1): - return - # check graph boxing flow.boxing.nccl.enable_use_compute_stream(True) class TestNcclLogicalSendRecvGraph(flow.nn.Graph): @@ -105,21 +78,7 @@ def build(self, x): y = graph(x) out_np = y.numpy() in_np = x.numpy() - equal = np.array_equal(out_np, in_np) - # Debug log - if flow.env.get_rank() == 1: - if equal: - print("test boxing passed form ", src_nd_sbp, " to ", dst_nd_sbp) - else: - print("graph repr:\n", graph) - print("local in data:\n", x.to_local().numpy()) - print("local out data:\n", y.to_local().numpy()) - print("global in np:\n", in_np) - print("global out np:\n", out_np) - print("global diff np:\n", out_np - in_np) - if not equal: - print("error rank: ", flow.env.get_rank()) - test_case.assertTrue(equal) + test_case.assertTrue(np.array_equal(out_np, in_np)) From 3fc1821c1780f2f451d26381a1a3d009edb7faba Mon Sep 17 00:00:00 2001 From: strint Date: Wed, 1 Jun 2022 16:56:44 +0800 Subject: [PATCH 31/46] fix conflict --- .../job_rewriter/insert_nccl_logical_op_pass.cpp | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/oneflow/core/job_rewriter/insert_nccl_logical_op_pass.cpp b/oneflow/core/job_rewriter/insert_nccl_logical_op_pass.cpp index 2e3a3c9e29e..b918f8ea6b8 100644 --- a/oneflow/core/job_rewriter/insert_nccl_logical_op_pass.cpp +++ b/oneflow/core/job_rewriter/insert_nccl_logical_op_pass.cpp @@ -434,10 +434,22 @@ bool TryBuildNcclLogicalOpConf(OperatorConf* ret, const OpNode* src_node, const src_reduced_hierarchy, lbn, scope_symbol_id, logical_blob_desc); } + } + if (!got_nccl) { + got_nccl = TryBuildNcclBy2DHierarchyOthers(ret, *src_reduced_nd_sbp, *dst_reduced_nd_sbp, + src_reduced_hierarchy, lbn, scope_symbol_id, + logical_blob_desc); + } + VLOG_IF(3, !got_nccl) << "Cannot get nccl logical op for 2D sbp, src nd sbp " + << NdSbpToString(*src_reduced_nd_sbp) << ", dst nd sbp " + << NdSbpToString(*dst_reduced_nd_sbp) << "."; + return got_nccl; + } return false; } bool ReverseOrderInsertNcclLogicalOps() { + return Global::Get()->resource().disable_group_boxing_by_dst_parallel(); } void InsertNcclLogicalOpsAsCloseAsPossibleToSrcNode( From 79e1290fa4836cb60dc67b44a5e470b72efee7ba Mon Sep 17 00:00:00 2001 From: strint Date: Thu, 2 Jun 2022 02:06:41 +0800 Subject: [PATCH 32/46] refine zero nd --- oneflow/core/job/job_conf.proto | 2 +- .../optimizer_placement_optimization_pass.cpp | 126 +++++++++++++----- python/oneflow/nn/graph/graph_config.py | 14 +- python/oneflow/test/graph/test_graph_zero.py | 9 +- 4 files changed, 103 insertions(+), 48 deletions(-) diff --git a/oneflow/core/job/job_conf.proto b/oneflow/core/job/job_conf.proto index 3be7d5447da..03638feec30 100644 --- a/oneflow/core/job/job_conf.proto +++ b/oneflow/core/job/job_conf.proto @@ -211,7 +211,7 @@ message JobConfigProto { optional bool enable_gradients_stats_aggregation = 106 [default = true]; optional string optimizer_placement_optimization_mode = 107; optional int64 optimizer_placement_optimization_threshold = 108 [default = 1024]; - optional int64 optimizer_placement_optimization_comsumer_limit_level = 110 [default = 2]; + optional int64 optimizer_placement_optimization_shard_restore_level = 110 [default = 2]; optional QatConfig qat_config = 109; diff --git a/oneflow/core/job_rewriter/optimizer_placement_optimization_pass.cpp b/oneflow/core/job_rewriter/optimizer_placement_optimization_pass.cpp index 0cd7f990438..9b582ded571 100644 --- a/oneflow/core/job_rewriter/optimizer_placement_optimization_pass.cpp +++ b/oneflow/core/job_rewriter/optimizer_placement_optimization_pass.cpp @@ -13,8 +13,10 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ +#include #include "oneflow/core/common/util.h" #include "oneflow/core/framework/user_op_conf.h" +#include "oneflow/core/job/nd_sbp_util.h" #include "oneflow/core/job/sbp_parallel.h" #include "oneflow/core/job/sbp_parallel.pb.h" #include "oneflow/core/job_rewriter/job_pass.h" @@ -57,6 +59,14 @@ class DataParallelNodeSequence final { int64_t len() const { return len_; } + bool resize(const int64_t size) { + if (size > len_) { return false; } + if (size <= 1) { return false; } + nodes_.resize(size); + len_ = nodes().size(); + return true; + } + private: std::vector nodes_; int64_t order_; @@ -64,7 +74,7 @@ class DataParallelNodeSequence final { int64_t len_; }; -using SequencePtr = std::shared_ptr; +using SequencePtr = std::shared_ptr; ParallelConf NonDistributedParallelConf4ParallelId(const ParallelDesc& pd, const int64_t parallel_id) { @@ -92,8 +102,6 @@ Maybe GetDataParallelVariableAndNaiveSuccNode( if (cur_node->in_edges().size() > 1) { break; } if (cur_node->op().input_bns().size() != 1) { break; } const std::string& sole_ibn = cur_node->op().SoleIbn(); - VLOG(3) << cur_node->op().op_name() - << " has sbp: " << cur_node->NdSbp4BnInOp(sole_ibn).DebugString(); const NdSbp& ibn_nd_sbp = cur_node->NdSbp4BnInOp(sole_ibn); if (ibn_nd_sbp.sbp_parallel_size() == 0) { break; } bool has_broadcast = false; @@ -102,11 +110,9 @@ Maybe GetDataParallelVariableAndNaiveSuccNode( } if (!has_broadcast) { break; } } - if (!IsAllowed(cur_node)) { break; } + //if (!IsAllowed(cur_node)) { break; } if (cur_node->op().output_bns().size() != 1) { break; } const std::string& sole_obn = cur_node->op().SoleObn(); - VLOG(3) << cur_node->op().op_name() - << " has sbp: " << cur_node->NdSbp4BnInOp(sole_obn).DebugString(); const NdSbp& obn_nd_sbp = cur_node->NdSbp4BnInOp(sole_obn); bool has_broadcast = false; FOR_RANGE(int, i, 0, obn_nd_sbp.sbp_parallel_size()) { @@ -156,15 +162,15 @@ void SetNdSbp4OpNodeIbn(JobBuilder* builder, const OpNode* node, const std::stri void SetNdSbp4Consumers(JobBuilder* builder, const SequencePtr& sequence, const NdSbp& nd_sbp) { const OpNode* node = sequence->GetLastNode(); const LogicalBlobId& lbi = node->op().BnInOp2Lbi(node->op().SoleObn()); - const int64_t limit_consumer_mode = - builder->job().job_conf().optimizer_placement_optimization_comsumer_limit_level(); - // If limit_consumer_mode == 0, no limit on consumer - if (limit_consumer_mode == 1) { - // input lbn for parallel cast op + const int64_t shard_restore_level = + builder->job().job_conf().optimizer_placement_optimization_shard_restore_level(); + // If shard_restore_level == 0, no limit on consumer + if (shard_restore_level == 1) { + // Input lbn for parallel cast op std::string parallel_cast_input_lbn = GenLogicalBlobName(lbi); // Add indentity to enable mem reuse of boxing op when there is no op between var op and boxing. if (sequence->len() == 1) { - LOG(ERROR) << "ZeRO find a data-parallel sequence only has one variable " << sequence->GetVariableNode()->op().op_name(); + VLOG(3) << "ZeRO find a data-parallel sequence only has one variable " << sequence->GetVariableNode()->op().op_name(); const auto var_identity_op = user_op::UserOpConfWrapperBuilder("System-ZeRO-Identity-" + node->op().op_name() + "-" + NewUniqueId()) .Op("identity") @@ -204,9 +210,8 @@ void SetNdSbp4Consumers(JobBuilder* builder, const SequencePtr& sequence, const } } }); - } else if (limit_consumer_mode == 2) { + } else if (shard_restore_level == 2) { // Hard limt consumer to consume weight as Broadcast. - // Default is 2. node->ForEachNodeOnOutEdge([&](const OpNode* out_node) { for (const std::string& ibn : out_node->op().input_bns()) { if (out_node->op().BnInOp2Lbi(ibn) == lbi) { @@ -246,7 +251,7 @@ void ForEachDataParallelNodeSequence(const OpGraph& op_graph, CHECK_JUST(GetDataParallelVariableAndNaiveSuccNode(node, IsAllowed, &nodes)); if (nodes.empty()) { return; } const int64_t order = GetMinConsumerOrder(op_graph, nodes.back(), OpNode2Order); - Handler(std::make_shared(std::move(nodes), order)); + Handler(std::make_shared(std::move(nodes), order)); }); } @@ -282,6 +287,24 @@ bool IsS0Parallel(const SbpSignature& signature, const std::string& bn) { return IsS0Parallel(signature.bn_in_op2sbp_parallel().at(bn)); } +bool IsNdSbpMatch(const NdSbpSignature& signature, const std::string& bn, const NdSbp& nd_sbp) { + return signature.bn_in_op2nd_sbp().at(bn) == nd_sbp; +} + +bool IsNdSbpSupported4Op(const OpNode* node, const NdSbp& nd_sbp) { + if (node->op().input_bns().size() != 1 || node->op().output_bns().size() != 1) { return false; } + std::vector list; + auto LogicalBlobDesc4Ibn = [&](const std::string& bn) -> Maybe { + return Maybe(node->LogicalBlobDesc4Lbi(node->op().BnInOp2Lbi(bn))); + }; + CHECK_JUST(node->op().GetNdSbpSignatureList(LogicalBlobDesc4Ibn, node->parallel_desc(), &list)); + const auto IsInAndOutMatch= [&](const NdSbpSignature& signature) { + return IsNdSbpMatch(signature, node->op().SoleIbn(), nd_sbp) + && IsNdSbpMatch(signature, node->op().SoleObn(), nd_sbp); + }; + return std::any_of(list.cbegin(), list.cend(), IsInAndOutMatch); +} + bool IsS0SignatureSupported(const OpNode* node) { if (node->op().input_bns().size() != 1 || node->op().output_bns().size() != 1) { return false; } SbpSignatureList list; @@ -318,16 +341,9 @@ void ForEachModelSizeBalancedPartition( Maybe RewriteDistributedSplit(const OpGraph& op_graph, JobBuilder* builder) { const int64_t threshold = builder->job().job_conf().optimizer_placement_optimization_threshold(); - const auto IsAllowed = [threshold](const OpNode* n) -> bool { - if (n->op().op_conf().has_variable_conf()) { - const Shape shape(n->op().op_conf().variable_conf().shape()); - const int64_t parallel_num = n->parallel_desc().parallel_num(); - // TODO(strint): zero with nd check size - // Parameter needs to be able to evenly splited and one slice size >= threshold - return shape.At(0) % parallel_num == 0 && shape.elem_cnt() >= threshold * parallel_num; - } else { - return IsS0SignatureSupported(n); - } + const auto IsAllowed = [](const OpNode* n) -> bool { + // No need to limit here. + return true; }; const auto PlacementSequencesAsSplitParallel = [&](const ParallelDesc& pd, std::vector&& sorted_sequences) { @@ -335,6 +351,7 @@ Maybe RewriteDistributedSplit(const OpGraph& op_graph, JobBuilder* builder // and add ctrl edge to control the exectuion order between variable ops. // A sequence is a variable op and its cast(fp32 to fp16) op. This is because the forward pass // consume the fp16 variable and the optimizer consume the fp32 variable. + std::string prev_allowed_op_name = ""; for (int64_t i = 0; i < sorted_sequences.size(); ++i) { const OpNode* var_node = sorted_sequences.at(i)->GetVariableNode(); OperatorConf new_var_op_conf = var_node->op().op_conf(); @@ -342,11 +359,14 @@ Maybe RewriteDistributedSplit(const OpGraph& op_graph, JobBuilder* builder const NdSbp& var_nd_sbp = var_node->NdSbp4BnInOp(sole_obn); std::string new_split_signature = ""; int64_t split_dim = 0; - if (new_var_op_conf.variable_conf().nd_sbp_size() == 1 - && new_var_op_conf.variable_conf().nd_sbp(0) == "B") { + if (new_var_op_conf.variable_conf().nd_sbp_size() > 0 && NdSbpIsAllBroadcast(var_nd_sbp)) { + // split last dim + split_dim = new_var_op_conf.variable_conf().nd_sbp_size() - 1; + // All B, B -> S0 new_split_signature = "S(0)"; - split_dim = 0; } else { + // ND sbp, (*, B, S, *) -> (*, S, S, *) + // ND sbp, (*, S, B, *) -> (*, S, S, *) FOR_RANGE(int64_t, j, 0, new_var_op_conf.variable_conf().nd_sbp_size()) { if (new_var_op_conf.variable_conf().nd_sbp(j) == "B") { std::vector adjacent_dim{j - 1, j + 1}; @@ -369,24 +389,60 @@ Maybe RewriteDistributedSplit(const OpGraph& op_graph, JobBuilder* builder } if (new_split_signature != "") { *new_var_op_conf.mutable_variable_conf()->mutable_nd_sbp(split_dim) = new_split_signature; - VLOG(3) << var_node->op().op_name() << " succeed to change form B to " - << new_split_signature << " on ranks dim " << split_dim << " with op conf " - << new_var_op_conf.variable_conf().DebugString(); } else { + continue; + } + + bool split_is_allowed = true; + if (split_is_allowed) { + NdSbp new_nd_sbp; + std::vector nd_sbp_str_vec; + for (const auto& sbp_str : new_var_op_conf.variable_conf().nd_sbp()) { + nd_sbp_str_vec.push_back(sbp_str); + } + ParseNdSbpFromStringList(nd_sbp_str_vec, &new_nd_sbp); + // check allowed by min shard size and evenly split + const auto slices = GetTensorSliceView(*pd.hierarchy(), new_nd_sbp, Shape(new_var_op_conf.variable_conf().shape())); + if (slices.size() < 2) { split_is_allowed = false; } + if (split_is_allowed && slices.at(0).shape().elem_cnt() < threshold) { split_is_allowed = false; } + if (split_is_allowed) { + FOR_RANGE(int64_t, slice_idx, 1, slices.size()) { + if (slices.at(slice_idx).shape() != slices.at(0).shape()) { split_is_allowed = false; break;} + } + } + if (split_is_allowed) { + // resize sequence by new nd sbp limit + auto& cur_seq = sorted_sequences.at(i); + int64_t max_len = 1; + if (cur_seq->len() > 1) { + FOR_RANGE(int64_t, node_idx, 1, cur_seq->len()) { + if (IsNdSbpSupported4Op(cur_seq->nodes().at(node_idx), new_nd_sbp)) { + ++max_len; + } else { + break; + } + } + } + if (max_len < cur_seq->len()) { cur_seq->resize(max_len); } + } + } + if (!split_is_allowed) { VLOG(3) << var_node->op().op_name() << " failed to change form B to S " << " with op conf " << new_var_op_conf.variable_conf().DebugString(); + continue; } if (i != 0) { - const std::string& prev_op_name = - sorted_sequences.at(i - 1)->GetVariableNode()->op().op_name(); - new_var_op_conf.add_ctrl_in_op_name(prev_op_name); + new_var_op_conf.add_ctrl_in_op_name(prev_allowed_op_name); } - // TODO(strint): rewrite with MutOpTransactioin builder->MutOpsOnlyOnce({new_var_op_conf}); // Set consumers to consum this variable op's cast op's output as Broadcast. if (new_split_signature != "") { SetNdSbp4Consumers(builder, sorted_sequences.at(i), var_nd_sbp); } + prev_allowed_op_name = var_node->op().op_name(); + VLOG(3) << var_node->op().op_name() << " succeed to change form B to " + << new_split_signature << " on ranks dim " << split_dim << " with op conf " + << new_var_op_conf.variable_conf().DebugString(); } }; ForEachParallelSortedNodeSequence(op_graph, IsAllowed, SequenceCompSortedByOrderAsc, diff --git a/python/oneflow/nn/graph/graph_config.py b/python/oneflow/nn/graph/graph_config.py index 72ff402b4fd..fdbac8eba19 100644 --- a/python/oneflow/nn/graph/graph_config.py +++ b/python/oneflow/nn/graph/graph_config.py @@ -77,8 +77,8 @@ def enable_zero( mode: bool = True, *, stage: int = 2, - min_shard_size: int = 1024, - parameter_consumer_limit_level: int = 2, + shard_min_size: int = 1024, + shard_restore_level: int = 1, ): r"""Enable ZeRO redundancy optimizer. @@ -106,21 +106,21 @@ def build(self, x): Args: mode (bool): if set to true, optimizer states of Data Parallel will be sharded across devices. stage (int): optimization stage, range from 1 to 3. - min_shard_size (int): min size of a shard of an optimizer state. - parameter_consumer_limit_level (int): limit consumer to comsume sharded parameter with Broadcast, level 2 is hard limit, level 1 is soft limit, level 0 is no limit. Note that this paremeter is at pre-alpha stage and is not stable. + shard_min_size (int): min size of a shard of an optimizer state. + shard_restore_level (int): level to restore sharded parameter to whole parameter for consumer operators, level 0 is no restore, level 1 is soft restore, level 2 is hard restore. Note that this paremeter is at pre-alpha stage. """ if not mode: self.proto.optimizer_placement_optimization_mode = "none" return assert stage >= 1 and stage <= 3, "ZeRO stage must range form 1 to 3." assert ( - min_shard_size > 0 + shard_min_size > 0 ), "ZeRO min size of a sharded optimizer state must > 0." assert stage >= 1 and stage <= 3, "ZeRO stage must range form 1 to 3." if stage >= 1: self.proto.optimizer_placement_optimization_mode = "distributed_split" - self.proto.optimizer_placement_optimization_threshold = min_shard_size - self.proto.optimizer_placement_optimization_comsumer_limit_level = parameter_consumer_limit_level + self.proto.optimizer_placement_optimization_threshold = shard_min_size + self.proto.optimizer_placement_optimization_shard_restore_level = shard_restore_level if stage >= 2: oneflow.boxing.nccl.enable_use_compute_stream(True) if stage >= 3: diff --git a/python/oneflow/test/graph/test_graph_zero.py b/python/oneflow/test/graph/test_graph_zero.py index 50502e21677..f3deb4f73a2 100644 --- a/python/oneflow/test/graph/test_graph_zero.py +++ b/python/oneflow/test/graph/test_graph_zero.py @@ -56,8 +56,8 @@ def __init__(self): self.config.enable_zero( True, stage=zero_stage, - min_shard_size=1, - parameter_consumer_limit_level=0, + shard_min_size=1, + shard_restore_level=0, ) self.debug(2) @@ -158,7 +158,6 @@ def forward(self, x): x = flow.randint(1, 100, (6, 800), dtype=flow.float32, placement=P, sbp=[S0, B]) - #flow.boxing.nccl.enable_use_compute_stream(True) class LinearTrainGraph2DWithZeRO(flow.nn.Graph): def __init__(self): super().__init__() @@ -173,8 +172,8 @@ def __init__(self): self.config.enable_zero( True, stage=zero_stage, - min_shard_size=1, - parameter_consumer_limit_level=1, + shard_min_size=1, + shard_restore_level=1, ) self.debug(1) From 322504538811e25a0c867f819aeb8aed419fffbe Mon Sep 17 00:00:00 2001 From: strint Date: Thu, 2 Jun 2022 02:24:06 +0800 Subject: [PATCH 33/46] refine --- oneflow/core/job/job_build_and_infer_ctx.cpp | 2 -- .../optimizer_placement_optimization_pass.cpp | 8 +++----- python/oneflow/nn/graph/graph_config.py | 6 +++--- 3 files changed, 6 insertions(+), 10 deletions(-) diff --git a/oneflow/core/job/job_build_and_infer_ctx.cpp b/oneflow/core/job/job_build_and_infer_ctx.cpp index 4e98e017384..839d9f411ca 100644 --- a/oneflow/core/job/job_build_and_infer_ctx.cpp +++ b/oneflow/core/job/job_build_and_infer_ctx.cpp @@ -13,8 +13,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include "glog/logging.h" -#include "oneflow/api/python/env/env.h" #include "oneflow/core/common/protobuf.h" #include "oneflow/core/vm/symbol_storage.h" #include "oneflow/core/framework/config_def.h" diff --git a/oneflow/core/job_rewriter/optimizer_placement_optimization_pass.cpp b/oneflow/core/job_rewriter/optimizer_placement_optimization_pass.cpp index 9b582ded571..3a506cc2df8 100644 --- a/oneflow/core/job_rewriter/optimizer_placement_optimization_pass.cpp +++ b/oneflow/core/job_rewriter/optimizer_placement_optimization_pass.cpp @@ -13,7 +13,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include #include "oneflow/core/common/util.h" #include "oneflow/core/framework/user_op_conf.h" #include "oneflow/core/job/nd_sbp_util.h" @@ -59,12 +58,11 @@ class DataParallelNodeSequence final { int64_t len() const { return len_; } - bool resize(const int64_t size) { - if (size > len_) { return false; } - if (size <= 1) { return false; } + void resize(const int64_t size) { + CHECK(size <= len_); + CHECK(size > 1); nodes_.resize(size); len_ = nodes().size(); - return true; } private: diff --git a/python/oneflow/nn/graph/graph_config.py b/python/oneflow/nn/graph/graph_config.py index fdbac8eba19..1c5775c97b5 100644 --- a/python/oneflow/nn/graph/graph_config.py +++ b/python/oneflow/nn/graph/graph_config.py @@ -17,7 +17,7 @@ from collections import OrderedDict -import oneflow +import oneflow.boxing.nccl as nccl_config from oneflow.nn.graph.optimizer import OptDict import oneflow.core.job.job_conf_pb2 as job_conf_pb @@ -122,9 +122,9 @@ def build(self, x): self.proto.optimizer_placement_optimization_threshold = shard_min_size self.proto.optimizer_placement_optimization_shard_restore_level = shard_restore_level if stage >= 2: - oneflow.boxing.nccl.enable_use_compute_stream(True) + nccl_config.enable_use_compute_stream(True) if stage >= 3: - oneflow.boxing.nccl.disable_group_boxing_by_dst_parallel(True) + nccl_config.disable_group_boxing_by_dst_parallel(True) def allow_fuse_model_update_ops(self, mode: bool = True): r"""If set to true, try to fuse cast + scale + l1_l2_regularize_gradient + model_update to one op to improve performance. From c75143553ec5fbce26599e2b8c8739c388ef98d7 Mon Sep 17 00:00:00 2001 From: strint Date: Thu, 2 Jun 2022 22:03:40 +0800 Subject: [PATCH 34/46] add full test --- oneflow/core/job/plan_util.cpp | 57 ++++++++++++++++++++ python/oneflow/test/graph/test_graph_zero.py | 32 +++++------ 2 files changed, 74 insertions(+), 15 deletions(-) diff --git a/oneflow/core/job/plan_util.cpp b/oneflow/core/job/plan_util.cpp index a2d3fc2f8da..50e5aed140b 100644 --- a/oneflow/core/job/plan_util.cpp +++ b/oneflow/core/job/plan_util.cpp @@ -879,12 +879,69 @@ void PlanUtil::PlanMemoryLog(Plan* plan, const std::string& plan_name) { } } +<<<<<<< Updated upstream for (auto pair : rank_device2size) { int64_t rank_id = pair.first.first; int64_t device_id = pair.first.second; double mem_size = pair.second * 1.0 / 1000000.0; LOG(INFO) << "Graph name " << plan_name << " needs to allocate [ " << mem_size << " MiB ] device memory in Rank: " << rank_id << " , Device: " << device_id << "."; +======= + for (const auto* task : ordered_tasks) { + for (const auto& pair : task->produced_regst_desc()) { + const auto& regst = pair.second; + if (regst.regst_desc_type().has_data_regst_desc() + && mem_block_id2info.find(regst.mem_block_id()) != mem_block_id2info.end()) { + const auto data_regst = regst.regst_desc_type().data_regst_desc(); + std::string op_name = data_regst.lbi2blob_desc(0).lbi().op_name(); + mem_block_id2info.at(regst.mem_block_id()).ordered_op_names.push_back(op_name); + } + } + } + + auto CompMemBlock = [&](int64_t a, int64_t b) { + return mem_block_id2info[a].mem_block_mem_size > mem_block_id2info[b].mem_block_mem_size; + }; + + auto B2MiB = [](int64_t val) { return val * 1.0 / 1000000.0; }; + + for (auto& rank_memory_info : rank_device_memory_infos) { + std::sort(rank_memory_info.chunk_info.mem_block_ids.begin(), + rank_memory_info.chunk_info.mem_block_ids.end(), CompMemBlock); + LOG(ERROR) << " Graph name " << plan_name << " in Rank: " << rank_memory_info.rank_id + << ", Device: " << rank_memory_info.device_id << " needs to allocate [ " + << B2MiB(rank_memory_info.total_mem_size) + << " MiB ] device memory. \n In general, Chunk id: " + << rank_memory_info.chunk_info.chunk_id << " memory is [ " + << B2MiB(rank_memory_info.chunk_info.chunk_mem_size) + << " MiB ]; \n Memory out of Chunk is [ " + << B2MiB(rank_memory_info.not_reused_mem_size) + << " MiB ]; and in particular: Eager Variable Tensor total memory is [ " + << B2MiB(rank_memory_info.eager_variable_total_mem_size) << " MiB ]."; + } + + if (IsInDebugMode()) { + for (const auto& rank_memory_info : rank_device_memory_infos) { + int64_t chunk_id = rank_memory_info.chunk_info.chunk_id; + VLOG(2) << " For detail: Chunk id: " << chunk_id << " has " + << rank_memory_info.chunk_info.mem_block_ids.size() << " MemBlocks."; + for (int64_t mem_block_id : rank_memory_info.chunk_info.mem_block_ids) { + CHECK(mem_block_id2info.find(mem_block_id) != mem_block_id2info.end()); + const auto& mem_block_info = mem_block_id2info.at(mem_block_id); + VLOG(2) << " In Chunk id: " << chunk_id << " MemBlock id: " << mem_block_id + << " has num = " << mem_block_info.ordered_op_names.size() + << " ops with mem size = " << B2MiB(mem_block_info.mem_block_mem_size); + } + for (int64_t mem_block_id : rank_memory_info.chunk_info.mem_block_ids) { + CHECK(mem_block_id2info.find(mem_block_id) != mem_block_id2info.end()); + const auto& mem_block_info = mem_block_id2info.at(mem_block_id); + for (int64_t i = 0; i < mem_block_info.ordered_op_names.size(); ++i) { + VLOG(3) << " In MemBlock id: " << mem_block_id << " order: " << i + << " op_name: " << mem_block_info.ordered_op_names.at(i); + } + } + } +>>>>>>> Stashed changes } } diff --git a/python/oneflow/test/graph/test_graph_zero.py b/python/oneflow/test/graph/test_graph_zero.py index f3deb4f73a2..d3e39be629c 100644 --- a/python/oneflow/test/graph/test_graph_zero.py +++ b/python/oneflow/test/graph/test_graph_zero.py @@ -122,11 +122,11 @@ def train_with_graph(iter_num=1): def get_mixed_linear(): linear_dp_mp = flow.nn.Linear(800, 400, bias=False) linear_dp_mp = linear_dp_mp.to_global(placement=P, sbp=[B, S0]) - flow.nn.init.constant_(linear_dp_mp.weight, 2.068758) + flow.nn.init.constant_(linear_dp_mp.weight, 1.068758) linear_mp_dp = flow.nn.Linear(800, 400, bias=False) linear_mp_dp = linear_mp_dp.to_global(placement=P, sbp=[S0, B]) - flow.nn.init.constant_(linear_mp_dp.weight, 2.068758) + flow.nn.init.constant_(linear_mp_dp.weight, 1.068758) class MixedLinear(flow.nn.Module): def __init__(self): @@ -156,7 +156,7 @@ def forward(self, x): ) grad_scaler = flow.amp.StaticGradScaler(200) - x = flow.randint(1, 100, (6, 800), dtype=flow.float32, placement=P, sbp=[S0, B]) + x = flow.rand((2, 800), dtype=flow.float32, placement=P, sbp=[S0, B]) class LinearTrainGraph2DWithZeRO(flow.nn.Graph): def __init__(self): @@ -175,14 +175,13 @@ def __init__(self): shard_min_size=1, shard_restore_level=1, ) - self.debug(1) def build(self, x): out = self.mixed_linear0(x) out = self.mixed_linear1(out) - loss = out.sum() + loss = out.mean() loss.backward() - return out + return loss class LinearEvalGraph2DWithZeRO(flow.nn.Graph): def __init__(self): @@ -198,14 +197,12 @@ def build(self, x): return out linear_t_g = LinearTrainGraph2DWithZeRO() - linear_t_g.debug(1) linear_e_g = LinearEvalGraph2DWithZeRO() - linear_e_g.debug(1) def one_train_iter(): out = linear_t_g(x) - if flow.env.get_rank() == 0: - print(linear_t_g) + #if flow.env.get_rank() == 0: + # print(linear_t_g) def one_eval_iter(): out = linear_e_g(x) @@ -213,10 +210,8 @@ def one_eval_iter(): for i in range(iter_num): one_train_iter() - # After pass rewrite in training graph, parameters' sbp has been - # changed from flow.sbp.broadcast to flow.sbp.split(0) - # test_case.assertEqual(linear_dp.weight.sbp[0], S0) - # test_case.assertEqual(linear_mp.weight.sbp[0], S0) + for state in linear_t_g._state(): + test_case.assertEqual(state.origin.sbp, (oneflow.sbp.split(axis=0), oneflow.sbp.split(axis=0))) # In evaluation graph, paramters's sbp are flow.sbp.split(0). # But their consumer will consum them as flow.sbp.broadcast. @@ -242,9 +237,16 @@ def _test_linear_train_graph_with_zero_3(test_case): @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") @flow.unittest.skip_unless_1n4d() class TestLinearTrainGraph2DWithZeRO(oneflow.unittest.TestCase): - def test_linear_train_graph_2d_with_zero_1(test_case): + def test_linear_train_graph_2d_with_zero_3(test_case): + _test_linear_train_graph_2d_with_zero(test_case, 3) + + def test_linear_train_graph_2d_with_zero_2(test_case): _test_linear_train_graph_2d_with_zero(test_case, 2) + def test_linear_train_graph_2d_with_zero_1(test_case): + _test_linear_train_graph_2d_with_zero(test_case, 1) + + if __name__ == "__main__": unittest.main() From 5c78921bdf840cacf41db3a5cde945d4f4593ad0 Mon Sep 17 00:00:00 2001 From: strint Date: Thu, 2 Jun 2022 22:23:52 +0800 Subject: [PATCH 35/46] revert change --- oneflow/core/job/plan_util.cpp | 57 ---------------------------------- 1 file changed, 57 deletions(-) diff --git a/oneflow/core/job/plan_util.cpp b/oneflow/core/job/plan_util.cpp index 50e5aed140b..a2d3fc2f8da 100644 --- a/oneflow/core/job/plan_util.cpp +++ b/oneflow/core/job/plan_util.cpp @@ -879,69 +879,12 @@ void PlanUtil::PlanMemoryLog(Plan* plan, const std::string& plan_name) { } } -<<<<<<< Updated upstream for (auto pair : rank_device2size) { int64_t rank_id = pair.first.first; int64_t device_id = pair.first.second; double mem_size = pair.second * 1.0 / 1000000.0; LOG(INFO) << "Graph name " << plan_name << " needs to allocate [ " << mem_size << " MiB ] device memory in Rank: " << rank_id << " , Device: " << device_id << "."; -======= - for (const auto* task : ordered_tasks) { - for (const auto& pair : task->produced_regst_desc()) { - const auto& regst = pair.second; - if (regst.regst_desc_type().has_data_regst_desc() - && mem_block_id2info.find(regst.mem_block_id()) != mem_block_id2info.end()) { - const auto data_regst = regst.regst_desc_type().data_regst_desc(); - std::string op_name = data_regst.lbi2blob_desc(0).lbi().op_name(); - mem_block_id2info.at(regst.mem_block_id()).ordered_op_names.push_back(op_name); - } - } - } - - auto CompMemBlock = [&](int64_t a, int64_t b) { - return mem_block_id2info[a].mem_block_mem_size > mem_block_id2info[b].mem_block_mem_size; - }; - - auto B2MiB = [](int64_t val) { return val * 1.0 / 1000000.0; }; - - for (auto& rank_memory_info : rank_device_memory_infos) { - std::sort(rank_memory_info.chunk_info.mem_block_ids.begin(), - rank_memory_info.chunk_info.mem_block_ids.end(), CompMemBlock); - LOG(ERROR) << " Graph name " << plan_name << " in Rank: " << rank_memory_info.rank_id - << ", Device: " << rank_memory_info.device_id << " needs to allocate [ " - << B2MiB(rank_memory_info.total_mem_size) - << " MiB ] device memory. \n In general, Chunk id: " - << rank_memory_info.chunk_info.chunk_id << " memory is [ " - << B2MiB(rank_memory_info.chunk_info.chunk_mem_size) - << " MiB ]; \n Memory out of Chunk is [ " - << B2MiB(rank_memory_info.not_reused_mem_size) - << " MiB ]; and in particular: Eager Variable Tensor total memory is [ " - << B2MiB(rank_memory_info.eager_variable_total_mem_size) << " MiB ]."; - } - - if (IsInDebugMode()) { - for (const auto& rank_memory_info : rank_device_memory_infos) { - int64_t chunk_id = rank_memory_info.chunk_info.chunk_id; - VLOG(2) << " For detail: Chunk id: " << chunk_id << " has " - << rank_memory_info.chunk_info.mem_block_ids.size() << " MemBlocks."; - for (int64_t mem_block_id : rank_memory_info.chunk_info.mem_block_ids) { - CHECK(mem_block_id2info.find(mem_block_id) != mem_block_id2info.end()); - const auto& mem_block_info = mem_block_id2info.at(mem_block_id); - VLOG(2) << " In Chunk id: " << chunk_id << " MemBlock id: " << mem_block_id - << " has num = " << mem_block_info.ordered_op_names.size() - << " ops with mem size = " << B2MiB(mem_block_info.mem_block_mem_size); - } - for (int64_t mem_block_id : rank_memory_info.chunk_info.mem_block_ids) { - CHECK(mem_block_id2info.find(mem_block_id) != mem_block_id2info.end()); - const auto& mem_block_info = mem_block_id2info.at(mem_block_id); - for (int64_t i = 0; i < mem_block_info.ordered_op_names.size(); ++i) { - VLOG(3) << " In MemBlock id: " << mem_block_id << " order: " << i - << " op_name: " << mem_block_info.ordered_op_names.at(i); - } - } - } ->>>>>>> Stashed changes } } From bfa726c55c63ff61fb7bdd94d92fb6becb6fe5b7 Mon Sep 17 00:00:00 2001 From: strint Date: Fri, 3 Jun 2022 01:12:49 +0800 Subject: [PATCH 36/46] refine split check --- .../optimizer_placement_optimization_pass.cpp | 31 +++++++++++++++---- 1 file changed, 25 insertions(+), 6 deletions(-) diff --git a/oneflow/core/job_rewriter/optimizer_placement_optimization_pass.cpp b/oneflow/core/job_rewriter/optimizer_placement_optimization_pass.cpp index 3a506cc2df8..028c1f08401 100644 --- a/oneflow/core/job_rewriter/optimizer_placement_optimization_pass.cpp +++ b/oneflow/core/job_rewriter/optimizer_placement_optimization_pass.cpp @@ -13,7 +13,9 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ +#include #include "oneflow/core/common/util.h" +#include "oneflow/core/framework/nd_sbp.h" #include "oneflow/core/framework/user_op_conf.h" #include "oneflow/core/job/nd_sbp_util.h" #include "oneflow/core/job/sbp_parallel.h" @@ -337,6 +339,26 @@ void ForEachModelSizeBalancedPartition( } } +bool IsSplitValid(const Shape& shape, const NdSbp& nd_sbp, const Shape& hierachy, int64_t min_size) { + if (shape.NumAxes() < 1 || shape.elem_cnt() < 1) { return false; } + CHECK_EQ(nd_sbp.sbp_parallel_size(), hierachy.NumAxes()); + Shape cur_shape = shape; + if (cur_shape.elem_cnt() < min_size) { return false; } + FOR_RANGE(int64_t, i, hierachy.NumAxes(), ++i) { + const auto& sbp = nd_sbp.sbp_parallel(i); + if (sbp.has_split_parallel()) { + const int64_t dim = sbp.split_parallel().axis(); + if (dim >= cur_shape.NumAxes()) { return false; } + // Evenly split. + if (cur_shape.At(dim) % hierachy.At(i) != 0) { return false; } + cur_shape.Set(dim, cur_shape.At(dim) / hierachy.At(i)); + // Larger then min size. + if (cur_shape.elem_cnt() < min_size) { return false; } + } + } + return true; +} + Maybe RewriteDistributedSplit(const OpGraph& op_graph, JobBuilder* builder) { const int64_t threshold = builder->job().job_conf().optimizer_placement_optimization_threshold(); const auto IsAllowed = [](const OpNode* n) -> bool { @@ -355,6 +377,7 @@ Maybe RewriteDistributedSplit(const OpGraph& op_graph, JobBuilder* builder OperatorConf new_var_op_conf = var_node->op().op_conf(); const std::string& sole_obn = var_node->op().SoleObn(); const NdSbp& var_nd_sbp = var_node->NdSbp4BnInOp(sole_obn); + const Shape& logical_shape = Shape(new_var_op_conf.variable_conf().shape()); std::string new_split_signature = ""; int64_t split_dim = 0; if (new_var_op_conf.variable_conf().nd_sbp_size() > 0 && NdSbpIsAllBroadcast(var_nd_sbp)) { @@ -400,13 +423,9 @@ Maybe RewriteDistributedSplit(const OpGraph& op_graph, JobBuilder* builder } ParseNdSbpFromStringList(nd_sbp_str_vec, &new_nd_sbp); // check allowed by min shard size and evenly split - const auto slices = GetTensorSliceView(*pd.hierarchy(), new_nd_sbp, Shape(new_var_op_conf.variable_conf().shape())); - if (slices.size() < 2) { split_is_allowed = false; } - if (split_is_allowed && slices.at(0).shape().elem_cnt() < threshold) { split_is_allowed = false; } + LOG(ERROR) << "op " << var_node->op().op_name() << " shape " << new_var_op_conf.variable_conf().shape().DebugString() << " sbp " << NdSbpToString(new_nd_sbp); if (split_is_allowed) { - FOR_RANGE(int64_t, slice_idx, 1, slices.size()) { - if (slices.at(slice_idx).shape() != slices.at(0).shape()) { split_is_allowed = false; break;} - } + split_is_allowed = IsSplitValid(logical_shape, new_nd_sbp, *pd.hierarchy(), threshold); } if (split_is_allowed) { // resize sequence by new nd sbp limit From 0bcbf30147dc6f1f35aa35a8663ba5ebe1b3e9c0 Mon Sep 17 00:00:00 2001 From: strint Date: Mon, 6 Jun 2022 12:16:19 +0800 Subject: [PATCH 37/46] fix typo --- .../core/job_rewriter/optimizer_placement_optimization_pass.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/oneflow/core/job_rewriter/optimizer_placement_optimization_pass.cpp b/oneflow/core/job_rewriter/optimizer_placement_optimization_pass.cpp index 028c1f08401..7acb5a20f1e 100644 --- a/oneflow/core/job_rewriter/optimizer_placement_optimization_pass.cpp +++ b/oneflow/core/job_rewriter/optimizer_placement_optimization_pass.cpp @@ -344,7 +344,7 @@ bool IsSplitValid(const Shape& shape, const NdSbp& nd_sbp, const Shape& hierachy CHECK_EQ(nd_sbp.sbp_parallel_size(), hierachy.NumAxes()); Shape cur_shape = shape; if (cur_shape.elem_cnt() < min_size) { return false; } - FOR_RANGE(int64_t, i, hierachy.NumAxes(), ++i) { + FOR_RANGE(int64_t, i, 0, hierachy.NumAxes()) { const auto& sbp = nd_sbp.sbp_parallel(i); if (sbp.has_split_parallel()) { const int64_t dim = sbp.split_parallel().axis(); From 14c8520ea1d5ebf75366b57ba3cdc1cba91e27e4 Mon Sep 17 00:00:00 2001 From: strint Date: Mon, 6 Jun 2022 18:20:33 +0800 Subject: [PATCH 38/46] rm log --- .../core/job_rewriter/optimizer_placement_optimization_pass.cpp | 2 -- 1 file changed, 2 deletions(-) diff --git a/oneflow/core/job_rewriter/optimizer_placement_optimization_pass.cpp b/oneflow/core/job_rewriter/optimizer_placement_optimization_pass.cpp index 7acb5a20f1e..e57a54266cb 100644 --- a/oneflow/core/job_rewriter/optimizer_placement_optimization_pass.cpp +++ b/oneflow/core/job_rewriter/optimizer_placement_optimization_pass.cpp @@ -110,7 +110,6 @@ Maybe GetDataParallelVariableAndNaiveSuccNode( } if (!has_broadcast) { break; } } - //if (!IsAllowed(cur_node)) { break; } if (cur_node->op().output_bns().size() != 1) { break; } const std::string& sole_obn = cur_node->op().SoleObn(); const NdSbp& obn_nd_sbp = cur_node->NdSbp4BnInOp(sole_obn); @@ -423,7 +422,6 @@ Maybe RewriteDistributedSplit(const OpGraph& op_graph, JobBuilder* builder } ParseNdSbpFromStringList(nd_sbp_str_vec, &new_nd_sbp); // check allowed by min shard size and evenly split - LOG(ERROR) << "op " << var_node->op().op_name() << " shape " << new_var_op_conf.variable_conf().shape().DebugString() << " sbp " << NdSbpToString(new_nd_sbp); if (split_is_allowed) { split_is_allowed = IsSplitValid(logical_shape, new_nd_sbp, *pd.hierarchy(), threshold); } From 56754bce241764fe2519e3efbce1832cb72ca8ad Mon Sep 17 00:00:00 2001 From: strint Date: Mon, 6 Jun 2022 20:05:46 +0800 Subject: [PATCH 39/46] spit long func --- .../optimizer_placement_optimization_pass.cpp | 190 ++++++++++-------- oneflow/user/ops/stack_op.cpp | 2 - 2 files changed, 101 insertions(+), 91 deletions(-) diff --git a/oneflow/core/job_rewriter/optimizer_placement_optimization_pass.cpp b/oneflow/core/job_rewriter/optimizer_placement_optimization_pass.cpp index e57a54266cb..5facfda11dd 100644 --- a/oneflow/core/job_rewriter/optimizer_placement_optimization_pass.cpp +++ b/oneflow/core/job_rewriter/optimizer_placement_optimization_pass.cpp @@ -14,6 +14,7 @@ See the License for the specific language governing permissions and limitations under the License. */ #include +#include #include "oneflow/core/common/util.h" #include "oneflow/core/framework/nd_sbp.h" #include "oneflow/core/framework/user_op_conf.h" @@ -23,6 +24,7 @@ limitations under the License. #include "oneflow/core/job_rewriter/job_pass.h" #include "oneflow/core/graph/op_graph.h" #include "oneflow/core/job/job_desc.h" +#include "oneflow/core/operator/op_conf.pb.h" #include "oneflow/core/operator/operator.h" namespace oneflow { @@ -338,6 +340,7 @@ void ForEachModelSizeBalancedPartition( } } +namespace { bool IsSplitValid(const Shape& shape, const NdSbp& nd_sbp, const Shape& hierachy, int64_t min_size) { if (shape.NumAxes() < 1 || shape.elem_cnt() < 1) { return false; } CHECK_EQ(nd_sbp.sbp_parallel_size(), hierachy.NumAxes()); @@ -358,107 +361,116 @@ bool IsSplitValid(const Shape& shape, const NdSbp& nd_sbp, const Shape& hierachy return true; } -Maybe RewriteDistributedSplit(const OpGraph& op_graph, JobBuilder* builder) { - const int64_t threshold = builder->job().job_conf().optimizer_placement_optimization_threshold(); - const auto IsAllowed = [](const OpNode* n) -> bool { - // No need to limit here. - return true; - }; - const auto PlacementSequencesAsSplitParallel = [&](const ParallelDesc& pd, - std::vector&& sorted_sequences) { - // For all sorted sequnence, set the variable op in the sequence to S - // and add ctrl edge to control the exectuion order between variable ops. - // A sequence is a variable op and its cast(fp32 to fp16) op. This is because the forward pass - // consume the fp16 variable and the optimizer consume the fp32 variable. - std::string prev_allowed_op_name = ""; - for (int64_t i = 0; i < sorted_sequences.size(); ++i) { - const OpNode* var_node = sorted_sequences.at(i)->GetVariableNode(); - OperatorConf new_var_op_conf = var_node->op().op_conf(); - const std::string& sole_obn = var_node->op().SoleObn(); - const NdSbp& var_nd_sbp = var_node->NdSbp4BnInOp(sole_obn); - const Shape& logical_shape = Shape(new_var_op_conf.variable_conf().shape()); - std::string new_split_signature = ""; - int64_t split_dim = 0; - if (new_var_op_conf.variable_conf().nd_sbp_size() > 0 && NdSbpIsAllBroadcast(var_nd_sbp)) { - // split last dim - split_dim = new_var_op_conf.variable_conf().nd_sbp_size() - 1; - // All B, B -> S0 - new_split_signature = "S(0)"; - } else { - // ND sbp, (*, B, S, *) -> (*, S, S, *) - // ND sbp, (*, S, B, *) -> (*, S, S, *) - FOR_RANGE(int64_t, j, 0, new_var_op_conf.variable_conf().nd_sbp_size()) { - if (new_var_op_conf.variable_conf().nd_sbp(j) == "B") { - std::vector adjacent_dim{j - 1, j + 1}; - for (auto const& dim_to_try : adjacent_dim) { - if (dim_to_try >= 0 && dim_to_try < new_var_op_conf.variable_conf().nd_sbp_size()) { - SbpParallel sbp; - if (ParseSbpParallelFromString(new_var_op_conf.variable_conf().nd_sbp(dim_to_try), - &sbp) - && sbp.has_split_parallel()) { - new_split_signature = new_var_op_conf.variable_conf().nd_sbp(dim_to_try); - split_dim = j; - } +void GenerateSplitSignature(const NdSbp& var_nd_sbp, const OperatorConf& new_var_op_conf, std::string& new_split_signature, int64_t& split_dim) { + if (new_var_op_conf.variable_conf().nd_sbp_size() > 0 && NdSbpIsAllBroadcast(var_nd_sbp)) { + // split last dim + split_dim = new_var_op_conf.variable_conf().nd_sbp_size() - 1; + // All B, B -> S0 + new_split_signature = "S(0)"; + } else { + // ND sbp, (*, B, S, *) -> (*, S, S, *) + // ND sbp, (*, S, B, *) -> (*, S, S, *) + FOR_RANGE(int64_t, j, 0, new_var_op_conf.variable_conf().nd_sbp_size()) { + if (new_var_op_conf.variable_conf().nd_sbp(j) == "B") { + std::vector adjacent_dim{j - 1, j + 1}; + for (auto const& dim_to_try : adjacent_dim) { + if (dim_to_try >= 0 && dim_to_try < new_var_op_conf.variable_conf().nd_sbp_size()) { + SbpParallel sbp; + if (ParseSbpParallelFromString(new_var_op_conf.variable_conf().nd_sbp(dim_to_try), + &sbp) + && sbp.has_split_parallel()) { + new_split_signature = new_var_op_conf.variable_conf().nd_sbp(dim_to_try); + split_dim = j; } - if (new_split_signature != "") break; } + if (new_split_signature != "") break; } - // Only split one more dim. - if (new_split_signature != "") break; } + // Only split one more dim. + if (new_split_signature != "") break; } - if (new_split_signature != "") { - *new_var_op_conf.mutable_variable_conf()->mutable_nd_sbp(split_dim) = new_split_signature; - } else { - continue; - } + } +} +void ShardSequence(JobBuilder* builder, const int64_t threshold, const ParallelDesc& pd, std::vector&& sorted_sequences) { + // For all sorted sequnence, set the variable op in the sequence to S + // and add ctrl edge to control the exectuion order between variable ops. + // A sequence is a variable op and its cast(fp32 to fp16) op. This is because the forward pass + // consume the fp16 variable and the optimizer consume the fp32 variable. + std::string prev_allowed_op_name = ""; + for (int64_t i = 0; i < sorted_sequences.size(); ++i) { + const OpNode* var_node = sorted_sequences.at(i)->GetVariableNode(); + OperatorConf new_var_op_conf = var_node->op().op_conf(); + const std::string& sole_obn = var_node->op().SoleObn(); + const NdSbp& var_nd_sbp = var_node->NdSbp4BnInOp(sole_obn); + const Shape& logical_shape = Shape(new_var_op_conf.variable_conf().shape()); + + std::string new_split_signature = ""; + int64_t split_dim = 0; + GenerateSplitSignature(var_nd_sbp, new_var_op_conf, new_split_signature, split_dim); + if (new_split_signature != "") { + *new_var_op_conf.mutable_variable_conf()->mutable_nd_sbp(split_dim) = new_split_signature; + } else { + continue; + } - bool split_is_allowed = true; + bool split_is_allowed = true; + if (split_is_allowed) { + NdSbp new_nd_sbp; + std::vector nd_sbp_str_vec; + for (const auto& sbp_str : new_var_op_conf.variable_conf().nd_sbp()) { + nd_sbp_str_vec.push_back(sbp_str); + } + ParseNdSbpFromStringList(nd_sbp_str_vec, &new_nd_sbp); + // check allowed by min shard size and evenly split if (split_is_allowed) { - NdSbp new_nd_sbp; - std::vector nd_sbp_str_vec; - for (const auto& sbp_str : new_var_op_conf.variable_conf().nd_sbp()) { - nd_sbp_str_vec.push_back(sbp_str); - } - ParseNdSbpFromStringList(nd_sbp_str_vec, &new_nd_sbp); - // check allowed by min shard size and evenly split - if (split_is_allowed) { - split_is_allowed = IsSplitValid(logical_shape, new_nd_sbp, *pd.hierarchy(), threshold); - } - if (split_is_allowed) { - // resize sequence by new nd sbp limit - auto& cur_seq = sorted_sequences.at(i); - int64_t max_len = 1; - if (cur_seq->len() > 1) { - FOR_RANGE(int64_t, node_idx, 1, cur_seq->len()) { - if (IsNdSbpSupported4Op(cur_seq->nodes().at(node_idx), new_nd_sbp)) { - ++max_len; - } else { - break; - } + split_is_allowed = IsSplitValid(logical_shape, new_nd_sbp, *pd.hierarchy(), threshold); + } + if (split_is_allowed) { + // resize sequence by new nd sbp limit + auto& cur_seq = sorted_sequences.at(i); + int64_t max_len = 1; + if (cur_seq->len() > 1) { + FOR_RANGE(int64_t, node_idx, 1, cur_seq->len()) { + if (IsNdSbpSupported4Op(cur_seq->nodes().at(node_idx), new_nd_sbp)) { + ++max_len; + } else { + break; } } - if (max_len < cur_seq->len()) { cur_seq->resize(max_len); } } + if (max_len < cur_seq->len()) { cur_seq->resize(max_len); } } - if (!split_is_allowed) { - VLOG(3) << var_node->op().op_name() << " failed to change form B to S " - << " with op conf " << new_var_op_conf.variable_conf().DebugString(); - continue; - } - if (i != 0) { - new_var_op_conf.add_ctrl_in_op_name(prev_allowed_op_name); - } - builder->MutOpsOnlyOnce({new_var_op_conf}); - // Set consumers to consum this variable op's cast op's output as Broadcast. - if (new_split_signature != "") { - SetNdSbp4Consumers(builder, sorted_sequences.at(i), var_nd_sbp); - } - prev_allowed_op_name = var_node->op().op_name(); - VLOG(3) << var_node->op().op_name() << " succeed to change form B to " - << new_split_signature << " on ranks dim " << split_dim << " with op conf " - << new_var_op_conf.variable_conf().DebugString(); } + if (!split_is_allowed) { + VLOG(3) << var_node->op().op_name() << " failed to change form B to S " + << " with op conf " << new_var_op_conf.variable_conf().DebugString(); + continue; + } + if (i != 0) { + new_var_op_conf.add_ctrl_in_op_name(prev_allowed_op_name); + } + builder->MutOpsOnlyOnce({new_var_op_conf}); + // Set consumers to consum this variable op's cast op's output as Broadcast. + if (new_split_signature != "") { + SetNdSbp4Consumers(builder, sorted_sequences.at(i), var_nd_sbp); + } + prev_allowed_op_name = var_node->op().op_name(); + VLOG(3) << var_node->op().op_name() << " succeed to change form B to " + << new_split_signature << " on ranks dim " << split_dim << " with op conf " + << new_var_op_conf.variable_conf().DebugString(); + } +} +} // namespace + +Maybe RewriteDistributedSplit(const OpGraph& op_graph, JobBuilder* builder) { + const int64_t threshold = builder->job().job_conf().optimizer_placement_optimization_threshold(); + const auto IsAllowed = [](const OpNode* n) -> bool { + // No need to limit here. + return true; + }; + const auto PlacementSequencesAsSplitParallel = [&](const ParallelDesc& pd, + std::vector&& sorted_sequences) { + ShardSequence(builder, threshold, pd, std::forward>(sorted_sequences)); }; ForEachParallelSortedNodeSequence(op_graph, IsAllowed, SequenceCompSortedByOrderAsc, PlacementSequencesAsSplitParallel); diff --git a/oneflow/user/ops/stack_op.cpp b/oneflow/user/ops/stack_op.cpp index 254cbcd1743..1dd129081bd 100644 --- a/oneflow/user/ops/stack_op.cpp +++ b/oneflow/user/ops/stack_op.cpp @@ -144,8 +144,6 @@ Maybe GenGradOp(const user_op::UserOpWrapper& op, const user_op::AddOpFn& /*static*/ Maybe StackGradOp::GetSbp(user_op::SbpContext* ctx) { const auto axis = ctx->Attr("axis"); - const int64_t in_num_axes = - ctx->LogicalTensorDesc4InputArgNameAndIndex("in", 0).shape().NumAxes(); const int64_t like_num_axes = ctx->LogicalTensorDesc4InputArgNameAndIndex("like", 0).shape().NumAxes(); FOR_RANGE(int64_t, i, 0, like_num_axes) { From 567af3301d422aec4d62c7a4081509fb67a2f4b0 Mon Sep 17 00:00:00 2001 From: strint Date: Tue, 7 Jun 2022 10:02:02 +0800 Subject: [PATCH 40/46] restore test --- python/oneflow/test/graph/test_graph_zero.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/oneflow/test/graph/test_graph_zero.py b/python/oneflow/test/graph/test_graph_zero.py index d3e39be629c..a1213f1b07e 100644 --- a/python/oneflow/test/graph/test_graph_zero.py +++ b/python/oneflow/test/graph/test_graph_zero.py @@ -227,10 +227,10 @@ class TestLinearTrainGraphWithZeRO(oneflow.unittest.TestCase): def test_linear_train_graph_with_zero_1(test_case): _test_linear_train_graph_with_zero(test_case, 1) - def _test_linear_train_graph_with_zero_2(test_case): + def test_linear_train_graph_with_zero_2(test_case): _test_linear_train_graph_with_zero(test_case, 2) - def _test_linear_train_graph_with_zero_3(test_case): + def test_linear_train_graph_with_zero_3(test_case): _test_linear_train_graph_with_zero(test_case, 3) From 84ca7786656450c98f180449c1f0d64b9d589692 Mon Sep 17 00:00:00 2001 From: Xiaoyu Xu Date: Thu, 9 Jun 2022 09:57:44 +0800 Subject: [PATCH 41/46] Update optimizer_placement_optimization_pass.cpp --- .../core/job_rewriter/optimizer_placement_optimization_pass.cpp | 1 - 1 file changed, 1 deletion(-) diff --git a/oneflow/core/job_rewriter/optimizer_placement_optimization_pass.cpp b/oneflow/core/job_rewriter/optimizer_placement_optimization_pass.cpp index 5facfda11dd..4cd37b4a820 100644 --- a/oneflow/core/job_rewriter/optimizer_placement_optimization_pass.cpp +++ b/oneflow/core/job_rewriter/optimizer_placement_optimization_pass.cpp @@ -105,7 +105,6 @@ Maybe GetDataParallelVariableAndNaiveSuccNode( if (cur_node->op().input_bns().size() != 1) { break; } const std::string& sole_ibn = cur_node->op().SoleIbn(); const NdSbp& ibn_nd_sbp = cur_node->NdSbp4BnInOp(sole_ibn); - if (ibn_nd_sbp.sbp_parallel_size() == 0) { break; } bool has_broadcast = false; FOR_RANGE(int, i, 0, ibn_nd_sbp.sbp_parallel_size()) { if (ibn_nd_sbp.sbp_parallel(i).has_broadcast_parallel()) { has_broadcast = true; }; From 7095ec3f274276a11f08d540a0924033107a623c Mon Sep 17 00:00:00 2001 From: oneflow-ci-bot Date: Thu, 9 Jun 2022 08:32:36 +0000 Subject: [PATCH 42/46] auto format by CI --- .../optimizer_placement_optimization_pass.cpp | 88 ++++++++++--------- 1 file changed, 45 insertions(+), 43 deletions(-) diff --git a/oneflow/core/job_rewriter/optimizer_placement_optimization_pass.cpp b/oneflow/core/job_rewriter/optimizer_placement_optimization_pass.cpp index 4cd37b4a820..2e6a2ca1911 100644 --- a/oneflow/core/job_rewriter/optimizer_placement_optimization_pass.cpp +++ b/oneflow/core/job_rewriter/optimizer_placement_optimization_pass.cpp @@ -170,14 +170,16 @@ void SetNdSbp4Consumers(JobBuilder* builder, const SequencePtr& sequence, const std::string parallel_cast_input_lbn = GenLogicalBlobName(lbi); // Add indentity to enable mem reuse of boxing op when there is no op between var op and boxing. if (sequence->len() == 1) { - VLOG(3) << "ZeRO find a data-parallel sequence only has one variable " << sequence->GetVariableNode()->op().op_name(); - const auto var_identity_op = user_op::UserOpConfWrapperBuilder("System-ZeRO-Identity-" + node->op().op_name() + "-" - + NewUniqueId()) - .Op("identity") - .Input("in", GenLogicalBlobName(lbi)) - .Output("out") - .ScopeSymbolId(node->op().op_conf().scope_symbol_id()) - .Build(); + VLOG(3) << "ZeRO find a data-parallel sequence only has one variable " + << sequence->GetVariableNode()->op().op_name(); + const auto var_identity_op = + user_op::UserOpConfWrapperBuilder("System-ZeRO-Identity-" + node->op().op_name() + "-" + + NewUniqueId()) + .Op("identity") + .Input("in", GenLogicalBlobName(lbi)) + .Output("out") + .ScopeSymbolId(node->op().op_conf().scope_symbol_id()) + .Build(); builder->AddOps(node->parallel_desc().parallel_conf(), {var_identity_op.op_conf()}); parallel_cast_input_lbn = var_identity_op.output("out", 0); } @@ -298,7 +300,7 @@ bool IsNdSbpSupported4Op(const OpNode* node, const NdSbp& nd_sbp) { return Maybe(node->LogicalBlobDesc4Lbi(node->op().BnInOp2Lbi(bn))); }; CHECK_JUST(node->op().GetNdSbpSignatureList(LogicalBlobDesc4Ibn, node->parallel_desc(), &list)); - const auto IsInAndOutMatch= [&](const NdSbpSignature& signature) { + const auto IsInAndOutMatch = [&](const NdSbpSignature& signature) { return IsNdSbpMatch(signature, node->op().SoleIbn(), nd_sbp) && IsNdSbpMatch(signature, node->op().SoleObn(), nd_sbp); }; @@ -340,7 +342,8 @@ void ForEachModelSizeBalancedPartition( } namespace { -bool IsSplitValid(const Shape& shape, const NdSbp& nd_sbp, const Shape& hierachy, int64_t min_size) { +bool IsSplitValid(const Shape& shape, const NdSbp& nd_sbp, const Shape& hierachy, + int64_t min_size) { if (shape.NumAxes() < 1 || shape.elem_cnt() < 1) { return false; } CHECK_EQ(nd_sbp.sbp_parallel_size(), hierachy.NumAxes()); Shape cur_shape = shape; @@ -348,7 +351,7 @@ bool IsSplitValid(const Shape& shape, const NdSbp& nd_sbp, const Shape& hierachy FOR_RANGE(int64_t, i, 0, hierachy.NumAxes()) { const auto& sbp = nd_sbp.sbp_parallel(i); if (sbp.has_split_parallel()) { - const int64_t dim = sbp.split_parallel().axis(); + const int64_t dim = sbp.split_parallel().axis(); if (dim >= cur_shape.NumAxes()) { return false; } // Evenly split. if (cur_shape.At(dim) % hierachy.At(i) != 0) { return false; } @@ -360,37 +363,38 @@ bool IsSplitValid(const Shape& shape, const NdSbp& nd_sbp, const Shape& hierachy return true; } -void GenerateSplitSignature(const NdSbp& var_nd_sbp, const OperatorConf& new_var_op_conf, std::string& new_split_signature, int64_t& split_dim) { - if (new_var_op_conf.variable_conf().nd_sbp_size() > 0 && NdSbpIsAllBroadcast(var_nd_sbp)) { - // split last dim - split_dim = new_var_op_conf.variable_conf().nd_sbp_size() - 1; - // All B, B -> S0 - new_split_signature = "S(0)"; - } else { - // ND sbp, (*, B, S, *) -> (*, S, S, *) - // ND sbp, (*, S, B, *) -> (*, S, S, *) - FOR_RANGE(int64_t, j, 0, new_var_op_conf.variable_conf().nd_sbp_size()) { - if (new_var_op_conf.variable_conf().nd_sbp(j) == "B") { - std::vector adjacent_dim{j - 1, j + 1}; - for (auto const& dim_to_try : adjacent_dim) { - if (dim_to_try >= 0 && dim_to_try < new_var_op_conf.variable_conf().nd_sbp_size()) { - SbpParallel sbp; - if (ParseSbpParallelFromString(new_var_op_conf.variable_conf().nd_sbp(dim_to_try), - &sbp) - && sbp.has_split_parallel()) { - new_split_signature = new_var_op_conf.variable_conf().nd_sbp(dim_to_try); - split_dim = j; - } +void GenerateSplitSignature(const NdSbp& var_nd_sbp, const OperatorConf& new_var_op_conf, + std::string& new_split_signature, int64_t& split_dim) { + if (new_var_op_conf.variable_conf().nd_sbp_size() > 0 && NdSbpIsAllBroadcast(var_nd_sbp)) { + // split last dim + split_dim = new_var_op_conf.variable_conf().nd_sbp_size() - 1; + // All B, B -> S0 + new_split_signature = "S(0)"; + } else { + // ND sbp, (*, B, S, *) -> (*, S, S, *) + // ND sbp, (*, S, B, *) -> (*, S, S, *) + FOR_RANGE(int64_t, j, 0, new_var_op_conf.variable_conf().nd_sbp_size()) { + if (new_var_op_conf.variable_conf().nd_sbp(j) == "B") { + std::vector adjacent_dim{j - 1, j + 1}; + for (auto const& dim_to_try : adjacent_dim) { + if (dim_to_try >= 0 && dim_to_try < new_var_op_conf.variable_conf().nd_sbp_size()) { + SbpParallel sbp; + if (ParseSbpParallelFromString(new_var_op_conf.variable_conf().nd_sbp(dim_to_try), &sbp) + && sbp.has_split_parallel()) { + new_split_signature = new_var_op_conf.variable_conf().nd_sbp(dim_to_try); + split_dim = j; } - if (new_split_signature != "") break; } + if (new_split_signature != "") break; } - // Only split one more dim. - if (new_split_signature != "") break; } + // Only split one more dim. + if (new_split_signature != "") break; } + } } -void ShardSequence(JobBuilder* builder, const int64_t threshold, const ParallelDesc& pd, std::vector&& sorted_sequences) { +void ShardSequence(JobBuilder* builder, const int64_t threshold, const ParallelDesc& pd, + std::vector&& sorted_sequences) { // For all sorted sequnence, set the variable op in the sequence to S // and add ctrl edge to control the exectuion order between variable ops. // A sequence is a variable op and its cast(fp32 to fp16) op. This is because the forward pass @@ -445,21 +449,19 @@ void ShardSequence(JobBuilder* builder, const int64_t threshold, const ParallelD << " with op conf " << new_var_op_conf.variable_conf().DebugString(); continue; } - if (i != 0) { - new_var_op_conf.add_ctrl_in_op_name(prev_allowed_op_name); - } + if (i != 0) { new_var_op_conf.add_ctrl_in_op_name(prev_allowed_op_name); } builder->MutOpsOnlyOnce({new_var_op_conf}); // Set consumers to consum this variable op's cast op's output as Broadcast. if (new_split_signature != "") { SetNdSbp4Consumers(builder, sorted_sequences.at(i), var_nd_sbp); } prev_allowed_op_name = var_node->op().op_name(); - VLOG(3) << var_node->op().op_name() << " succeed to change form B to " - << new_split_signature << " on ranks dim " << split_dim << " with op conf " + VLOG(3) << var_node->op().op_name() << " succeed to change form B to " << new_split_signature + << " on ranks dim " << split_dim << " with op conf " << new_var_op_conf.variable_conf().DebugString(); } } -} // namespace +} // namespace Maybe RewriteDistributedSplit(const OpGraph& op_graph, JobBuilder* builder) { const int64_t threshold = builder->job().job_conf().optimizer_placement_optimization_threshold(); @@ -468,7 +470,7 @@ Maybe RewriteDistributedSplit(const OpGraph& op_graph, JobBuilder* builder return true; }; const auto PlacementSequencesAsSplitParallel = [&](const ParallelDesc& pd, - std::vector&& sorted_sequences) { + std::vector&& sorted_sequences) { ShardSequence(builder, threshold, pd, std::forward>(sorted_sequences)); }; ForEachParallelSortedNodeSequence(op_graph, IsAllowed, SequenceCompSortedByOrderAsc, From b401e669f21e380bcf18d9406939e8fd4c230839 Mon Sep 17 00:00:00 2001 From: oneflow-ci-bot Date: Thu, 9 Jun 2022 08:54:15 +0000 Subject: [PATCH 43/46] auto format by CI --- python/oneflow/nn/graph/graph_config.py | 4 +++- python/oneflow/test/graph/test_graph_zero.py | 25 ++++++++------------ 2 files changed, 13 insertions(+), 16 deletions(-) diff --git a/python/oneflow/nn/graph/graph_config.py b/python/oneflow/nn/graph/graph_config.py index 1c5775c97b5..ca03078f91b 100644 --- a/python/oneflow/nn/graph/graph_config.py +++ b/python/oneflow/nn/graph/graph_config.py @@ -120,7 +120,9 @@ def build(self, x): if stage >= 1: self.proto.optimizer_placement_optimization_mode = "distributed_split" self.proto.optimizer_placement_optimization_threshold = shard_min_size - self.proto.optimizer_placement_optimization_shard_restore_level = shard_restore_level + self.proto.optimizer_placement_optimization_shard_restore_level = ( + shard_restore_level + ) if stage >= 2: nccl_config.enable_use_compute_stream(True) if stage >= 3: diff --git a/python/oneflow/test/graph/test_graph_zero.py b/python/oneflow/test/graph/test_graph_zero.py index a1213f1b07e..51fa38a8657 100644 --- a/python/oneflow/test/graph/test_graph_zero.py +++ b/python/oneflow/test/graph/test_graph_zero.py @@ -54,10 +54,7 @@ def __init__(self): self.config.enable_amp(True) self.set_grad_scaler(grad_scaler) self.config.enable_zero( - True, - stage=zero_stage, - shard_min_size=1, - shard_restore_level=0, + True, stage=zero_stage, shard_min_size=1, shard_restore_level=0, ) self.debug(2) @@ -133,14 +130,14 @@ def __init__(self): super().__init__() self.dp_mp = linear_dp_mp self.mp_dp = linear_mp_dp - + def forward(self, x): x = self.dp_mp(x) x = flow.relu(x) x = self.mp_dp(x) x = flow.relu(x) return x - + return MixedLinear() mixed_linear0 = get_mixed_linear() @@ -170,10 +167,7 @@ def __init__(self): self.config.enable_amp(True) self.set_grad_scaler(grad_scaler) self.config.enable_zero( - True, - stage=zero_stage, - shard_min_size=1, - shard_restore_level=1, + True, stage=zero_stage, shard_min_size=1, shard_restore_level=1, ) def build(self, x): @@ -186,8 +180,8 @@ def build(self, x): class LinearEvalGraph2DWithZeRO(flow.nn.Graph): def __init__(self): super().__init__() - self.mixed_linear0 = mixed_linear0 - self.mixed_linear1 = mixed_linear1 + self.mixed_linear0 = mixed_linear0 + self.mixed_linear1 = mixed_linear1 self.config.enable_amp(True) @@ -201,7 +195,7 @@ def build(self, x): def one_train_iter(): out = linear_t_g(x) - #if flow.env.get_rank() == 0: + # if flow.env.get_rank() == 0: # print(linear_t_g) def one_eval_iter(): @@ -211,7 +205,9 @@ def one_eval_iter(): one_train_iter() for state in linear_t_g._state(): - test_case.assertEqual(state.origin.sbp, (oneflow.sbp.split(axis=0), oneflow.sbp.split(axis=0))) + test_case.assertEqual( + state.origin.sbp, (oneflow.sbp.split(axis=0), oneflow.sbp.split(axis=0)) + ) # In evaluation graph, paramters's sbp are flow.sbp.split(0). # But their consumer will consum them as flow.sbp.broadcast. @@ -247,6 +243,5 @@ def test_linear_train_graph_2d_with_zero_1(test_case): _test_linear_train_graph_2d_with_zero(test_case, 1) - if __name__ == "__main__": unittest.main() From 2b0324e397aa35bd9ee22478d740132db5be3fef Mon Sep 17 00:00:00 2001 From: strint Date: Thu, 9 Jun 2022 22:35:57 +0800 Subject: [PATCH 44/46] fix static check --- .../job_rewriter/optimizer_placement_optimization_pass.cpp | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/oneflow/core/job_rewriter/optimizer_placement_optimization_pass.cpp b/oneflow/core/job_rewriter/optimizer_placement_optimization_pass.cpp index 2e6a2ca1911..7aaf2e75426 100644 --- a/oneflow/core/job_rewriter/optimizer_placement_optimization_pass.cpp +++ b/oneflow/core/job_rewriter/optimizer_placement_optimization_pass.cpp @@ -40,11 +40,10 @@ int64_t GetSoleOutBlobSize(const OpNode* node) { class DataParallelNodeSequence final { public: DataParallelNodeSequence(std::vector nodes, int64_t order) - : nodes_(std::move(nodes)), order_(order) { + : nodes_(std::move(nodes)), order_(order), len_(nodes_.size()) { const OpNode* var_node = nodes_.front(); CHECK(var_node->op().op_conf().has_variable_conf()); model_size_ = GetSoleOutBlobSize(var_node); - len_ = nodes_.size(); } ~DataParallelNodeSequence() = default; From 7d611c4fa1e785c698b2f8bc2b1ac42fe0f339be Mon Sep 17 00:00:00 2001 From: strint Date: Fri, 10 Jun 2022 11:07:05 +0800 Subject: [PATCH 45/46] add tips for zero api change --- python/oneflow/nn/graph/graph_config.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/python/oneflow/nn/graph/graph_config.py b/python/oneflow/nn/graph/graph_config.py index ca03078f91b..33810922a1c 100644 --- a/python/oneflow/nn/graph/graph_config.py +++ b/python/oneflow/nn/graph/graph_config.py @@ -71,6 +71,11 @@ def build(self, x): """ assert type(mode) is bool self.proto.enable_auto_mixed_precision = mode + + def set_zero_redundancy_optimizer_mode(self, mode: str = "distributed_split"): + raise RuntimeError( + "`set_zero_redundancy_optimizer_mode` has been changed to `enable_zero`, please use `enable_zero(True)` to activate ZeRO optimization." + ) def enable_zero( self, From 640487b19fc09acfa3e0044ca38e0e0207058a96 Mon Sep 17 00:00:00 2001 From: oneflow-ci-bot Date: Fri, 10 Jun 2022 03:09:47 +0000 Subject: [PATCH 46/46] auto format by CI --- python/oneflow/nn/graph/graph_config.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/oneflow/nn/graph/graph_config.py b/python/oneflow/nn/graph/graph_config.py index 33810922a1c..ea48ad8d957 100644 --- a/python/oneflow/nn/graph/graph_config.py +++ b/python/oneflow/nn/graph/graph_config.py @@ -71,7 +71,7 @@ def build(self, x): """ assert type(mode) is bool self.proto.enable_auto_mixed_precision = mode - + def set_zero_redundancy_optimizer_mode(self, mode: str = "distributed_split"): raise RuntimeError( "`set_zero_redundancy_optimizer_mode` has been changed to `enable_zero`, please use `enable_zero(True)` to activate ZeRO optimization."