Skip to content

Commit

Permalink
Revert "pattern to remove dead results of stablehlo.while"
Browse files Browse the repository at this point in the history
This reverts commit 47db9dd.
  • Loading branch information
wsmoses committed Jan 24, 2025
1 parent df4b196 commit 5beb203
Show file tree
Hide file tree
Showing 4 changed files with 13 additions and 233 deletions.
186 changes: 13 additions & 173 deletions src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,6 @@
// ops.
//===----------------------------------------------------------------------===//

#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/Analysis/TopologicalSortUtils.h"
#include "mlir/Dialect/CommonFolders.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
Expand Down Expand Up @@ -1444,145 +1442,6 @@ struct ShiftRightLogicalSimplify final
}
};

struct WhileDeadResults final : OpRewritePattern<mlir::stablehlo::WhileOp> {
using OpRewritePattern::OpRewritePattern;

bool isLoopResultDead(OpResult result) const {
// Not dead if the result is in use.
if (!result.use_empty())
return false;

// Or if the corresponding argument is being used in computing the
// condition.
auto whileOp = cast<mlir::stablehlo::WhileOp>(result.getOwner());
Value condArgument =
whileOp.getCond().getArgument(result.getResultNumber());
SetVector<Operation *> forwardSlice;
getForwardSlice(condArgument, &forwardSlice);
if (!llvm::all_of(forwardSlice, mlir::isPure))
return false;
if (forwardSlice.contains(whileOp.getCond().front().getTerminator()))
return false;

// Or in computing another result. We first do a fast-path check of having
// the argument not influencing the terminator operation, before going into
// finer-grain analysis.
//
// TODO: it is possible that this argument does influence another terminator
// operand, but that operand in turn corresponds to a dead value, but
// handling that would require more complex logic of detecting dead cycles
// of value chains.
forwardSlice.clear();
assert(llvm::hasSingleElement(whileOp.getBody()));
Value bodyArgument =
whileOp.getBody().getArgument(result.getResultNumber());
getForwardSlice(bodyArgument, &forwardSlice);
if (!llvm::all_of(forwardSlice, mlir::isPure))
return false;

Operation *bodyTerminator = whileOp.getBody().front().getTerminator();
if (!forwardSlice.contains(bodyTerminator))
return true;

for (OpOperand &terminatorOperand : bodyTerminator->getOpOperands()) {
if (terminatorOperand.getOperandNumber() == result.getResultNumber())
continue;

SetVector<Operation *> backwardSlice;
getBackwardSlice(terminatorOperand.get(), &backwardSlice);
for (Operation *op : backwardSlice) {
if (llvm::is_contained(op->getOperands(), bodyArgument))
return false;
}
}
return true;
}

void replaceTerminator(PatternRewriter &rewriter, Region &region,
ArrayRef<int64_t> deadResults) const {
Operation *terminator = region.front().getTerminator();
SmallVector<Value> terminatorOperands;
for (auto &&[i, operand] : llvm::enumerate(terminator->getOperands())) {
if (!llvm::is_contained(deadResults, i))
terminatorOperands.push_back(operand);
}
OpBuilder::InsertionGuard guard(rewriter);
rewriter.setInsertionPoint(terminator);
rewriter.replaceOpWithNewOp<mlir::stablehlo::ReturnOp>(
terminator, TypeRange(), terminatorOperands, terminator->getAttrs());
}

LogicalResult matchAndRewrite(mlir::stablehlo::WhileOp op,
PatternRewriter &rewriter) const override {
SmallVector<int64_t> deadResults;
for (OpResult result : op.getResults()) {
if (!isLoopResultDead(result))
continue;

deadResults.push_back(result.getResultNumber());
}
if (deadResults.empty())
return failure();

SetVector<Operation *> condSlice, bodySlice;
for (int64_t i : deadResults) {
getForwardSlice(op.getCond().getArgument(i), &condSlice);
getForwardSlice(op.getBody().getArgument(i), &bodySlice);
}
condSlice.remove(op.getCond().front().getTerminator());
bodySlice.remove(op.getBody().front().getTerminator());
condSlice = mlir::topologicalSort(condSlice);
bodySlice = mlir::topologicalSort(bodySlice);
for (Operation *erasable : llvm::reverse(condSlice))
rewriter.eraseOp(erasable);
for (Operation *erasable : llvm::reverse(bodySlice))
rewriter.eraseOp(erasable);

replaceTerminator(rewriter, op.getCond(), deadResults);
replaceTerminator(rewriter, op.getBody(), deadResults);

SmallVector<Value> operands;
SmallVector<Type> resultTypes;
SmallVector<Location> condBlockArgLocs, bodyBlockArgsLocs;
for (auto &&[i, operand, resultType] :
llvm::enumerate(op->getOperands(), op.getResultTypes())) {
if (llvm::is_contained(deadResults, i))
continue;

operands.push_back(operand);
resultTypes.push_back(resultType);
condBlockArgLocs.push_back(op.getCond().getArgument(i).getLoc());
bodyBlockArgsLocs.push_back(op.getBody().getArgument(i).getLoc());
}

auto updated = rewriter.create<mlir::stablehlo::WhileOp>(
op->getLoc(), resultTypes, operands, op->getAttrs());
SmallVector<Value> resultReplacements;
for (int64_t old = 0, upd = 0, end = op->getNumResults(); old < end;
++old) {
if (llvm::is_contained(deadResults, old)) {
resultReplacements.push_back(nullptr);
continue;
}
resultReplacements.push_back(updated->getResult(upd));
++upd;
}

for (int64_t i : llvm::reverse(deadResults))
op.getCond().eraseArgument(i);
rewriter.inlineRegionBefore(op.getCond(), updated.getCond(),
updated.getCond().begin());

for (int64_t i : llvm::reverse(deadResults))
op.getBody().eraseArgument(i);
rewriter.inlineRegionBefore(op.getBody(), updated.getBody(),
updated.getBody().begin());

rewriter.replaceOp(op, resultReplacements);
return success();
}
};

