[XLA:GPU] Support F16 and BF16 math calls in triton fusion.
As in the XLA fusion emitter, we can upcast to F32, use the libdevice
math call for F32, and then downcast back to the original type.

PiperOrigin-RevId: 691295714
akuegel authored and tensorflower-gardener committed Oct 30, 2024
1 parent 225fb6d commit af53eb2
Showing 7 changed files with 169 additions and 43 deletions.
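
The commit message above describes the approach at a high level. Condensed into a single sketch, the new floating-point path in EmitElementwise looks roughly as follows. This is a simplified illustration of the pattern in the diffs below, not the verbatim code; Cast, ObtainDeviceFunctionName, and mt::ExternElementwiseOp are the helpers the emitter already uses.

  // libdevice provides the math calls for F32/F64, so for F16/BF16 inputs:
  // upcast to F32, call the F32 libdevice function, downcast the result.
  Type input_type = mlir::getElementTypeOrSelf(inputs[0]);
  bool needs_upcast = input_type.isBF16() || input_type.isF16();

  llvm::SmallVector<Value, 2> casted_inputs;
  PrimitiveType output_type = hlo.shape().element_type();
  if (needs_upcast) {
    for (Value input : inputs) {
      casted_inputs.push_back(Cast(b, input, b.getF32Type()));  // upcast to F32
    }
    output_type = F32;  // selects the F32 libdevice variant below
  } else {
    casted_inputs.assign(inputs.begin(), inputs.end());
  }

  Value res = b.create<mt::ExternElementwiseOp>(
      casted_inputs[0].getType(), casted_inputs, "libdevice", libdevice_path,
      ObtainDeviceFunctionName(dev_fn_id.value(), output_type, triple),
      /*pure=*/true);

  if (needs_upcast) {
    res = Cast(b, res, input_type);  // downcast back to F16/BF16
  }
  return res;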
32 changes: 24 additions & 8 deletions third_party/xla/xla/service/gpu/fusions/triton/emitter_helpers.cc
@@ -326,24 +326,40 @@ absl::StatusOr<Value> EmitElementwise(ImplicitLocOpBuilder& b,
                                       const se::DeviceDescription& device_info,
                                       const HloInstruction& hlo,
                                       ValueRange inputs) {
-  if (mlir::getElementTypeOrSelf(inputs[0]).isF32() ||
-      mlir::getElementTypeOrSelf(inputs[0]).isF64()) {
+  Type input_type = mlir::getElementTypeOrSelf(inputs[0]);
+  if (input_type.isBF16() || input_type.isF16() || input_type.isF32() ||
+      input_type.isF64()) {
     auto dev_fn_id = GetTargetDeviceFunctionID(hlo.opcode());
     if (dev_fn_id.ok()) {
       llvm::Triple triple("nvptx64-unknown-unknown");
       if (std::holds_alternative<se::RocmComputeCapability>(
               device_info.gpu_compute_capability())) {
         triple.setTriple("amdgcn-unknown-unknown");
       }
-      return b.create<mt::ExternElementwiseOp>(
-          inputs[0].getType(), inputs, "libdevice", libdevice_path,
-          ObtainDeviceFunctionName(dev_fn_id.value(),
-                                   hlo.shape().element_type(), triple),
+      PrimitiveType output_type = hlo.shape().element_type();
+      llvm::SmallVector<Value, 2> casted_inputs;
+      if (input_type.isBF16() || input_type.isF16()) {
+        // Upcast the inputs to F32.
+        for (int64_t i = 0; i < inputs.size(); ++i) {
+          casted_inputs.push_back(Cast(b, inputs[i], b.getF32Type()));
+        }
+        output_type = F32;
+      } else {
+        casted_inputs.assign(inputs.begin(), inputs.end());
+      }
+      Value res = b.create<mt::ExternElementwiseOp>(
+          casted_inputs[0].getType(), casted_inputs, "libdevice",
+          libdevice_path,
+          ObtainDeviceFunctionName(dev_fn_id.value(), output_type, triple),
           /*pure=*/true);
+      if (input_type.isBF16() || input_type.isF16()) {
+        // Downcast back to the original input type.
+        res = Cast(b, res, input_type);
+      }
+      return res;
     }
   }
-  const bool is_integer =
-      mlir::isa<mlir::IntegerType>(mlir::getElementTypeOrSelf(inputs[0]));
+  const bool is_integer = mlir::isa<mlir::IntegerType>(input_type);

   switch (hlo.opcode()) {
     case HloOpcode::kCopy:
Second changed file (path not shown):
@@ -398,24 +398,41 @@ absl::StatusOr<Value> EmitElementwise(ImplicitLocOpBuilder& b,
                                       const se::DeviceDescription& device_info,
                                       const HloInstruction& hlo,
                                       ValueRange inputs) {
-  if (mlir::getElementTypeOrSelf(inputs[0]).isF32() ||
-      mlir::getElementTypeOrSelf(inputs[0]).isF64()) {
+  Type input_type = mlir::getElementTypeOrSelf(inputs[0]);
+  if (input_type.isBF16() || input_type.isF16() || input_type.isF32() ||
+      input_type.isF64()) {
     auto dev_fn_id = GetTargetDeviceFunctionID(hlo.opcode());
     if (dev_fn_id.ok()) {
       llvm::Triple triple("nvptx64-unknown-unknown");
       if (std::holds_alternative<se::RocmComputeCapability>(
               device_info.gpu_compute_capability())) {
         triple.setTriple("amdgcn-unknown-unknown");
       }
-      return b.create<mt::ExternElementwiseOp>(
-          inputs[0].getType(), inputs, "libdevice", libdevice_path,
-          ObtainDeviceFunctionName(dev_fn_id.value(),
-                                   hlo.shape().element_type(), triple),
+      PrimitiveType output_type = hlo.shape().element_type();
+      llvm::SmallVector<Value, 2> casted_inputs;
+      if (input_type.isBF16() || input_type.isF16()) {
+        // Upcast the inputs to F32.
+        for (int64_t i = 0; i < inputs.size(); ++i) {
+          casted_inputs.push_back(Cast(b, inputs[i], b.getF32Type()));
+        }
+        output_type = F32;
+      } else {
+        casted_inputs.assign(inputs.begin(), inputs.end());
+      }
+      Value res = b.create<mt::ExternElementwiseOp>(
+          casted_inputs[0].getType(), casted_inputs, "libdevice",
+          libdevice_path,
+          ObtainDeviceFunctionName(dev_fn_id.value(), output_type, triple),
           /*pure=*/true);
+      if (input_type.isBF16() || input_type.isF16()) {
+        // Downcast back to the original input type.
+        res = Cast(b, res, input_type);
+      }
+      return res;
     }
   }
   const bool is_integer =
-      mlir::isa<mlir::IntegerType>(mlir::getElementTypeOrSelf(inputs[0]));
+      mlir::isa<mlir::IntegerType>(mlir::getElementTypeOrSelf(input_type));

   switch (hlo.opcode()) {
     case HloOpcode::kCopy:
Third changed file (path not shown):
@@ -249,6 +249,50 @@ ENTRY e {
       /*run_hlo_passes=*/false));
 }

+TEST_P(UnaryElementwiseTest, ElementwiseUnaryOpExecutesCorrectly) {
+  PrimitiveType data_type;
+  HloOpcode opcode;
+  float tolerance;
+  std::tie(data_type, opcode, tolerance) = GetParam();
+
+  const std::string kHloTestTemplate = R"(
+triton_computation {
+  parameter_0 = $0[33,68]{1,0} parameter(0)
+  f1.1 = $0[33,68]{1,0} $1(parameter_0)
+  ROOT convert = f32[33,68]{1,0} convert(f1.1)
+}
+ENTRY e {
+  p0 = $0[33,68]{1,0} parameter(0)
+  ROOT triton_fusion = f32[33,68]{1,0} fusion(p0), kind=kCustom,
+    calls=triton_computation,
+    backend_config={"fusion_backend_config":{"kind":"__triton",
+      "block_level_fusion_config":{"output_tile_sizes":["1", "1"],"num_warps":"1"}}}
+})";
+  const std::string hlo_test = absl::Substitute(
+      kHloTestTemplate, primitive_util::LowercasePrimitiveTypeName(data_type),
+      HloOpcodeString(opcode));
+
+  const std::string kHloRefTemplate = R"(
+fused_computation {
+  param_0.1 = $0[33,68]{1,0} parameter(0)
+  f.1 = $0[33,68]{1,0} $1(param_0.1)
+  ROOT convert = f32[33,68]{1,0} convert(f.1)
+}
+ENTRY e {
+  p0 = $0[33,68]{1,0} parameter(0)
+  ROOT fusion = f32[33,68]{1,0} fusion(p0), kind=kLoop, calls=fused_computation
+})";
+  const std::string hlo_ref = absl::Substitute(
+      kHloRefTemplate, primitive_util::LowercasePrimitiveTypeName(data_type),
+      HloOpcodeString(opcode));
+
+  EXPECT_TRUE(RunAndCompareTwoModules(
+      hlo_ref, hlo_test, ErrorSpec{/*aabs=*/tolerance, /*arel=*/tolerance},
+      /*run_hlo_passes=*/false));
+}
+
 INSTANTIATE_TEST_SUITE_P(
     ElementwiseTestSuitePRED, UnaryElementwiseTest,
     ::testing::Combine(
@@ -380,6 +424,54 @@ ENTRY e {
       /*run_hlo_passes=*/false, /*args_max_bits_of_precision=*/6));
 }

+TEST_P(BinaryElementwiseTest, ElementwiseBinaryOpExecutesCorrectly) {
+  PrimitiveType data_type;
+  HloOpcode opcode;
+  float tolerance;
+  std::tie(data_type, opcode, tolerance) = GetParam();
+
+  const std::string kHloTestTemplate = R"(
+triton_computation {
+  parameter_0 = $0[11,63]{1,0} parameter(0)
+  parameter_1 = $0[11,63]{1,0} parameter(1)
+  f1.1 = $0[11,63]{1,0} $1(parameter_0, parameter_1)
+  ROOT c.1 = f32[11,63]{1,0} convert(f1.1)
+}
+ENTRY e {
+  p0 = $0[11,63]{1,0} parameter(0)
+  p1 = $0[11,63]{1,0} parameter(1)
+  ROOT triton_fusion = f32[11,63]{1,0} fusion(p0, p1), kind=kCustom,
+    calls=triton_computation,
+    backend_config={"fusion_backend_config":{"kind":"__triton",
+      "block_level_fusion_config":{"output_tile_sizes":["1", "1"],"num_warps":"1"}}}
+})";
+  const std::string hlo_test = absl::Substitute(
+      kHloTestTemplate, primitive_util::LowercasePrimitiveTypeName(data_type),
+      HloOpcodeString(opcode));
+
+  const std::string kHloRefTemplate = R"(
+fused_computation {
+  p0 = $0[11,63]{1,0} parameter(0)
+  p1 = $0[11,63]{1,0} parameter(1)
+  f.1 = $0[11,63]{1,0} $1(p0, p1)
+  ROOT convert.1 = f32[11,63]{1,0} convert(f.1)
+}
+ENTRY e {
+  p1 = $0[11,63]{1,0} parameter(1)
+  p0 = $0[11,63]{1,0} parameter(0)
+  ROOT fusion = f32[11,63]{1,0} fusion(p0, p1), kind=kLoop, calls=fused_computation
+})";
+  const std::string hlo_ref = absl::Substitute(
+      kHloRefTemplate, primitive_util::LowercasePrimitiveTypeName(data_type),
+      HloOpcodeString(opcode));
+
+  EXPECT_TRUE(RunAndCompareTwoModules(
+      hlo_ref, hlo_test, ErrorSpec{/*aabs=*/tolerance, /*arel=*/tolerance},
+      /*run_hlo_passes=*/false, /*args_max_bits_of_precision=*/6));
+}
+
 std::vector<HloOpcode> TestedBinaryElementwise(PrimitiveType element_type) {
   std::vector<HloOpcode> ret =
       legacy_triton::TritonSupportedBinaryElementwiseUpToFloatNormalization(
@@ -1063,6 +1155,7 @@ TEST_P(TritonSoftmaxTest,
   if (data_type == F16) {
     GTEST_SKIP() << "Exponential op does not support F16.";
   }
+
   const std::string hlo_text_template = R"(
 HloModule softmax
 max_computation {
@@ -1743,10 +1836,6 @@ ENTRY main {
 TEST_P(TritonSoftmaxTest, CanFuseAndEmitRMSNormDiamond) {
   PrimitiveType data_type = GetParam();

-  if (data_type == F16) {
-    GTEST_SKIP() << "rsqrt op does not support F16.";
-  }
-
   const std::string hlo_text_template = R"(
 HloModule rms_norm
 add_computation {
@@ -1794,7 +1883,7 @@ ENTRY main.30 {
       tolerance = 1e-6;
       break;
     case F16:
-      tolerance = 2e-4;
+      tolerance = 5e-4;
       break;
     case BF16:
       tolerance = 4e-2;
17 changes: 9 additions & 8 deletions third_party/xla/xla/service/gpu/fusions/triton/triton_support.cc
@@ -89,7 +89,9 @@ absl::flat_hash_set<HloOpcode> TritonSupportedUnaryElementwiseOps(
     ret.insert(HloOpcode::kNot);
   }

-  if (element_type == PrimitiveType::F32 ||
+  if (element_type == PrimitiveType::BF16 ||
+      element_type == PrimitiveType::F16 ||
+      element_type == PrimitiveType::F32 ||
       element_type == PrimitiveType::F64) {
     absl::flat_hash_set<HloOpcode> additional_opcodes{
         HloOpcode::kCos, HloOpcode::kExp, HloOpcode::kExpm1,
@@ -100,13 +102,6 @@ absl::flat_hash_set<HloOpcode> TritonSupportedUnaryElementwiseOps(
     ret.insert(additional_opcodes.begin(), additional_opcodes.end());
   }

-  if (element_type == PrimitiveType::BF16 ||
-      element_type == PrimitiveType::F16) {
-    absl::flat_hash_set<HloOpcode> additional_opcodes{HloOpcode::kFloor,
-                                                      HloOpcode::kCeil};
-    ret.insert(additional_opcodes.begin(), additional_opcodes.end());
-  }
-
   if (primitive_util::IsFloatingPointType(element_type)) {
     ret.insert(HloOpcode::kReducePrecision);
   }
@@ -185,6 +180,12 @@ absl::flat_hash_set<HloOpcode> TritonSupportedBinaryElementwiseOps(
     ret.insert(HloOpcode::kRemainder);
     ret.insert(HloOpcode::kPower);
   }
+  if (element_type == PrimitiveType::BF16 ||
+      element_type == PrimitiveType::F16) {
+    ret.insert(HloOpcode::kAtan2);
+    ret.insert(HloOpcode::kPower);
+    ret.insert(HloOpcode::kRemainder);
+  }

   return ret;
 }
Fifth changed file (path not shown):
@@ -144,6 +144,7 @@ std::vector<HloOpcode> TritonSupportedUnaryElementwiseUpToFloatNormalization(
   ret.push_back(HloOpcode::kNegate);
   if (element_type == PrimitiveType::F32 ||
       element_type == PrimitiveType::BF16 ||
+      element_type == PrimitiveType::F16 ||
       element_type == PrimitiveType::F64) {
     absl::c_copy(std::vector<HloOpcode>{HloOpcode::kCos, HloOpcode::kExp,
                                         HloOpcode::kExpm1, HloOpcode::kFloor,
@@ -168,10 +169,13 @@ std::vector<HloOpcode> TritonSupportedBinaryElementwiseUpToFloatNormalization(
       HloOpcode::kMultiply, HloOpcode::kSubtract};
   if (element_type == PrimitiveType::F32 ||
       element_type == PrimitiveType::BF16 ||
+      element_type == PrimitiveType::F16 ||
       element_type == PrimitiveType::F64) {
     ret.push_back(HloOpcode::kAtan2);
-    ret.push_back(HloOpcode::kDivide);
     ret.push_back(HloOpcode::kPower);
+    if (element_type != PrimitiveType::F16) {
+      ret.push_back(HloOpcode::kDivide);
+    }
   }
   return ret;
 }
Sixth changed file (path not shown):
@@ -123,15 +123,13 @@ ENTRY main {
   reduce = f16[127]{0} reduce(param_0, constant_neg_inf), dimensions={1}, to_apply=max_computation
   broadcast = f16[127,125]{1,0} broadcast(reduce), dimensions={0}
   subtract = f16[127,125]{1,0} subtract(param_0, broadcast)
-  // Replace Softmax exponential with abs, because Triton doesn't support
-  // non-f32 exponentials.
-  abs = f16[127,125]{1,0} abs(subtract)
+  exp = f16[127,125]{1,0} exponential(subtract)
   constant_zero = f16[] constant(0)
-  second_reduce = f16[127]{0} reduce(abs, constant_zero), dimensions={1}, to_apply=add_computation
+  second_reduce = f16[127]{0} reduce(exp, constant_zero), dimensions={1}, to_apply=add_computation
   second_broadcast = f16[127,125]{1,0} broadcast(second_reduce), dimensions={0}
-  // Replace divide with multiply, because Triton doesn't support f16
-  // divisions.
-  ROOT multiply = f16[127,125]{1,0} multiply(abs, second_broadcast)
+  ROOT multiply = f16[127,125]{1,0} multiply(exp, second_broadcast)
 }
 )";
   auto module = ParseAndReturnVerifiedModule(hlo_string).value();
@@ -209,20 +207,20 @@ ENTRY main {
   reduce = bf16[127]{0} reduce(param_0, constant_neg_inf), dimensions={1}, to_apply=max_computation
   broadcast = bf16[127,125]{1,0} broadcast(reduce), dimensions={0}
   subtract = bf16[127,125]{1,0} subtract(param_0, broadcast)
-  ROOT exponential = bf16[127,125]{1,0} exponential(subtract)
+  ROOT round = bf16[127,125]{1,0} round-nearest-even(subtract)
 })";

   auto module = ParseAndReturnVerifiedModule(hlo_string).value();
-  const HloInstruction* bf16_exponential =
+  const HloInstruction* bf16_round_nearest_even =
       hlo_query::GetFirstInstructionWithOpcode(*module->entry_computation(),
-                                               HloOpcode::kExp);
+                                               HloOpcode::kRoundNearestEven);
   EXPECT_FALSE(IsTritonSupportedInstruction(
-      *bf16_exponential, device_info_.gpu_compute_capability()));
+      *bf16_round_nearest_even, device_info_.gpu_compute_capability()));
   EXPECT_TRUE(fusion_rewriter_.Run(module.get()).value());
   EXPECT_TRUE(verifier().Run(module.get()).status().ok());
   EXPECT_THAT(
       module->entry_computation()->root_instruction(),
-      GmockMatch(m::Exp(
+      GmockMatch(m::RoundNearestEven(
           m::Fusion(m::Parameter()).WithPredicate(HasBlockLevelFusionConfig))));
 }

@@ -733,8 +731,8 @@ max_computation {
 ENTRY main {
   param_0 = f16[127,125]{1,0} parameter(0)
-  exponential = f16[127,125] exponential(param_0)
-  convert = f32[127,125] convert(exponential)
+  round-nearest-even = f16[127,125] round-nearest-even(param_0)
+  convert = f32[127,125] convert(round-nearest-even)
   constant_neg_inf = f32[] constant(-inf)
   reduce = f32[127]{0} reduce(convert, constant_neg_inf), dimensions={1}, to_apply=max_computation
   broadcast = f32[127,125]{1,0} broadcast(reduce), dimensions={0}
@@ -745,7 +743,7 @@ ENTRY main {
   EXPECT_TRUE(fusion_rewriter_.Run(module.get()).value());
   EXPECT_TRUE(verifier().Run(module.get()).status().ok());
   EXPECT_THAT(module->entry_computation()->root_instruction(),
-              GmockMatch(m::Fusion(m::Exp(m::Parameter()))
+              GmockMatch(m::Fusion(m::RoundNearestEven(m::Parameter()))
                              .WithPredicate(HasBlockLevelFusionConfig)));
 }

3 changes: 2 additions & 1 deletion third_party/xla/xla/service/pattern_matcher.h
@@ -2678,7 +2678,6 @@ XLA_NULLOP_PATTERN(ReplicaId)
         .WithOperand(0, std::forward<Arg>(arg));                           \
   }
 XLA_UNOP_PATTERN(Abs)
-XLA_UNOP_PATTERN(RoundNearestAfz)
 XLA_UNOP_PATTERN(Bitcast)
 XLA_UNOP_PATTERN(BitcastConvert)
 XLA_UNOP_PATTERN(Broadcast)
@@ -2717,6 +2716,8 @@ XLA_UNOP_PATTERN(RecvDone)
 XLA_UNOP_PATTERN(ReducePrecision)
 XLA_UNOP_PATTERN(Reshape)
 XLA_UNOP_PATTERN(Reverse)
+XLA_UNOP_PATTERN(RoundNearestAfz)
+XLA_UNOP_PATTERN(RoundNearestEven)
 XLA_UNOP_PATTERN(Rsqrt)
 XLA_UNOP_PATTERN(SendDone)
 XLA_UNOP_PATTERN(Sign)
