@@ -161,7 +161,7 @@ std::shared_ptr<DimTrans> make_split(const std::shared_ptr<DimTrans> dim,
   // map between from idx in shape to new_shape
   std::vector<int64_t> idx_map(shape.size(), -1);
   for (int i = 0, n = static_cast<int>(shape.size()); i < n; ++i) {
-    if (shape[id] != 1) {
+    if (shape[i] != 1) {
       idx_map[i] = static_cast<int64_t>(new_shape.size());
       new_shape.emplace_back(shape[i]);
     }
@@ -272,6 +272,139 @@ std::vector<std::shared_ptr<DimTrans>> GetDimTrans(
   return ret_dim_trans;
 }

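+// Co-shard variant of GetDimTrans: each input dimension may be mapped to
+// several mesh dimensions (input_dims_mapping[i] lists the mesh axes the
+// i-th dimension is sharded on). It returns the InputDim nodes whose
+// sharding can be kept for this dim_trans, records every input dimension it
+// visits in `seen_dims`, and updates `shardable` to record which mesh
+// dimensions each affected input dimension can still be sharded on.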
+std::vector<std::shared_ptr<DimTrans>> GetDimTransCoShard(
+    const std::shared_ptr<DimTrans> dim_trans,
+    const std::vector<int64_t>& input_shape,
+    const std::vector<int64_t>& mesh_shape,
+    const std::vector<std::vector<int64_t>>& input_dims_mapping,
+    const std::set<int64_t>& sharded_input_dims,
+    std::vector<std::vector<bool>>* shardable,
+    std::set<int64_t>* seen_dims) {
+  DimTrans::Type type = dim_trans->type();
+  std::vector<std::shared_ptr<DimTrans>> ret_dim_trans;
+
+  if (type == DimTrans::Type::INPUTDIM) {
+    std::shared_ptr<InputDim> inputdim =
+        std::dynamic_pointer_cast<InputDim>(dim_trans);
+    int64_t dim = inputdim->input_dim();
+    seen_dims->insert(dim);
+
+    if (sharded_input_dims.count(dim) > 0) {
+      ret_dim_trans.push_back(dim_trans);
+    }
+  } else if (type == DimTrans::Type::FLATTEN) {
+    std::shared_ptr<Flatten> flatten =
+        std::dynamic_pointer_cast<Flatten>(dim_trans);
+    const std::vector<std::shared_ptr<DimTrans>>& inputs = flatten->inputs();
+
+    int64_t nmesh = (*shardable)[0].size();  // NOLINT
+    int64_t mesh_shape_prod = 1;
+
+    int last_shard_idx = -1;
+    int first_shard_idx = -1;
+    int64_t first_sharded_shape = -1;
+    for (int i = 0, n = static_cast<int>(inputs.size()); i < n; ++i) {
+      std::shared_ptr<DimTrans> input = inputs[i];
+      if (input->type() == DimTrans::Type::INPUTDIM) {
+        std::shared_ptr<InputDim> inputdim =
+            std::dynamic_pointer_cast<InputDim>(input);
+        if (sharded_input_dims.count(inputdim->input_dim()) > 0) {
+          if (first_shard_idx == -1) {
+            first_shard_idx = i;
+            first_sharded_shape = input_shape[inputdim->input_dim()];
+          }
+          for (const auto& dim : input_dims_mapping[inputdim->input_dim()]) {
+            mesh_shape_prod *= mesh_shape[dim];
+          }
+          if (first_sharded_shape % mesh_shape_prod == 0) {
+            ret_dim_trans.push_back(inputdim);
+          } else {
+            break;
+          }
+        } else {
+          break;
+        }
+        last_shard_idx = i;
+      } else {
+        break;
+      }
+    }
+
+    for (int i = last_shard_idx + 1, n = static_cast<int>(inputs.size()); i < n;
+         i++) {
+      std::shared_ptr<DimTrans> input = inputs[i];
+      if (input->type() == DimTrans::Type::INPUTDIM) {
+        std::shared_ptr<InputDim> inputdim =
+            std::dynamic_pointer_cast<InputDim>(input);
+        (*shardable)[inputdim->input_dim()].assign(nmesh, false);
+      }
+
+      GetDimTransCoShard(input,
+                         input_shape,
+                         mesh_shape,
+                         input_dims_mapping,
+                         sharded_input_dims,
+                         shardable,
+                         seen_dims);
+    }
+  } else if (type == DimTrans::Type::SPLIT) {
+    std::shared_ptr<Split> split = std::dynamic_pointer_cast<Split>(dim_trans);
+    std::vector<std::shared_ptr<DimTrans>> dims =
+        GetDimTransCoShard(split->input(),
+                           input_shape,
+                           mesh_shape,
+                           input_dims_mapping,
+                           sharded_input_dims,
+                           shardable,
+                           seen_dims);
+    int64_t ret_size = split->local_split_shape_value();
+
+    if (split->split_id() == 0) {
+      int64_t mesh_shape_prod = 1;
+      int64_t first_shard_idx = -1;
+      int64_t first_sharded_shape = -1;
+      for (const auto& dim : dims) {
+        PADDLE_ENFORCE_EQ(dim->type(),
+                          DimTrans::Type::INPUTDIM,
+                          common::errors::InvalidArgument(
+                              "The returned dim_trans must be INPUTDIM."));
+        std::shared_ptr<InputDim> inputdim =
+            std::dynamic_pointer_cast<InputDim>(dim);
+        int64_t nmesh = static_cast<int64_t>(mesh_shape.size());
+        int64_t input_axis = inputdim->input_dim();
+
+        // Check whether the sharded dim can be sharded on
+        // each mesh dimension. The dimension should be
+        // divisible by the mesh size that it is sharded on.
+        for (int64_t imesh = 0; imesh < nmesh; imesh++) {
+          (*shardable)[input_axis][imesh] = (ret_size % mesh_shape[imesh] == 0);
+        }
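+        // For example, with ret_size = 4 and mesh_shape = {2, 3}, the dim
+        // stays shardable on mesh dim 0 (4 % 2 == 0) but not on mesh dim 1
+        // (4 % 3 != 0).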
+
+        if (first_shard_idx == -1) {
+          first_shard_idx = input_axis;
+          first_sharded_shape = input_shape[input_axis];
+        }
+
+        if (sharded_input_dims.count(input_axis) > 0) {
+          for (const auto& dim : input_dims_mapping[input_axis]) {
+            mesh_shape_prod *= mesh_shape[dim];
+          }
+          if ((ret_size % mesh_shape_prod == 0) &&
+              (first_sharded_shape % mesh_shape_prod == 0)) {
+            ret_dim_trans.push_back(dim);
+          } else {
+            break;
+          }
+        } else {
+          break;
+        }
+      }
+    }
+  } else if (type == DimTrans::Type::SINGLETON) {
+  }
+  return ret_dim_trans;
+}
+
 void GetUsedInputDim(const std::shared_ptr<DimTrans> dim_trans,
                      std::set<int64_t>* seen_dims) {
   if (dim_trans->type() == DimTrans::Type::INPUTDIM) {
@@ -311,6 +444,27 @@ InferFromDimTrans(const DistMetaTensor& input_spec,
   return InferFromDimTrans(input_spec, input_shape, dim_trans);
 }

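+// Overload that derives the input shape from `input_spec` (dropping the
+// leading 0 kept for reshape's xshape in dynamic mode) before dispatching to
+// the shape-aware overload below.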
+std::tuple<std::vector<std::vector<int64_t>>, std::vector<std::vector<int64_t>>>
+InferFromDimTransCoShard(
+    const DistMetaTensor& input_spec,
+    const std::vector<std::shared_ptr<DimTrans>>& dim_trans) {
+  auto input_shape = phi::vectorize(input_spec.dims());
+  // deal with reshape's xshape in dynamic mode: drop its leading 0 dim
+  if (input_shape[0] == 0 &&
+      input_shape.size() !=
+          input_spec.dist_attr().multi_dims_mapping().size()) {
+    input_shape.erase(input_shape.begin());
+  }
+  PADDLE_ENFORCE_EQ(input_shape.size(),
+                    input_spec.dist_attr().multi_dims_mapping().size(),
+                    common::errors::InvalidArgument(
+                        "The Tensor X's rank [%d] and X's "
+                        "dims_mapping size [%d] are not matched.",
+                        input_shape.size(),
+                        input_spec.dist_attr().multi_dims_mapping().size()));
+  return InferFromDimTransCoShard(input_spec, input_shape, dim_trans);
+}
+
 std::tuple<std::vector<std::vector<int64_t>>, std::vector<std::vector<int64_t>>>
 InferFromDimTrans(const DistMetaTensor& input,
                   const std::vector<int64_t>& input_shape,
@@ -400,4 +554,105 @@ InferFromDimTrans(const DistMetaTensor& input,
   return {new_input_dims_mapping, out_dims_mapping};
 }

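+// Co-shard counterpart of InferFromDimTrans: the input's dist_attr carries a
+// multi_dims_mapping, i.e. a list of mesh axes per tensor dimension. Returns
+// {new_input_dims_mapping, out_dims_mapping}; dimensions that cannot keep
+// their shard fall back to an empty (replicated) mapping.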
+std::tuple<std::vector<std::vector<int64_t>>, std::vector<std::vector<int64_t>>>
+InferFromDimTransCoShard(
+    const DistMetaTensor& input,
+    const std::vector<int64_t>& input_shape,
+    const std::vector<std::shared_ptr<DimTrans>>& dim_trans) {
+  const std::vector<std::vector<int64_t>>& input_dims_mapping =
+      input.dist_attr().multi_dims_mapping();
+  const ProcessMesh& mesh = input.dist_attr().process_mesh();
+  const std::vector<int64_t>& mesh_shape = mesh.shape();
+
+  std::set<int64_t> sharded_input_dims;
+  for (int64_t i = 0, n = static_cast<int64_t>(input_dims_mapping.size());
+       i < n;
+       ++i) {
+    if (std::any_of(input_dims_mapping[i].begin(),
+                    input_dims_mapping[i].end(),
+                    [](int64_t dim) { return dim > -1; })) {
+      sharded_input_dims.insert(i);
+    }
+  }
+  int64_t ndim = static_cast<int64_t>(input_shape.size());
+  int64_t nmesh = static_cast<int64_t>(mesh_shape.size());
+  std::vector<std::vector<bool>> shardable(ndim,
+                                           std::vector<bool>(nmesh, true));
+
+  std::set<int64_t> seen_input_dims;
+  for (const std::shared_ptr<DimTrans>& trans : dim_trans) {
+    GetUsedInputDim(trans, &seen_input_dims);
+  }
+
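+  // Input dims that are not referenced by any dim_trans cannot keep a shard,
+  // so mark them as unshardable on every mesh dimension.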
+  for (int64_t idim = 0; idim < ndim; idim++) {
+    bool seen = seen_input_dims.count(idim);
+    if (!seen) {
+      shardable[idim].assign(nmesh, seen);
+    }
+  }
+
+  // get the maps between sharded input dimensions and output dimensions:
+  // dim_map_src2tgt maps src dim -> dst dim, and dim_map_dst2src maps
+  // dst dim -> the src dims it comes from.
+  std::vector<int64_t> dim_map_src2tgt(ndim, -1);
+  std::unordered_map<int, std::vector<int>> dim_map_dst2src;
+  for (int64_t i = 0, n = static_cast<int64_t>(dim_trans.size()); i < n; i++) {
+    std::vector<std::shared_ptr<DimTrans>> dims =
+        GetDimTransCoShard(dim_trans[i],
+                           input_shape,
+                           mesh_shape,
+                           input_dims_mapping,
+                           sharded_input_dims,
+                           &shardable,
+                           &seen_input_dims);
+    for (auto dim : dims) {
+      if (dim->type() == DimTrans::Type::INPUTDIM) {
+        std::shared_ptr<InputDim> inputdim =
+            std::dynamic_pointer_cast<InputDim>(dim);
+        dim_map_src2tgt[inputdim->input_dim()] = i;
+        dim_map_dst2src[i].push_back(inputdim->input_dim());
+      }
+    }
+  }
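+  // For example, flattening input dims 0 and 1 (both sharded, and assuming
+  // the divisibility checks pass) into output dim 0 gives
+  // dim_map_src2tgt[0] = dim_map_src2tgt[1] = 0 and dim_map_dst2src[0] = {0, 1}.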
+
+  std::vector<std::vector<int64_t>> out_dims_mapping(dim_trans.size());
+  std::vector<std::vector<int64_t>> new_input_dims_mapping(
+      input_dims_mapping.size());
+  for (size_t i = 0; i < input_dims_mapping.size(); i++) {
+    if (std::any_of(input_dims_mapping[i].begin(),
+                    input_dims_mapping[i].end(),
+                    [](int64_t dim) { return dim >= 0; })) {
+      new_input_dims_mapping[i] = input_dims_mapping[i];
+    }
+  }
+
+  // set output dims mapping with corresponding input dimensions.
+  // if one input dimension is sharded on an unshardable mesh dimension after
+  // splitting, we need to make it replicated.
+  for (int64_t i = 0; i < ndim; i++) {
+    std::vector<int64_t> mesh_dims = input_dims_mapping[i];
+    for (int64_t mesh_dim : mesh_dims) {
+      if (mesh_dim > -1 && shardable[i][mesh_dim] && dim_map_src2tgt[i] > -1) {
+        int dst_dim = dim_map_src2tgt[i];
+        out_dims_mapping[dst_dim].push_back(mesh_dim);
+
+        auto src_dim = dim_map_dst2src[dst_dim];
+        auto min_dim = std::min_element(src_dim.begin(), src_dim.end());
+        if (*min_dim == i) {
+          new_input_dims_mapping[*min_dim].push_back(mesh_dim);
+          continue;
+        }
+        new_input_dims_mapping[*min_dim].insert(
+            new_input_dims_mapping[*min_dim].end(),
+            new_input_dims_mapping[i].begin(),
+            new_input_dims_mapping[i].end());
+        new_input_dims_mapping[i] = {};
+      } else {
+        new_input_dims_mapping[i] = {};
+      }
+    }
+  }
+
+  return {new_input_dims_mapping, out_dims_mapping};
+}
+
 }  // namespace phi::distributed