Skip to content

Commit dd404b0

Browse files
committed
feat: convert JLOpSharding to OpSharding
1 parent 33acc4b commit dd404b0

File tree

1 file changed

+51
-4
lines changed

1 file changed

+51
-4
lines changed

deps/ReactantExtra/API.cpp

Lines changed: 51 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -734,10 +734,57 @@ void OpShardingToJLOpSharding(const xla::OpSharding &op_sharding,
734734
jl_op_sharding->shard_group_type = op_sharding.shard_group_type();
735735
}
736736

737-
// xla::OpSharding JLOpShardingToOpSharding(const JLOpSharding &jl_op_sharding)
738-
// {
739-
// xla::OpSharding op_sharding;
740-
// }
737+
xla::OpSharding JLOpShardingToOpSharding(const JLOpSharding &jl_op_sharding) {
738+
xla::OpSharding op_sharding;
739+
740+
op_sharding.set_type(static_cast<xla::OpSharding_Type>(jl_op_sharding.type));
741+
op_sharding.set_replicate_on_last_tile_dim(
742+
jl_op_sharding.replicate_on_last_tile_dim);
743+
744+
xla::ShapeProto *mutable_shape_proto = op_sharding.mutable_tile_shape();
745+
746+
for (int i = 0; i < jl_op_sharding.n_tile_dimensions; i++) {
747+
mutable_shape_proto->add_dimensions(jl_op_sharding.tile_dimensions[i]);
748+
}
749+
750+
if (jl_op_sharding.n_layout_minor_to_major > 0) {
751+
auto *mutable_layout = mutable_shape_proto->mutable_layout();
752+
for (int i = 0; i < jl_op_sharding.n_layout_minor_to_major; i++) {
753+
mutable_layout->add_minor_to_major(
754+
jl_op_sharding.layout_minor_to_major[i]);
755+
}
756+
}
757+
758+
for (int i = 0; i < jl_op_sharding.n_tile_dimensions; i++) {
759+
op_sharding.add_last_tile_dims(
760+
static_cast<xla::OpSharding_Type>(jl_op_sharding.last_tile_dims[i]));
761+
}
762+
763+
for (int i = 0; i < jl_op_sharding.n_tile_assignment_dimensions; i++) {
764+
op_sharding.add_tile_assignment_dimensions(
765+
jl_op_sharding.tile_assignment_dimensions[i]);
766+
}
767+
768+
for (int i = 0; i < jl_op_sharding.n_tile_assignment_devices; i++) {
769+
op_sharding.add_tile_assignment_devices(
770+
jl_op_sharding.tile_assignment_devices[i]);
771+
}
772+
773+
for (int i = 0; i < jl_op_sharding.n_iota_reshape_dims; i++) {
774+
op_sharding.add_iota_reshape_dims(jl_op_sharding.iota_reshape_dims[i]);
775+
}
776+
777+
for (int i = 0; i < jl_op_sharding.n_iota_transpose_perm; i++) {
778+
op_sharding.add_iota_transpose_perm(jl_op_sharding.iota_transpose_perm[i]);
779+
}
780+
781+
op_sharding.set_is_shard_group(jl_op_sharding.is_shard_group);
782+
op_sharding.set_shard_group_id(jl_op_sharding.shard_group_id);
783+
op_sharding.set_shard_group_type(static_cast<xla::OpSharding_ShardGroupType>(
784+
jl_op_sharding.shard_group_type));
785+
786+
return op_sharding;
787+
}
741788

742789
typedef PjRtFuture<> FutureType;
743790
extern "C" void FreeFuture(FutureType *Future) { delete Future; }

0 commit comments

Comments
 (0)