From 5beb2038222e8c1dc457f8982f7bc1f2447ae3f4 Mon Sep 17 00:00:00 2001 From: "William S. Moses" Date: Thu, 23 Jan 2025 23:43:04 -0600 Subject: [PATCH] Revert "pattern to remove dead results of stablehlo.while" This reverts commit 47db9dda28d30c618c3bf59d4406d2def0162f24. --- src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp | 186 ++---------------- .../jax/TransformOps/TransformOps.td | 4 - src/enzyme_ad/jax/primitives.py | 1 - test/lit_tests/whiledeadarg.mlir | 55 ------ 4 files changed, 13 insertions(+), 233 deletions(-) delete mode 100644 test/lit_tests/whiledeadarg.mlir diff --git a/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp b/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp index 876bcf6e8..971c49ff7 100644 --- a/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp +++ b/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp @@ -10,8 +10,6 @@ // ops. //===----------------------------------------------------------------------===// -#include "mlir/Analysis/SliceAnalysis.h" -#include "mlir/Analysis/TopologicalSortUtils.h" #include "mlir/Dialect/CommonFolders.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" @@ -1444,145 +1442,6 @@ struct ShiftRightLogicalSimplify final } }; -struct WhileDeadResults final : OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - bool isLoopResultDead(OpResult result) const { - // Not dead if the result is in use. - if (!result.use_empty()) - return false; - - // Or if the corresponding argument is being used in computing the - // condition. - auto whileOp = cast(result.getOwner()); - Value condArgument = - whileOp.getCond().getArgument(result.getResultNumber()); - SetVector forwardSlice; - getForwardSlice(condArgument, &forwardSlice); - if (!llvm::all_of(forwardSlice, mlir::isPure)) - return false; - if (forwardSlice.contains(whileOp.getCond().front().getTerminator())) - return false; - - // Or in computing another result. 
We first do a fast-path check of having - // the argument not influencing the terminator operation, before going into - // finer-grain analysis. - // - // TODO: it is possible that this argument does influence another terminator - // operand, but that operand in turn corresponds to a dead value, but - // handling that would require more complex logic of detecting dead cycles - // of value chains. - forwardSlice.clear(); - assert(llvm::hasSingleElement(whileOp.getBody())); - Value bodyArgument = - whileOp.getBody().getArgument(result.getResultNumber()); - getForwardSlice(bodyArgument, &forwardSlice); - if (!llvm::all_of(forwardSlice, mlir::isPure)) - return false; - - Operation *bodyTerminator = whileOp.getBody().front().getTerminator(); - if (!forwardSlice.contains(bodyTerminator)) - return true; - - for (OpOperand &terminatorOperand : bodyTerminator->getOpOperands()) { - if (terminatorOperand.getOperandNumber() == result.getResultNumber()) - continue; - - SetVector backwardSlice; - getBackwardSlice(terminatorOperand.get(), &backwardSlice); - for (Operation *op : backwardSlice) { - if (llvm::is_contained(op->getOperands(), bodyArgument)) - return false; - } - } - return true; - } - - void replaceTerminator(PatternRewriter &rewriter, Region &region, - ArrayRef deadResults) const { - Operation *terminator = region.front().getTerminator(); - SmallVector terminatorOperands; - for (auto &&[i, operand] : llvm::enumerate(terminator->getOperands())) { - if (!llvm::is_contained(deadResults, i)) - terminatorOperands.push_back(operand); - } - OpBuilder::InsertionGuard guard(rewriter); - rewriter.setInsertionPoint(terminator); - rewriter.replaceOpWithNewOp( - terminator, TypeRange(), terminatorOperands, terminator->getAttrs()); - } - - LogicalResult matchAndRewrite(mlir::stablehlo::WhileOp op, - PatternRewriter &rewriter) const override { - SmallVector deadResults; - for (OpResult result : op.getResults()) { - if (!isLoopResultDead(result)) - continue; - - 
deadResults.push_back(result.getResultNumber()); - } - if (deadResults.empty()) - return failure(); - - SetVector condSlice, bodySlice; - for (int64_t i : deadResults) { - getForwardSlice(op.getCond().getArgument(i), &condSlice); - getForwardSlice(op.getBody().getArgument(i), &bodySlice); - } - condSlice.remove(op.getCond().front().getTerminator()); - bodySlice.remove(op.getBody().front().getTerminator()); - condSlice = mlir::topologicalSort(condSlice); - bodySlice = mlir::topologicalSort(bodySlice); - for (Operation *erasable : llvm::reverse(condSlice)) - rewriter.eraseOp(erasable); - for (Operation *erasable : llvm::reverse(bodySlice)) - rewriter.eraseOp(erasable); - - replaceTerminator(rewriter, op.getCond(), deadResults); - replaceTerminator(rewriter, op.getBody(), deadResults); - - SmallVector operands; - SmallVector resultTypes; - SmallVector condBlockArgLocs, bodyBlockArgsLocs; - for (auto &&[i, operand, resultType] : - llvm::enumerate(op->getOperands(), op.getResultTypes())) { - if (llvm::is_contained(deadResults, i)) - continue; - - operands.push_back(operand); - resultTypes.push_back(resultType); - condBlockArgLocs.push_back(op.getCond().getArgument(i).getLoc()); - bodyBlockArgsLocs.push_back(op.getBody().getArgument(i).getLoc()); - } - - auto updated = rewriter.create( - op->getLoc(), resultTypes, operands, op->getAttrs()); - SmallVector resultReplacements; - for (int64_t old = 0, upd = 0, end = op->getNumResults(); old < end; - ++old) { - if (llvm::is_contained(deadResults, old)) { - resultReplacements.push_back(nullptr); - continue; - } - resultReplacements.push_back(updated->getResult(upd)); - ++upd; - } - - for (int64_t i : llvm::reverse(deadResults)) - op.getCond().eraseArgument(i); - rewriter.inlineRegionBefore(op.getCond(), updated.getCond(), - updated.getCond().begin()); - - for (int64_t i : llvm::reverse(deadResults)) - op.getBody().eraseArgument(i); - rewriter.inlineRegionBefore(op.getBody(), updated.getBody(), - updated.getBody().begin()); - - 
rewriter.replaceOp(op, resultReplacements); - return success(); - } -}; - struct NegativePadToSlice final : OpRewritePattern { using OpRewritePattern::OpRewritePattern; @@ -7520,38 +7379,19 @@ struct EnzymeHLOOptPass : public EnzymeHLOOptPassBase { } patterns.add((no_nan || all_finite), context); - // clang-format off - patterns.add< - BroadcastInDimOpCanon, - ChainedDynamicBroadcastInDimCanonicalization, - CompareOpCanon, - ConjComplexNegate, - ConvertOpCanon, - DivideSqrtToMultiplyRsqrt, - DynamicBroadcastInDimAllDimsNonExpanding, - DynamicBroadcastInDimOpNotActuallyDynamic, - DynamicGatherOpIsNotDynamic, - DynamicReshapeOpCanon, - EmptyReduceOpCanon, - GatherOpCanon, - GetDimensionSizeOpCanon, - GetTupleElementOpCanon, - IfInline, - IfToSelect, - ImagOpCanon, - MergeConsecutiveReshapes, - NoopReduceOpCanon, - RealOpCanon, - ReorderElementwiseAndShapeOp, - ReshapeOpCanon, - SelectOpUsedWithinIf, - TransposeBroadcastInDimToBroadcastInDim, - TransposeIsReshape, - WhileDeadResults, - WhileSimplify, - ZeroExtentTensorCanon - >(context); - // clang-format on + patterns.add( + context); patterns.add(max_constant_expansion, context, PatternBenefit(65000)); patterns.add(max_constant_expansion, context, diff --git a/src/enzyme_ad/jax/TransformOps/TransformOps.td b/src/enzyme_ad/jax/TransformOps/TransformOps.td index 0412b312f..f664ef7ca 100644 --- a/src/enzyme_ad/jax/TransformOps/TransformOps.td +++ b/src/enzyme_ad/jax/TransformOps/TransformOps.td @@ -797,10 +797,6 @@ def ApplyShiftRightLogicalSimplifyPatterns : EnzymeHLOPatternOp< "shift_right_logical_simplify"> { let patterns = ["ShiftRightLogicalSimplify"]; } -def WhileDeadResultPatterns : EnzymeHLOPatternOp< - "while_deadresult"> { - let patterns = ["WhileDeadResults"]; -} def ApplyRemSimplifyPatterns : EnzymeHLOPatternOp< "rem_simplify"> { let patterns = ["RemSimplify"]; diff --git a/src/enzyme_ad/jax/primitives.py b/src/enzyme_ad/jax/primitives.py index 20262baab..30a9d0f4f 100644 --- 
a/src/enzyme_ad/jax/primitives.py +++ b/src/enzyme_ad/jax/primitives.py @@ -285,7 +285,6 @@ def hlo_opts(): if_inline<1>; if_to_select<1>; while_simplify<1>; -while_deadresult<1>; dot_reshape_pad<1>; pad_dot_general<1>(1); diff --git a/test/lit_tests/whiledeadarg.mlir b/test/lit_tests/whiledeadarg.mlir deleted file mode 100644 index 84be8854c..000000000 --- a/test/lit_tests/whiledeadarg.mlir +++ /dev/null @@ -1,55 +0,0 @@ -// RUN: enzymexlamlir-opt %s --enzyme-hlo-opt | FileCheck %s - -// CHECK-LABEL: @while_deadarg -func.func @while_deadarg(%arg0: tensor<2x6x3xf32>, %arg1: tensor<3x3xf32>, %arg2: tensor<3x3xf32>, %arg3: tensor<3xf32>, %arg4: tensor<3xf32>, %arg5: tensor<2xui64>) -> (tensor<2x3xf32>, tensor<2xui64>, tensor<2x6x3xf32>, tensor<3x3xf32>, tensor<3x3xf32>, tensor<3xf32>, tensor<3xf32>) { - %c = stablehlo.constant dense<5> : tensor - %cst = stablehlo.constant dense<0.000000e+00> : tensor - %c_0 = stablehlo.constant dense<2> : tensor - %c_1 = stablehlo.constant dense<1> : tensor - %c_2 = stablehlo.constant dense<0> : tensor - %0 = stablehlo.transpose %arg0, dims = [2, 1, 0] : (tensor<2x6x3xf32>) -> tensor<3x6x2xf32> - %1 = stablehlo.transpose %arg1, dims = [1, 0] : (tensor<3x3xf32>) -> tensor<3x3xf32> - %2 = stablehlo.transpose %arg2, dims = [1, 0] : (tensor<3x3xf32>) -> tensor<3x3xf32> - %3 = stablehlo.slice %0 [0:3, 0:1, 0:2] : (tensor<3x6x2xf32>) -> tensor<3x1x2xf32> - %4 = stablehlo.transpose %3, dims = [2, 1, 0] : (tensor<3x1x2xf32>) -> tensor<2x1x3xf32> - %5 = stablehlo.reshape %4 : (tensor<2x1x3xf32>) -> tensor<2x3xf32> - %6 = stablehlo.broadcast_in_dim %arg4, dims = [0] : (tensor<3xf32>) -> tensor<3x2xf32> - %7 = stablehlo.dot_general %arg1, %5, contracting_dims = [0] x [1] : (tensor<3x3xf32>, tensor<2x3xf32>) -> tensor<3x2xf32> - %8 = stablehlo.broadcast_in_dim %arg3, dims = [0] : (tensor<3xf32>) -> tensor<3x2xf32> - %9 = stablehlo.add %7, %8 : tensor<3x2xf32> - %10 = stablehlo.add %6, %9 : tensor<3x2xf32> - %11 = stablehlo.tanh %10 : 
tensor<3x2xf32> - %12 = stablehlo.reshape %11 : (tensor<3x2xf32>) -> tensor<3x2x1xf32> - %13 = stablehlo.pad %12, %cst, low = [0, 0, 0], high = [0, 0, 5], interior = [0, 0, 0] : (tensor<3x2x1xf32>, tensor) -> tensor<3x2x6xf32> - - // CHECK: %{{.+}}:8 = stablehlo.while - %14:9 = stablehlo.while(%iterArg = %c_2, %iterArg_3 = %13, %iterArg_4 = %1, %iterArg_5 = %2, %iterArg_6 = %arg3, %iterArg_7 = %arg4, %iterArg_8 = %arg5, %iterArg_9 = %11, %iterArg_10 = %0) : tensor, tensor<3x2x6xf32>, tensor<3x3xf32>, tensor<3x3xf32>, tensor<3xf32>, tensor<3xf32>, tensor<2xui64>, tensor<3x2xf32>, tensor<3x6x2xf32> - cond { - %19 = stablehlo.compare LT, %iterArg, %c : (tensor, tensor) -> tensor - stablehlo.return %19 : tensor - } do { - %19 = stablehlo.add %c_0, %iterArg : tensor - %20 = stablehlo.subtract %19, %c_1 : tensor - %21 = stablehlo.dynamic_slice %iterArg_10, %c_2, %20, %c_2, sizes = [3, 1, 2] : (tensor<3x6x2xf32>, tensor, tensor, tensor) -> tensor<3x1x2xf32> - %22 = stablehlo.transpose %21, dims = [2, 1, 0] : (tensor<3x1x2xf32>) -> tensor<2x1x3xf32> - %23 = stablehlo.reshape %22 : (tensor<2x1x3xf32>) -> tensor<2x3xf32> - %24 = stablehlo.dot_general %iterArg_5, %iterArg_9, contracting_dims = [1] x [0] : (tensor<3x3xf32>, tensor<3x2xf32>) -> tensor<3x2xf32> - %25 = stablehlo.broadcast_in_dim %iterArg_7, dims = [0] : (tensor<3xf32>) -> tensor<3x2xf32> - %26 = stablehlo.add %24, %25 : tensor<3x2xf32> - %27 = stablehlo.dot_general %iterArg_4, %23, contracting_dims = [1] x [1] : (tensor<3x3xf32>, tensor<2x3xf32>) -> tensor<3x2xf32> - %28 = stablehlo.broadcast_in_dim %iterArg_6, dims = [0] : (tensor<3xf32>) -> tensor<3x2xf32> - %29 = stablehlo.add %27, %28 : tensor<3x2xf32> - %30 = stablehlo.add %26, %29 : tensor<3x2xf32> - %31 = stablehlo.tanh %30 : tensor<3x2xf32> - %32 = stablehlo.reshape %31 : (tensor<3x2xf32>) -> tensor<3x2x1xf32> - // CHECK-NOT: dynamic_update_slice - %33 = stablehlo.dynamic_update_slice %iterArg_3, %32, %c_2, %c_2, %20 : (tensor<3x2x6xf32>, 
tensor<3x2x1xf32>, tensor, tensor, tensor) -> tensor<3x2x6xf32> - %34 = stablehlo.add %iterArg, %c_1 : tensor - stablehlo.return %34, %33, %iterArg_4, %iterArg_5, %iterArg_6, %iterArg_7, %iterArg_8, %31, %iterArg_10 : tensor, tensor<3x2x6xf32>, tensor<3x3xf32>, tensor<3x3xf32>, tensor<3xf32>, tensor<3xf32>, tensor<2xui64>, tensor<3x2xf32>, tensor<3x6x2xf32> - } - %15 = stablehlo.transpose %14#7, dims = [1, 0] : (tensor<3x2xf32>) -> tensor<2x3xf32> - %16 = stablehlo.transpose %14#8, dims = [2, 1, 0] : (tensor<3x6x2xf32>) -> tensor<2x6x3xf32> - %17 = stablehlo.transpose %14#2, dims = [1, 0] : (tensor<3x3xf32>) -> tensor<3x3xf32> - %18 = stablehlo.transpose %14#3, dims = [1, 0] : (tensor<3x3xf32>) -> tensor<3x3xf32> - return %15, %14#6, %16, %17, %18, %14#4, %14#5 : tensor<2x3xf32>, tensor<2xui64>, tensor<2x6x3xf32>, tensor<3x3xf32>, tensor<3x3xf32>, tensor<3xf32>, tensor<3xf32> -}