Skip to content

feat: simplify operations with sign #336

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 4 commits into
base: main
Choose a base branch
from
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
231 changes: 230 additions & 1 deletion src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp
Original file line number Diff line number Diff line change
@@ -7792,6 +7792,231 @@ struct TransposeReduceSimplify : public OpRewritePattern<stablehlo::ReduceOp> {
}
};

// (select (x > 0) z (neg z)) -> (mul (sign x) z)
// (select (x >= 0) z (neg z)) -> (mul (sign x) z)
// (select (x > 0) (neg z) z) -> (mul (sign x) (neg z))
// (select (x >= 0) (neg z) z) -> (mul (sign x) (neg z))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is this actually simpler/faster?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This one mostly enables the other optimizations.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

a select is usually faster than a mul, so this in isolation makes things locally worse. Is it feasible for the downstream operations to work on select style forms?

// (select (x < 0) z (neg z)) -> (mul (sign x) z)
// (select (x <= 0) z (neg z)) -> (mul (sign x) z)
// (select (x < 0) (neg z) z) -> (mul (sign x) (neg z))
// (select (x <= 0) (neg z) z) -> (mul (sign x) (neg z))
struct PositiveNegativeSelectSimplify
: public OpRewritePattern<stablehlo::SelectOp> {
using OpRewritePattern<stablehlo::SelectOp>::OpRewritePattern;

LogicalResult matchAndRewrite(stablehlo::SelectOp op,
PatternRewriter &rewriter) const override {
auto cond = op.getPred();
auto trueValue = op.getOnTrue();
auto falseValue = op.getOnFalse();

Value rhs = nullptr;
bool lhspositive = true;
if (trueValue.getDefiningOp<stablehlo::NegOp>()) {
if (trueValue.getDefiningOp<stablehlo::NegOp>().getOperand() !=
falseValue)
return failure();

rhs = falseValue; // cond ? -z : z
lhspositive = false;
} else if (falseValue.getDefiningOp<stablehlo::NegOp>()) {
if (falseValue.getDefiningOp<stablehlo::NegOp>().getOperand() !=
trueValue)
return failure();

rhs = trueValue; // cond ? z : -z
} else {
return failure();
}

auto compareOp = cond.getDefiningOp<stablehlo::CompareOp>();
if (!compareOp)
return failure();

if (compareOp.getComparisonDirection() ==
stablehlo::ComparisonDirection::EQ ||
compareOp.getComparisonDirection() ==
stablehlo::ComparisonDirection::NE)
return failure();

Value condValue = nullptr;
bool positive = true;
auto lhsCompareOp = compareOp.getLhs();
auto rhsCompareOp = compareOp.getRhs();
if (matchPattern(lhsCompareOp, m_AnyZeroFloat()) ||
matchPattern(lhsCompareOp, m_Zero())) {
condValue = compareOp.getRhs();
positive = compareOp.getComparisonDirection() ==
stablehlo::ComparisonDirection::GT ||
compareOp.getComparisonDirection() ==
stablehlo::ComparisonDirection::GE;
} else if (matchPattern(rhsCompareOp, m_AnyZeroFloat()) ||
matchPattern(rhsCompareOp, m_Zero())) {
condValue = compareOp.getLhs();
positive = compareOp.getComparisonDirection() ==
stablehlo::ComparisonDirection::LT ||
compareOp.getComparisonDirection() ==
stablehlo::ComparisonDirection::LE;
} else {
return failure();
}

auto newOp = rewriter.create<stablehlo::SignOp>(op.getLoc(), condValue);
if (positive) { // cond > or >= 0
if (lhspositive) {
rewriter.replaceOpWithNewOp<stablehlo::MulOp>(op, newOp, rhs);
} else {
auto negRhs = rewriter.create<stablehlo::NegOp>(op.getLoc(), rhs);
rewriter.replaceOpWithNewOp<stablehlo::MulOp>(op, newOp, negRhs);
}
} else { // cond < or <= 0
if (lhspositive) {
auto negRhs = rewriter.create<stablehlo::NegOp>(op.getLoc(), rhs);
rewriter.replaceOpWithNewOp<stablehlo::MulOp>(op, newOp, negRhs);
} else {
rewriter.replaceOpWithNewOp<stablehlo::MulOp>(op, newOp, rhs);
}
}

return success();
}
};

// (mul (sign x) (abs x)) -> x
// (mul (abs x) (sign x)) -> x
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this seems reasonable, can you split to do this individually?

