Skip to content

Commit d18fd40

Browse files
authored
feat: more JLL changes for #788 (#790)
1 parent 230d77e commit d18fd40

File tree

2 files changed

+30
-0
lines changed

2 files changed

+30
-0
lines changed

deps/ReactantExtra/API.cpp

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,7 @@
7878
#include "shardy/integrations/c/attributes.h"
7979
#include "xla/pjrt/mlir_to_hlo.h"
8080
#include "xla/service/spmd/shardy/stablehlo_round_trip/export_shardings.h"
81+
#include "xla/service/spmd/shardy/stablehlo_round_trip/stablehlo_import.h"
8182

8283
// IFRT
8384
#include "xla/python/ifrt/array.h"
@@ -1956,4 +1957,31 @@ hloShardingFromTensorShardingAttr(mlir::sdy::TensorShardingAttr attr,
19561957
xla::sdy::convertToHloSharding(attr, get_mesh_attr, manual_axes));
19571958
}
19581959

1960+
extern "C" mlir::sdy::TensorShardingAttr
1961+
hloShardingToTensorShardingAttr(const xla::HloSharding *hloSharding,
1962+
mlir::sdy::MeshAttr meshAttr, int64_t rank,
1963+
const bool *isClosed, const int64_t *priority) {
1964+
const SmallDenseMap<int64_t, StringRef> deviceIdToMaximalMeshName =
1965+
SmallDenseMap<int64_t, StringRef>();
1966+
mlir::sdy::TensorShardingAttr tensorShardingAttr =
1967+
xla::sdy::convertToSdySharding(*hloSharding, meshAttr,
1968+
deviceIdToMaximalMeshName, rank,
1969+
/*openDims=*/true);
1970+
1971+
for (int64_t i = 0; i < rank; i++) {
1972+
auto oldDimSharding = tensorShardingAttr.getDimSharding(i);
1973+
1974+
std::optional<int64_t> dimPriority;
1975+
if (priority[i] > 0)
1976+
dimPriority = priority[i];
1977+
1978+
tensorShardingAttr = tensorShardingAttr.replaceDimSharding(
1979+
i, mlir::sdy::DimensionShardingAttr::get(oldDimSharding.getContext(),
1980+
oldDimSharding.getAxes(),
1981+
isClosed[i], dimPriority));
1982+
}
1983+
1984+
return tensorShardingAttr;
1985+
}
1986+
19591987
#pragma endregion

deps/ReactantExtra/BUILD

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -492,6 +492,7 @@ cc_library(
492492
"-Wl,-exported_symbol,_ClientGetAddressableDevices",
493493
"-Wl,-exported_symbol,_hloShardingFromTensorShardingAttr",
494494
"-Wl,-exported_symbol,_op_sharding_*",
495+
"-Wl,-exported_symbol,_hloShardingToTensorShardingAttr",
495496
]}),
496497
deps = [
497498
"@enzyme//:EnzymeMLIR",
@@ -541,6 +542,7 @@ cc_library(
541542
"@xla//xla/pjrt/distributed:client",
542543
"@xla//xla/pjrt/distributed:service",
543544
"@xla//xla/service/spmd/shardy/stablehlo_round_trip:export_shardings",
545+
"@xla//xla/service/spmd/shardy/stablehlo_round_trip:stablehlo_import",
544546

545547
"@xla//xla:xla_proto_cc",
546548
"@xla//xla:xla_proto_cc_impl",

0 commit comments

Comments
 (0)