From 05ef04046333d1f37116ae175da8024fc0c8b3e6 Mon Sep 17 00:00:00 2001
From: Yichen Zhang
Date: Fri, 18 Aug 2023 21:06:17 +0800
Subject: [PATCH] add softmax backward rule

---
 .../spmd_rules/softmax_spmd_rule.cc | 73 +++++++++++++++++--
 .../spmd_rules/test_softmax_rule.py | 64 ++++++++++++++++
 2 files changed, 131 insertions(+), 6 deletions(-)

diff --git a/paddle/fluid/distributed/auto_parallel/spmd_rules/softmax_spmd_rule.cc b/paddle/fluid/distributed/auto_parallel/spmd_rules/softmax_spmd_rule.cc
index 9bafda84e29fe..4d14955d32679 100644
--- a/paddle/fluid/distributed/auto_parallel/spmd_rules/softmax_spmd_rule.cc
+++ b/paddle/fluid/distributed/auto_parallel/spmd_rules/softmax_spmd_rule.cc
@@ -20,10 +20,10 @@ namespace auto_parallel {
 
 using phi::distributed::auto_parallel::str_join;
 
-// step0: verify input args based on softmax logic
 std::pair<std::vector<TensorDistAttr>, std::vector<TensorDistAttr>>
 SoftmaxSPMDRule::InferForward(const std::vector<DistTensorSpec>& input_specs,
                               const paddle::framework::AttributeMap& attrs) {
+  // step0: verify input args based on softmax logic
   auto input_specs_size = input_specs.size();
   PADDLE_ENFORCE_EQ(
       input_specs_size,
@@ -94,7 +94,7 @@ SoftmaxSPMDRule::InferForward(const std::vector<DistTensorSpec>& input_specs,
   TensorDistAttr x_dist_attr_dst = CopyTensorDistAttrForOutput(x_dist_attr_src);
   x_dist_attr_dst.set_dims_mapping(x_dims_mapping);
 
-  VLOG(4) << "EmbeddingSPMDRule InferForward: "
+  VLOG(4) << "SoftmaxSPMDRule InferForward: "
           << "Einsum notation: [" << x_axes << " --> " << out_axes << "]. "
           << std::endl
           << "X shape: [" << str_join(x_shape) << "], src_dims_mapping: ["
@@ -106,11 +106,72 @@ SoftmaxSPMDRule::InferForward(const std::vector<DistTensorSpec>& input_specs,
 }
 
 std::pair<std::vector<TensorDistAttr>, std::vector<TensorDistAttr>>
-SoftmaxSPMDRule::InferBackward(const std::vector<DistTensorSpec>& output_specs,
-                               const std::vector<DistTensorSpec>& input_specs,
+SoftmaxSPMDRule::InferBackward(const std::vector<DistTensorSpec>& input_specs,
+                               const std::vector<DistTensorSpec>& output_specs,
                                const paddle::framework::AttributeMap& attrs) {
-  PADDLE_THROW(phi::errors::Unimplemented(
-      "InferBackward of SoftmaxSPMDRule is NOT implemented yet."));
+  // step0: verify input args based on softmax logic
+  int64_t input_specs_size = input_specs.size();
+  int64_t output_specs_size = output_specs.size();
+  PADDLE_ENFORCE_EQ(
+      input_specs_size,
+      1,
+      phi::errors::InvalidArgument(
+          "The size of InputSpec of softmax should be 1, but got [%d].",
+          input_specs_size));
+  PADDLE_ENFORCE_EQ(
+      output_specs_size,
+      1,
+      phi::errors::InvalidArgument(
+          "The size of OutputSpec of softmax should be 1, but got [%d].",
+          output_specs_size));
+  VerifySpecs(output_specs, "softmax_backward");
+
+  // step1: build Einsum Notation
+  std::vector<int64_t> x_shape = input_specs[0].shape();
+  int64_t x_ndim = input_specs[0].shape().size();
+
+  std::string alphabet = "abcdefghijklmnopqrstuvwxyz";
+  std::string x_axes = GetBroadcastAxes(x_ndim, x_ndim, alphabet);
+  std::string out_axes = x_axes;
+
+  int axis = ExtractAttr<int>("axis", attrs);
+  // normalize axis
+  if (axis < 0) {
+    axis = x_ndim + axis;
+  }
+
+  // Sharding on the softmax_axis is not supported now, so set
+  // the notation on softmax_axis to '1' so that we can set
+  // its dims mapping to -1.
+  x_axes[axis] = '1';
+
+  // step2: Sharding Propagation
+  std::vector<int64_t> out_dims_mapping = output_specs[0].dims_mapping();
+  std::unordered_map<std::string, int64_t> axis_to_dim_map =
+      ShardingMergeForTensors({{out_axes, out_dims_mapping}});
+
+  // infer input's dims mapping
+  std::vector<int64_t> x_dims_mapping =
+      GetDimsMappingForAxes(x_axes, axis_to_dim_map);
+  TensorDistAttr input_dist_attr(input_specs[0].dist_attr());
+  input_dist_attr.set_dims_mapping(x_dims_mapping);
+
+  // update output's dims mapping
+  out_dims_mapping[axis] = -1;
+  TensorDistAttr output_dist_attr(output_specs[0].dist_attr());
+  output_dist_attr.set_dims_mapping(out_dims_mapping);
+
+  VLOG(4) << "SoftmaxSPMDRule InferBackward: "
+          << "softmax_axis: " << axis << std::endl
+          << "Einsum notation: [" << x_axes << " --> " << out_axes << "]. "
+          << std::endl
+          << "Output shape: [" << str_join(output_specs[0].shape())
+          << "], src_dims_mapping: ["
+          << str_join(output_specs[0].dims_mapping())
+          << "], dst_dims_mapping: [" << str_join(out_dims_mapping)
+          << "]; Input dims_mapping: [" << str_join(x_dims_mapping) << "]";
+
+  return {{input_dist_attr}, {output_dist_attr}};
 }
 
 }  // namespace auto_parallel
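The InferBackward added above propagates the output's sharding back to the input: every non-softmax axis keeps the output's dims mapping, while the softmax axis is forced to stay replicated (mapped to -1) on both tensors, since sharding along the softmax axis is not supported. The following is a minimal standalone sketch of that propagation in plain Python; the helper name infer_softmax_backward is hypothetical and does not use the Paddle SPMD API.

def infer_softmax_backward(out_dims_mapping, axis):
    # Normalize a negative softmax axis, as the C++ rule does.
    ndim = len(out_dims_mapping)
    if axis < 0:
        axis = ndim + axis
    # Non-softmax axes map one-to-one from output to input; the softmax
    # axis is kept replicated (-1) on both the input and the updated output.
    x_dims_mapping = list(out_dims_mapping)
    x_dims_mapping[axis] = -1
    new_out_dims_mapping = list(out_dims_mapping)
    new_out_dims_mapping[axis] = -1
    return x_dims_mapping, new_out_dims_mapping

# An output sharded as [1, -1, 0] with axis=-1 gives [1, -1, -1] for both tensors.
assert infer_softmax_backward([1, -1, 0], -1) == ([1, -1, -1], [1, -1, -1])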
diff --git a/test/auto_parallel/spmd_rules/test_softmax_rule.py b/test/auto_parallel/spmd_rules/test_softmax_rule.py
index 57e4532ba3cfb..3b33b99cfe74b 100644
--- a/test/auto_parallel/spmd_rules/test_softmax_rule.py
+++ b/test/auto_parallel/spmd_rules/test_softmax_rule.py
@@ -33,6 +33,8 @@ def setUp(self):
         x_tensor_dist_attr.process_mesh = process_mesh
         self.x_dist_tensor_spec = DistTensorSpec(x_shape, x_tensor_dist_attr)
 
+        self.out_dist_tensor_spec = DistTensorSpec(self.x_dist_tensor_spec)
+
         self.attrs = {
             'axis': -1,
         }
@@ -99,6 +101,68 @@ def test_softmax_infer_forward(self):
         self.assertEqual(infered_input_dist_attrs[0].dims_mapping, [-1, -1, 0])
         self.assertEqual(infered_output_dist_attrs[0].dims_mapping, [-1, -1, 0])
 
+    def test_softmax_infer_backward(self):
+        # sharding on batch axis I
+        self.out_dist_tensor_spec.set_dims_mapping([1, -1, -1])
+        result_dist_attrs = self.rule1.infer_backward(
+            [self.x_dist_tensor_spec], [self.out_dist_tensor_spec], self.attrs
+        )
+        self.assertEqual(len(result_dist_attrs), 2)
+        infered_input_dist_attrs = result_dist_attrs[0]
+        infered_output_dist_attrs = result_dist_attrs[1]
+
+        self.assertEqual(len(infered_input_dist_attrs), 1)
+        self.assertEqual(len(infered_output_dist_attrs), 1)
+
+        self.assertEqual(infered_input_dist_attrs[0].dims_mapping, [1, -1, -1])
+        self.assertEqual(infered_output_dist_attrs[0].dims_mapping, [1, -1, -1])
+
+        # sharding on batch axis II
+        self.out_dist_tensor_spec.set_dims_mapping([-1, 1, -1])
+        result_dist_attrs = self.rule1.infer_backward(
+            [self.x_dist_tensor_spec], [self.out_dist_tensor_spec], self.attrs
+        )
+        infered_input_dist_attrs = result_dist_attrs[0]
+        infered_output_dist_attrs = result_dist_attrs[1]
+        self.assertEqual(infered_input_dist_attrs[0].dims_mapping, [-1, 1, -1])
+        self.assertEqual(infered_output_dist_attrs[0].dims_mapping, [-1, 1, -1])
+
+        # sharding on softmax_axis
+        self.out_dist_tensor_spec.set_dims_mapping([1, -1, 0])
+        result_dist_attrs = self.rule1.infer_backward(
+            [self.x_dist_tensor_spec], [self.out_dist_tensor_spec], self.attrs
+        )
+        infered_input_dist_attrs = result_dist_attrs[0]
+        infered_output_dist_attrs = result_dist_attrs[1]
+        self.assertEqual(infered_input_dist_attrs[0].dims_mapping, [1, -1, -1])
+        self.assertEqual(infered_output_dist_attrs[0].dims_mapping, [1, -1, -1])
+
+        # sharding on softmax_axis + axis = 1
+        self.attrs = {
+            'axis': 1,
+        }
+        self.out_dist_tensor_spec.set_dims_mapping([-1, 1, 0])
+        result_dist_attrs = self.rule1.infer_backward(
+            [self.x_dist_tensor_spec], [self.out_dist_tensor_spec], self.attrs
+        )
+        infered_input_dist_attrs = result_dist_attrs[0]
+        infered_output_dist_attrs = result_dist_attrs[1]
+        self.assertEqual(infered_input_dist_attrs[0].dims_mapping, [-1, -1, 0])
+        self.assertEqual(infered_output_dist_attrs[0].dims_mapping, [-1, -1, 0])
+
+        # sharding on softmax_axis + axis = -2
+        self.attrs = {
+            'axis': -2,
+        }
+        self.out_dist_tensor_spec.set_dims_mapping([-1, 1, 0])
+        result_dist_attrs = self.rule1.infer_backward(
+            [self.x_dist_tensor_spec], [self.out_dist_tensor_spec], self.attrs
+        )
+        infered_input_dist_attrs = result_dist_attrs[0]
+        infered_output_dist_attrs = result_dist_attrs[1]
+        self.assertEqual(infered_input_dist_attrs[0].dims_mapping, [-1, -1, 0])
+        self.assertEqual(infered_output_dist_attrs[0].dims_mapping, [-1, -1, 0])
+
 
 if __name__ == "__main__":
     unittest.main()
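As a quick check of the last test case above (axis = -2, output dims_mapping [-1, 1, 0] on a rank-3 tensor): the negative axis normalizes to 3 + (-2) = 1, so dimension 1 is reset to -1 and both tensors end up with dims mapping [-1, -1, 0], which is what the assertions expect. The same arithmetic in plain Python, with no Paddle involved:

out_dims_mapping = [-1, 1, 0]
axis, ndim = -2, 3
axis = ndim + axis if axis < 0 else axis  # normalizes to 1
x_dims_mapping = list(out_dims_mapping)
x_dims_mapping[axis] = -1                 # softmax axis stays replicated
out_dims_mapping[axis] = -1
assert x_dims_mapping == [-1, -1, 0]
assert out_dims_mapping == [-1, -1, 0]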