Skip to content

Commit

Permalink
feat: rewrite constants into scatter body
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Feb 7, 2025
1 parent bba7c6b commit 620d438
Show file tree
Hide file tree
Showing 4 changed files with 116 additions and 1 deletion.
99 changes: 98 additions & 1 deletion src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7643,6 +7643,102 @@ struct CommonCompareExpressionRewrite
}
};

struct ScatterUpdateComputationConstProp
: public OpRewritePattern<stablehlo::ScatterOp> {
using OpRewritePattern<stablehlo::ScatterOp>::OpRewritePattern;

LogicalResult matchAndRewrite(stablehlo::ScatterOp op,
PatternRewriter &rewriter) const final {
// If this scatter op was created by this pass, don't rewrite it again
// Q: Is there a better way to do this?

llvm::errs() << "op: " << op << "\n";
llvm::errs() << "op.getUpdateComputation() attr: "
<< op->hasAttr("enzymexla.transformed_by_scatter_const_prop")
<< "\n";

if (op->hasAttr("enzymexla.transformed_by_scatter_const_prop"))
return failure();

auto &region = op.getUpdateComputation();
auto &block = region.front();

// Check all inputs are constant and splat and their values are the same.
auto [constInput, inputSplatAttr] =
isConstantSplatValueRange(op.getInputs());

// Check all updates are constant and splat and their values are the same.
auto [constUpdate, updateSplatAttr] =
isConstantSplatValueRange(op.getUpdates());

if (constInput || constUpdate) {
if (constInput) {
auto blockArgInput = block.getArgument(0);
llvm::errs() << "blockArgInput: " << blockArgInput << "\n";
auto denseAttr = DenseElementsAttr::get(
blockArgInput.getType().cast<ShapedType>(), inputSplatAttr);
auto constInputOp =
rewriter.create<stablehlo::ConstantOp>(op.getLoc(), denseAttr);
llvm::errs() << "constInputOp: " << constInputOp << "\n";
blockArgInput.replaceAllUsesWith(constInputOp);
}

if (constUpdate) {
auto blockArgUpdate = block.getArgument(1);
llvm::errs() << "blockArgUpdate: " << blockArgUpdate << "\n";
auto denseAttr = DenseElementsAttr::get(
blockArgUpdate.getType().cast<ShapedType>(), updateSplatAttr);
auto constUpdateOp =
rewriter.create<stablehlo::ConstantOp>(op.getLoc(), denseAttr);
llvm::errs() << "constUpdateOp: " << constUpdateOp << "\n";
blockArgUpdate.replaceAllUsesWith(constUpdateOp);
}

auto newOp = rewriter.create<stablehlo::ScatterOp>(
op.getLoc(), op.getResultTypes(), op.getInputs(),
op.getScatterIndices(), op.getUpdates(),
op.getScatterDimensionNumbers(), op.getIndicesAreSorted(),
op.getUniqueIndices());
newOp.getUpdateComputation().takeBody(region);
newOp->setAttr("enzymexla.transformed_by_scatter_const_prop",
rewriter.getUnitAttr());
rewriter.replaceOp(op, newOp);

return success();
}

return failure();
}

private:
std::tuple<bool, Attribute>
isConstantSplatValueRange(ValueRange range) const {
Attribute splatAttr = nullptr;
bool isConstant = true;
for (auto val : range) {
DenseElementsAttr attr;
if (matchPattern(val, m_Constant(&attr))) {
if (attr.isSplat()) {
if (!splatAttr) {
splatAttr = attr.getSplatValue<Attribute>();
continue;
} else if (splatAttr != attr.getSplatValue<Attribute>()) {
isConstant = false;
break;
}
} else {
isConstant = false;
break;
}
} else {
isConstant = false;
break;
}
}
return std::make_tuple(isConstant, splatAttr);
};
};

/////////////// End Imported from stablehlo

#include "src/enzyme_ad/jax/Passes/EnzymeHLOPatterns.cpp.inc"
Expand Down Expand Up @@ -7873,7 +7969,8 @@ struct EnzymeHLOOptPass
ZeroExtentTensorCanon,
CompareSelectSimplify,
NotSelectSimplify,
CommonCompareExpressionRewrite
CommonCompareExpressionRewrite,
ScatterUpdateComputationConstProp
>(context);
// clang-format on
patterns.add<SelectOpCanon>(max_constant_expansion, context,
Expand Down
5 changes: 5 additions & 0 deletions src/enzyme_ad/jax/TransformOps/TransformOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -858,6 +858,11 @@ def CommonCompareExpressionRewritePatterns : EnzymeHLOPatternOp<
let patterns = ["CommonCompareExpressionRewrite"];
}

def ApplyScatterUpdateComputationConstPropPatterns : EnzymeHLOPatternOp<
"scatter_update_computation_const_prop"> {
let patterns = ["ScatterUpdateComputationConstProp"];
}

// TODO: better naming for parameters requires a static interface for
// constructing them in search.

Expand Down
1 change: 1 addition & 0 deletions src/enzyme_ad/jax/primitives.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,6 +220,7 @@ def hlo_opts():
compare_select_simplify;
common_compare_expression_rewrite;
not_select_simplify;
scatter_update_computation_const_prop;
transpose_unary_transpose_abs<1>;
transpose_unary_transpose_neg<1>;
Expand Down
12 changes: 12 additions & 0 deletions test/lit_tests/scatterupdatecomputationconstprop.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
// RUN: enzymexlamlir-opt --enzyme-hlo-opt %s | FileCheck %s

func.func @main1(%arg0: tensor<64x3xf32>, %arg1: tensor<45x1xi32>, %arg2: tensor<45x3xf32>) -> (tensor<64x3xf32>, tensor<45x1xi32>) {
%cst = stablehlo.constant dense<0.000000e+00> : tensor<64x3xf32>
%c = stablehlo.constant dense<0> : tensor<45x1xi32>
%0 = "stablehlo.scatter"(%cst, %arg1, %arg2) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter<update_window_dims = [1], inserted_window_dims = [0], scatter_dims_to_operand_dims = [0], index_vector_dim = 1>, unique_indices = false}> ({
^bb0(%arg3: tensor<f32>, %arg4: tensor<f32>):
%1 = stablehlo.add %arg3, %arg4 : tensor<f32>
stablehlo.return %1 : tensor<f32>
}) : (tensor<64x3xf32>, tensor<45x1xi32>, tensor<45x3xf32>) -> tensor<64x3xf32>
return %0, %c : tensor<64x3xf32>, tensor<45x1xi32>
}

0 comments on commit 620d438

Please sign in to comment.