From 16b2257885a3408f2a8eec5d20678feda1e63e46 Mon Sep 17 00:00:00 2001
From: Roc
Date: Thu, 8 Dec 2022 05:40:03 +0000
Subject: [PATCH 1/2] support sharding in fp16 on xpu, change reduce_max to
 reduce_sum for found nan or inf

---
 .../meta_parallel/sharding/group_sharded_storage.py   | 10 +++++++++-
 .../meta_parallel/sharding/group_sharded_utils.py     |  6 ++++--
 .../fleet/test_imperative_auto_mixed_precision.py     |  2 +-
 .../test_imperative_auto_mixed_precision_for_eager.py |  2 +-
 4 files changed, 15 insertions(+), 5 deletions(-)

diff --git a/python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_storage.py b/python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_storage.py
index abab68a1912170..63076daac1a214 100644
--- a/python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_storage.py
+++ b/python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_storage.py
@@ -162,7 +162,15 @@ def _add_param_as_view(self, param, align, convert_gpu=True):
 
         origin_state = param.stop_gradient
         param.stop_gradient = True
-        param.flatten_()
+
+        if paddle.is_compiled_with_xpu():
+            # now xpu doesn't support flatten kernel for fp16, but supports fp32
+            # TODO(liupeng51): support fp16
+            tmp = paddle.flatten(param.astype("float32")).astype(param.dtype)
+            paddle.assign(tmp, output=param)
+        else:
+            param.flatten_()
+
         param.stop_gradient = origin_state
 
         # Copy the current param value
diff --git a/python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_utils.py b/python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_utils.py
index c12381c894e794..d845f3b78c6345 100644
--- a/python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_utils.py
+++ b/python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_utils.py
@@ -220,7 +220,8 @@ def unscale_method(self, optimizer):
         temp_found_inf_fp16 = to_variable(np.array([0]).astype(np.bool_))
         temp_found_inf_fp32 = to_variable(np.array([0]).astype(np.bool_))
 
-        device = "cpu" if optimizer.offload else "gpu"
+        device = paddle.get_device().split(":")[0]
+        device = "cpu" if optimizer.offload else device
         dev_id = (
             0 if device == "cpu" else int(paddle.get_device().split(":")[1])
         )
@@ -245,8 +246,9 @@ def unscale_method(self, optimizer):
 
         is_found_inf = paddle.to_tensor([self._found_inf], dtype="int32")
         paddle.distributed.all_reduce(
-            is_found_inf, op=paddle.distributed.ReduceOp.MAX, group=None
+            is_found_inf, op=paddle.distributed.ReduceOp.SUM, group=None
         )
+        self._found_inf = is_found_inf.numpy()[0]
 
     scaler._unscale = MethodType(unscale_method, scaler)
 
diff --git a/python/paddle/fluid/tests/unittests/collective/fleet/test_imperative_auto_mixed_precision.py b/python/paddle/fluid/tests/unittests/collective/fleet/test_imperative_auto_mixed_precision.py
index 3141ed81a458b2..54f81cfd5de684 100644
--- a/python/paddle/fluid/tests/unittests/collective/fleet/test_imperative_auto_mixed_precision.py
+++ b/python/paddle/fluid/tests/unittests/collective/fleet/test_imperative_auto_mixed_precision.py
@@ -344,7 +344,7 @@ def nan_inf(self):
         scaled_loss = scaler.scale(loss)
         scaled_loss.backward()
         optimize_ops, params_grads = scaler.minimize(optimizer, scaled_loss)
-        self.assertEqual(scaler._found_inf.numpy() == 1, True)
+        self.assertEqual(scaler._found_inf.numpy() >= 1, True)
 
         for param in model.parameters():
             # param not update when tensor contains nan or inf
diff --git a/python/paddle/fluid/tests/unittests/collective/fleet/test_imperative_auto_mixed_precision_for_eager.py b/python/paddle/fluid/tests/unittests/collective/fleet/test_imperative_auto_mixed_precision_for_eager.py
index dbcdf3ac46c9e8..8800515e690e35 100644
--- a/python/paddle/fluid/tests/unittests/collective/fleet/test_imperative_auto_mixed_precision_for_eager.py
+++ b/python/paddle/fluid/tests/unittests/collective/fleet/test_imperative_auto_mixed_precision_for_eager.py
@@ -343,7 +343,7 @@ def nan_inf(self):
         scaled_loss = scaler.scale(loss)
         scaled_loss.backward()
         optimize_ops, params_grads = scaler.minimize(optimizer, scaled_loss)
-        self.assertEqual(scaler._found_inf.numpy() == 1, True)
+        self.assertEqual(scaler._found_inf.numpy() >= 1, True)
 
         for param in model.parameters():
             # param not update when tensor contains nan or inf

From 0737c559e5f8e713026bf22bb5b2f8a15e10a4d6 Mon Sep 17 00:00:00 2001
From: Roc
Date: Fri, 9 Dec 2022 06:32:23 +0000
Subject: [PATCH 2/2] update

---
 .../meta_parallel/sharding/group_sharded_storage.py | 10 +---------
 1 file changed, 1 insertion(+), 9 deletions(-)

diff --git a/python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_storage.py b/python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_storage.py
index 63076daac1a214..abab68a1912170 100644
--- a/python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_storage.py
+++ b/python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_storage.py
@@ -162,15 +162,7 @@ def _add_param_as_view(self, param, align, convert_gpu=True):
 
         origin_state = param.stop_gradient
         param.stop_gradient = True
-
-        if paddle.is_compiled_with_xpu():
-            # now xpu doesn't support flatten kernel for fp16, but supports fp32
-            # TODO(liupeng51): support fp16
-            tmp = paddle.flatten(param.astype("float32")).astype(param.dtype)
-            paddle.assign(tmp, output=param)
-        else:
-            param.flatten_()
-
+        param.flatten_()
         param.stop_gradient = origin_state
 
         # Copy the current param value
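Note on the aggregation change in PATCH 1/2, as a minimal sketch rather than code taken from the patches: switching the all_reduce of the found-inf flag from ReduceOp.MAX to ReduceOp.SUM means each rank contributes 0 or 1 and the reduced value is the number of ranks that observed nan/inf, which is why the tests now assert >= 1 instead of == 1. The helper name sync_found_inf below is hypothetical, and an already-initialized process group (e.g. via paddle.distributed.init_parallel_env()) is assumed.

# Illustrative sketch of the SUM-based found-inf synchronization; assumes
# paddle.distributed.init_parallel_env() has already been called.
import paddle
import paddle.distributed as dist


def sync_found_inf(local_found_inf: bool) -> int:
    # Each rank contributes 0 or 1.
    flag = paddle.to_tensor([1 if local_found_inf else 0], dtype="int32")
    # SUM returns how many ranks saw nan/inf; with MAX the result would be
    # capped at 1. Any value >= 1 means the optimizer step should be skipped.
    dist.all_reduce(flag, op=dist.ReduceOp.SUM)
    return int(flag.numpy()[0])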