[ElementwiseOpToLLVM] Replace bitwise fp32 to bf16 conversion with Intel SPIR-V Extension (triton-lang#1074)

Related to issue triton-lang#1001.

This pass is already lowering `arith::TruncFOp` and `arith::ExtFOp`, so
the original suggestion of lowering to arith operators didn't make
sense. Instead, I have replaced most of the bit operations with calls
to an Intel SPIR-V extension that translates to a MOV instruction in
vISA. I couldn't remove the round-to-zero mode of `convertFp32ToBf16`,
since the extension only supports round-to-nearest-even. The code that
calls `convertFp32ToBf16` uses round-to-nearest-even by default, so
that's fine.
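
For reference, the mangled names in this change demangle to the OpenCL
builtins `intel_convert_bfloat16_as_ushort(float)` and
`intel_convert_as_bfloat16_float(ushort)`. Below is a minimal
standalone C++ sketch, with illustrative names not taken from this
patch, of the round-to-nearest-even bit manipulation that the builtin
call replaces:

```cpp
#include <cstdint>
#include <cstring>

// Sketch only: software round-to-nearest-even fp32 -> bf16, mirroring
// the bit operations removed from convertFp32ToBf16 in this change.
uint16_t fp32ToBf16Rtne(float f) {
  uint32_t bits;
  std::memcpy(&bits, &f, sizeof(bits)); // bitcast(v, i32_ty)
  // An all-ones exponent means Inf/NaN; skip rounding there so a NaN
  // payload is never incremented into an Inf.
  if ((bits & 0x7f800000u) != 0x7f800000u) {
    // Add 0x7fff plus the LSB of the truncated mantissa, so that
    // exact ties round to even.
    bits += 0x7fffu + ((bits >> 16) & 1u);
  }
  return static_cast<uint16_t>(bits >> 16); // keep the high half
}
```

For example, `fp32ToBf16Rtne(1.0f)` yields `0x3f80`, the bf16 pattern
for 1.0; the round-to-zero path that remains in the pass essentially
just drops the low 16 bits instead.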
FMarno authored May 13, 2024
1 parent 3eb4027 commit 69e0020
Showing 5 changed files with 120 additions and 38 deletions.
51 changes: 51 additions & 0 deletions test/Conversion/arith_to_llvm.mlir
@@ -0,0 +1,51 @@
// RUN: triton-opt %s -split-input-file --convert-triton-intel-gpu-to-llvm | FileCheck %s

// CHECK-DAG: llvm.func spir_funccc @_Z32intel_convert_bfloat16_as_ushortf(f32) -> i16
// CHECK-DAG: #[[$ATTR_0:.+]] = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
#blocked = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}>

// CHECK-LABEL: llvm.func spir_kernelcc @float_to_bfloat_conversion(
// CHECK-SAME: %[[VAL_0:.*]]: !llvm.struct<(f32, f32, f32, f32)>) -> !llvm.struct<(i16, i16, i16, i16)>
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} {
tt.func @float_to_bfloat_conversion(%arg0 : tensor<512xf32, #blocked>) -> tensor<512xbf16, #blocked>{
// CHECK: builtin.unrealized_conversion_cast %[[VAL_0]] : !llvm.struct<(f32, f32, f32, f32)> to tensor<512xf32, #[[$ATTR_0]]>
// CHECK: %[[VAL_2:.*]] = llvm.extractvalue %[[VAL_0]][0] : !llvm.struct<(f32, f32, f32, f32)>
// CHECK: %[[VAL_3:.*]] = llvm.extractvalue %[[VAL_0]][1] : !llvm.struct<(f32, f32, f32, f32)>
// CHECK: %[[VAL_4:.*]] = llvm.extractvalue %[[VAL_0]][2] : !llvm.struct<(f32, f32, f32, f32)>
// CHECK: %[[VAL_5:.*]] = llvm.extractvalue %[[VAL_0]][3] : !llvm.struct<(f32, f32, f32, f32)>
// CHECK: %[[VAL_6:.*]] = llvm.call spir_funccc @_Z32intel_convert_bfloat16_as_ushortf(%[[VAL_2]]) : (f32) -> i16
// CHECK: %[[VAL_7:.*]] = llvm.call spir_funccc @_Z32intel_convert_bfloat16_as_ushortf(%[[VAL_3]]) : (f32) -> i16
// CHECK: %[[VAL_8:.*]] = llvm.call spir_funccc @_Z32intel_convert_bfloat16_as_ushortf(%[[VAL_4]]) : (f32) -> i16
// CHECK: %[[VAL_9:.*]] = llvm.call spir_funccc @_Z32intel_convert_bfloat16_as_ushortf(%[[VAL_5]]) : (f32) -> i16
// CHECK: %[[VAL_10:.*]] = llvm.mlir.undef : !llvm.struct<(i16, i16, i16, i16)>
// CHECK: %[[VAL_11:.*]] = llvm.insertvalue %[[VAL_6]], %[[VAL_10]][0] : !llvm.struct<(i16, i16, i16, i16)>
// CHECK: %[[VAL_12:.*]] = llvm.insertvalue %[[VAL_7]], %[[VAL_11]][1] : !llvm.struct<(i16, i16, i16, i16)>
// CHECK: %[[VAL_13:.*]] = llvm.insertvalue %[[VAL_8]], %[[VAL_12]][2] : !llvm.struct<(i16, i16, i16, i16)>
// CHECK: %[[VAL_14:.*]] = llvm.insertvalue %[[VAL_9]], %[[VAL_13]][3] : !llvm.struct<(i16, i16, i16, i16)>
%1 = arith.truncf %arg0 : tensor<512xf32, #blocked> to tensor<512xbf16, #blocked>
// CHECK: llvm.return %[[VAL_14]] : !llvm.struct<(i16, i16, i16, i16)>
tt.return %1: tensor<512xbf16, #blocked>
}

// CHECK-LABEL: llvm.func spir_kernelcc @bfloat_to_float_conversion(
// CHECK-SAME: %[[VAL_0:.*]]: !llvm.struct<(i16, i16, i16, i16)>) -> !llvm.struct<(f32, f32, f32, f32)>
tt.func @bfloat_to_float_conversion(%arg0 : tensor<512xbf16, #blocked>) -> tensor<512xf32, #blocked>{
// CHECK: builtin.unrealized_conversion_cast %[[VAL_0]] : !llvm.struct<(i16, i16, i16, i16)> to tensor<512xbf16, #[[$ATTR_0]]>
// CHECK: %[[VAL_2:.*]] = llvm.extractvalue %[[VAL_0]][0] : !llvm.struct<(i16, i16, i16, i16)>
// CHECK: %[[VAL_3:.*]] = llvm.extractvalue %[[VAL_0]][1] : !llvm.struct<(i16, i16, i16, i16)>
// CHECK: %[[VAL_4:.*]] = llvm.extractvalue %[[VAL_0]][2] : !llvm.struct<(i16, i16, i16, i16)>
// CHECK: %[[VAL_5:.*]] = llvm.extractvalue %[[VAL_0]][3] : !llvm.struct<(i16, i16, i16, i16)>
// CHECK: %[[VAL_6:.*]] = llvm.call spir_funccc @_Z31intel_convert_as_bfloat16_floatt(%[[VAL_2]]) : (i16) -> f32
// CHECK: %[[VAL_7:.*]] = llvm.call spir_funccc @_Z31intel_convert_as_bfloat16_floatt(%[[VAL_3]]) : (i16) -> f32
// CHECK: %[[VAL_8:.*]] = llvm.call spir_funccc @_Z31intel_convert_as_bfloat16_floatt(%[[VAL_4]]) : (i16) -> f32
// CHECK: %[[VAL_9:.*]] = llvm.call spir_funccc @_Z31intel_convert_as_bfloat16_floatt(%[[VAL_5]]) : (i16) -> f32
// CHECK: %[[VAL_10:.*]] = llvm.mlir.undef : !llvm.struct<(f32, f32, f32, f32)>
// CHECK: %[[VAL_11:.*]] = llvm.insertvalue %[[VAL_6]], %[[VAL_10]][0] : !llvm.struct<(f32, f32, f32, f32)>
// CHECK: %[[VAL_12:.*]] = llvm.insertvalue %[[VAL_7]], %[[VAL_11]][1] : !llvm.struct<(f32, f32, f32, f32)>
// CHECK: %[[VAL_13:.*]] = llvm.insertvalue %[[VAL_8]], %[[VAL_12]][2] : !llvm.struct<(f32, f32, f32, f32)>
// CHECK: %[[VAL_14:.*]] = llvm.insertvalue %[[VAL_9]], %[[VAL_13]][3] : !llvm.struct<(f32, f32, f32, f32)>
%1 = arith.extf %arg0 : tensor<512xbf16, #blocked> to tensor<512xf32, #blocked>
// CHECK: llvm.return %[[VAL_14]] : !llvm.struct<(f32, f32, f32, f32)>
tt.return %1: tensor<512xf32, #blocked>
}
}
12 changes: 12 additions & 0 deletions third_party/intel/include/Dialect/TritonIntelGPU/Transforms/Utility.h
@@ -13,6 +13,10 @@

#include "intel/include/Dialect/TritonIntelGPU/Transforms/Passes.h"

namespace mlir {
class ConversionPatternRewriter;
}

namespace mlir::triton::gpu::intel {

// data type for D_C_A_B.
@@ -53,6 +57,14 @@ LogicalResult getConvertBackwardSlice(
DenseMap<Value, Attribute> &layout,
std::function<bool(Operation *)> stopPropagation = nullptr);

LLVM::LLVMFuncOp lookupOrCreateSPIRVFn(Operation *symbolTable, StringRef name,
ArrayRef<Type> paramTypes,
Type resultType);

LLVM::CallOp createSPIRVBuiltinCall(Location loc,
ConversionPatternRewriter &rewriter,
LLVM::LLVMFuncOp func, ValueRange args);

} // namespace mlir::triton::gpu::intel

#endif // TRITON_DIALECT_TRITONINTELGPU_TRANSFORMS_UTILITY_H
37 changes: 11 additions & 26 deletions third_party/intel/lib/GPUToTritonGEN/OpToFuncCallLowering.h
@@ -13,6 +13,9 @@
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/SymbolTable.h"

#include "intel/include/Dialect/TritonIntelGPU/Transforms/Utility.h"

namespace mlir {

@@ -52,17 +55,17 @@ struct OpToFuncCallLowering : public ConvertOpToLLVMPattern<SourceOp> {
for (Value operand : adaptor.getOperands())
castedOperands.push_back(maybeCast(operand, rewriter));

Type resultType = castedOperands.front().getType();
Type funcType = getFunctionType(resultType, castedOperands);
StringRef funcName =
getFunctionName(cast<LLVM::LLVMFunctionType>(funcType).getReturnType());
SmallVector<Type> parameters(ValueRange(castedOperands).getTypes());
Type resultType = parameters.front();
StringRef funcName = getFunctionName(resultType);
if (funcName.empty())
return failure();

LLVMFuncOp funcOp = appendOrGetFuncOp(funcName, funcType, op);
auto callOp =
rewriter.create<LLVM::CallOp>(op->getLoc(), funcOp, castedOperands);
callOp.setCConv(LLVM::cconv::CConv::SPIR_FUNC);
auto moduleOp = op->template getParentWithTrait<OpTrait::SymbolTable>();
auto funcOp = triton::gpu::intel::lookupOrCreateSPIRVFn(
moduleOp, funcName, parameters, resultType);
auto callOp = triton::gpu::intel::createSPIRVBuiltinCall(
op->getLoc(), rewriter, funcOp, castedOperands);

if (resultType == adaptor.getOperands().front().getType()) {
rewriter.replaceOp(op, {callOp.getResult()});
@@ -86,11 +89,6 @@ struct OpToFuncCallLowering : public ConvertOpToLLVMPattern<SourceOp> {
operand.getLoc(), Float32Type::get(rewriter.getContext()), operand);
}

Type getFunctionType(Type resultType, ValueRange operands) const {
SmallVector<Type> operandTypes(operands.getTypes());
return LLVM::LLVMFunctionType::get(resultType, operandTypes);
}

StringRef getFunctionName(Type type) const {
if (isa<Float32Type>(type))
return f32Func;
@@ -99,19 +97,6 @@
return "";
}

LLVM::LLVMFuncOp appendOrGetFuncOp(StringRef funcName, Type funcType,
Operation *op) const {
using LLVM::LLVMFuncOp;

auto funcAttr = StringAttr::get(op->getContext(), funcName);
Operation *funcOp = SymbolTable::lookupNearestSymbolFrom(op, funcAttr);
if (funcOp)
return cast<LLVMFuncOp>(*funcOp);

mlir::OpBuilder b(op->getParentOfType<FunctionOpInterface>());
return b.create<LLVMFuncOp>(op->getLoc(), funcName, funcType);
}

const std::string f32Func;
const std::string f64Func;
};
33 changes: 21 additions & 12 deletions third_party/intel/lib/TritonIntelGPUToLLVM/ElementwiseOpToLLVM.cpp
Expand Up @@ -2,6 +2,7 @@
#include "mlir/Conversion/ArithCommon/AttrToLLVMConverter.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/IR/MLIRContext.h"
#include "third_party/intel/include/Dialect/TritonIntelGPU/Transforms/Utility.h"
#include "triton/Conversion/TritonGPUToLLVM/TargetInfoBase.h"

namespace {
@@ -1284,36 +1285,44 @@ struct FpToFpOpConversion
static Value convertBf16ToFp32(Location loc,
ConversionPatternRewriter &rewriter,
const Value &v) {
auto as_int16 = bitcast(v, i16_ty);
auto as_int32 = zext(i32_ty, as_int16);
auto shifted = shl(i32_ty, as_int32, i32_val(16));
return (bitcast(shifted, f32_ty));
auto moduleOp =
v.getDefiningOp()->getParentWithTrait<OpTrait::SymbolTable>();
constexpr StringLiteral name = "_Z31intel_convert_as_bfloat16_floatt";
auto ext_func = triton::gpu::intel::lookupOrCreateSPIRVFn(moduleOp, name,
i16_ty, f32_ty);
auto call =
triton::gpu::intel::createSPIRVBuiltinCall(loc, rewriter, ext_func, v);
return call.getResult();
}

static Value convertFp16ToFp32(Location loc,
ConversionPatternRewriter &rewriter,
const Value &v) {
auto ctx = rewriter.getContext();
return rewriter.create<LLVM::FPExtOp>(loc, f32_ty, v);
}

static Value convertFp32ToBf16(Location loc,
ConversionPatternRewriter &rewriter,
const Value &v, const RoundingMode rounding) {
if (rounding == RoundingMode::RTNE) {
auto moduleOp =
v.getDefiningOp()->getParentWithTrait<OpTrait::SymbolTable>();
// Intel SPIR-V extension only supports round-to-nearest-even
constexpr StringLiteral name = "_Z32intel_convert_bfloat16_as_ushortf";
auto trunc_func = triton::gpu::intel::lookupOrCreateSPIRVFn(
moduleOp, name, f32_ty, i16_ty);
auto call = triton::gpu::intel::createSPIRVBuiltinCall(loc, rewriter,
trunc_func, v);
return call.getResult();
}

auto as_uint32 = bitcast(v, i32_ty);
auto check_exponent =
and_(i32_ty, xor_(i32_ty, as_uint32, i32_val(0xffffffff)),
i32_val(0x7f800000));
auto exponent_not_all1s = icmp_ne(check_exponent, i32_val(0));
auto exponent_all1s = icmp_eq(check_exponent, i32_val(0));
Value rounded = as_uint32;
if (rounding == RoundingMode::RTNE) {
rounded =
add(i32_ty, i32_val(0x7fff),
and_(i32_ty, lshr(i32_ty, as_uint32, i32_val(16)), i32_val(1)));
rounded = add(i32_ty, rounded, as_uint32);
rounded = select(exponent_not_all1s, rounded, as_uint32);
}

auto preserve_nan =
and_(i1_ty, exponent_all1s,
25 changes: 25 additions & 0 deletions third_party/intel/lib/TritonIntelGPUTransforms/Utility.cpp
@@ -8,11 +8,13 @@

#include "triton/Analysis/Utility.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/Transforms/DialectConversion.h"

#include "intel/include/Dialect/TritonIntelGPU/IR/Attributes.h"
#include "intel/include/Dialect/TritonIntelGPU/Transforms/Utility.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/Transforms/Utility.h"

#include <optional>

using namespace mlir;
@@ -235,4 +237,27 @@ getConvertBackwardSlice(Value root, SetVector<Value> &slice,
return success();
}

LLVM::LLVMFuncOp lookupOrCreateSPIRVFn(Operation *symbolTable, StringRef name,
ArrayRef<Type> paramTypes,
Type resultType) {
auto func = dyn_cast_or_null<LLVM::LLVMFuncOp>(
SymbolTable::lookupSymbolIn(symbolTable, name));
if (!func) {
OpBuilder b(symbolTable->getRegion(0));
func = b.create<LLVM::LLVMFuncOp>(
symbolTable->getLoc(), name,
LLVM::LLVMFunctionType::get(resultType, paramTypes));
func.setCConv(LLVM::cconv::CConv::SPIR_FUNC);
}
return func;
}

LLVM::CallOp createSPIRVBuiltinCall(Location loc,
ConversionPatternRewriter &rewriter,
LLVM::LLVMFuncOp func, ValueRange args) {
auto call = rewriter.create<LLVM::CallOp>(loc, func, args);
call.setCConv(func.getCConv());
return call;
}

} // namespace mlir::triton::gpu::intel
