Skip to content

Commit

Permalink
add split backward rule
Browse files Browse the repository at this point in the history
  • Loading branch information
pkuzyc committed Sep 10, 2023
1 parent 3c19d80 commit 15cd1ac
Show file tree
Hide file tree
Showing 3 changed files with 347 additions and 8 deletions.
110 changes: 104 additions & 6 deletions paddle/fluid/distributed/auto_parallel/spmd_rules/split_spmd_rule.cc
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ SplitSPMDRule::InferForward(const std::vector<DistTensorSpec>& input_specs,
std::unordered_map<std::string, int64_t> axis_to_dim_map =
ShardingMergeForTensors(axes_sharding_info);

// step2.2: infer output dimsmapping from merged input dimsmapping
// step2.2: infer output dims mapping from merged input dims mapping
std::vector<int64_t> output_dims_mapping =
GetDimsMappingForAxes(output_axes, axis_to_dim_map);

Expand All @@ -94,7 +94,7 @@ SplitSPMDRule::InferForward(const std::vector<DistTensorSpec>& input_specs,
new_input_dims_mapping[axis] = -1;
new_input_dist_attrs[0].set_dims_mapping(new_input_dims_mapping);

// Step2.4 handle input tensor partial (TODO)
// Step3 Handle input tensor partial (TODO)
VLOG(4) << "SplitSPMDRule InferForward: ";
for (int64_t i = 0; i < ninputs; i++) {
VLOG(4) << "Input" << std::to_string(i) << " shape: ["
Expand All @@ -113,12 +113,110 @@ SplitSPMDRule::InferForward(const std::vector<DistTensorSpec>& input_specs,
}

std::pair<std::vector<TensorDistAttr>, std::vector<TensorDistAttr>>
SplitSPMDRule::InferBackward(const std::vector<DistTensorSpec>& output_specs,
SplitSPMDRule::InferBackward(const std::vector<DistTensorSpec>& input_specs,
const std::vector<DistTensorSpec>& output_specs,
const paddle::framework::AttributeMap& attrs) {
PADDLE_THROW(phi::errors::Unimplemented(
"InferBackward of SplitPMDRule is NOT implemented yet."));
// step0: Verify Input Args Based on Elementwise Logic
int64_t ninputs = input_specs.size();
int64_t noutputs = output_specs.size();
PADDLE_ENFORCE_EQ(
ninputs,
1,
phi::errors::InvalidArgument("The size of InputSpec in split must "
"be equal to 1, but got [%d].",
ninputs));
VerifySpecs(output_specs, "split");

// check whether the size of output_specs equals
// to the specified split num in op attributes
int64_t specified_split_num = -1;
// split api uses num or sections as attribute
if (attrs.find("num") != attrs.end()) {
specified_split_num = ExtractAttr<int64_t>("num", attrs);
} else if (attrs.find("sections") != attrs.end()) {
std::vector<int64_t> sections =
ExtractAttr<std::vector<int64_t>>("sections", attrs);
specified_split_num = sections.size();
}
PADDLE_ENFORCE_EQ(
noutputs,
specified_split_num,
phi::errors::InvalidArgument("The size of OutputSpec [%d] is not equal "
"to the specified split number [%d]",
noutputs,
specified_split_num));

// check whether all dims mapping in output_specs are the same
const std::vector<int64_t>& dims_mapping0 = output_specs[0].dims_mapping();
for (int64_t i = 1; i < noutputs; i++) {
const std::vector<int64_t>& dims_mapping = output_specs[i].dims_mapping();
if (!std::equal(
dims_mapping0.begin(), dims_mapping0.end(), dims_mapping.begin())) {
PADDLE_THROW(
phi::errors::InvalidArgument("Not all dims_mappings in "
"output_specs are the same."));
}
}

// step1: Build Einsum Notation
int64_t ndim = input_specs[0].shape().size();
int64_t axis = ExtractAttr<int>("axis", attrs);
if (axis < 0) {
axis += ndim;
}
std::string alphabet = "abcdefghijlmnopqrstuvwxyz";

// get einsum notation for input, use a special
// notation 'k' to mark the splitted axis in input
std::string input_axes = alphabet.substr(0, ndim);
input_axes[axis] = 'k';

// get einsum notation for output
std::string output_axes(input_axes);
// the splitted axis cannot be sharded, set its notation
// with the special '1' to set its dim mapping to -1.
output_axes[axis] = '1';

// step2: Sharding Propogation
// step2.1: merge input shardings
std::vector<std::pair<std::string, std::vector<int64_t>>> axes_sharding_info;
axes_sharding_info = {{output_axes, dims_mapping0}};
std::unordered_map<std::string, int64_t> axis_to_dim_map =
ShardingMergeForTensors(axes_sharding_info);

// step2.2: infer input dims mapping from output dims mapping
// the split axis in input is set to -1.
std::vector<int64_t> input_dims_mapping =
GetDimsMappingForAxes(input_axes, axis_to_dim_map, true);

std::vector<TensorDistAttr> output_dist_attrs;
for (int64_t i = 0; i < noutputs; i++) {
output_dist_attrs.emplace_back(output_specs[i].dist_attr());
}

// step2.3 get new dist attribute for input. the splitted
// cannot be sharded, if it is sharded, set it to replicated.
TensorDistAttr input_dist_attr(input_specs[0].dist_attr());
input_dist_attr.set_dims_mapping(input_dims_mapping);

// step3 Handle input tensor partial (TODO)

VLOG(4) << "SplitSPMDRule InferBackward: ";
for (int64_t i = 0; i < noutputs; i++) {
VLOG(4) << "Output" << std::to_string(i) << " shape: ["
<< str_join(output_specs[i].shape()) << "] "
<< "einsum_notation: " << output_axes << " dims_mapping: ["
<< str_join(output_specs[i].dims_mapping()) << "]";
}
for (int64_t i = 0; i < ninputs; i++) {
VLOG(4) << "Input" << std::to_string(i) << " shape: ["
<< str_join(input_specs[i].shape()) << "] "
<< "einsum_notation: " << input_axes << " dims_mapping: ["
<< str_join(input_dims_mapping) << "]";
}
VLOG(4) << std::endl;

return {};
return {{input_dist_attr}, output_dist_attrs};
}

} // namespace auto_parallel
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,8 @@ class SplitSPMDRule : public SPMDRuleBase {
const paddle::framework::AttributeMap& attrs) override;

std::pair<std::vector<TensorDistAttr>, std::vector<TensorDistAttr>>
InferBackward(const std::vector<DistTensorSpec>& output_specs,
InferBackward(const std::vector<DistTensorSpec>& input_specs,
const std::vector<DistTensorSpec>& output_specs,
const paddle::framework::AttributeMap& attrs) override;
};
} // namespace auto_parallel
Expand Down
242 changes: 241 additions & 1 deletion test/auto_parallel/spmd_rules/test_split_rule.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ def test_single_mesh_dim(self):
self.assertEqual(infered_output_dist_attrs[2].dims_mapping, [0, -1, -1])

# num_or_sections = [15, 16, 17], axis = 2
# [-1, -1, 0] --> [-1, -1, -1], [-1, -1, -1], [-1 -1, -1], [-1, -1, -1]
# [-1, -1, 0] --> [-1, -1, -1], [-1, -1, -1], [-1, -1, -1], [-1, -1, -1]
self.attrs = {}
self.attrs['sections'] = [15, 16, 17]
self.attrs['axis'] = 2
Expand Down Expand Up @@ -200,6 +200,246 @@ def test_multi_mesh_dim(self):
infered_output_dist_attrs[2].dims_mapping, [-1, 1, -1, -1]
)

def test_backward_single_mesh_dim(self):
    # Backward inference of the split rule on a single mesh dim: the
    # (identical) shardings of all outputs are propagated back to the
    # input, and the split axis always stays replicated (-1).
    x_shape = self.x_dist_tensor_spec.shape

    def make_out_specs(shapes, dims_mapping):
        # Build one DistTensorSpec per output tensor, all derived from
        # the input spec and sharing the same dims mapping.
        specs = []
        for shape in shapes:
            spec = DistTensorSpec(self.x_dist_tensor_spec)
            spec.shape = list(shape)
            spec.set_dims_mapping(list(dims_mapping))
            specs.append(spec)
        return specs

    def infer_and_check(expected_mapping, num_outputs):
        # Run infer_backward and verify that the input and every output
        # end up with `expected_mapping` as their dims mapping.
        result = self.rule.infer_backward(
            [self.x_dist_tensor_spec], self.out_spec_list, self.attrs
        )
        self.assertEqual(len(result), 2)
        infered_input_dist_attrs, infered_output_dist_attrs = result
        self.assertEqual(len(infered_input_dist_attrs), 1)
        self.assertEqual(len(infered_output_dist_attrs), num_outputs)
        self.assertEqual(
            infered_input_dist_attrs[0].dims_mapping, expected_mapping
        )
        for out_attr in infered_output_dist_attrs:
            self.assertEqual(out_attr.dims_mapping, expected_mapping)

    halves = [x_shape[0], x_shape[1] // 2, x_shape[2]]
    section_shapes = [[x_shape[0], x_shape[1], s] for s in (15, 16, 17)]

    # num_or_sections = 2, axis = 1
    # [0, -1, -1], [0, -1, -1] --> [0, -1, -1], [0, -1, -1], [0, -1, -1]
    # (outputs --> input, outputs)
    self.rule = get_spmd_rule("split_with_num")
    self.attrs = {'num': 2, 'axis': 1}
    self.out_spec_list = make_out_specs([halves] * 2, [0, -1, -1])
    infer_and_check([0, -1, -1], 2)

    # num_or_sections = [15, 16, 17], axis = 2
    # [0, -1, -1] x 3 --> [0, -1, -1] x 4
    # (outputs --> input, outputs)
    self.rule = get_spmd_rule("split")
    self.attrs = {'sections': [15, 16, 17], 'axis': 2}
    self.out_spec_list = make_out_specs(section_shapes, [0, -1, -1])
    infer_and_check([0, -1, -1], 3)

    # num_or_sections = [15, 16, 17], axis = 2
    # [-1, -1, -1] x 3 --> [-1, -1, -1] x 4
    # (outputs --> input, outputs)
    self.attrs = {'sections': [15, 16, 17], 'axis': 2}
    self.out_spec_list = make_out_specs(section_shapes, [-1, -1, -1])
    infer_and_check([-1, -1, -1], 3)

    # num_or_sections = 2, axis = -2 (negative axis is normalized)
    # [0, -1, -1], [0, -1, -1] --> [0, -1, -1], [0, -1, -1], [0, -1, -1]
    # (outputs --> input, outputs)
    self.rule = get_spmd_rule("split_with_num")
    self.attrs = {'num': 2, 'axis': -2}
    self.out_spec_list = make_out_specs([halves] * 2, [0, -1, -1])
    infer_and_check([0, -1, -1], 2)

def test_backward_multi_mesh_dim(self):
    # Backward inference of the split rule on a 2-D process mesh:
    # output shardings over both mesh dims flow back to the input,
    # and mismatched output shardings must raise.
    x_shape = [96, 32, 48, 24]
    process_mesh = auto.ProcessMesh(mesh=[[0, 1, 2], [3, 4, 5]])
    self.x_dist_tensor_spec.set_process_mesh(process_mesh)
    self.x_dist_tensor_spec.shape = x_shape

    def make_out_specs(shapes, dims_mapping):
        # Build one DistTensorSpec per output tensor, all derived from
        # the input spec and sharing the same dims mapping.
        specs = []
        for shape in shapes:
            spec = DistTensorSpec(self.x_dist_tensor_spec)
            spec.shape = list(shape)
            spec.set_dims_mapping(list(dims_mapping))
            specs.append(spec)
        return specs

    def infer_and_check(expected_mapping, num_outputs):
        # Run infer_backward and verify that the input and every output
        # end up with `expected_mapping` as their dims mapping.
        result = self.rule.infer_backward(
            [self.x_dist_tensor_spec], self.out_spec_list, self.attrs
        )
        self.assertEqual(len(result), 2)
        infered_input_dist_attrs, infered_output_dist_attrs = result
        self.assertEqual(len(infered_input_dist_attrs), 1)
        self.assertEqual(len(infered_output_dist_attrs), num_outputs)
        self.assertEqual(
            infered_input_dist_attrs[0].dims_mapping, expected_mapping
        )
        for out_attr in infered_output_dist_attrs:
            self.assertEqual(out_attr.dims_mapping, expected_mapping)

    # num_or_sections = 3, axis = -1
    # [0, 1, -1, -1] x 3 --> [0, 1, -1, -1] x 4
    # (outputs --> input, outputs)
    self.rule = get_spmd_rule("split_with_num")
    self.attrs = {'num': 3, 'axis': -1}
    last_axis_shape = [x_shape[0], x_shape[1], x_shape[2], x_shape[3] // 3]
    self.out_spec_list = make_out_specs(
        [last_axis_shape] * 3, [0, 1, -1, -1]
    )
    infer_and_check([0, 1, -1, -1], 3)

    # num_or_sections = [32, 32, 32], axis = 0
    # [-1, 1, -1, -1] x 3 --> [-1, 1, -1, -1] x 4
    # (outputs --> input, outputs)
    self.rule = get_spmd_rule("split")
    self.attrs = {'sections': [32, 32, 32], 'axis': 0}
    first_axis_shape = [32, x_shape[1], x_shape[2], x_shape[3]]
    self.out_spec_list = make_out_specs(
        [first_axis_shape] * 3, [-1, 1, -1, -1]
    )
    infer_and_check([-1, 1, -1, -1], 3)

    # [-1, 1, -1, -1], [0, 1, -1, -1], [-1, 1, -1, -1] --> error:
    # the outputs of split must all share the same dims mapping
    self.out_spec_list[1].set_dims_mapping([0, 1, -1, -1])
    with self.assertRaises(BaseException):
        self.rule.infer_backward(
            [self.x_dist_tensor_spec], self.out_spec_list, self.attrs
        )


# Run the split SPMD-rule tests when this file is executed as a script.
if __name__ == "__main__":
    unittest.main()

0 comments on commit 15cd1ac

Please sign in to comment.