
Commit 9d5bb83

Reverts af53eb2
PiperOrigin-RevId: 691339219
akuegel authored and tensorflower-gardener committed Oct 30, 2024
1 parent 2d93f75 commit 9d5bb83
Showing 7 changed files with 43 additions and 169 deletions.
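In short: af53eb2 had taught the Triton GPU emitter to handle BF16/F16 transcendental elementwise ops by computing them through F32; this commit restores the previous F32/F64-only behavior. A minimal standalone sketch of the upcast-compute-downcast pattern being removed; the helper LibdeviceExpF32 is hypothetical and only stands in for the emitter's ExternElementwiseOp libdevice call, and the example assumes a compiler with _Float16 support:

#include <cmath>

// Hypothetical stand-in for an f32-only libdevice routine such as __nv_expf.
float LibdeviceExpF32(float x) { return std::exp(x); }

// Sketch of the reverted emitter behavior for BF16/F16 transcendentals:
// widen the half-precision input to F32, run the F32 routine, narrow back.
_Float16 HalfExpViaF32(_Float16 x) {
  float widened = static_cast<float>(x);    // upcast, as in Cast(b, input, f32)
  float result = LibdeviceExpF32(widened);  // compute in F32
  return static_cast<_Float16>(result);     // downcast to the original type
}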
32 changes: 8 additions & 24 deletions third_party/xla/xla/service/gpu/fusions/triton/emitter_helpers.cc
@@ -326,40 +326,24 @@ absl::StatusOr<Value> EmitElementwise(ImplicitLocOpBuilder& b,
                                       const se::DeviceDescription& device_info,
                                       const HloInstruction& hlo,
                                       ValueRange inputs) {
-  Type input_type = mlir::getElementTypeOrSelf(inputs[0]);
-  if (input_type.isBF16() || input_type.isF16() || input_type.isF32() ||
-      input_type.isF64()) {
+  if (mlir::getElementTypeOrSelf(inputs[0]).isF32() ||
+      mlir::getElementTypeOrSelf(inputs[0]).isF64()) {
     auto dev_fn_id = GetTargetDeviceFunctionID(hlo.opcode());
     if (dev_fn_id.ok()) {
       llvm::Triple triple("nvptx64-unknown-unknown");
       if (std::holds_alternative<se::RocmComputeCapability>(
               device_info.gpu_compute_capability())) {
         triple.setTriple("amdgcn-unknown-unknown");
       }
-      PrimitiveType output_type = hlo.shape().element_type();
-      llvm::SmallVector<Value, 2> casted_inputs;
-      if (input_type.isBF16() || input_type.isF16()) {
-        // Upcast the inputs to F32.
-        for (int64_t i = 0; i < inputs.size(); ++i) {
-          casted_inputs.push_back(Cast(b, inputs[i], b.getF32Type()));
-        }
-        output_type = F32;
-      } else {
-        casted_inputs.assign(inputs.begin(), inputs.end());
-      }
-      Value res = b.create<mt::ExternElementwiseOp>(
-          casted_inputs[0].getType(), casted_inputs, "libdevice",
-          libdevice_path,
-          ObtainDeviceFunctionName(dev_fn_id.value(), output_type, triple),
+      return b.create<mt::ExternElementwiseOp>(
+          inputs[0].getType(), inputs, "libdevice", libdevice_path,
+          ObtainDeviceFunctionName(dev_fn_id.value(),
+                                   hlo.shape().element_type(), triple),
           /*pure=*/true);
-      if (input_type.isBF16() || input_type.isF16()) {
-        // Downcast back to the original input type.
-        res = Cast(b, res, input_type);
-      }
-      return res;
     }
   }
-  const bool is_integer = mlir::isa<mlir::IntegerType>(input_type);
+  const bool is_integer =
+      mlir::isa<mlir::IntegerType>(mlir::getElementTypeOrSelf(inputs[0]));
 
   switch (hlo.opcode()) {
     case HloOpcode::kCopy:
@@ -398,41 +398,24 @@ absl::StatusOr<Value> EmitElementwise(ImplicitLocOpBuilder& b,
                                       const se::DeviceDescription& device_info,
                                       const HloInstruction& hlo,
                                       ValueRange inputs) {
-  Type input_type = mlir::getElementTypeOrSelf(inputs[0]);
-  if (input_type.isBF16() || input_type.isF16() || input_type.isF32() ||
-      input_type.isF64()) {
+  if (mlir::getElementTypeOrSelf(inputs[0]).isF32() ||
+      mlir::getElementTypeOrSelf(inputs[0]).isF64()) {
     auto dev_fn_id = GetTargetDeviceFunctionID(hlo.opcode());
     if (dev_fn_id.ok()) {
       llvm::Triple triple("nvptx64-unknown-unknown");
       if (std::holds_alternative<se::RocmComputeCapability>(
              device_info.gpu_compute_capability())) {
        triple.setTriple("amdgcn-unknown-unknown");
      }
-      PrimitiveType output_type = hlo.shape().element_type();
-      llvm::SmallVector<Value, 2> casted_inputs;
-      if (input_type.isBF16() || input_type.isF16()) {
-        // Upcast the inputs to F32.
-        for (int64_t i = 0; i < inputs.size(); ++i) {
-          casted_inputs.push_back(Cast(b, inputs[i], b.getF32Type()));
-        }
-        output_type = F32;
-      } else {
-        casted_inputs.assign(inputs.begin(), inputs.end());
-      }
-      Value res = b.create<mt::ExternElementwiseOp>(
-          casted_inputs[0].getType(), casted_inputs, "libdevice",
-          libdevice_path,
-          ObtainDeviceFunctionName(dev_fn_id.value(), output_type, triple),
+      return b.create<mt::ExternElementwiseOp>(
+          inputs[0].getType(), inputs, "libdevice", libdevice_path,
+          ObtainDeviceFunctionName(dev_fn_id.value(),
+                                   hlo.shape().element_type(), triple),
           /*pure=*/true);
-      if (input_type.isBF16() || input_type.isF16()) {
-        // Downcast back to the original input type.
-        res = Cast(b, res, input_type);
-      }
-      return res;
     }
   }
   const bool is_integer =
-      mlir::isa<mlir::IntegerType>(mlir::getElementTypeOrSelf(input_type));
+      mlir::isa<mlir::IntegerType>(mlir::getElementTypeOrSelf(inputs[0]));
 
   switch (hlo.opcode()) {
     case HloOpcode::kCopy:
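Both copies of EmitElementwise keep the same GPU target-triple selection for libdevice. A hedged standalone sketch of that dispatch, with illustrative placeholder structs standing in for the real se::CudaComputeCapability / se::RocmComputeCapability types:

#include <string>
#include <variant>

// Illustrative placeholders; only the variant dispatch matters here.
struct CudaComputeCapability {};
struct RocmComputeCapability {};
using GpuComputeCapability =
    std::variant<CudaComputeCapability, RocmComputeCapability>;

// Mirrors the emitter above: NVPTX by default, AMDGCN for ROCm devices.
std::string DeviceTriple(const GpuComputeCapability& cc) {
  if (std::holds_alternative<RocmComputeCapability>(cc)) {
    return "amdgcn-unknown-unknown";
  }
  return "nvptx64-unknown-unknown";
}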
@@ -249,50 +249,6 @@ ENTRY e {
       /*run_hlo_passes=*/false));
 }
 
-TEST_P(UnaryElementwiseTest, ElementwiseUnaryOpExecutesCorrectly) {
-  PrimitiveType data_type;
-  HloOpcode opcode;
-  float tolerance;
-  std::tie(data_type, opcode, tolerance) = GetParam();
-
-  const std::string kHloTestTemplate = R"(
-triton_computation {
-  parameter_0 = $0[33,68]{1,0} parameter(0)
-  f1.1 = $0[33,68]{1,0} $1(parameter_0)
-  ROOT convert = f32[33,68]{1,0} convert(f1.1)
-}
-ENTRY e {
-  p0 = $0[33,68]{1,0} parameter(0)
-  ROOT triton_fusion = f32[33,68]{1,0} fusion(p0), kind=kCustom,
-    calls=triton_computation,
-    backend_config={"fusion_backend_config":{"kind":"__triton",
-    "block_level_fusion_config":{"output_tile_sizes":["1", "1"],"num_warps":"1"}}}
-})";
-  const std::string hlo_test = absl::Substitute(
-      kHloTestTemplate, primitive_util::LowercasePrimitiveTypeName(data_type),
-      HloOpcodeString(opcode));
-
-  const std::string kHloRefTemplate = R"(
-fused_computation {
-  param_0.1 = $0[33,68]{1,0} parameter(0)
-  f.1 = $0[33,68]{1,0} $1(param_0.1)
-  ROOT convert = f32[33,68]{1,0} convert(f.1)
-}
-ENTRY e {
-  p0 = $0[33,68]{1,0} parameter(0)
-  ROOT fusion = f32[33,68]{1,0} fusion(p0), kind=kLoop, calls=fused_computation
-})";
-  const std::string hlo_ref = absl::Substitute(
-      kHloRefTemplate, primitive_util::LowercasePrimitiveTypeName(data_type),
-      HloOpcodeString(opcode));
-
-  EXPECT_TRUE(RunAndCompareTwoModules(
-      hlo_ref, hlo_test, ErrorSpec{/*aabs=*/tolerance, /*arel=*/tolerance},
-      /*run_hlo_passes=*/false));
-}
-
 INSTANTIATE_TEST_SUITE_P(
     ElementwiseTestSuitePRED, UnaryElementwiseTest,
     ::testing::Combine(
@@ -424,54 +380,6 @@ ENTRY e {
       /*run_hlo_passes=*/false, /*args_max_bits_of_precision=*/6));
 }
 
-TEST_P(BinaryElementwiseTest, ElementwiseBinaryOpExecutesCorrectly) {
-  PrimitiveType data_type;
-  HloOpcode opcode;
-  float tolerance;
-  std::tie(data_type, opcode, tolerance) = GetParam();
-
-  const std::string kHloTestTemplate = R"(
-triton_computation {
-  parameter_0 = $0[11,63]{1,0} parameter(0)
-  parameter_1 = $0[11,63]{1,0} parameter(1)
-  f1.1 = $0[11,63]{1,0} $1(parameter_0, parameter_1)
-  ROOT c.1 = f32[11,63]{1,0} convert(f1.1)
-}
-ENTRY e {
-  p0 = $0[11,63]{1,0} parameter(0)
-  p1 = $0[11,63]{1,0} parameter(1)
-  ROOT triton_fusion = f32[11,63]{1,0} fusion(p0, p1), kind=kCustom,
-    calls=triton_computation,
-    backend_config={"fusion_backend_config":{"kind":"__triton",
-    "block_level_fusion_config":{"output_tile_sizes":["1", "1"],"num_warps":"1"}}}
-})";
-  const std::string hlo_test = absl::Substitute(
-      kHloTestTemplate, primitive_util::LowercasePrimitiveTypeName(data_type),
-      HloOpcodeString(opcode));
-
-  const std::string kHloRefTemplate = R"(
-fused_computation {
-  p0 = $0[11,63]{1,0} parameter(0)
-  p1 = $0[11,63]{1,0} parameter(1)
-  f.1 = $0[11,63]{1,0} $1(p0, p1)
-  ROOT convert.1 = f32[11,63]{1,0} convert(f.1)
-}
-ENTRY e {
-  p1 = $0[11,63]{1,0} parameter(1)
-  p0 = $0[11,63]{1,0} parameter(0)
-  ROOT fusion = f32[11,63]{1,0} fusion(p0, p1), kind=kLoop, calls=fused_computation
-})";
-  const std::string hlo_ref = absl::Substitute(
-      kHloRefTemplate, primitive_util::LowercasePrimitiveTypeName(data_type),
-      HloOpcodeString(opcode));
-
-  EXPECT_TRUE(RunAndCompareTwoModules(
-      hlo_ref, hlo_test, ErrorSpec{/*aabs=*/tolerance, /*arel=*/tolerance},
-      /*run_hlo_passes=*/false, /*args_max_bits_of_precision=*/6));
-}
-
 std::vector<HloOpcode> TestedBinaryElementwise(PrimitiveType element_type) {
   std::vector<HloOpcode> ret =
       legacy_triton::TritonSupportedBinaryElementwiseUpToFloatNormalization(
@@ -1155,7 +1063,6 @@ TEST_P(TritonSoftmaxTest,
   if (data_type == F16) {
     GTEST_SKIP() << "Exponential op does not support F16.";
   }
-
   const std::string hlo_text_template = R"(
 HloModule softmax
 max_computation {
@@ -1836,6 +1743,10 @@ ENTRY main {
 TEST_P(TritonSoftmaxTest, CanFuseAndEmitRMSNormDiamond) {
   PrimitiveType data_type = GetParam();
 
+  if (data_type == F16) {
+    GTEST_SKIP() << "rsqrt op does not support F16.";
+  }
+
   const std::string hlo_text_template = R"(
 HloModule rms_norm
 add_computation {
@@ -1883,7 +1794,7 @@ ENTRY main.30 {
       tolerance = 1e-6;
       break;
     case F16:
-      tolerance = 5e-4;
+      tolerance = 2e-4;
       break;
     case BF16:
       tolerance = 4e-2;
17 changes: 8 additions & 9 deletions third_party/xla/xla/service/gpu/fusions/triton/triton_support.cc
@@ -89,9 +89,7 @@ absl::flat_hash_set<HloOpcode> TritonSupportedUnaryElementwiseOps(
     ret.insert(HloOpcode::kNot);
   }
 
-  if (element_type == PrimitiveType::BF16 ||
-      element_type == PrimitiveType::F16 ||
-      element_type == PrimitiveType::F32 ||
+  if (element_type == PrimitiveType::F32 ||
       element_type == PrimitiveType::F64) {
     absl::flat_hash_set<HloOpcode> additional_opcodes{
         HloOpcode::kCos, HloOpcode::kExp, HloOpcode::kExpm1,
@@ -102,6 +100,13 @@ absl::flat_hash_set<HloOpcode> TritonSupportedUnaryElementwiseOps(
     ret.insert(additional_opcodes.begin(), additional_opcodes.end());
   }
 
+  if (element_type == PrimitiveType::BF16 ||
+      element_type == PrimitiveType::F16) {
+    absl::flat_hash_set<HloOpcode> additional_opcodes{HloOpcode::kFloor,
+                                                      HloOpcode::kCeil};
+    ret.insert(additional_opcodes.begin(), additional_opcodes.end());
+  }
+
   if (primitive_util::IsFloatingPointType(element_type)) {
     ret.insert(HloOpcode::kReducePrecision);
   }
@@ -180,12 +185,6 @@ absl::flat_hash_set<HloOpcode> TritonSupportedBinaryElementwiseOps(
     ret.insert(HloOpcode::kRemainder);
     ret.insert(HloOpcode::kPower);
   }
-  if (element_type == PrimitiveType::BF16 ||
-      element_type == PrimitiveType::F16) {
-    ret.insert(HloOpcode::kAtan2);
-    ret.insert(HloOpcode::kPower);
-    ret.insert(HloOpcode::kRemainder);
-  }
 
   return ret;
 }
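A hedged usage sketch of the support table changed above; TritonSupportedUnaryElementwiseOps is the function from this file, its exact parameter list and visibility to callers are assumptions, and the wrapper below is illustrative only:

#include "absl/container/flat_hash_set.h"

// Assumes the XLA headers declaring HloOpcode, PrimitiveType, and
// TritonSupportedUnaryElementwiseOps are available to the caller.
bool IsUnaryOpTritonSupported(HloOpcode opcode, PrimitiveType element_type) {
  absl::flat_hash_set<HloOpcode> ops =
      TritonSupportedUnaryElementwiseOps(element_type);
  return ops.contains(opcode);
}

// After this revert, per the hunks above:
// IsUnaryOpTritonSupported(HloOpcode::kFloor, PrimitiveType::F16) -> true
// IsUnaryOpTritonSupported(HloOpcode::kExp, PrimitiveType::F16)   -> false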
@@ -144,7 +144,6 @@ std::vector<HloOpcode> TritonSupportedUnaryElementwiseUpToFloatNormalization(
   ret.push_back(HloOpcode::kNegate);
   if (element_type == PrimitiveType::F32 ||
       element_type == PrimitiveType::BF16 ||
-      element_type == PrimitiveType::F16 ||
       element_type == PrimitiveType::F64) {
     absl::c_copy(std::vector<HloOpcode>{HloOpcode::kCos, HloOpcode::kExp,
                                         HloOpcode::kExpm1, HloOpcode::kFloor,
@@ -169,13 +168,10 @@ std::vector<HloOpcode> TritonSupportedBinaryElementwiseUpToFloatNormalization(
                                HloOpcode::kMultiply, HloOpcode::kSubtract};
   if (element_type == PrimitiveType::F32 ||
       element_type == PrimitiveType::BF16 ||
-      element_type == PrimitiveType::F16 ||
       element_type == PrimitiveType::F64) {
     ret.push_back(HloOpcode::kAtan2);
+    ret.push_back(HloOpcode::kDivide);
     ret.push_back(HloOpcode::kPower);
-    if (element_type != PrimitiveType::F16) {
-      ret.push_back(HloOpcode::kDivide);
-    }
   }
   return ret;
 }
@@ -123,13 +123,15 @@ ENTRY main {
   reduce = f16[127]{0} reduce(param_0, constant_neg_inf), dimensions={1}, to_apply=max_computation
   broadcast = f16[127,125]{1,0} broadcast(reduce), dimensions={0}
   subtract = f16[127,125]{1,0} subtract(param_0, broadcast)
-  exp = f16[127,125]{1,0} exponential(subtract)
+  // Replace Softmax exponential with abs, because Triton doesn't support
+  // non-f32 exponentials.
+  abs = f16[127,125]{1,0} abs(subtract)
   constant_zero = f16[] constant(0)
-  second_reduce = f16[127]{0} reduce(exp, constant_zero), dimensions={1}, to_apply=add_computation
+  second_reduce = f16[127]{0} reduce(abs, constant_zero), dimensions={1}, to_apply=add_computation
   second_broadcast = f16[127,125]{1,0} broadcast(second_reduce), dimensions={0}
-  ROOT multiply = f16[127,125]{1,0} multiply(exp, second_broadcast)
+  // Replace divide with multiply, because Triton doesn't support f16
+  // divisions.
+  ROOT multiply = f16[127,125]{1,0} multiply(abs, second_broadcast)
 }
 )";
   auto module = ParseAndReturnVerifiedModule(hlo_string).value();
@@ -207,20 +209,20 @@ ENTRY main {
   reduce = bf16[127]{0} reduce(param_0, constant_neg_inf), dimensions={1}, to_apply=max_computation
   broadcast = bf16[127,125]{1,0} broadcast(reduce), dimensions={0}
   subtract = bf16[127,125]{1,0} subtract(param_0, broadcast)
-  ROOT round = bf16[127,125]{1,0} round-nearest-even(subtract)
+  ROOT exponential = bf16[127,125]{1,0} exponential(subtract)
 })";
 
   auto module = ParseAndReturnVerifiedModule(hlo_string).value();
-  const HloInstruction* bf16_round_nearest_even =
+  const HloInstruction* bf16_exponential =
       hlo_query::GetFirstInstructionWithOpcode(*module->entry_computation(),
-                                               HloOpcode::kRoundNearestEven);
+                                               HloOpcode::kExp);
   EXPECT_FALSE(IsTritonSupportedInstruction(
-      *bf16_round_nearest_even, device_info_.gpu_compute_capability()));
+      *bf16_exponential, device_info_.gpu_compute_capability()));
   EXPECT_TRUE(fusion_rewriter_.Run(module.get()).value());
   EXPECT_TRUE(verifier().Run(module.get()).status().ok());
   EXPECT_THAT(
       module->entry_computation()->root_instruction(),
-      GmockMatch(m::RoundNearestEven(
+      GmockMatch(m::Exp(
           m::Fusion(m::Parameter()).WithPredicate(HasBlockLevelFusionConfig))));
 }
@@ -731,8 +733,8 @@ max_computation {
 ENTRY main {
   param_0 = f16[127,125]{1,0} parameter(0)
-  round-nearest-even = f16[127,125] round-nearest-even(param_0)
-  convert = f32[127,125] convert(round-nearest-even)
+  exponential = f16[127,125] exponential(param_0)
+  convert = f32[127,125] convert(exponential)
   constant_neg_inf = f32[] constant(-inf)
   reduce = f32[127]{0} reduce(convert, constant_neg_inf), dimensions={1}, to_apply=max_computation
   broadcast = f32[127,125]{1,0} broadcast(reduce), dimensions={0}
@@ -743,7 +745,7 @@ ENTRY main {
   EXPECT_TRUE(fusion_rewriter_.Run(module.get()).value());
   EXPECT_TRUE(verifier().Run(module.get()).status().ok());
   EXPECT_THAT(module->entry_computation()->root_instruction(),
-              GmockMatch(m::Fusion(m::RoundNearestEven(m::Parameter()))
+              GmockMatch(m::Fusion(m::Exp(m::Parameter()))
                          .WithPredicate(HasBlockLevelFusionConfig)));
 }

3 changes: 1 addition & 2 deletions third_party/xla/xla/service/pattern_matcher.h
@@ -2678,6 +2678,7 @@ XLA_NULLOP_PATTERN(ReplicaId)
         .WithOperand(0, std::forward<Arg>(arg));                          \
   }
 XLA_UNOP_PATTERN(Abs)
+XLA_UNOP_PATTERN(RoundNearestAfz)
 XLA_UNOP_PATTERN(Bitcast)
 XLA_UNOP_PATTERN(BitcastConvert)
 XLA_UNOP_PATTERN(Broadcast)
@@ -2716,8 +2717,6 @@ XLA_UNOP_PATTERN(RecvDone)
 XLA_UNOP_PATTERN(ReducePrecision)
 XLA_UNOP_PATTERN(Reshape)
 XLA_UNOP_PATTERN(Reverse)
-XLA_UNOP_PATTERN(RoundNearestAfz)
-XLA_UNOP_PATTERN(RoundNearestEven)
 XLA_UNOP_PATTERN(Rsqrt)
 XLA_UNOP_PATTERN(SendDone)
 XLA_UNOP_PATTERN(Sign)
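Each XLA_UNOP_PATTERN(Op) invocation above expands to an m::Op(...) matcher factory, which is what lets the tests earlier in this commit write GmockMatch(m::Exp(...)). A hedged sketch of a typical call site, assuming the usual namespace alias used in XLA tests and an HloInstruction* in hand:

namespace m = match;

// After this change m::RoundNearestAfz still exists (moved up the list),
// while m::RoundNearestEven is gone, matching the reverted Triton support.
bool IsExpOfParameter(HloInstruction* instr) {
  return Match(instr, m::Exp(m::Parameter()));
}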
