|
78 | 78 | #include "shardy/integrations/c/attributes.h" |
79 | 79 | #include "xla/pjrt/mlir_to_hlo.h" |
80 | 80 | #include "xla/service/spmd/shardy/stablehlo_round_trip/export_shardings.h" |
| 81 | +#include "xla/service/spmd/shardy/stablehlo_round_trip/stablehlo_import.h" |
81 | 82 |
|
82 | 83 | // IFRT |
83 | 84 | #include "xla/python/ifrt/array.h" |
@@ -1956,4 +1957,31 @@ hloShardingFromTensorShardingAttr(mlir::sdy::TensorShardingAttr attr, |
1956 | 1957 | xla::sdy::convertToHloSharding(attr, get_mesh_attr, manual_axes)); |
1957 | 1958 | } |
1958 | 1959 |
|
| 1960 | +extern "C" mlir::sdy::TensorShardingAttr |
| 1961 | +hloShardingToTensorShardingAttr(const xla::HloSharding *hloSharding, |
| 1962 | + mlir::sdy::MeshAttr meshAttr, int64_t rank, |
| 1963 | + const bool *isClosed, const int64_t *priority) { |
| 1964 | + const SmallDenseMap<int64_t, StringRef> deviceIdToMaximalMeshName = |
| 1965 | + SmallDenseMap<int64_t, StringRef>(); |
| 1966 | + mlir::sdy::TensorShardingAttr tensorShardingAttr = |
| 1967 | + xla::sdy::convertToSdySharding(*hloSharding, meshAttr, |
| 1968 | + deviceIdToMaximalMeshName, rank, |
| 1969 | + /*openDims=*/true); |
| 1970 | + |
| 1971 | + for (int64_t i = 0; i < rank; i++) { |
| 1972 | + auto oldDimSharding = tensorShardingAttr.getDimSharding(i); |
| 1973 | + |
| 1974 | + std::optional<int64_t> dimPriority; |
| 1975 | + if (priority[i] > 0) |
| 1976 | + dimPriority = priority[i]; |
| 1977 | + |
| 1978 | + tensorShardingAttr = tensorShardingAttr.replaceDimSharding( |
| 1979 | + i, mlir::sdy::DimensionShardingAttr::get(oldDimSharding.getContext(), |
| 1980 | + oldDimSharding.getAxes(), |
| 1981 | + isClosed[i], dimPriority)); |
| 1982 | + } |
| 1983 | + |
| 1984 | + return tensorShardingAttr; |
| 1985 | +} |
| 1986 | + |
1959 | 1987 | #pragma endregion |
0 commit comments