@@ -419,13 +419,11 @@ struct ConvertShardShapeOp : public OpConversionPattern<ShardShapeOp> {
419419 SmallVector<Value> dynDims, dynDevice;
420420 for (auto dim : adaptor.getDimsDynamic ()) {
421421 // type conversion should be 1:1 for ints
422- assert (dim.size () == 1 );
423- dynDims.emplace_back (dim[0 ]);
422+ dynDims.emplace_back (llvm::getSingleElement (dim));
424423 }
425424 // same for device
426425 for (auto device : adaptor.getDeviceDynamic ()) {
427- assert (device.size () == 1 );
428- dynDevice.emplace_back (device[0 ]);
426+ dynDevice.emplace_back (llvm::getSingleElement (device));
429427 }
430428
431429 // To keep the code simple, convert dims/device to values when they are
@@ -771,18 +769,17 @@ struct ConvertMeshToMPIPass
771769 typeConverter.addConversion ([](Type type) { return type; });
772770
773771 // convert mesh::ShardingType to a tuple of RankedTensorTypes
774- typeConverter.addConversion (
775- [](ShardingType type,
776- SmallVectorImpl<Type> &results) -> std::optional<LogicalResult> {
777- auto i16 = IntegerType::get (type.getContext (), 16 );
778- auto i64 = IntegerType::get (type.getContext (), 64 );
779- std::array<int64_t , 2 > shp = {ShapedType::kDynamic ,
780- ShapedType::kDynamic };
781- results.emplace_back (RankedTensorType::get (shp, i16 ));
782- results.emplace_back (RankedTensorType::get (shp, i64 )); // actually ?x2
783- results.emplace_back (RankedTensorType::get (shp, i64 ));
784- return success ();
785- });
772+ typeConverter.addConversion ([](ShardingType type,
773+ SmallVectorImpl<Type> &results)
774+ -> std::optional<LogicalResult> {
775+ auto i16 = IntegerType::get (type.getContext (), 16 );
776+ auto i64 = IntegerType::get (type.getContext (), 64 );
777+ std::array<int64_t , 2 > shp = {ShapedType::kDynamic , ShapedType::kDynamic };
778+ results.emplace_back (RankedTensorType::get (shp, i16 ));
779+ results.emplace_back (RankedTensorType::get (shp, i64 )); // actually ?x2
780+ results.emplace_back (RankedTensorType::get (shp, i64 ));
781+ return success ();
782+ });
786783
787784 // To 'extract' components, a UnrealizedConversionCastOp is expected
788785 // to define the input
0 commit comments