@@ -734,10 +734,57 @@ void OpShardingToJLOpSharding(const xla::OpSharding &op_sharding,
734734 jl_op_sharding->shard_group_type = op_sharding.shard_group_type ();
735735}
736736
737- // xla::OpSharding JLOpShardingToOpSharding(const JLOpSharding &jl_op_sharding)
738- // {
739- // xla::OpSharding op_sharding;
740- // }
737+ xla::OpSharding JLOpShardingToOpSharding (const JLOpSharding &jl_op_sharding) {
738+ xla::OpSharding op_sharding;
739+
740+ op_sharding.set_type (static_cast <xla::OpSharding_Type>(jl_op_sharding.type ));
741+ op_sharding.set_replicate_on_last_tile_dim (
742+ jl_op_sharding.replicate_on_last_tile_dim );
743+
744+ xla::ShapeProto *mutable_shape_proto = op_sharding.mutable_tile_shape ();
745+
746+ for (int i = 0 ; i < jl_op_sharding.n_tile_dimensions ; i++) {
747+ mutable_shape_proto->add_dimensions (jl_op_sharding.tile_dimensions [i]);
748+ }
749+
750+ if (jl_op_sharding.n_layout_minor_to_major > 0 ) {
751+ auto *mutable_layout = mutable_shape_proto->mutable_layout ();
752+ for (int i = 0 ; i < jl_op_sharding.n_layout_minor_to_major ; i++) {
753+ mutable_layout->add_minor_to_major (
754+ jl_op_sharding.layout_minor_to_major [i]);
755+ }
756+ }
757+
758+ for (int i = 0 ; i < jl_op_sharding.n_tile_dimensions ; i++) {
759+ op_sharding.add_last_tile_dims (
760+ static_cast <xla::OpSharding_Type>(jl_op_sharding.last_tile_dims [i]));
761+ }
762+
763+ for (int i = 0 ; i < jl_op_sharding.n_tile_assignment_dimensions ; i++) {
764+ op_sharding.add_tile_assignment_dimensions (
765+ jl_op_sharding.tile_assignment_dimensions [i]);
766+ }
767+
768+ for (int i = 0 ; i < jl_op_sharding.n_tile_assignment_devices ; i++) {
769+ op_sharding.add_tile_assignment_devices (
770+ jl_op_sharding.tile_assignment_devices [i]);
771+ }
772+
773+ for (int i = 0 ; i < jl_op_sharding.n_iota_reshape_dims ; i++) {
774+ op_sharding.add_iota_reshape_dims (jl_op_sharding.iota_reshape_dims [i]);
775+ }
776+
777+ for (int i = 0 ; i < jl_op_sharding.n_iota_transpose_perm ; i++) {
778+ op_sharding.add_iota_transpose_perm (jl_op_sharding.iota_transpose_perm [i]);
779+ }
780+
781+ op_sharding.set_is_shard_group (jl_op_sharding.is_shard_group );
782+ op_sharding.set_shard_group_id (jl_op_sharding.shard_group_id );
783+ op_sharding.set_shard_group_type (static_cast <xla::OpSharding_ShardGroupType>(
784+ jl_op_sharding.shard_group_type ));
785+
786+ return op_sharding;
787+ }
741788
742789typedef PjRtFuture<> FutureType;
743790extern " C" void FreeFuture (FutureType *Future) { delete Future; }
0 commit comments