Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve enzyme gradient ops removal in while op #167

Merged
merged 8 commits into from
Nov 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -439,15 +439,23 @@ class AutoDiffWhileRev
Block *oBB = &orig->getRegion(1).front();
auto term = oBB->getTerminator();

// All values defined in the body should have no use outside this block
// therefore we can set their diffe to zero upon entering the reverse
// block to simplify the work of the remove-unnecessary-enzyme-ops pass.
for (auto operand : oBB->getArguments()) {
// All arguments should have no use outside this block therefore we can
// set their diffe to zero upon entering the reverse block to simplify
// the work of the remove-unnecessary-enzyme-ops pass.
if (!gutils->isConstantValue(operand)) {
gutils->zeroDiffe(operand, bodyBuilder);
}
}

for (auto &it : oBB->getOperations()) {
for (auto res : it.getResults()) {
if (!gutils->isConstantValue(res)) {
gutils->zeroDiffe(res, bodyBuilder);
}
}
}

int revIdx = 1;
for (auto &&[active, operand] :
llvm::zip(operandsActive, term->getOperands())) {
Expand Down
6 changes: 3 additions & 3 deletions src/enzyme_ad/jax/Implementations/WhileLoopInfo.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,9 @@ namespace enzyme {
struct WhileLoopInfo {
WhileOp op;

mlir::Value start; // garanteed to dominate the while op
mlir::Value limit; // not garanteed to dominate the while op
mlir::Value step; // not garanteed to dominate the while op
mlir::Value start; // guaranteed to dominate the while op
mlir::Value limit; // not guaranteed to dominate the while op
mlir::Value step; // not guaranteed to dominate the while op

WhileLoopInfo(WhileOp op_) : op(op_) {}

Expand Down
45 changes: 20 additions & 25 deletions test/lit_tests/diffrules/stablehlo/while2.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -46,35 +46,30 @@ module {
// REVERSE-NEXT: %cst = arith.constant dense<0.000000e+00> : tensor<1xf32>
// REVERSE-NEXT: %cst_1 = arith.constant dense<0.000000e+00> : tensor<5xf32>
// REVERSE-NEXT: %cst_2 = arith.constant dense<0.000000e+00> : tensor<f32>
// REVERSE-NEXT: %0 = "enzyme.init"() : () -> !enzyme.Gradient<tensor<1xf32>>
// REVERSE-NEXT: "enzyme.set"(%0, %cst) : (!enzyme.Gradient<tensor<1xf32>>, tensor<1xf32>) -> ()
// REVERSE-NEXT: %1:4 = stablehlo.while(%iterArg = %arg0, %iterArg_3 = %c, %iterArg_4 = %arg1, %iterArg_5 = %c) : tensor<5xf32>, tensor<i64>, tensor<f32>, tensor<i64>
// REVERSE-NEXT: %0:4 = stablehlo.while(%iterArg = %arg0, %iterArg_3 = %c, %iterArg_4 = %arg1, %iterArg_5 = %c) : tensor<5xf32>, tensor<i64>, tensor<f32>, tensor<i64>
// REVERSE-NEXT: cond {
// REVERSE-NEXT: %6 = stablehlo.compare LT, %iterArg_3, %c_0, SIGNED : (tensor<i64>, tensor<i64>) -> tensor<i1>
// REVERSE-NEXT: stablehlo.return %6 : tensor<i1>
// REVERSE-NEXT: %5 = stablehlo.compare LT, %iterArg_3, %c_0, SIGNED : (tensor<i64>, tensor<i64>) -> tensor<i1>
// REVERSE-NEXT: stablehlo.return %5 : tensor<i1>
// REVERSE-NEXT: } do {
// REVERSE-NEXT: %6 = stablehlo.add %iterArg_5, %c_0 : tensor<i64>
// REVERSE-NEXT: %7 = stablehlo.slice %iterArg [0:1] : (tensor<5xf32>) -> tensor<1xf32>
// REVERSE-NEXT: %8 = stablehlo.reshape %7 : (tensor<1xf32>) -> tensor<f32>
// REVERSE-NEXT: stablehlo.return %iterArg, %c, %8, %6 : tensor<5xf32>, tensor<i64>, tensor<f32>, tensor<i64>
// REVERSE-NEXT: %5 = stablehlo.add %iterArg_5, %c_0 : tensor<i64>
// REVERSE-NEXT: %6 = stablehlo.slice %iterArg [0:1] : (tensor<5xf32>) -> tensor<1xf32>
// REVERSE-NEXT: %7 = stablehlo.reshape %6 : (tensor<1xf32>) -> tensor<f32>
// REVERSE-NEXT: stablehlo.return %iterArg, %c, %7, %5 : tensor<5xf32>, tensor<i64>, tensor<f32>, tensor<i64>
// REVERSE-NEXT: }
// REVERSE-NEXT: %2 = arith.addf %arg2, %cst_2 : tensor<f32>
// REVERSE-NEXT: %3:3 = stablehlo.while(%iterArg = %c, %iterArg_3 = %cst_1, %iterArg_4 = %2) : tensor<i64>, tensor<5xf32>, tensor<f32>
// REVERSE-NEXT: %1 = arith.addf %arg2, %cst_2 : tensor<f32>
// REVERSE-NEXT: %2:3 = stablehlo.while(%iterArg = %c, %iterArg_3 = %cst_1, %iterArg_4 = %1) : tensor<i64>, tensor<5xf32>, tensor<f32>
// REVERSE-NEXT: cond {
// REVERSE-NEXT: %6 = stablehlo.compare LT, %iterArg, %1#3 : (tensor<i64>, tensor<i64>) -> tensor<i1>
// REVERSE-NEXT: stablehlo.return %6 : tensor<i1>
// REVERSE-NEXT: %5 = stablehlo.compare LT, %iterArg, %0#3 : (tensor<i64>, tensor<i64>) -> tensor<i1>
// REVERSE-NEXT: stablehlo.return %5 : tensor<i1>
// REVERSE-NEXT: } do {
// REVERSE-NEXT: %6 = stablehlo.add %iterArg, %c_0 : tensor<i64>
// REVERSE-NEXT: %7 = stablehlo.reshape %iterArg_4 : (tensor<f32>) -> tensor<1xf32>
// REVERSE-NEXT: %8 = "enzyme.get"(%0) : (!enzyme.Gradient<tensor<1xf32>>) -> tensor<1xf32>
// REVERSE-NEXT: %9 = arith.addf %8, %7 : tensor<1xf32>
// REVERSE-NEXT: "enzyme.set"(%0, %9) : (!enzyme.Gradient<tensor<1xf32>>, tensor<1xf32>) -> ()
// REVERSE-NEXT: "enzyme.set"(%0, %cst) : (!enzyme.Gradient<tensor<1xf32>>, tensor<1xf32>) -> ()
// REVERSE-NEXT: %10 = stablehlo.pad %9, %cst_2, low = [0], high = [4], interior = [0] : (tensor<1xf32>, tensor<f32>) -> tensor<5xf32>
// REVERSE-NEXT: %11 = arith.addf %iterArg_3, %10 : tensor<5xf32>
// REVERSE-NEXT: stablehlo.return %6, %11, %cst_2 : tensor<i64>, tensor<5xf32>, tensor<f32>
// REVERSE-NEXT: %5 = stablehlo.add %iterArg, %c_0 : tensor<i64>
// REVERSE-NEXT: %6 = stablehlo.reshape %iterArg_4 : (tensor<f32>) -> tensor<1xf32>
// REVERSE-NEXT: %7 = arith.addf %6, %cst : tensor<1xf32>
// REVERSE-NEXT: %8 = stablehlo.pad %7, %cst_2, low = [0], high = [4], interior = [0] : (tensor<1xf32>, tensor<f32>) -> tensor<5xf32>
// REVERSE-NEXT: %9 = arith.addf %iterArg_3, %8 : tensor<5xf32>
// REVERSE-NEXT: stablehlo.return %5, %9, %cst_2 : tensor<i64>, tensor<5xf32>, tensor<f32>
// REVERSE-NEXT: }
// REVERSE-NEXT: %4 = arith.addf %3#1, %cst_1 : tensor<5xf32>
// REVERSE-NEXT: %5 = arith.addf %3#2, %cst_2 : tensor<f32>
// REVERSE-NEXT: return %4, %5 : tensor<5xf32>, tensor<f32>
// REVERSE-NEXT: %3 = arith.addf %2#1, %cst_1 : tensor<5xf32>
// REVERSE-NEXT: %4 = arith.addf %2#2, %cst_2 : tensor<f32>
// REVERSE-NEXT: return %3, %4 : tensor<5xf32>, tensor<f32>
// REVERSE-NEXT: }
Loading