Skip to content

Commit

Permalink
[MLIR] Add utility function to create values for all dimensions of a …
Browse files Browse the repository at this point in the history
…tensor value

This is a variant of the already provided `createDynamicDimValues` helper.

Differential Revision: https://reviews.llvm.org/D131798
  • Loading branch information
frgossen committed Aug 12, 2022
1 parent 6826682 commit 2c3ca3b
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 1 deletion.
7 changes: 6 additions & 1 deletion mlir/include/mlir/Dialect/Tensor/Utils/Utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,16 @@ PadOp createPadScalarOp(Type type, Value source, Value pad,
ArrayRef<OpFoldResult> low, ArrayRef<OpFoldResult> high,
bool nofold, Location loc, OpBuilder &builder);

// Creates dim ops for each dynamic dimension of the raked tensor argument and
// Creates dim ops for each dynamic dimension of the ranked tensor argument and
// returns these as values.
SmallVector<Value> createDynamicDimValues(OpBuilder &b, Location loc,
Value rankedTensor);

// Creates dim ops or constant ops for each dimension of the ranked tensor
// argument and returns these as values.
SmallVector<Value> createDimValues(OpBuilder &b, Location loc,
Value rankedTensor);

} // namespace tensor
} // namespace mlir

Expand Down
11 changes: 11 additions & 0 deletions mlir/lib/Dialect/Tensor/Utils/Utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -68,3 +68,14 @@ SmallVector<Value> mlir::tensor::createDynamicDimValues(OpBuilder &b,
}
return dynamicDims;
}

SmallVector<Value> mlir::tensor::createDimValues(OpBuilder &b, Location loc,
Value rankedTensor) {
auto tensorTy = rankedTensor.getType().cast<RankedTensorType>();
SmallVector<Value> dims;
for (const auto &en : llvm::enumerate(tensorTy.getShape())) {
dims.push_back(
b.createOrFold<tensor::DimOp>(loc, rankedTensor, en.index()));
}
return dims;
}

0 comments on commit 2c3ca3b

Please sign in to comment.