diff --git a/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp b/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp
index c6e5e63dc..430ff922b 100644
--- a/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp
+++ b/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp
@@ -6935,6 +6935,304 @@ struct NoopReverse final : OpRewritePattern<mlir::stablehlo::ReverseOp> {
   }
 };
 
+/// Checks that `dim` continues the arithmetic progression tracked for index
+/// component `k`, recording the first value seen as that component's start.
+///
+/// `sliceStart[k] == -1` marks "no value seen yet", so index values must be
+/// non-negative for the sentinel to be unambiguous.
+///
+/// \param sliceStart  per-component first value (-1 until one is seen).
+/// \param sliceStride per-component expected difference between neighbours.
+/// \param prev        per-component most recently accepted value.
+/// \param dim         value of component `k` in the current index vector.
+/// \param k           position inside the index vector.
+/// \returns true if `dim` is the first value or extends the progression.
+bool check_periodicity(std::vector<int> &sliceStart,
+                       std::vector<int> &sliceStride, std::vector<int> &prev,
+                       int dim, int k) {
+  if (sliceStart[k] == -1) {
+    sliceStart[k] = dim;
+    prev[k] = dim;
+    return true;
+  }
+  if (dim - prev[k] != sliceStride[k])
+    return false;
+  prev[k] = dim;
+  return true;
+}
+
+/// Converts a linear iteration counter into an element offset.
+///
+/// `iter` is decomposed into mixed-radix digits over the dimensions
+/// `rank_lower` down to `rank_upper` (inclusive, `rank_lower` varying
+/// fastest) and each digit is weighted by the stride of its dimension.
+/// Returns 0 when `rank_lower < rank_upper`, i.e. for an empty range.
+int get_index(const std::vector<int> &strides, const std::vector<int> &dims,
+              int rank_lower, int rank_upper, int iter) {
+  int index = 0;
+  for (int i = rank_lower; i >= rank_upper; i--) {
+    index += (iter % dims[i]) * strides[i];
+    iter /= dims[i];
+  }
+  return index;
+}
+
+/// Tests whether every component of a set of constant index vectors forms an
+/// arithmetic progression when the vectors are visited in row-major order.
+///
+/// `data` holds the row-major contents of an index tensor of shape `dims`
+/// (dimension 0 is the outermost / slowest moving one) and `x_dim` is the
+/// dimension that holds the components of each index vector.
+///
+/// On success the outputs describe, per component `k`:
+///   - stride >= 0: sliceStart = first value, sliceEnd = last value + 1;
+///   - stride <  0: sliceStart = first value + 1, sliceEnd = last value, so
+///     swapping start/end and negating the stride gives the ascending range
+///     that covers the same values.
+/// A single index vector is reported with stride 0.
+///
+/// Returns false for malformed shapes, for negative index values (-1 is the
+/// "unset" sentinel and a negative index cannot be a slice bound) and when
+/// some component is not an arithmetic progression.
+bool isGatherPeriodic(const std::vector<int> &data,
+                      const std::vector<int> &dims, int rank, int x_dim,
+                      std::vector<int> &sliceStart, std::vector<int> &sliceEnd,
+                      std::vector<int> &sliceStride) {
+  // Largest value for which `value + 1` is still representable as int.
+  constexpr int kMaxIndex = static_cast<int>(~0u >> 1) - 1;
+
+  if (rank <= 0 || x_dim < 0 || x_dim >= rank ||
+      dims.size() < static_cast<size_t>(rank) ||
+      data.size() > static_cast<size_t>(kMaxIndex))
+    return false;
+
+  // Row-major strides. The running product is compared against data.size()
+  // before every multiplication, so it can never overflow.
+  std::vector<int> strides(rank);
+  size_t total_size = 1;
+  for (int i = rank - 1; i >= 0; i--) {
+    if (dims[i] <= 0)
+      return false;
+    strides[i] = static_cast<int>(total_size);
+    if (total_size > data.size() / static_cast<size_t>(dims[i]))
+      return false;
+    total_size *= static_cast<size_t>(dims[i]);
+  }
+  if (total_size != data.size())
+    return false;
+
+  const int vec_length = dims[x_dim];
+  const size_t num_components = static_cast<size_t>(vec_length);
+  if (sliceStart.size() < num_components ||
+      sliceEnd.size() < num_components ||
+      sliceStride.size() < num_components)
+    return false;
+
+  // Negative values would collide with the -1 sentinel used below.
+  for (int value : data)
+    if (value < 0 || value > kMaxIndex)
+      return false;
+
+  // Elements after x_dim sit between consecutive components of one vector.
+  const int inner_batch_size = strides[x_dim];
+  const int outer_batch_size =
+      static_cast<int>(total_size) / (vec_length * inner_batch_size);
+
+  // Offset of component 0 of the index vector at batch position (i, j).
+  auto vectorBase = [&](int i, int j) {
+    return get_index(strides, dims, x_dim - 1, 0, i) +
+           get_index(strides, dims, rank - 1, x_dim + 1, j);
+  };
+
+  for (int k = 0; k < vec_length; k++) {
+    sliceStart[k] = -1;
+    sliceEnd[k] = -1;
+    sliceStride[k] = 0;
+  }
+
+  // The candidate stride is the difference between the first two vectors.
+  if (outer_batch_size * inner_batch_size > 1) {
+    const int second =
+        inner_batch_size > 1 ? vectorBase(0, 1) : vectorBase(1, 0);
+    for (int k = 0; k < vec_length; k++)
+      sliceStride[k] = data[second + k * inner_batch_size] -
+                       data[k * inner_batch_size];
+  }
+
+  // Every vector must then continue the progression of its predecessor.
+  std::vector<int> prev(vec_length, -1);
+  for (int i = 0; i < outer_batch_size; i++) {
+    for (int j = 0; j < inner_batch_size; j++) {
+      const int base = vectorBase(i, j);
+      for (int k = 0; k < vec_length; k++) {
+        const int value = data[base + k * inner_batch_size];
+        if (!check_periodicity(sliceStart, sliceStride, prev, value, k))
+          return false;
+      }
+    }
+  }
+
+  for (int k = 0; k < vec_length; k++) {
+    if (sliceStride[k] >= 0) {
+      sliceEnd[k] = prev[k] + 1;
+    } else {
+      sliceEnd[k] = prev[k];
+      sliceStart[k] = sliceStart[k] + 1;
+    }
+  }
+  return true;
+}
+
+/// Rewrites a `stablehlo.gather` with constant start indices into a
+/// `stablehlo.slice` (plus a `stablehlo.reverse` for descending indices)
+/// that is reshaped to the gather's result type.
+///
+/// The rewrite only applies when
+///   - every slice size is 1, so each index vector selects one element;
+///   - `start_index_map` addresses every operand dimension exactly once;
+///   - the gathered elements form one strided run along a single operand
+///     dimension (or there is just one index vector), which is established
+///     by `isGatherPeriodic` plus an element-count check;
+///   - every index is inside the operand, so the start-index clamping that
+///     gather performs cannot change which elements are read.
+struct GatherToSliceOp final : OpRewritePattern<mlir::stablehlo::GatherOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(mlir::stablehlo::GatherOp gather,
+                                PatternRewriter &rewriter) const override {
+    // Each index vector must select exactly one element.
+    if (llvm::any_of(gather.getSliceSizes(),
+                     [](int64_t size) { return size != 1; }))
+      return failure();
+
+    auto operandType =
+        llvm::dyn_cast<RankedTensorType>(gather.getOperand().getType());
+    auto resultType = llvm::dyn_cast<RankedTensorType>(gather.getType());
+    if (!operandType || !resultType || !resultType.hasStaticShape())
+      return failure();
+
+    // Only constant start indices can be analysed.
+    DenseIntElementsAttr index;
+    if (!matchPattern(gather.getStartIndices(), m_Constant(&index)))
+      return failure();
+
+    auto indexType = llvm::dyn_cast<RankedTensorType>(index.getType());
+    if (!indexType)
+      return failure();
+    const int64_t indexRank = indexType.getRank();
+    ArrayRef<int64_t> indexShape = indexType.getShape();
+
+    auto dimensionNumbers = gather.getDimensionNumbers();
+    auto startIndexMap = dimensionNumbers.getStartIndexMap();
+    const int64_t indexVectorDim = dimensionNumbers.getIndexVectorDim();
+
+    // index_vector_dim == rank(start_indices) means an implicit trailing
+    // unit dimension, which has no entry in the constant's shape.
+    if (indexVectorDim < 0 || indexVectorDim >= indexRank)
+      return failure();
+
+    // Require one index component per operand dimension; dimensions that
+    // are not addressed by start_index_map are not handled here.
+    const int64_t operandRank = operandType.getRank();
+    if (indexShape[indexVectorDim] != operandRank ||
+        static_cast<int64_t>(startIndexMap.size()) != operandRank)
+      return failure();
+
+    // isGatherPeriodic works on int; bail out instead of truncating.
+    std::vector<int> dims;
+    dims.reserve(indexShape.size());
+    for (int64_t extent : indexShape) {
+      if (static_cast<int>(extent) != extent)
+        return failure();
+      dims.push_back(static_cast<int>(extent));
+    }
+    std::vector<int> indices;
+    indices.reserve(index.getNumElements());
+    for (auto value : index.getValues<APInt>()) {
+      const int64_t wide = value.getSExtValue();
+      if (static_cast<int>(wide) != wide)
+        return failure();
+      indices.push_back(static_cast<int>(wide));
+    }
+
+    std::vector<int> sliceStart(operandRank, -1);
+    std::vector<int> sliceEnd(operandRank, -1);
+    std::vector<int> sliceStride(operandRank, -1);
+    if (!isGatherPeriodic(indices, dims, static_cast<int>(indexRank),
+                          static_cast<int>(indexVectorDim), sliceStart,
+                          sliceEnd, sliceStride))
+      return failure();
+
+    // Scatter the per-component ranges onto the operand dimensions.
+    SmallVector<int64_t> starts(operandRank, 0);
+    SmallVector<int64_t> limits(operandRank, 1);
+    SmallVector<int64_t> strides(operandRank, 1);
+    SmallVector<int64_t> sliceShape(operandRank, 1);
+    SmallVector<int64_t> reverseDims;
+    SmallVector<bool> assigned(operandRank, false);
+    int64_t numSliceElements = 1;
+    const int64_t numResultElements = resultType.getNumElements();
+    for (int64_t k = 0; k < operandRank; k++) {
+      const int64_t dim = startIndexMap[k];
+      if (dim < 0 || dim >= operandRank || assigned[dim])
+        return failure();
+      assigned[dim] = true;
+
+      int64_t start = sliceStart[k];
+      int64_t limit = sliceEnd[k];
+      int64_t stride = sliceStride[k];
+      if (stride < 0) {
+        // Descending indices: slice the ascending range and reverse it.
+        std::swap(start, limit);
+        stride = -stride;
+        reverseDims.push_back(dim);
+      } else if (stride == 0) {
+        // Constant component: slice strides must be positive.
+        stride = 1;
+      }
+
+      // Out-of-range indices would be clamped by gather but not by slice.
+      const int64_t dimSize = operandType.getDimSize(dim);
+      if (ShapedType::isDynamic(dimSize) || start < 0 || limit > dimSize)
+        return failure();
+
+      starts[dim] = start;
+      limits[dim] = limit;
+      strides[dim] = stride;
+      sliceShape[dim] = (limit - start + stride - 1) / stride;
+      numSliceElements *= sliceShape[dim];
+      if (numSliceElements > numResultElements)
+        return failure();
+    }
+
+    // More than one varying dimension (e.g. a diagonal) or a repeated
+    // element selects a different number of elements than the gather.
+    if (numSliceElements != numResultElements)
+      return failure();
+
+    Location loc = gather.getLoc();
+    auto sliceType =
+        RankedTensorType::get(sliceShape, resultType.getElementType());
+    Value result = rewriter.create<mlir::stablehlo::SliceOp>(
+        loc, sliceType, gather.getOperand(),
+        rewriter.getDenseI64ArrayAttr(starts),
+        rewriter.getDenseI64ArrayAttr(limits),
+        rewriter.getDenseI64ArrayAttr(strides));
+
+    if (!reverseDims.empty())
+      result = rewriter.create<mlir::stablehlo::ReverseOp>(loc, result,
+                                                           reverseDims);
+
+    // Restore the gather's result shape.
+    if (result.getType() != resultType)
+      result = rewriter.create<mlir::stablehlo::ReshapeOp>(loc, resultType,
+                                                           result);
+
+    rewriter.replaceOp(gather, result);
+    return success();
+  }
+};
+
 /// Converts gather ops to slice ops in case we have a single set of constant
 /// indices.
 struct GatherOpCanon final : OpRewritePattern<mlir::stablehlo::GatherOp> {
@@ -8215,6 +8513,7 @@ struct EnzymeHLOOptPass
 
     // clang-format off
     patterns.add<
+        GatherToSliceOp,
         BroadcastInDimOpCanon,
         ChainedDynamicBroadcastInDimCanonicalization,
         CompareOpCanon,
diff --git a/test/lit_tests/gatherToSlice.mlir b/test/lit_tests/gatherToSlice.mlir
new file mode 100644
index 000000000..03a6c2ef8
--- /dev/null
+++ b/test/lit_tests/gatherToSlice.mlir
@@ -0,0 +1,87 @@
+// RUN: mlir-opt -split-input-file -convert-stablehlo-gather-to-slice %s | FileCheck %s
+
+// Original example:
+// %c_1179 = stablehlo.constant dense<"0x00000000000000006B00000000000000070000000000000000..."> : tensor<180x3xi64>
+// %2803 = stablehlo.dynamic_update_slice %2715, %2802, %c_1336, %c_1308, %c_1329 : (tensor<1x128x194xf64>, tensor<1x1x180xf64>, tensor, tensor, tensor) -> tensor<1x128x194xf64>
+// %2804 = "stablehlo.gather"(%2803, %c_1179) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false,
+// slice_sizes = array}> : (tensor<1x128x194xf64>, tensor<180x3xi64>) -> tensor<180xf64>
+
+// CHECK-LABEL: func @gather_to_slice
+//func.func @gather_to_slice_wrapped_around(%arg0: tensor<1x128x194xf64>) -> tensor<180xf64> {
+// %indices = stablehlo.constant dense<"0xtensor<180x3xi64>
+//
+// // CHECK: %[[SLICE:.*]] = "stablehlo.slice"(%arg0)
+// // CHECK-SAME: start_indices = array
+// // CHECK-SAME: limit_indices = array
+// // CHECK-SAME: strides = array
+// // CHECK: return %[[SLICE]] : tensor<180xf64>
+// %result = "stablehlo.gather"(%arg0, %indices) {
+// dimension_numbers = #stablehlo.gather<
+// collapsed_slice_dims = [0, 1, 2],
+// start_index_map = [0, 1, 2],
+// index_vector_dim = 1
+// >,
+// 
indices_are_sorted = false,
+// slice_sizes = array
+// } : (tensor<1x128x194xf64>, tensor<180x3xi64>) -> tensor<180xf64>
+//
+// return %result : tensor<180xf64>
+//}
+
+// -----
+
+// Example of multi dim strided slice op (for reference):
+// %1 = "stablehlo.slice"(%arg0)
+// start_indices = array
+// limit_indices = array
+// strides = array
+// } : tensor<1x128x194xf64> -> tensor<180xf64>
+//
+
+func.func @gather_to_slice_reverse(%arg0: tensor<1x128x194xf64>) -> tensor<179xf64> {
+  %indices = stablehlo.constant dense<"0xtensor<179x3xi64>
+  // CHECK-NOT: stablehlo.gather
+  // CHECK: %[[SLICE:.*]] = stablehlo.slice %arg0 [0:1, 107:108, 8:187]
+  // CHECK-SAME: : (tensor<1x128x194xf64>) -> tensor<1x1x179xf64>
+  // CHECK: %[[REVERSED:.*]] = stablehlo.reverse %[[SLICE]], dims = [2]
+  // CHECK-SAME: : tensor<1x1x179xf64>
+  // CHECK: %[[RESHAPED:.*]] = stablehlo.reshape %[[REVERSED]]
+  // CHECK-SAME: : (tensor<1x1x179xf64>) -> tensor<179xf64>
+  %result = "stablehlo.gather"(%arg0, %indices) {
+    dimension_numbers = #stablehlo.gather<
+      collapsed_slice_dims = [0, 1, 2],
+      start_index_map = [0, 1, 2],
+      index_vector_dim = 1
+    >,
+    indices_are_sorted = false,
+    slice_sizes = array<i64: 1, 1, 1>
+  } : (tensor<1x128x194xf64>, tensor<179x3xi64>) -> tensor<179xf64>
+
+  return %result : tensor<179xf64>
+}
+
+func.func @gather_to_slice_collapse_dims(%arg0: tensor<1x128x194xf64>) -> tensor<3xf64> {
+  %indices = stablehlo.constant dense<[
+    [0, 10, 4],
+    [0, 10, 5],
+    [0, 10, 6]
+]> : tensor<3x3xi64>
+  // CHECK-NOT: stablehlo.gather
+  // CHECK: %[[SLICE:.*]] = stablehlo.slice %arg0 [0:1, 10:11, 4:7]
+  // CHECK-SAME: : (tensor<1x128x194xf64>) -> tensor<1x1x3xf64>
+  // CHECK: %[[RESHAPE:.*]] = stablehlo.reshape %[[SLICE]]
+  // CHECK-SAME: : (tensor<1x1x3xf64>) -> tensor<3xf64>
+  %result = "stablehlo.gather"(%arg0, %indices) {
+    dimension_numbers = #stablehlo.gather<
+      collapsed_slice_dims = [0, 1, 2],
+      start_index_map = [0, 1, 2],
+      index_vector_dim = 1
+    >,
+    indices_are_sorted = false,
+    slice_sizes = array<i64: 1, 1, 1>
+  } : (tensor<1x128x194xf64>, tensor<3x3xi64>) -> tensor<3xf64>
+
+  return %result : tensor<3xf64>
+}
+