Skip to content

Commit

Permalink
fix: no recursion + test update
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Feb 7, 2025
1 parent 3e2df17 commit d5b2556
Show file tree
Hide file tree
Showing 3 changed files with 49 additions and 17 deletions.
23 changes: 12 additions & 11 deletions src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7649,11 +7649,6 @@ struct ScatterUpdateComputationConstProp

LogicalResult matchAndRewrite(stablehlo::ScatterOp op,
PatternRewriter &rewriter) const final {
// If this scatter op was created by this pass, don't rewrite it again
// Q: Is there a better way to do this?
if (op->hasAttr("enzymexla.transformed_by_scatter_const_prop"))
return failure();

auto &region = op.getUpdateComputation();
auto &block = region.front();

Expand All @@ -7666,32 +7661,38 @@ struct ScatterUpdateComputationConstProp
isConstantSplatValueRange(op.getUpdates());

if (constInput || constUpdate) {
if (constInput) {
auto blockArgInput = block.getArgument(0);
bool inputTransformed = false;
bool updateTransformed = false;
auto blockArgInput = block.getArgument(0);
auto blockArgUpdate = block.getArgument(1);

if (constInput && !blockArgInput.getUses().empty()) {
inputTransformed = true;
auto denseAttr = DenseElementsAttr::get(
blockArgInput.getType().cast<ShapedType>(), inputSplatAttr);
auto constInputOp =
rewriter.create<stablehlo::ConstantOp>(op.getLoc(), denseAttr);
blockArgInput.replaceAllUsesWith(constInputOp);
}

if (constUpdate) {
auto blockArgUpdate = block.getArgument(1);
if (constUpdate && !blockArgUpdate.getUses().empty()) {
updateTransformed = true;
auto denseAttr = DenseElementsAttr::get(
blockArgUpdate.getType().cast<ShapedType>(), updateSplatAttr);
auto constUpdateOp =
rewriter.create<stablehlo::ConstantOp>(op.getLoc(), denseAttr);
blockArgUpdate.replaceAllUsesWith(constUpdateOp);
}

if (!inputTransformed && !updateTransformed)
return failure();

auto newOp = rewriter.create<stablehlo::ScatterOp>(
op.getLoc(), op.getResultTypes(), op.getInputs(),
op.getScatterIndices(), op.getUpdates(),
op.getScatterDimensionNumbers(), op.getIndicesAreSorted(),
op.getUniqueIndices());
newOp.getUpdateComputation().takeBody(region);
newOp->setAttr("enzymexla.transformed_by_scatter_const_prop",
rewriter.getUnitAttr());
rewriter.replaceOp(op, newOp);

return success();
Expand Down
9 changes: 3 additions & 6 deletions test/lit_tests/diffrules/stablehlo/gather.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,8 @@ module {
// REVERSE-NEXT: %cst = stablehlo.constant dense<0.000000e+00> : tensor<64x3xf32>
// REVERSE-NEXT: %c = stablehlo.constant dense<0> : tensor<45x1xi32>
// REVERSE-NEXT: %0 = "stablehlo.scatter"(%cst, %arg1, %arg2) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter<update_window_dims = [1], inserted_window_dims = [0], scatter_dims_to_operand_dims = [0], index_vector_dim = 1>, unique_indices = false}> ({
// REVERSE-NEXT: ^bb0(%arg3: tensor<f32>, %arg4: tensor<f32>):
// REVERSE-NEXT: %1 = stablehlo.add %arg3, %arg4 : tensor<f32>
// REVERSE-NEXT: stablehlo.return %1 : tensor<f32>
// REVERSE-NEXT: }) : (tensor<64x3xf32>, tensor<45x1xi32>, tensor<45x3xf32>) -> tensor<64x3xf32>
// REVERSE-NEXT: ^bb0(%arg3: tensor<f32>, %arg4: tensor<f32>):
// REVERSE-NEXT: stablehlo.return %arg4 : tensor<f32>
// REVERSE-NEXT: }) : (tensor<64x3xf32>, tensor<45x1xi32>, tensor<45x3xf32>) -> tensor<64x3xf32>
// REVERSE-NEXT: return %0, %c : tensor<64x3xf32>, tensor<45x1xi32>
// REVERSE-NEXT: }


34 changes: 34 additions & 0 deletions test/lit_tests/scatterupdatecomputationconstprop.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -10,3 +10,37 @@ func.func @main1(%arg0: tensor<64x3xf32>, %arg1: tensor<45x1xi32>, %arg2: tensor
}) : (tensor<64x3xf32>, tensor<45x1xi32>, tensor<45x3xf32>) -> tensor<64x3xf32>
return %0, %c : tensor<64x3xf32>, tensor<45x1xi32>
}

// CHECK: func.func @main1(%arg0: tensor<64x3xf32>, %arg1: tensor<45x1xi32>, %arg2: tensor<45x3xf32>) -> (tensor<64x3xf32>, tensor<45x1xi32>) {
// CHECK-NEXT: %cst = stablehlo.constant dense<0.000000e+00> : tensor<64x3xf32>
// CHECK-NEXT: %c = stablehlo.constant dense<0> : tensor<45x1xi32>
// CHECK-NEXT: %0 = "stablehlo.scatter"(%cst, %arg1, %arg2) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter<update_window_dims = [1], inserted_window_dims = [0], scatter_dims_to_operand_dims = [0], index_vector_dim = 1>, unique_indices = false}> ({
// CHECK-NEXT: ^bb0(%arg3: tensor<f32>, %arg4: tensor<f32>):
// CHECK-NEXT: stablehlo.return %arg4 : tensor<f32>
// CHECK-NEXT: }) : (tensor<64x3xf32>, tensor<45x1xi32>, tensor<45x3xf32>) -> tensor<64x3xf32>
// CHECK-NEXT: return %0, %c : tensor<64x3xf32>, tensor<45x1xi32>
// CHECK-NEXT: }

func.func @main2(%arg0: tensor<64x3xf32>, %arg1: tensor<45x1xi32>) -> (tensor<64x3xf32>, tensor<45x1xi32>) {
%cst = stablehlo.constant dense<2.000000e+00> : tensor<64x3xf32>
%c = stablehlo.constant dense<0> : tensor<45x1xi32>
%cst_2 = stablehlo.constant dense<5.000000e+00> : tensor<45x3xf32>
%0 = "stablehlo.scatter"(%cst, %arg1, %cst_2) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter<update_window_dims = [1], inserted_window_dims = [0], scatter_dims_to_operand_dims = [0], index_vector_dim = 1>, unique_indices = false}> ({
^bb0(%arg3: tensor<f32>, %arg4: tensor<f32>):
%1 = stablehlo.multiply %arg3, %arg4 : tensor<f32>
stablehlo.return %1 : tensor<f32>
}) : (tensor<64x3xf32>, tensor<45x1xi32>, tensor<45x3xf32>) -> tensor<64x3xf32>
return %0, %c : tensor<64x3xf32>, tensor<45x1xi32>
}

// CHECK: func.func @main2(%arg0: tensor<64x3xf32>, %arg1: tensor<45x1xi32>) -> (tensor<64x3xf32>, tensor<45x1xi32>) {
// CHECK-NEXT: %cst = stablehlo.constant dense<1.000000e+01> : tensor<f32>
// CHECK-NEXT: %cst_0 = stablehlo.constant dense<2.000000e+00> : tensor<64x3xf32>
// CHECK-NEXT: %c = stablehlo.constant dense<0> : tensor<45x1xi32>
// CHECK-NEXT: %cst_1 = stablehlo.constant dense<5.000000e+00> : tensor<45x3xf32>
// CHECK-NEXT: %0 = "stablehlo.scatter"(%cst_0, %arg1, %cst_1) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter<update_window_dims = [1], inserted_window_dims = [0], scatter_dims_to_operand_dims = [0], index_vector_dim = 1>, unique_indices = false}> ({
// CHECK-NEXT: ^bb0(%arg2: tensor<f32>, %arg3: tensor<f32>):
// CHECK-NEXT: stablehlo.return %cst : tensor<f32>
// CHECK-NEXT: }) : (tensor<64x3xf32>, tensor<45x1xi32>, tensor<45x3xf32>) -> tensor<64x3xf32>
// CHECK-NEXT: return %0, %c : tensor<64x3xf32>, tensor<45x1xi32>
// CHECK-NEXT: }

0 comments on commit d5b2556

Please sign in to comment.