struct SignAbsSimplify : public OpRewritePattern<stablehlo::MulOp> {
using OpRewritePattern<stablehlo::MulOp>::OpRewritePattern;

LogicalResult matchAndRewrite(stablehlo::MulOp op,
PatternRewriter &rewriter) const override {
auto lhs = op.getOperand(0);
auto rhs = op.getOperand(1);

auto lhsSignOp = lhs.getDefiningOp<stablehlo::SignOp>();
if (lhsSignOp) {
auto rhsAbsOp = rhs.getDefiningOp<stablehlo::AbsOp>();
if (!rhsAbsOp)
return failure();

if (lhsSignOp.getOperand() != rhsAbsOp.getOperand())
return failure();

rewriter.replaceOp(op, lhsSignOp.getOperand());
return success();
}

auto rhsSignOp = rhs.getDefiningOp<stablehlo::SignOp>();
if (rhsSignOp) {
auto lhsAbsOp = lhs.getDefiningOp<stablehlo::AbsOp>();
if (!lhsAbsOp)
return failure();

if (rhsSignOp.getOperand() != lhsAbsOp.getOperand())
return failure();

rewriter.replaceOp(op, rhsSignOp.getOperand());
return success();
}

return failure();
}
};

// (mul (neg x) (neg y)) -> (mul x y)
// (mul (neg x) y) -> (neg (mul x y))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

top one is always good, the one negation variations we should separate since there's a separate question of whether we want to propagate them up or down (e.g. if we had mul (neg x), constant) we'd want to do mul x (-constant)

// (mul x (neg y)) -> (neg (mul x y))
struct MultiplyNegateSimplify : public OpRewritePattern<stablehlo::MulOp> {
using OpRewritePattern<stablehlo::MulOp>::OpRewritePattern;

LogicalResult matchAndRewrite(stablehlo::MulOp op,
PatternRewriter &rewriter) const override {
auto lhs = op.getOperand(0);
auto rhs = op.getOperand(1);

auto lhsNegOp = lhs.getDefiningOp<stablehlo::NegOp>();
auto rhsNegOp = rhs.getDefiningOp<stablehlo::NegOp>();
if (!lhsNegOp && !rhsNegOp)
return failure();

if (lhsNegOp) {
if (rhsNegOp) {
rewriter.replaceOpWithNewOp<stablehlo::MulOp>(op, lhsNegOp.getOperand(),
rhsNegOp.getOperand());
return success();
} else {
if (!isOnlyUsedInOperation(lhsNegOp, op))
return failure();
auto newOp = rewriter.create<stablehlo::MulOp>(
op.getLoc(), lhsNegOp.getOperand(), rhs);
rewriter.replaceOpWithNewOp<stablehlo::NegOp>(op, newOp);
return success();
}
} else if (rhsNegOp) {
if (!isOnlyUsedInOperation(rhsNegOp, op))
return failure();
auto newOp = rewriter.create<stablehlo::MulOp>(op.getLoc(), lhs,
rhsNegOp.getOperand());
rewriter.replaceOpWithNewOp<stablehlo::NegOp>(op, newOp);
return success();
}

return failure();
}
};

// This pattern only does partially the following. We rely on transforming the op to a
// pattern which further uses the above pattern.
// (mul (sign x) (add (abs x) (abs x))) -> (mul x x)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

longer term i feel like this merits a broader sign analysis (alongside perhaps a transpose analysis)

// TODO: We can simplify for cases where only one of the add operands is abs.
struct MultiplySignAddSimplify : public OpRewritePattern<stablehlo::MulOp> {
using OpRewritePattern<stablehlo::MulOp>::OpRewritePattern;

LogicalResult matchAndRewrite(stablehlo::MulOp op,
PatternRewriter &rewriter) const override {
auto lhs = op.getOperand(0);
auto rhs = op.getOperand(1);

stablehlo::SignOp signOp = nullptr;
stablehlo::AddOp addOp = nullptr;
if (lhs.getDefiningOp<stablehlo::SignOp>()) {
signOp = lhs.getDefiningOp<stablehlo::SignOp>();
if (rhs.getDefiningOp<stablehlo::AddOp>()) {
addOp = rhs.getDefiningOp<stablehlo::AddOp>();
} else {
return failure();
}
} else if (rhs.getDefiningOp<stablehlo::SignOp>()) {
signOp = rhs.getDefiningOp<stablehlo::SignOp>();
if (lhs.getDefiningOp<stablehlo::AddOp>()) {
addOp = lhs.getDefiningOp<stablehlo::AddOp>();
} else {
return failure();
}
} else {
return failure();
}

auto signOperand = signOp.getOperand();

auto lhsAddOp = addOp.getOperand(0);
auto rhsAddOp = addOp.getOperand(1);

if (lhsAddOp != rhsAddOp)
return failure(); // TODO: Can support more cases.

auto lhsAddAbsOp = lhsAddOp.getDefiningOp<stablehlo::AbsOp>();
auto rhsAddAbsOp = rhsAddOp.getDefiningOp<stablehlo::AbsOp>();
if (!lhsAddAbsOp || !rhsAddAbsOp)
return failure();

if (signOperand != lhsAddAbsOp.getOperand() || signOperand != rhsAddAbsOp.getOperand())
return failure();

rewriter.replaceOpWithNewOp<stablehlo::MulOp>(op, signOperand, signOperand);
return success();
}
};

