-
Notifications
You must be signed in to change notification settings - Fork 15.1k
[MLIR][NVVM] Add support for Convert Ops with rs-rounding mode #165736
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[MLIR][NVVM] Add support for Convert Ops with rs-rounding mode #165736
Conversation
…s in NVVM Dialect
|
@llvm/pr-subscribers-mlir-llvm Author: Stefan Mada (smada3) ChangesAdded support for all stochastic rounding conversion intrinsics in NVVM Dialect. Below table shows the mapping. Note that the left 2 columns already existed - this PR only adds support for it from the NVVM dialect:
Patch is 32.57 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/165736.diff 4 Files Affected:
diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
index 4f483859ac18d..ce041e9ad854c 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -1604,10 +1604,11 @@ def FPRoundingModeRM : I32EnumAttrCase<"RM", 2, "rm">;
def FPRoundingModeRP : I32EnumAttrCase<"RP", 3, "rp">;
def FPRoundingModeRZ : I32EnumAttrCase<"RZ", 4, "rz">;
def FPRoundingModeRNA : I32EnumAttrCase<"RNA", 5, "rna">;
+def FPRoundingModeRS : I32EnumAttrCase<"RS", 6, "rs">;
def FPRoundingMode : I32EnumAttr<"FPRoundingMode", "NVVM FPRoundingMode kind",
[FPRoundingModeNone, FPRoundingModeRN, FPRoundingModeRM,
- FPRoundingModeRP, FPRoundingModeRZ, FPRoundingModeRNA]> {
+ FPRoundingModeRP, FPRoundingModeRZ, FPRoundingModeRNA, FPRoundingModeRS]> {
let genSpecializedAttr = 0;
let cppNamespace = "::mlir::NVVM";
}
@@ -1921,6 +1922,95 @@ def NVVM_ConvertF6x2ToF16x2Op :
def NVVM_ConvertF4x2ToF16x2Op :
NVVM_ConvertToFP16x2Op_Base<"F4", I8, "F16">;
+//===----------------------------------------------------------------------===//
+// NVVM Stochastic Rounding Conversion Ops
+//===----------------------------------------------------------------------===//
+
+// Base class for conversions from F32x2 to FPx2 formats
+// (F16x2, BF16x2)
+// TODO: In separate PR, add .rn and .rz rounding variants for this conversion
+// as currently only support .rs rounding mode
+class NVVM_ConvertF32x2ToFPx2OpBase<string dstFormat, string mnemonic, Type resultType> :
+ NVVM_Op<mnemonic, [Pure, NVVMRequiresSMa<[100, 103]>]>,
+ Results<(outs resultType:$dst)>,
+ Arguments<(ins F32:$src_hi, F32:$src_lo, I32:$rbits,
+ DefaultValuedAttr<FPRoundingModeAttr, "FPRoundingMode::RS">:$rnd,
+ DefaultValuedAttr<SaturationModeAttr, "SaturationMode::NONE">:$sat,
+ DefaultValuedAttr<BoolAttr, "false">:$relu)> {
+ let summary = "Convert two F32 values to packed " # dstFormat # " with stochastic rounding (.rs)";
+ let description = [{
+ Converts two F32 values to packed }] # dstFormat # [{ format using stochastic
+ rounding (.rs) mode with randomness provided by the `rbits` parameter. The
+ `relu` attribute clamps negative results to 0. The `sat` attribute determines
+ saturation behavior.
+
+ [For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cvt)
+ }];
+
+ let assemblyFormat = "$src_hi `,` $src_lo `,` $rbits attr-dict `:` type($dst)";
+
+ let hasVerifier = 1;
+
+ let extraClassDeclaration = [{
+ llvm::Intrinsic::ID getIntrinsicID();
+ }];
+
+ string llvmBuilder = [{
+ auto intId = op.getIntrinsicID();
+ $dst = createIntrinsicCall(builder, intId, {$src_hi, $src_lo, $rbits});
+ }];
+ }
+
+// F32x2 -> F16x2 with stochastic rounding
+def NVVM_ConvertF32x2ToF16x2Op : NVVM_ConvertF32x2ToFPx2OpBase<"f16x2", "convert.f32x2.to.f16x2", VectorOfLengthAndType<[2], [F16]>>;
+
+// F32x2 -> BF16x2 with stochastic rounding
+def NVVM_ConvertF32x2ToBF16x2Op : NVVM_ConvertF32x2ToFPx2OpBase<"bf16x2", "convert.f32x2.to.bf16x2", VectorOfLengthAndType<[2], [BF16]>>;
+
+// Base class for stochastic rounding conversions from F32x4 to FPx4 formats
+// (E4M3x4, E5M2x4, E2M3x4, E3M2x4, E2M1x4)
+class NVVM_ConvertF32x4ToFPx4OpBase<string dstFormat, string mnemonic, Type resultType> :
+ NVVM_Op<mnemonic, [Pure, NVVMRequiresSMa<[100, 103]>]>,
+ Results<(outs resultType:$dst)>,
+ Arguments<(ins VectorOfLengthAndType<[4], [F32]>:$src, I32:$rbits,
+ DefaultValuedAttr<FPRoundingModeAttr, "FPRoundingMode::RS">:$rnd,
+ DefaultValuedAttr<SaturationModeAttr, "SaturationMode::SATFINITE">:$sat,
+ DefaultValuedAttr<BoolAttr, "false">:$relu,
+ TypeAttr:$dstTy)> {
+ let summary = "Convert vector<4xf32> to packed " # dstFormat # " with stochastic rounding (.rs)";
+ let description = [{
+ Converts a vector<4xf32> to packed }] # dstFormat # [{ format using
+ stochastic rounding (.rs) mode with randomness provided by the `rbits`
+ parameter. The `dstTy` attribute specifies the target format. The `relu`
+ attribute clamps negative results to 0. The `sat` attribute determines
+ saturation behavior.
+
+ [For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cvt)
+ }];
+
+ let assemblyFormat = "$src `,` $rbits attr-dict `:` type($src) `->` type($dst) `(` $dstTy `)`";
+
+ let hasVerifier = 1;
+
+ let extraClassDeclaration = [{
+ llvm::Intrinsic::ID getIntrinsicID();
+ }];
+
+ string llvmBuilder = [{
+ auto intId = op.getIntrinsicID();
+ $dst = createIntrinsicCall(builder, intId, {$src, $rbits});
+ }];
+}
+
+// F32x4 -> F8x4 with stochastic rounding (supports E4M3FN, E5M2)
+def NVVM_ConvertF32x4ToF8x4Op : NVVM_ConvertF32x4ToFPx4OpBase<"f8x4", "convert.f32x4.to.f8x4", VectorOfLengthAndType<[4], [I8]>>;
+
+// F32x4 -> F6x4 with stochastic rounding (supports E2M3FN, E3M2FN)
+def NVVM_ConvertF32x4ToF6x4Op : NVVM_ConvertF32x4ToFPx4OpBase<"f6x4", "convert.f32x4.to.f6x4", VectorOfLengthAndType<[4], [I8]>>;
+
+// F32x4 -> F4x4 with stochastic rounding (supports E2M1FN)
+def NVVM_ConvertF32x4ToF4x4Op : NVVM_ConvertF32x4ToFPx4OpBase<"f4x4", "convert.f32x4.to.f4x4", I16>;
+
//===----------------------------------------------------------------------===//
// NVVM MMA Ops
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
index f0de4dbcc1d4b..1baca404da21a 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -365,6 +365,83 @@ LogicalResult ConvertF4x2ToF16x2Op::verify() {
return success();
}
+//===----------------------------------------------------------------------===//
+// Stochastic Rounding Conversion Ops - getIntrinsicID implementations
+//===----------------------------------------------------------------------===//
+
+LogicalResult ConvertF32x2ToF16x2Op::verify() {
+ if (getRnd() != FPRoundingMode::RS)
+ return emitOpError("Only RS rounding mode is supported for "
+ "conversions from f32x2 to f16x2.");
+ return success();
+}
+
+LogicalResult ConvertF32x2ToBF16x2Op::verify() {
+ if (getRnd() != FPRoundingMode::RS)
+ return emitOpError("Only RS rounding mode is supported for "
+ "conversions from f32x2 to bf16x2.");
+ return success();
+}
+
+LogicalResult ConvertF32x4ToF8x4Op::verify() {
+ mlir::MLIRContext *ctx = getContext();
+
+ if (getRnd() != FPRoundingMode::RS)
+ return emitOpError("Only RS rounding mode is supported for "
+ "conversions from f32x4 to f8x4.");
+
+ if (getSat() == SaturationMode::NONE)
+ return emitOpError("Only SATFINITE saturation mode is supported for "
+ "conversions from f32x4 to f8x4.");
+
+ if (!llvm::isa<mlir::Float8E4M3FNType, mlir::Float8E5M2Type>(getDstTy()))
+ return emitOpError("Only ")
+ << mlir::Float8E4M3FNType::get(ctx) << " and "
+ << mlir::Float8E5M2Type::get(ctx)
+ << " types are supported for conversions from f32x4 to f8x4.";
+
+ return success();
+}
+
+LogicalResult ConvertF32x4ToF6x4Op::verify() {
+ mlir::MLIRContext *ctx = getContext();
+
+ if (getRnd() != FPRoundingMode::RS)
+ return emitOpError("Only RS rounding mode is supported for "
+ "conversions from f32x4 to f6x4.");
+
+ if (getSat() == SaturationMode::NONE)
+ return emitOpError("Only SATFINITE saturation mode is supported for "
+ "conversions from f32x4 to f6x4.");
+
+ if (!llvm::isa<mlir::Float6E2M3FNType, mlir::Float6E3M2FNType>(getDstTy()))
+ return emitOpError("Only ")
+ << mlir::Float6E2M3FNType::get(ctx) << " and "
+ << mlir::Float6E3M2FNType::get(ctx)
+ << " types are supported for conversions from f32x4 to f6x4.";
+
+ return success();
+}
+
+LogicalResult ConvertF32x4ToF4x4Op::verify() {
+ mlir::MLIRContext *ctx = getContext();
+
+ if (getRnd() != FPRoundingMode::RS)
+ return emitOpError("Only RS rounding mode is supported for "
+ "conversions from f32x4 to f4x4.");
+
+ if (getSat() == SaturationMode::NONE)
+ return emitOpError("Only SATFINITE saturation mode is supported for "
+ "conversions from f32x4 to f4x4.");
+
+ if (!llvm::isa<mlir::Float4E2M1FNType>(getDstTy()))
+ return emitOpError("Only ") << mlir::Float4E2M1FNType::get(ctx)
+ << " type is supported for conversions from "
+ "f32x4 to f4x4.";
+
+ return success();
+}
+
LogicalResult BulkStoreOp::verify() {
if (getInitVal() != 0)
return emitOpError("only 0 is supported for initVal, got ") << getInitVal();
@@ -2412,6 +2489,84 @@ Tcgen05CommitOp::getIntrinsicIDAndArgs(Operation &op,
return TCGEN05_CP_2CTA(shape_mc, , is_2cta); \
}()
+// Helper macros for F32x2 to FPx2 stochastic rounding intrinsics
+#define GET_F32x2_TO_F16x2_RS_ID(has_relu, has_satf) \
+ (has_relu ? (has_satf ? llvm::Intrinsic::nvvm_ff2f16x2_rs_relu_satfinite \
+ : llvm::Intrinsic::nvvm_ff2f16x2_rs_relu) \
+ : (has_satf ? llvm::Intrinsic::nvvm_ff2f16x2_rs_satfinite \
+ : llvm::Intrinsic::nvvm_ff2f16x2_rs))
+
+#define GET_F32x2_TO_BF16x2_RS_ID(has_relu, has_satf) \
+ (has_relu ? (has_satf ? llvm::Intrinsic::nvvm_ff2bf16x2_rs_relu_satfinite \
+ : llvm::Intrinsic::nvvm_ff2bf16x2_rs_relu) \
+ : (has_satf ? llvm::Intrinsic::nvvm_ff2bf16x2_rs_satfinite \
+ : llvm::Intrinsic::nvvm_ff2bf16x2_rs))
+
+llvm::Intrinsic::ID ConvertF32x2ToF16x2Op::getIntrinsicID() {
+ bool hasRelu = getRelu();
+ bool hasSatFinite = (getSat() == NVVM::SaturationMode::SATFINITE);
+ return GET_F32x2_TO_F16x2_RS_ID(hasRelu, hasSatFinite);
+}
+
+llvm::Intrinsic::ID ConvertF32x2ToBF16x2Op::getIntrinsicID() {
+ bool hasRelu = getRelu();
+ bool hasSatFinite = (getSat() == NVVM::SaturationMode::SATFINITE);
+ return GET_F32x2_TO_BF16x2_RS_ID(hasRelu, hasSatFinite);
+}
+
+llvm::Intrinsic::ID ConvertF32x4ToF8x4Op::getIntrinsicID() {
+ mlir::Type dstTy = getDstTy();
+ bool hasRelu = getRelu();
+
+ return llvm::TypeSwitch<mlir::Type, llvm::Intrinsic::ID>(dstTy)
+ .Case<mlir::Float8E4M3FNType>([&](mlir::Float8E4M3FNType) {
+ return hasRelu ? llvm::Intrinsic::nvvm_f32x4_to_e4m3x4_rs_relu_satfinite
+ : llvm::Intrinsic::nvvm_f32x4_to_e4m3x4_rs_satfinite;
+ })
+ .Case<mlir::Float8E5M2Type>([&](mlir::Float8E5M2Type) {
+ return hasRelu ? llvm::Intrinsic::nvvm_f32x4_to_e5m2x4_rs_relu_satfinite
+ : llvm::Intrinsic::nvvm_f32x4_to_e5m2x4_rs_satfinite;
+ })
+ .Default([](mlir::Type) {
+ llvm_unreachable("Invalid F8 type in ConvertF32x4ToF8x4Op");
+ return llvm::Intrinsic::not_intrinsic;
+ });
+}
+
+llvm::Intrinsic::ID ConvertF32x4ToF6x4Op::getIntrinsicID() {
+ mlir::Type dstTy = getDstTy();
+ bool hasRelu = getRelu();
+
+ return llvm::TypeSwitch<mlir::Type, llvm::Intrinsic::ID>(dstTy)
+ .Case<mlir::Float6E2M3FNType>([&](mlir::Float6E2M3FNType) {
+ return hasRelu ? llvm::Intrinsic::nvvm_f32x4_to_e2m3x4_rs_relu_satfinite
+ : llvm::Intrinsic::nvvm_f32x4_to_e2m3x4_rs_satfinite;
+ })
+ .Case<mlir::Float6E3M2FNType>([&](mlir::Float6E3M2FNType) {
+ return hasRelu ? llvm::Intrinsic::nvvm_f32x4_to_e3m2x4_rs_relu_satfinite
+ : llvm::Intrinsic::nvvm_f32x4_to_e3m2x4_rs_satfinite;
+ })
+ .Default([](mlir::Type) {
+ llvm_unreachable("Invalid F6 type in ConvertF32x4ToF6x4Op");
+ return llvm::Intrinsic::not_intrinsic;
+ });
+}
+
+llvm::Intrinsic::ID ConvertF32x4ToF4x4Op::getIntrinsicID() {
+ mlir::Type dstTy = getDstTy();
+ bool hasRelu = getRelu();
+
+ return llvm::TypeSwitch<mlir::Type, llvm::Intrinsic::ID>(dstTy)
+ .Case<mlir::Float4E2M1FNType>([&](mlir::Float4E2M1FNType) {
+ return hasRelu ? llvm::Intrinsic::nvvm_f32x4_to_e2m1x4_rs_relu_satfinite
+ : llvm::Intrinsic::nvvm_f32x4_to_e2m1x4_rs_satfinite;
+ })
+ .Default([](mlir::Type) {
+ llvm_unreachable("Invalid F4 type in ConvertF32x4ToF4x4Op");
+ return llvm::Intrinsic::not_intrinsic;
+ });
+}
+
llvm::Intrinsic::ID Tcgen05CpOp::getIntrinsicID(Operation &op) {
auto curOp = cast<NVVM::Tcgen05CpOp>(op);
bool is2CTA = curOp.getGroup() == CTAGroupKind::CTA_2;
diff --git a/mlir/test/Dialect/LLVMIR/nvvm/invalid-convert-stochastic-rounding.mlir b/mlir/test/Dialect/LLVMIR/nvvm/invalid-convert-stochastic-rounding.mlir
new file mode 100644
index 0000000000000..7ffafdeadb62e
--- /dev/null
+++ b/mlir/test/Dialect/LLVMIR/nvvm/invalid-convert-stochastic-rounding.mlir
@@ -0,0 +1,147 @@
+// RUN: mlir-opt %s -split-input-file -verify-diagnostics
+
+// Test invalid target architecture (sm_100 instead of sm_100a)
+gpu.module @invalid_arch_sm_100 [#nvvm.target<chip = "sm_100">] {
+ func.func @convert_rs() {
+ %f1 = llvm.mlir.constant(1.0 : f32) : f32
+ %f2 = llvm.mlir.constant(2.0 : f32) : f32
+ %rbits = llvm.mlir.constant(0x12345678 : i32) : i32
+ // expected-error@+1 {{'nvvm.convert.f32x2.to.f16x2' op is not supported on sm_100}}
+ %res = nvvm.convert.f32x2.to.f16x2 %f1, %f2, %rbits : vector<2xf16>
+ return
+ }
+}
+
+// -----
+
+// Test that FP8/FP6/FP4 conversions require satfinite mode
+llvm.func @invalid_sat_mode_f8x4_e4m3(%src : vector<4xf32>, %rbits : i32) -> vector<4xi8> {
+ // expected-error@+1 {{Only SATFINITE saturation mode is supported for conversions from f32x4 to f8x4.}}
+ %res = nvvm.convert.f32x4.to.f8x4 %src, %rbits {sat = #nvvm.sat_mode<none>} : vector<4xf32> -> vector<4xi8> (f8E4M3FN)
+ llvm.return %res : vector<4xi8>
+}
+
+// -----
+
+llvm.func @invalid_sat_mode_f8x4_e5m2(%src : vector<4xf32>, %rbits : i32) -> vector<4xi8> {
+ // expected-error@+1 {{Only SATFINITE saturation mode is supported for conversions from f32x4 to f8x4.}}
+ %res = nvvm.convert.f32x4.to.f8x4 %src, %rbits {sat = #nvvm.sat_mode<none>} : vector<4xf32> -> vector<4xi8> (f8E5M2)
+ llvm.return %res : vector<4xi8>
+}
+
+// -----
+
+llvm.func @invalid_sat_mode_f6x4_e2m3(%src : vector<4xf32>, %rbits : i32) -> vector<4xi8> {
+ // expected-error@+1 {{Only SATFINITE saturation mode is supported for conversions from f32x4 to f6x4.}}
+ %res = nvvm.convert.f32x4.to.f6x4 %src, %rbits {sat = #nvvm.sat_mode<none>} : vector<4xf32> -> vector<4xi8> (f6E2M3FN)
+ llvm.return %res : vector<4xi8>
+}
+
+// -----
+
+llvm.func @invalid_sat_mode_f6x4_e3m2(%src : vector<4xf32>, %rbits : i32) -> vector<4xi8> {
+ // expected-error@+1 {{Only SATFINITE saturation mode is supported for conversions from f32x4 to f6x4.}}
+ %res = nvvm.convert.f32x4.to.f6x4 %src, %rbits {sat = #nvvm.sat_mode<none>} : vector<4xf32> -> vector<4xi8> (f6E3M2FN)
+ llvm.return %res : vector<4xi8>
+}
+
+// -----
+
+llvm.func @invalid_sat_mode_f4x4_e2m1(%src : vector<4xf32>, %rbits : i32) -> i16 {
+ // expected-error@+1 {{Only SATFINITE saturation mode is supported for conversions from f32x4 to f4x4.}}
+ %res = nvvm.convert.f32x4.to.f4x4 %src, %rbits {sat = #nvvm.sat_mode<none>} : vector<4xf32> -> i16 (f4E2M1FN)
+ llvm.return %res : i16
+}
+
+// -----
+
+// Test that operations require stochastic rounding mode
+llvm.func @invalid_rnd_mode_f16x2(%srcA : f32, %srcB : f32, %rbits : i32) -> vector<2xf16> {
+ // expected-error@+1 {{Only RS rounding mode is supported for conversions from f32x2 to f16x2.}}
+ %res = nvvm.convert.f32x2.to.f16x2 %srcA, %srcB, %rbits {rnd = #nvvm.fp_rnd_mode<rn>} : vector<2xf16>
+ llvm.return %res : vector<2xf16>
+}
+
+// -----
+
+llvm.func @invalid_rnd_mode_bf16x2(%srcA : f32, %srcB : f32, %rbits : i32) -> vector<2xbf16> {
+ // expected-error@+1 {{Only RS rounding mode is supported for conversions from f32x2 to bf16x2.}}
+ %res = nvvm.convert.f32x2.to.bf16x2 %srcA, %srcB, %rbits {rnd = #nvvm.fp_rnd_mode<rz>} : vector<2xbf16>
+ llvm.return %res : vector<2xbf16>
+}
+
+// -----
+
+llvm.func @invalid_rnd_mode_f8x4_e4m3(%src : vector<4xf32>, %rbits : i32) -> vector<4xi8> {
+ // expected-error@+1 {{Only RS rounding mode is supported for conversions from f32x4 to f8x4.}}
+ %res = nvvm.convert.f32x4.to.f8x4 %src, %rbits {rnd = #nvvm.fp_rnd_mode<rn>, sat = #nvvm.sat_mode<satfinite>} : vector<4xf32> -> vector<4xi8> (f8E4M3FN)
+ llvm.return %res : vector<4xi8>
+}
+
+// -----
+
+llvm.func @invalid_rnd_mode_f4x4_e2m1(%src : vector<4xf32>, %rbits : i32) -> i16 {
+ // expected-error@+1 {{Only RS rounding mode is supported for conversions from f32x4 to f4x4.}}
+ %res = nvvm.convert.f32x4.to.f4x4 %src, %rbits {rnd = #nvvm.fp_rnd_mode<rp>, sat = #nvvm.sat_mode<satfinite>} : vector<4xf32> -> i16 (f4E2M1FN)
+ llvm.return %res : i16
+}
+
+// -----
+
+// Test invalid destination types for f8x4 (should only accept f8E4M3FN, f8E5M2)
+llvm.func @invalid_dst_type_f8x4_e3m4(%src : vector<4xf32>, %rbits : i32) -> vector<4xi8> {
+ // expected-error@+1 {{Only 'f8E4M3FN' and 'f8E5M2' types are supported for conversions from f32x4 to f8x4.}}
+ %res = nvvm.convert.f32x4.to.f8x4 %src, %rbits : vector<4xf32> -> vector<4xi8> (f8E3M4)
+ llvm.return %res : vector<4xi8>
+}
+
+// -----
+
+llvm.func @invalid_dst_type_f8x4_e8m0(%src : vector<4xf32>, %rbits : i32) -> vector<4xi8> {
+ // expected-error@+1 {{Only 'f8E4M3FN' and 'f8E5M2' types are supported for conversions from f32x4 to f8x4.}}
+ %res = nvvm.convert.f32x4.to.f8x4 %src, %rbits : vector<4xf32> -> vector<4xi8> (f8E8M0FNU)
+ llvm.return %res : vector<4xi8>
+}
+
+// -----
+
+// Test invalid destination types for f6x4 (should only accept f6E2M3FN, f6E3M2FN)
+llvm.func @invalid_dst_type_f6x4_f8(%src : vector<4xf32>, %rbits : i32) -> vector<4xi8> {
+ // expected-error@+1 {{Only 'f6E2M3FN' and 'f6E3M2FN' types are supported for conversions from f32x4 to f6x4.}}
+ %res = nvvm.convert.f32x4.to.f6x4 %src, %rbits : vector<4xf32> -> vector<4xi8> (f8E4M3FN)
+ llvm.return %res : vector<4xi8>
+}
+
+// -----
+
+// Test invalid destination type for f4x4 (should only accept f4E2M1FN)
+llvm.func @invalid_dst_type_f4x4_f6(%src : vector<4xf32>, %rbits : i32) -> i16 {
+ // expected-error@+1 {{Only 'f4E2M1FN' type is supported for conversions from f32x4 to f4x4.}}
+ %res = nvvm.convert.f32x4.to.f4x4 %src, %rbits : vector<4xf32> -> i16 (f6E2M3FN)
+ llvm.return %res : i16
+}
+
+// -----
+
+// Test invalid rounding modes for non-stochastic ops
+llvm.func @convert_float_to_tf32_rs_not_supported(%src : f32) -> i32 {
+ // expected-error @below {{Only {rn,rz,rna} rounding modes supported for ConvertFloatToTF32Op.}}
+ %res = nvvm.convert.float.to.tf32 %src {rnd = #nvvm.fp_rnd_mode<rs>}
+ llvm.return %res : i32
+}
+
+// -----
+
+llvm.func @convert_f32x2_to_f8x2_rs_not_supported(%a : f32, %b : f32) {
+ // expected-error @below {{Only RN rounding mode is supported for conversions from f32x2 to 'f8E4M3FN' and 'f8E5M2' types}}
+ %res = nvvm.convert.f32x2.to.f8x2 %a, %b {rnd = #nvvm.fp_rnd_mode<rs>, sat = #nvvm.sat_mode<satfinite>} : i16 (f8E4M3FN)
+ llvm.return
+}
+
+// -----
+
+llvm.func @convert_bf16x2_to_f8x2_rs_not_supported(%src : vector<2xbf16>) {
+ // expected-error @below {{Only RZ and RP rounding modes are supported for conversions from bf16x2 to f8x2.}}
+ %res = nvvm.convert.bf16x2.to.f8x2 %src {rnd = #nvvm.fp_rnd_mode<rs>} : vector<2xbf16> -> i16 (f8E8M0FNU)
+ llvm.return
+}
diff --git a/mlir/test/Target/LLVMIR/nvvm/convert_stochastic_rounding.mlir b/mlir/test/Target/LLVMIR/nvvm/convert_stochastic_rounding.mlir
new file mode 100644
index 0000000000000..227726ac300b5
--- /dev/null
+++ b/mlir/test/Target/LLVMIR/nvvm/convert_stochastic_rounding.mlir
@@ -0,0 +1,252 @@
+// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s
+
+// -----
+
+// Test valid architectures work
+
+// Valid case on sm_100a
+gpu.module @valid_f16x2_rs_sm_100a [#nvvm.target<chip = "sm_100a">] {
+ func.func @convert_rs() {
+ ...
[truncated]
|
|
@llvm/pr-subscribers-mlir Author: Stefan Mada (smada3) ChangesAdded support for all stochastic rounding conversion intrinsics in NVVM Dialect. Below table shows the mapping. Note that the left 2 columns already existed - this PR only adds support for it from the NVVM dialect:
Patch is 32.57 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/165736.diff 4 Files Affected:
diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
index 4f483859ac18d..ce041e9ad854c 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -1604,10 +1604,11 @@ def FPRoundingModeRM : I32EnumAttrCase<"RM", 2, "rm">;
def FPRoundingModeRP : I32EnumAttrCase<"RP", 3, "rp">;
def FPRoundingModeRZ : I32EnumAttrCase<"RZ", 4, "rz">;
def FPRoundingModeRNA : I32EnumAttrCase<"RNA", 5, "rna">;
+def FPRoundingModeRS : I32EnumAttrCase<"RS", 6, "rs">;
def FPRoundingMode : I32EnumAttr<"FPRoundingMode", "NVVM FPRoundingMode kind",
[FPRoundingModeNone, FPRoundingModeRN, FPRoundingModeRM,
- FPRoundingModeRP, FPRoundingModeRZ, FPRoundingModeRNA]> {
+ FPRoundingModeRP, FPRoundingModeRZ, FPRoundingModeRNA, FPRoundingModeRS]> {
let genSpecializedAttr = 0;
let cppNamespace = "::mlir::NVVM";
}
@@ -1921,6 +1922,95 @@ def NVVM_ConvertF6x2ToF16x2Op :
def NVVM_ConvertF4x2ToF16x2Op :
NVVM_ConvertToFP16x2Op_Base<"F4", I8, "F16">;
+//===----------------------------------------------------------------------===//
+// NVVM Stochastic Rounding Conversion Ops
+//===----------------------------------------------------------------------===//
+
+// Base class for conversions from F32x2 to FPx2 formats
+// (F16x2, BF16x2)
+// TODO: In separate PR, add .rn and .rz rounding variants for this conversion
+// as currently only support .rs rounding mode
+class NVVM_ConvertF32x2ToFPx2OpBase<string dstFormat, string mnemonic, Type resultType> :
+ NVVM_Op<mnemonic, [Pure, NVVMRequiresSMa<[100, 103]>]>,
+ Results<(outs resultType:$dst)>,
+ Arguments<(ins F32:$src_hi, F32:$src_lo, I32:$rbits,
+ DefaultValuedAttr<FPRoundingModeAttr, "FPRoundingMode::RS">:$rnd,
+ DefaultValuedAttr<SaturationModeAttr, "SaturationMode::NONE">:$sat,
+ DefaultValuedAttr<BoolAttr, "false">:$relu)> {
+ let summary = "Convert two F32 values to packed " # dstFormat # " with stochastic rounding (.rs)";
+ let description = [{
+ Converts two F32 values to packed }] # dstFormat # [{ format using stochastic
+ rounding (.rs) mode with randomness provided by the `rbits` parameter. The
+ `relu` attribute clamps negative results to 0. The `sat` attribute determines
+ saturation behavior.
+
+ [For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cvt)
+ }];
+
+ let assemblyFormat = "$src_hi `,` $src_lo `,` $rbits attr-dict `:` type($dst)";
+
+ let hasVerifier = 1;
+
+ let extraClassDeclaration = [{
+ llvm::Intrinsic::ID getIntrinsicID();
+ }];
+
+ string llvmBuilder = [{
+ auto intId = op.getIntrinsicID();
+ $dst = createIntrinsicCall(builder, intId, {$src_hi, $src_lo, $rbits});
+ }];
+ }
+
+// F32x2 -> F16x2 with stochastic rounding
+def NVVM_ConvertF32x2ToF16x2Op : NVVM_ConvertF32x2ToFPx2OpBase<"f16x2", "convert.f32x2.to.f16x2", VectorOfLengthAndType<[2], [F16]>>;
+
+// F32x2 -> BF16x2 with stochastic rounding
+def NVVM_ConvertF32x2ToBF16x2Op : NVVM_ConvertF32x2ToFPx2OpBase<"bf16x2", "convert.f32x2.to.bf16x2", VectorOfLengthAndType<[2], [BF16]>>;
+
+// Base class for stochastic rounding conversions from F32x4 to FPx4 formats
+// (E4M3x4, E5M2x4, E2M3x4, E3M2x4, E2M1x4)
+class NVVM_ConvertF32x4ToFPx4OpBase<string dstFormat, string mnemonic, Type resultType> :
+ NVVM_Op<mnemonic, [Pure, NVVMRequiresSMa<[100, 103]>]>,
+ Results<(outs resultType:$dst)>,
+ Arguments<(ins VectorOfLengthAndType<[4], [F32]>:$src, I32:$rbits,
+ DefaultValuedAttr<FPRoundingModeAttr, "FPRoundingMode::RS">:$rnd,
+ DefaultValuedAttr<SaturationModeAttr, "SaturationMode::SATFINITE">:$sat,
+ DefaultValuedAttr<BoolAttr, "false">:$relu,
+ TypeAttr:$dstTy)> {
+ let summary = "Convert vector<4xf32> to packed " # dstFormat # " with stochastic rounding (.rs)";
+ let description = [{
+ Converts a vector<4xf32> to packed }] # dstFormat # [{ format using
+ stochastic rounding (.rs) mode with randomness provided by the `rbits`
+ parameter. The `dstTy` attribute specifies the target format. The `relu`
+ attribute clamps negative results to 0. The `sat` attribute determines
+ saturation behavior.
+
+ [For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cvt)
+ }];
+
+ let assemblyFormat = "$src `,` $rbits attr-dict `:` type($src) `->` type($dst) `(` $dstTy `)`";
+
+ let hasVerifier = 1;
+
+ let extraClassDeclaration = [{
+ llvm::Intrinsic::ID getIntrinsicID();
+ }];
+
+ string llvmBuilder = [{
+ auto intId = op.getIntrinsicID();
+ $dst = createIntrinsicCall(builder, intId, {$src, $rbits});
+ }];
+}
+
+// F32x4 -> F8x4 with stochastic rounding (supports E4M3FN, E5M2)
+def NVVM_ConvertF32x4ToF8x4Op : NVVM_ConvertF32x4ToFPx4OpBase<"f8x4", "convert.f32x4.to.f8x4", VectorOfLengthAndType<[4], [I8]>>;
+
+// F32x4 -> F6x4 with stochastic rounding (supports E2M3FN, E3M2FN)
+def NVVM_ConvertF32x4ToF6x4Op : NVVM_ConvertF32x4ToFPx4OpBase<"f6x4", "convert.f32x4.to.f6x4", VectorOfLengthAndType<[4], [I8]>>;
+
+// F32x4 -> F4x4 with stochastic rounding (supports E2M1FN)
+def NVVM_ConvertF32x4ToF4x4Op : NVVM_ConvertF32x4ToFPx4OpBase<"f4x4", "convert.f32x4.to.f4x4", I16>;
+
//===----------------------------------------------------------------------===//
// NVVM MMA Ops
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
index f0de4dbcc1d4b..1baca404da21a 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -365,6 +365,83 @@ LogicalResult ConvertF4x2ToF16x2Op::verify() {
return success();
}
+//===----------------------------------------------------------------------===//
+// Stochastic Rounding Conversion Ops - getIntrinsicID implementations
+//===----------------------------------------------------------------------===//
+
+LogicalResult ConvertF32x2ToF16x2Op::verify() {
+ if (getRnd() != FPRoundingMode::RS)
+ return emitOpError("Only RS rounding mode is supported for "
+ "conversions from f32x2 to f16x2.");
+ return success();
+}
+
+LogicalResult ConvertF32x2ToBF16x2Op::verify() {
+ if (getRnd() != FPRoundingMode::RS)
+ return emitOpError("Only RS rounding mode is supported for "
+ "conversions from f32x2 to bf16x2.");
+ return success();
+}
+
+LogicalResult ConvertF32x4ToF8x4Op::verify() {
+ mlir::MLIRContext *ctx = getContext();
+
+ if (getRnd() != FPRoundingMode::RS)
+ return emitOpError("Only RS rounding mode is supported for "
+ "conversions from f32x4 to f8x4.");
+
+ if (getSat() == SaturationMode::NONE)
+ return emitOpError("Only SATFINITE saturation mode is supported for "
+ "conversions from f32x4 to f8x4.");
+
+ if (!llvm::isa<mlir::Float8E4M3FNType, mlir::Float8E5M2Type>(getDstTy()))
+ return emitOpError("Only ")
+ << mlir::Float8E4M3FNType::get(ctx) << " and "
+ << mlir::Float8E5M2Type::get(ctx)
+ << " types are supported for conversions from f32x4 to f8x4.";
+
+ return success();
+}
+
+LogicalResult ConvertF32x4ToF6x4Op::verify() {
+ mlir::MLIRContext *ctx = getContext();
+
+ if (getRnd() != FPRoundingMode::RS)
+ return emitOpError("Only RS rounding mode is supported for "
+ "conversions from f32x4 to f6x4.");
+
+ if (getSat() == SaturationMode::NONE)
+ return emitOpError("Only SATFINITE saturation mode is supported for "
+ "conversions from f32x4 to f6x4.");
+
+ if (!llvm::isa<mlir::Float6E2M3FNType, mlir::Float6E3M2FNType>(getDstTy()))
+ return emitOpError("Only ")
+ << mlir::Float6E2M3FNType::get(ctx) << " and "
+ << mlir::Float6E3M2FNType::get(ctx)
+ << " types are supported for conversions from f32x4 to f6x4.";
+
+ return success();
+}
+
+LogicalResult ConvertF32x4ToF4x4Op::verify() {
+ mlir::MLIRContext *ctx = getContext();
+
+ if (getRnd() != FPRoundingMode::RS)
+ return emitOpError("Only RS rounding mode is supported for "
+ "conversions from f32x4 to f4x4.");
+
+ if (getSat() == SaturationMode::NONE)
+ return emitOpError("Only SATFINITE saturation mode is supported for "
+ "conversions from f32x4 to f4x4.");
+
+ if (!llvm::isa<mlir::Float4E2M1FNType>(getDstTy()))
+ return emitOpError("Only ") << mlir::Float4E2M1FNType::get(ctx)
+ << " type is supported for conversions from "
+ "f32x4 to f4x4.";
+
+ return success();
+}
+
LogicalResult BulkStoreOp::verify() {
if (getInitVal() != 0)
return emitOpError("only 0 is supported for initVal, got ") << getInitVal();
@@ -2412,6 +2489,84 @@ Tcgen05CommitOp::getIntrinsicIDAndArgs(Operation &op,
return TCGEN05_CP_2CTA(shape_mc, , is_2cta); \
}()
+// Helper macros for F32x2 to FPx2 stochastic rounding intrinsics
+#define GET_F32x2_TO_F16x2_RS_ID(has_relu, has_satf) \
+ (has_relu ? (has_satf ? llvm::Intrinsic::nvvm_ff2f16x2_rs_relu_satfinite \
+ : llvm::Intrinsic::nvvm_ff2f16x2_rs_relu) \
+ : (has_satf ? llvm::Intrinsic::nvvm_ff2f16x2_rs_satfinite \
+ : llvm::Intrinsic::nvvm_ff2f16x2_rs))
+
+#define GET_F32x2_TO_BF16x2_RS_ID(has_relu, has_satf) \
+ (has_relu ? (has_satf ? llvm::Intrinsic::nvvm_ff2bf16x2_rs_relu_satfinite \
+ : llvm::Intrinsic::nvvm_ff2bf16x2_rs_relu) \
+ : (has_satf ? llvm::Intrinsic::nvvm_ff2bf16x2_rs_satfinite \
+ : llvm::Intrinsic::nvvm_ff2bf16x2_rs))
+
+llvm::Intrinsic::ID ConvertF32x2ToF16x2Op::getIntrinsicID() {
+ bool hasRelu = getRelu();
+ bool hasSatFinite = (getSat() == NVVM::SaturationMode::SATFINITE);
+ return GET_F32x2_TO_F16x2_RS_ID(hasRelu, hasSatFinite);
+}
+
+llvm::Intrinsic::ID ConvertF32x2ToBF16x2Op::getIntrinsicID() {
+ bool hasRelu = getRelu();
+ bool hasSatFinite = (getSat() == NVVM::SaturationMode::SATFINITE);
+ return GET_F32x2_TO_BF16x2_RS_ID(hasRelu, hasSatFinite);
+}
+
+llvm::Intrinsic::ID ConvertF32x4ToF8x4Op::getIntrinsicID() {
+ mlir::Type dstTy = getDstTy();
+ bool hasRelu = getRelu();
+
+ return llvm::TypeSwitch<mlir::Type, llvm::Intrinsic::ID>(dstTy)
+ .Case<mlir::Float8E4M3FNType>([&](mlir::Float8E4M3FNType) {
+ return hasRelu ? llvm::Intrinsic::nvvm_f32x4_to_e4m3x4_rs_relu_satfinite
+ : llvm::Intrinsic::nvvm_f32x4_to_e4m3x4_rs_satfinite;
+ })
+ .Case<mlir::Float8E5M2Type>([&](mlir::Float8E5M2Type) {
+ return hasRelu ? llvm::Intrinsic::nvvm_f32x4_to_e5m2x4_rs_relu_satfinite
+ : llvm::Intrinsic::nvvm_f32x4_to_e5m2x4_rs_satfinite;
+ })
+ .Default([](mlir::Type) {
+ llvm_unreachable("Invalid F8 type in ConvertF32x4ToF8x4Op");
+ return llvm::Intrinsic::not_intrinsic;
+ });
+}
+
+llvm::Intrinsic::ID ConvertF32x4ToF6x4Op::getIntrinsicID() {
+ mlir::Type dstTy = getDstTy();
+ bool hasRelu = getRelu();
+
+ return llvm::TypeSwitch<mlir::Type, llvm::Intrinsic::ID>(dstTy)
+ .Case<mlir::Float6E2M3FNType>([&](mlir::Float6E2M3FNType) {
+ return hasRelu ? llvm::Intrinsic::nvvm_f32x4_to_e2m3x4_rs_relu_satfinite
+ : llvm::Intrinsic::nvvm_f32x4_to_e2m3x4_rs_satfinite;
+ })
+ .Case<mlir::Float6E3M2FNType>([&](mlir::Float6E3M2FNType) {
+ return hasRelu ? llvm::Intrinsic::nvvm_f32x4_to_e3m2x4_rs_relu_satfinite
+ : llvm::Intrinsic::nvvm_f32x4_to_e3m2x4_rs_satfinite;
+ })
+ .Default([](mlir::Type) {
+ llvm_unreachable("Invalid F6 type in ConvertF32x4ToF6x4Op");
+ return llvm::Intrinsic::not_intrinsic;
+ });
+}
+
+llvm::Intrinsic::ID ConvertF32x4ToF4x4Op::getIntrinsicID() {
+ mlir::Type dstTy = getDstTy();
+ bool hasRelu = getRelu();
+
+ return llvm::TypeSwitch<mlir::Type, llvm::Intrinsic::ID>(dstTy)
+ .Case<mlir::Float4E2M1FNType>([&](mlir::Float4E2M1FNType) {
+ return hasRelu ? llvm::Intrinsic::nvvm_f32x4_to_e2m1x4_rs_relu_satfinite
+ : llvm::Intrinsic::nvvm_f32x4_to_e2m1x4_rs_satfinite;
+ })
+ .Default([](mlir::Type) {
+ llvm_unreachable("Invalid F4 type in ConvertF32x4ToF4x4Op");
+ return llvm::Intrinsic::not_intrinsic;
+ });
+}
+
llvm::Intrinsic::ID Tcgen05CpOp::getIntrinsicID(Operation &op) {
auto curOp = cast<NVVM::Tcgen05CpOp>(op);
bool is2CTA = curOp.getGroup() == CTAGroupKind::CTA_2;
diff --git a/mlir/test/Dialect/LLVMIR/nvvm/invalid-convert-stochastic-rounding.mlir b/mlir/test/Dialect/LLVMIR/nvvm/invalid-convert-stochastic-rounding.mlir
new file mode 100644
index 0000000000000..7ffafdeadb62e
--- /dev/null
+++ b/mlir/test/Dialect/LLVMIR/nvvm/invalid-convert-stochastic-rounding.mlir
@@ -0,0 +1,147 @@
+// RUN: mlir-opt %s -split-input-file -verify-diagnostics
+
+// Test invalid target architecture (sm_100 instead of sm_100a)
+gpu.module @invalid_arch_sm_100 [#nvvm.target<chip = "sm_100">] {
+ func.func @convert_rs() {
+ %f1 = llvm.mlir.constant(1.0 : f32) : f32
+ %f2 = llvm.mlir.constant(2.0 : f32) : f32
+ %rbits = llvm.mlir.constant(0x12345678 : i32) : i32
+ // expected-error@+1 {{'nvvm.convert.f32x2.to.f16x2' op is not supported on sm_100}}
+ %res = nvvm.convert.f32x2.to.f16x2 %f1, %f2, %rbits : vector<2xf16>
+ return
+ }
+}
+
+// -----
+
+// Test that FP8/FP6/FP4 conversions require satfinite mode
+llvm.func @invalid_sat_mode_f8x4_e4m3(%src : vector<4xf32>, %rbits : i32) -> vector<4xi8> {
+ // expected-error@+1 {{Only SATFINITE saturation mode is supported for conversions from f32x4 to f8x4.}}
+ %res = nvvm.convert.f32x4.to.f8x4 %src, %rbits {sat = #nvvm.sat_mode<none>} : vector<4xf32> -> vector<4xi8> (f8E4M3FN)
+ llvm.return %res : vector<4xi8>
+}
+
+// -----
+
+llvm.func @invalid_sat_mode_f8x4_e5m2(%src : vector<4xf32>, %rbits : i32) -> vector<4xi8> {
+ // expected-error@+1 {{Only SATFINITE saturation mode is supported for conversions from f32x4 to f8x4.}}
+ %res = nvvm.convert.f32x4.to.f8x4 %src, %rbits {sat = #nvvm.sat_mode<none>} : vector<4xf32> -> vector<4xi8> (f8E5M2)
+ llvm.return %res : vector<4xi8>
+}
+
+// -----
+
+llvm.func @invalid_sat_mode_f6x4_e2m3(%src : vector<4xf32>, %rbits : i32) -> vector<4xi8> {
+ // expected-error@+1 {{Only SATFINITE saturation mode is supported for conversions from f32x4 to f6x4.}}
+ %res = nvvm.convert.f32x4.to.f6x4 %src, %rbits {sat = #nvvm.sat_mode<none>} : vector<4xf32> -> vector<4xi8> (f6E2M3FN)
+ llvm.return %res : vector<4xi8>
+}
+
+// -----
+
+llvm.func @invalid_sat_mode_f6x4_e3m2(%src : vector<4xf32>, %rbits : i32) -> vector<4xi8> {
+ // expected-error@+1 {{Only SATFINITE saturation mode is supported for conversions from f32x4 to f6x4.}}
+ %res = nvvm.convert.f32x4.to.f6x4 %src, %rbits {sat = #nvvm.sat_mode<none>} : vector<4xf32> -> vector<4xi8> (f6E3M2FN)
+ llvm.return %res : vector<4xi8>
+}
+
+// -----
+
+llvm.func @invalid_sat_mode_f4x4_e2m1(%src : vector<4xf32>, %rbits : i32) -> i16 {
+ // expected-error@+1 {{Only SATFINITE saturation mode is supported for conversions from f32x4 to f4x4.}}
+ %res = nvvm.convert.f32x4.to.f4x4 %src, %rbits {sat = #nvvm.sat_mode<none>} : vector<4xf32> -> i16 (f4E2M1FN)
+ llvm.return %res : i16
+}
+
+// -----
+
+// Test that operations require stochastic rounding mode
+llvm.func @invalid_rnd_mode_f16x2(%srcA : f32, %srcB : f32, %rbits : i32) -> vector<2xf16> {
+ // expected-error@+1 {{Only RS rounding mode is supported for conversions from f32x2 to f16x2.}}
+ %res = nvvm.convert.f32x2.to.f16x2 %srcA, %srcB, %rbits {rnd = #nvvm.fp_rnd_mode<rn>} : vector<2xf16>
+ llvm.return %res : vector<2xf16>
+}
+
+// -----
+
+llvm.func @invalid_rnd_mode_bf16x2(%srcA : f32, %srcB : f32, %rbits : i32) -> vector<2xbf16> {
+ // expected-error@+1 {{Only RS rounding mode is supported for conversions from f32x2 to bf16x2.}}
+ %res = nvvm.convert.f32x2.to.bf16x2 %srcA, %srcB, %rbits {rnd = #nvvm.fp_rnd_mode<rz>} : vector<2xbf16>
+ llvm.return %res : vector<2xbf16>
+}
+
+// -----
+
+llvm.func @invalid_rnd_mode_f8x4_e4m3(%src : vector<4xf32>, %rbits : i32) -> vector<4xi8> {
+ // expected-error@+1 {{Only RS rounding mode is supported for conversions from f32x4 to f8x4.}}
+ %res = nvvm.convert.f32x4.to.f8x4 %src, %rbits {rnd = #nvvm.fp_rnd_mode<rn>, sat = #nvvm.sat_mode<satfinite>} : vector<4xf32> -> vector<4xi8> (f8E4M3FN)
+ llvm.return %res : vector<4xi8>
+}
+
+// -----
+
+llvm.func @invalid_rnd_mode_f4x4_e2m1(%src : vector<4xf32>, %rbits : i32) -> i16 {
+ // expected-error@+1 {{Only RS rounding mode is supported for conversions from f32x4 to f4x4.}}
+ %res = nvvm.convert.f32x4.to.f4x4 %src, %rbits {rnd = #nvvm.fp_rnd_mode<rp>, sat = #nvvm.sat_mode<satfinite>} : vector<4xf32> -> i16 (f4E2M1FN)
+ llvm.return %res : i16
+}
+
+// -----
+
+// Test invalid destination types for f8x4 (should only accept f8E4M3FN, f8E5M2)
+llvm.func @invalid_dst_type_f8x4_e3m4(%src : vector<4xf32>, %rbits : i32) -> vector<4xi8> {
+ // expected-error@+1 {{Only 'f8E4M3FN' and 'f8E5M2' types are supported for conversions from f32x4 to f8x4.}}
+ %res = nvvm.convert.f32x4.to.f8x4 %src, %rbits : vector<4xf32> -> vector<4xi8> (f8E3M4)
+ llvm.return %res : vector<4xi8>
+}
+
+// -----
+
+llvm.func @invalid_dst_type_f8x4_e8m0(%src : vector<4xf32>, %rbits : i32) -> vector<4xi8> {
+ // expected-error@+1 {{Only 'f8E4M3FN' and 'f8E5M2' types are supported for conversions from f32x4 to f8x4.}}
+ %res = nvvm.convert.f32x4.to.f8x4 %src, %rbits : vector<4xf32> -> vector<4xi8> (f8E8M0FNU)
+ llvm.return %res : vector<4xi8>
+}
+
+// -----
+
+// Test invalid destination types for f6x4 (should only accept f6E2M3FN, f6E3M2FN)
+llvm.func @invalid_dst_type_f6x4_f8(%src : vector<4xf32>, %rbits : i32) -> vector<4xi8> {
+ // expected-error@+1 {{Only 'f6E2M3FN' and 'f6E3M2FN' types are supported for conversions from f32x4 to f6x4.}}
+ %res = nvvm.convert.f32x4.to.f6x4 %src, %rbits : vector<4xf32> -> vector<4xi8> (f8E4M3FN)
+ llvm.return %res : vector<4xi8>
+}
+
+// -----
+
+// Test invalid destination type for f4x4 (should only accept f4E2M1FN)
+llvm.func @invalid_dst_type_f4x4_f6(%src : vector<4xf32>, %rbits : i32) -> i16 {
+ // expected-error@+1 {{Only 'f4E2M1FN' type is supported for conversions from f32x4 to f4x4.}}
+ %res = nvvm.convert.f32x4.to.f4x4 %src, %rbits : vector<4xf32> -> i16 (f6E2M3FN)
+ llvm.return %res : i16
+}
+
+// -----
+
+// Test invalid rounding modes for non-stochastic ops
+llvm.func @convert_float_to_tf32_rs_not_supported(%src : f32) -> i32 {
+ // expected-error @below {{Only {rn,rz,rna} rounding modes supported for ConvertFloatToTF32Op.}}
+ %res = nvvm.convert.float.to.tf32 %src {rnd = #nvvm.fp_rnd_mode<rs>}
+ llvm.return %res : i32
+}
+
+// -----
+
+llvm.func @convert_f32x2_to_f8x2_rs_not_supported(%a : f32, %b : f32) {
+ // expected-error @below {{Only RN rounding mode is supported for conversions from f32x2 to 'f8E4M3FN' and 'f8E5M2' types}}
+ %res = nvvm.convert.f32x2.to.f8x2 %a, %b {rnd = #nvvm.fp_rnd_mode<rs>, sat = #nvvm.sat_mode<satfinite>} : i16 (f8E4M3FN)
+ llvm.return
+}
+
+// -----
+
+llvm.func @convert_bf16x2_to_f8x2_rs_not_supported(%src : vector<2xbf16>) {
+ // expected-error @below {{Only RZ and RP rounding modes are supported for conversions from bf16x2 to f8x2.}}
+ %res = nvvm.convert.bf16x2.to.f8x2 %src {rnd = #nvvm.fp_rnd_mode<rs>} : vector<2xbf16> -> i16 (f8E8M0FNU)
+ llvm.return
+}
diff --git a/mlir/test/Target/LLVMIR/nvvm/convert_stochastic_rounding.mlir b/mlir/test/Target/LLVMIR/nvvm/convert_stochastic_rounding.mlir
new file mode 100644
index 0000000000000..227726ac300b5
--- /dev/null
+++ b/mlir/test/Target/LLVMIR/nvvm/convert_stochastic_rounding.mlir
@@ -0,0 +1,252 @@
+// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s
+
+// -----
+
+// Test valid architectures work
+
+// Valid case on sm_100a
+gpu.module @valid_f16x2_rs_sm_100a [#nvvm.target<chip = "sm_100a">] {
+ func.func @convert_rs() {
+ ...
[truncated]
|
| provided by the `rbits` parameter. The `dstTy` attribute specifies the | ||
| target format. The `relu` attribute clamps negative results to 0. | ||
|
|
||
| Note: These operations always use RS rounding mode and SATFINITE saturation mode. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nice, Thanks for this note!
durga4github
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM with a few minor nitpicks.
|
@smada3 , I think we do not need the entire table in the commit-message and could simplify it. |
Wolfram70
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM as well, thanks!
Added NVVM dialect operations for stochastic rounding (.rs) conversions
from F32 to various packed floating-point formats. These operations map
to existing PTX instructions and LLVM intrinsics.
Supported conversions:
All operations support stochastic rounding with randomness provided via
an rbits parameter, and optional relu and saturation modifiers.