Commit
add softmax backward rule
pkuzyc committed Aug 18, 2023
1 parent 6e45db7 commit 05ef040
Showing 2 changed files with 131 additions and 6 deletions.
@@ -20,10 +20,10 @@ namespace auto_parallel {

using phi::distributed::auto_parallel::str_join;

-// step0: verify input args based on softmax logic
std::pair<std::vector<TensorDistAttr>, std::vector<TensorDistAttr>>
SoftmaxSPMDRule::InferForward(const std::vector<DistTensorSpec>& input_specs,
                              const paddle::framework::AttributeMap& attrs) {
+  // step0: verify input args based on softmax logic
  auto input_specs_size = input_specs.size();
  PADDLE_ENFORCE_EQ(
      input_specs_size,
@@ -94,7 +94,7 @@ SoftmaxSPMDRule::InferForward(const std::vector<DistTensorSpec>& input_specs,
  TensorDistAttr x_dist_attr_dst = CopyTensorDistAttrForOutput(x_dist_attr_src);
  x_dist_attr_dst.set_dims_mapping(x_dims_mapping);

VLOG(4) << "EmbeddingSPMDRule InferForward: "
VLOG(4) << "SoftmaxSPMDRule InferForward: "
<< "Einsum notation: [" << x_axes << " --> " << out_axes << "]. "
<< std::endl
<< "X shape: [" << str_join(x_shape) << "], src_dims_mapping: ["
@@ -106,11 +106,72 @@ SoftmaxSPMDRule::InferForward(const std::vector<DistTensorSpec>& input_specs,
}

std::pair<std::vector<TensorDistAttr>, std::vector<TensorDistAttr>>
-SoftmaxSPMDRule::InferBackward(const std::vector<DistTensorSpec>& output_specs,
-                               const std::vector<DistTensorSpec>& input_specs,
+SoftmaxSPMDRule::InferBackward(const std::vector<DistTensorSpec>& input_specs,
+                               const std::vector<DistTensorSpec>& output_specs,
                               const paddle::framework::AttributeMap& attrs) {
-  PADDLE_THROW(phi::errors::Unimplemented(
-      "InferBackward of SoftmaxSPMDRule is NOT implemented yet."));
  // step0: verify input args based on softmax logic
  int64_t input_specs_size = input_specs.size();
  int64_t output_specs_size = output_specs.size();
  PADDLE_ENFORCE_EQ(
      input_specs_size,
      1,
      phi::errors::InvalidArgument(
          "The size of InputSpec of softmax should be 1, but got [%d].",
          input_specs_size));
  PADDLE_ENFORCE_EQ(
      output_specs_size,
      1,
      phi::errors::InvalidArgument(
          "The size of OutputSpec of softmax should be 1, but got [%d].",
          output_specs_size));
  VerifySpecs(output_specs, "softmax_backward");

  // step1: build Einsum Notation
  std::vector<int64_t> x_shape = input_specs[0].shape();
  int64_t x_ndim = input_specs[0].shape().size();

  std::string alphabet = "abcdefghijklmnopqrstuvwxyz";
  std::string x_axes = GetBroadcastAxes(x_ndim, x_ndim, alphabet);
  std::string out_axes = x_axes;

  int axis = ExtractAttr<int>("axis", attrs);
  // normalize axis
  if (axis < 0) {
    axis = x_ndim + axis;
  }

  // Sharding along softmax_axis is not supported for now, so set the
  // notation on softmax_axis to '1' so that its dims mapping can be
  // set to -1.
  x_axes[axis] = '1';

  // step2: Sharding Propagation
  std::vector<int64_t> out_dims_mapping = output_specs[0].dims_mapping();
  std::unordered_map<std::string, int64_t> axis_to_dim_map =
      ShardingMergeForTensors({{out_axes, out_dims_mapping}});

  // infer the input's dims mapping
  std::vector<int64_t> x_dims_mapping =
      GetDimsMappingForAxes(x_axes, axis_to_dim_map);
  TensorDistAttr input_dist_attr(input_specs[0].dist_attr());
  input_dist_attr.set_dims_mapping(x_dims_mapping);

  // update the output's dims mapping
  out_dims_mapping[axis] = -1;
  TensorDistAttr output_dist_attr(output_specs[0].dist_attr());
  output_dist_attr.set_dims_mapping(out_dims_mapping);

  VLOG(4) << "SoftmaxSPMDRule InferBackward: "
          << "softmax_axis: " << axis << std::endl
          << "Einsum notation: [" << x_axes << " --> " << out_axes << "]. "
          << std::endl
          << "Output shape: [" << str_join(output_specs[0].shape())
          << "], src_dims_mapping: ["
          << str_join(output_specs[0].dims_mapping())
          << "], dst_dims_mapping: [" << str_join(out_dims_mapping)
          << "]; Input dims_mapping: [" << str_join(x_dims_mapping) << "]";

  return {{input_dist_attr}, {output_dist_attr}};
}

} // namespace auto_parallel
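
For readers skimming the diff, the backward rule boils down to: take the output's dims mapping, force the softmax axis back to -1 (replicated), and use the result for both the input and the updated output. The following is a minimal standalone Python sketch of that propagation, written only for illustration; infer_softmax_backward is a hypothetical helper and does not touch Paddle's DistTensorSpec or TensorDistAttr classes.

def infer_softmax_backward(out_dims_mapping, axis):
    """Mirror SoftmaxSPMDRule::InferBackward on plain Python lists.

    Returns (input_dims_mapping, updated_output_dims_mapping).
    """
    ndim = len(out_dims_mapping)
    if axis < 0:
        # normalize a negative axis, as the C++ rule does
        axis = ndim + axis
    x_dims_mapping = list(out_dims_mapping)
    x_dims_mapping[axis] = -1  # the softmax axis is never sharded on the input
    new_out_dims_mapping = list(out_dims_mapping)
    new_out_dims_mapping[axis] = -1  # and is reset on the output as well
    return x_dims_mapping, new_out_dims_mapping


# Output sharded on the softmax axis (last dim mapped to mesh dim 0):
# both sides fall back to replicated on that axis.
print(infer_softmax_backward([1, -1, 0], axis=-1))
# -> ([1, -1, -1], [1, -1, -1])
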
64 changes: 64 additions & 0 deletions test/auto_parallel/spmd_rules/test_softmax_rule.py
@@ -33,6 +33,8 @@ def setUp(self):
        x_tensor_dist_attr.process_mesh = process_mesh
        self.x_dist_tensor_spec = DistTensorSpec(x_shape, x_tensor_dist_attr)

        self.out_dist_tensor_spec = DistTensorSpec(self.x_dist_tensor_spec)

        self.attrs = {
            'axis': -1,
        }
@@ -99,6 +101,68 @@ def test_softmax_infer_forward(self):
        self.assertEqual(infered_input_dist_attrs[0].dims_mapping, [-1, -1, 0])
        self.assertEqual(infered_output_dist_attrs[0].dims_mapping, [-1, -1, 0])

    def test_softmax_infer_backward(self):
        # sharding on batch axis I
        self.out_dist_tensor_spec.set_dims_mapping([1, -1, -1])
        result_dist_attrs = self.rule1.infer_backward(
            [self.x_dist_tensor_spec], [self.out_dist_tensor_spec], self.attrs
        )
        self.assertEqual(len(result_dist_attrs), 2)
        infered_input_dist_attrs = result_dist_attrs[0]
        infered_output_dist_attrs = result_dist_attrs[1]

        self.assertEqual(len(infered_input_dist_attrs), 1)
        self.assertEqual(len(infered_output_dist_attrs), 1)

        self.assertEqual(infered_input_dist_attrs[0].dims_mapping, [1, -1, -1])
        self.assertEqual(infered_output_dist_attrs[0].dims_mapping, [1, -1, -1])

        # sharding on batch axis II
        self.out_dist_tensor_spec.set_dims_mapping([-1, 1, -1])
        result_dist_attrs = self.rule1.infer_backward(
            [self.x_dist_tensor_spec], [self.out_dist_tensor_spec], self.attrs
        )
        infered_input_dist_attrs = result_dist_attrs[0]
        infered_output_dist_attrs = result_dist_attrs[1]
        self.assertEqual(infered_input_dist_attrs[0].dims_mapping, [-1, 1, -1])
        self.assertEqual(infered_output_dist_attrs[0].dims_mapping, [-1, 1, -1])

        # sharding on softmax_axis
        self.out_dist_tensor_spec.set_dims_mapping([1, -1, 0])
        result_dist_attrs = self.rule1.infer_backward(
            [self.x_dist_tensor_spec], [self.out_dist_tensor_spec], self.attrs
        )
        infered_input_dist_attrs = result_dist_attrs[0]
        infered_output_dist_attrs = result_dist_attrs[1]
        self.assertEqual(infered_input_dist_attrs[0].dims_mapping, [1, -1, -1])
        self.assertEqual(infered_output_dist_attrs[0].dims_mapping, [1, -1, -1])

        # sharding on softmax_axis + axis = 1
        self.attrs = {
            'axis': 1,
        }
        self.out_dist_tensor_spec.set_dims_mapping([-1, 1, 0])
        result_dist_attrs = self.rule1.infer_backward(
            [self.x_dist_tensor_spec], [self.out_dist_tensor_spec], self.attrs
        )
        infered_input_dist_attrs = result_dist_attrs[0]
        infered_output_dist_attrs = result_dist_attrs[1]
        self.assertEqual(infered_input_dist_attrs[0].dims_mapping, [-1, -1, 0])
        self.assertEqual(infered_output_dist_attrs[0].dims_mapping, [-1, -1, 0])

        # sharding on softmax_axis + axis = -2
        self.attrs = {
            'axis': -2,
        }
        self.out_dist_tensor_spec.set_dims_mapping([-1, 1, 0])
        result_dist_attrs = self.rule1.infer_backward(
            [self.x_dist_tensor_spec], [self.out_dist_tensor_spec], self.attrs
        )
        infered_input_dist_attrs = result_dist_attrs[0]
        infered_output_dist_attrs = result_dist_attrs[1]
        self.assertEqual(infered_input_dist_attrs[0].dims_mapping, [-1, -1, 0])
        self.assertEqual(infered_output_dist_attrs[0].dims_mapping, [-1, -1, 0])


if __name__ == "__main__":
    unittest.main()
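
The expected values in these backward cases can be checked by hand: whatever mesh dimension the output carries on the softmax axis is dropped back to -1, and every other axis is passed through unchanged. The snippet below is a plain-Python sanity check of the cases above, written only for illustration; it does not call the SPMD rule or use DistTensorSpec.

cases = [
    ([1, -1, -1], -1),  # sharding on batch axis I
    ([-1, 1, -1], -1),  # sharding on batch axis II
    ([1, -1, 0], -1),   # sharding on softmax_axis
    ([-1, 1, 0], 1),    # sharding on softmax_axis, axis = 1
    ([-1, 1, 0], -2),   # sharding on softmax_axis, axis = -2
]
for out_mapping, axis in cases:
    axis = axis % len(out_mapping)  # normalize a negative axis
    expected = list(out_mapping)
    expected[axis] = -1  # the softmax axis must stay replicated
    print(out_mapping, "->", expected)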