/////////////// End Imported from stablehlo

#include "src/enzyme_ad/jax/Passes/EnzymeHLOPatterns.cpp.inc"
@@ -8032,7 +8257,11 @@ struct EnzymeHLOOptPass
CommonCompareExpressionRewrite,
ScatterUpdateComputationConstProp,
ScatterIndicesAreUnique,
TransposeReduceSimplify
TransposeReduceSimplify,
SignAbsSimplify,
PositiveNegativeSelectSimplify,
MultiplyNegateSimplify,
MultiplySignAddSimplify
>(context);
// clang-format on
patterns.add<SelectOpCanon>(max_constant_expansion, context,
20 changes: 20 additions & 0 deletions src/enzyme_ad/jax/TransformOps/TransformOps.td
Original file line number Diff line number Diff line change
@@ -888,6 +888,26 @@ def ApplyTransposeReduceSimplifyPatterns : EnzymeHLOPatternOp<
let patterns = ["TransposeReduceSimplify"];
}

def ApplySignAbsSimplifyPatterns : EnzymeHLOPatternOp<
"sign_abs_simplify"> {
let patterns = ["SignAbsSimplify"];
}

def ApplyPositiveNegativeSelectSimplifyPatterns : EnzymeHLOPatternOp<
"positive_negative_select_simplify"> {
let patterns = ["PositiveNegativeSelectSimplify"];
}

def ApplyMultiplyNegateSimplifyPatterns : EnzymeHLOPatternOp<
"multiply_negate_simplify"> {
let patterns = ["MultiplyNegateSimplify"];
}

def ApplyMultiplySignAddSimplifyPatterns : EnzymeHLOPatternOp<
"multiply_sign_add_simplify"> {
let patterns = ["MultiplySignAddSimplify"];
}

// TODO: better naming for parameters requires a static interface for
// constructing them in search.

4 changes: 4 additions & 0 deletions src/enzyme_ad/jax/primitives.py
Original file line number Diff line number Diff line change
@@ -225,6 +225,10 @@ def hlo_opts():
scatter_update_computation_const_prop;
scatter_indices_are_unique;
transpose_reduce_simplify;
sign_abs_simplify;
positive_negative_select_simplify;
multiply_negate_simplify;
multiply_sign_add_simplify;

transpose_unary_transpose_abs<1>;
transpose_unary_transpose_neg<1>;
27 changes: 27 additions & 0 deletions test/lit_tests/signabssimplify.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
// RUN: enzymexlamlir-opt --enzyme-hlo-opt %s | FileCheck %s

func.func @main1(%arg0: tensor<8x4xf32>) -> tensor<8x4xf32> {
%0 = stablehlo.transpose %arg0, dims = [1, 0] : (tensor<8x4xf32>) -> tensor<4x8xf32>
%1 = stablehlo.sign %0 : tensor<4x8xf32>
%2 = stablehlo.abs %0 : tensor<4x8xf32>
%3 = stablehlo.multiply %1, %2 : tensor<4x8xf32>
%4 = stablehlo.transpose %3, dims = [1, 0] : (tensor<4x8xf32>) -> tensor<8x4xf32>
return %4 : tensor<8x4xf32>
}

// CHECK: func.func @main1(%arg0: tensor<8x4xf32>) -> tensor<8x4xf32> {
// CHECK-NEXT: return %arg0 : tensor<8x4xf32>
// CHECK-NEXT: }

func.func @main2(%arg0: tensor<8x4xf32>) -> tensor<8x4xf32> {
%0 = stablehlo.transpose %arg0, dims = [1, 0] : (tensor<8x4xf32>) -> tensor<4x8xf32>
%1 = stablehlo.sign %0 : tensor<4x8xf32>
%2 = stablehlo.abs %0 : tensor<4x8xf32>
%3 = stablehlo.multiply %2, %1 : tensor<4x8xf32>
%4 = stablehlo.transpose %3, dims = [1, 0] : (tensor<4x8xf32>) -> tensor<8x4xf32>
return %4 : tensor<8x4xf32>
}

// CHECK: func.func @main2(%arg0: tensor<8x4xf32>) -> tensor<8x4xf32> {
// CHECK-NEXT: return %arg0 : tensor<8x4xf32>
// CHECK-NEXT: }