struct NegativePadToSlice final : OpRewritePattern<mlir::stablehlo::PadOp> {
using OpRewritePattern::OpRewritePattern;

Expand Down Expand Up @@ -7520,38 +7379,19 @@ struct EnzymeHLOOptPass : public EnzymeHLOOptPassBase<EnzymeHLOOptPass> {
}
patterns.add<NoNanAddSubSimplify>((no_nan || all_finite), context);

// clang-format off
patterns.add<
BroadcastInDimOpCanon,
ChainedDynamicBroadcastInDimCanonicalization,
CompareOpCanon,
ConjComplexNegate,
ConvertOpCanon,
DivideSqrtToMultiplyRsqrt,
DynamicBroadcastInDimAllDimsNonExpanding,
DynamicBroadcastInDimOpNotActuallyDynamic,
DynamicGatherOpIsNotDynamic,
DynamicReshapeOpCanon,
EmptyReduceOpCanon,
GatherOpCanon,
GetDimensionSizeOpCanon,
GetTupleElementOpCanon,
IfInline,
IfToSelect,
ImagOpCanon,
MergeConsecutiveReshapes,
NoopReduceOpCanon,
RealOpCanon,
ReorderElementwiseAndShapeOp,
ReshapeOpCanon,
SelectOpUsedWithinIf,
TransposeBroadcastInDimToBroadcastInDim,
TransposeIsReshape,
WhileDeadResults,
WhileSimplify,
ZeroExtentTensorCanon
>(context);
// clang-format on
patterns.add<CompareOpCanon, BroadcastInDimOpCanon,
TransposeBroadcastInDimToBroadcastInDim, ConvertOpCanon,
DynamicBroadcastInDimOpNotActuallyDynamic,
ChainedDynamicBroadcastInDimCanonicalization,
DynamicBroadcastInDimAllDimsNonExpanding, NoopReduceOpCanon,
EmptyReduceOpCanon, DynamicReshapeOpCanon,
GetTupleElementOpCanon, RealOpCanon, ImagOpCanon,
ConjComplexNegate, GetDimensionSizeOpCanon, GatherOpCanon,
ReshapeOpCanon, MergeConsecutiveReshapes, TransposeIsReshape,
SelectOpUsedWithinIf, IfInline, IfToSelect, WhileSimplify,
ZeroExtentTensorCanon, ReorderElementwiseAndShapeOp,
DynamicGatherOpIsNotDynamic, DivideSqrtToMultiplyRsqrt>(
context);
patterns.add<SelectOpCanon>(max_constant_expansion, context,
PatternBenefit(65000));
patterns.add<ConcatenateOpCanon>(max_constant_expansion, context,
Expand Down
4 changes: 0 additions & 4 deletions src/enzyme_ad/jax/TransformOps/TransformOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -797,10 +797,6 @@ def ApplyShiftRightLogicalSimplifyPatterns : EnzymeHLOPatternOp<
"shift_right_logical_simplify"> {
let patterns = ["ShiftRightLogicalSimplify"];
}
def WhileDeadResultPatterns : EnzymeHLOPatternOp<
"while_deadresult"> {
let patterns = ["WhileDeadResults"];
}
def ApplyRemSimplifyPatterns : EnzymeHLOPatternOp<
"rem_simplify"> {
let patterns = ["RemSimplify"];
Expand Down
1 change: 0 additions & 1 deletion src/enzyme_ad/jax/primitives.py
Original file line number Diff line number Diff line change
Expand Up @@ -285,7 +285,6 @@ def hlo_opts():
if_inline<1>;
if_to_select<1>;
while_simplify<1>;
while_deadresult<1>;
dot_reshape_pad<1>;
pad_dot_general<1>(1);
Expand Down
55 changes: 0 additions & 55 deletions test/lit_tests/whiledeadarg.mlir

This file was deleted.

0 comments on commit 5beb203

Please sign in to comment.