diff --git a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td index 3170115883e2b..8fcc413edf272 100644 --- a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td +++ b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td @@ -344,6 +344,43 @@ def Tensor_ExtractOp : Tensor_Op<"extract", [ let hasVerifier = 1; } +//===----------------------------------------------------------------------===// +// ExtractStaticOp +//===----------------------------------------------------------------------===// + +def Tensor_ExtractStaticOp : Tensor_Op<"extract_static", [ + DeclareOpInterfaceMethods, + Pure, + TypesMatchWith<"result type matches element type of tensor", + "tensor", "result", + "::llvm::cast($_self).getElementType()">]> { + let summary = "element extraction operation with static indices"; + let description = [{ + The same as `tensor.extract` op except that `tensor.extract_static` op only + takes static indices. + + Example: + + ```mlir + %4 = tensor.extract_static %t[1, 2] : tensor<4x4xi32> + %5 = tensor.extract_static %rt[1, 2] : tensor + ``` + }]; + + let arguments = (ins + AnyRankedTensor:$tensor, + DenseI64ArrayAttr:$static_indices + ); + + let results = (outs AnyType:$result); + let assemblyFormat = [{$tensor `` $static_indices attr-dict `:` type($tensor)}]; + + let hasCanonicalizer = 1; + let hasFolder = 1; + let hasVerifier = 1; +} + + //===----------------------------------------------------------------------===// // ExtractSliceOp @@ -822,6 +859,50 @@ def Tensor_InsertOp : Tensor_Op<"insert", [ let hasVerifier = 1; } +//===----------------------------------------------------------------------===// +// InsertStaticOp +//===----------------------------------------------------------------------===// + +def Tensor_InsertStaticOp : Tensor_Op<"insert_static", [ + DeclareOpInterfaceMethods, + DestinationStyleOpInterface, + Pure, + TypesMatchWith<"result type matches type of dest", + "dest", "result", + "$_self">, + TypesMatchWith<"scalar type matches element type of dest", + "dest", "scalar", + "::llvm::cast($_self).getElementType()">]> { + let summary = "element insertion operation with static indices"; + let description = [{ + The same as `tensor.insert` op except that `tensor.insert_static` op only + takes static indices. + + Example: + + ```mlir + %4 = tensor.insert_static %t into %dest[1, 2] : tensor<4x4xi32> + %5 = tensor.insert_static %rt into %dest[1, 2] : tensor + ``` + }]; + + let arguments = (ins AnyType:$scalar, + AnyRankedTensor:$dest, + DenseI64ArrayAttr:$static_indices); + let results = (outs AnyRankedTensor:$result); + let assemblyFormat = [{ + $scalar `into` $dest `` $static_indices attr-dict `:` type($dest) + }]; + + let extraClassDeclaration = [{ + MutableOperandRange getDpsInitsMutable() { return getDestMutable(); } + }]; + + let hasFolder = 1; + let hasVerifier = 1; +} + + //===----------------------------------------------------------------------===// // InsertSliceOp //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp index 1ac96756e22b5..26d4434a484d6 100644 --- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp +++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp @@ -39,6 +39,59 @@ using llvm::divideCeilSigned; using llvm::divideFloorSigned; using llvm::mod; +namespace { +template +OpFoldResult foldExtractFromElementsHelper(ExtractOpTy op, + FromElementsOp fromElementsOp, + ArrayRef indices) { + // Fold extract(from_elements(...)). + auto tensorType = llvm::cast(fromElementsOp.getType()); + auto rank = tensorType.getRank(); + assert(static_cast(indices.size()) == tensorType.getRank() && + "rank mismatch"); + int flatIndex = 0; + int stride = 1; + for (int i = rank - 1; i >= 0; --i) { + flatIndex += indices[i] * stride; + stride *= tensorType.getDimSize(i); + } + // Prevent out of bounds accesses. This can happen in invalid code that + // will never execute. + if (static_cast(fromElementsOp.getElements().size()) <= flatIndex || + flatIndex < 0) + return {}; + return fromElementsOp.getElements()[flatIndex]; +} + +LogicalResult verifyStaticIndicesInBound(RankedTensorType type, + ArrayRef indices) { + ArrayRef shape = type.getShape(); + for (auto [dim, index] : llvm::zip(shape, indices)) { + if (index < 0) + return failure(); + if (ShapedType::isDynamic(dim)) + continue; + if (index >= dim) + return failure(); + } + return success(); +} + +template +OpFoldResult insertOpFoldHelper(InsertOpTy insert, AdapterTy adaptor) { + Attribute scalar = adaptor.getScalar(); + Attribute dest = adaptor.getDest(); + if (scalar && dest) { + if (auto splatDest = llvm::dyn_cast(dest)) { + if (scalar == splatDest.getSplatValue()) + return dest; + } + } + return {}; +} + +} // namespace + /// Materialize a single constant operation from a given attribute value with /// the desired resultant type. Operation *TensorDialect::materializeConstant(OpBuilder &builder, @@ -1097,18 +1150,28 @@ namespace { /// to /// /// %extracted_element = tensor.extract %source[%c0] : tensor -struct ExtractFromTensorCast : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; +template +struct ExtractFromTensorCast : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; - LogicalResult matchAndRewrite(tensor::ExtractOp extract, + LogicalResult matchAndRewrite(ExtractOpTy extract, PatternRewriter &rewriter) const final { - auto tensorCast = extract.getTensor().getDefiningOp(); + auto tensorCast = + extract.getTensor().template getDefiningOp(); if (!tensorCast) return failure(); if (!llvm::isa(tensorCast.getSource().getType())) return failure(); - rewriter.replaceOpWithNewOp( - extract, tensorCast.getSource(), extract.getIndices()); + Operation *op = extract; + if (auto extractOp = llvm::dyn_cast(op)) { + rewriter.replaceOpWithNewOp( + extractOp, tensorCast.getSource(), extractOp.getIndices()); + } else if (auto extractStaticOp = + llvm::dyn_cast(op)) { + rewriter.replaceOpWithNewOp( + extractStaticOp, tensorCast.getSource(), + extractStaticOp.getStaticIndices()); + } return success(); } }; @@ -1145,22 +1208,8 @@ OpFoldResult ExtractOp::fold(FoldAdaptor adaptor) { // Fold extract(from_elements(...)). if (auto fromElementsOp = getTensor().getDefiningOp()) { - auto tensorType = llvm::cast(fromElementsOp.getType()); - auto rank = tensorType.getRank(); - assert(static_cast(indices.size()) == tensorType.getRank() && - "rank mismatch"); - int flatIndex = 0; - int stride = 1; - for (int i = rank - 1; i >= 0; --i) { - flatIndex += indices[i] * stride; - stride *= tensorType.getDimSize(i); - } - // Prevent out of bounds accesses. This can happen in invalid code that - // will never execute. - if (static_cast(fromElementsOp.getElements().size()) <= flatIndex || - flatIndex < 0) - return {}; - return fromElementsOp.getElements()[flatIndex]; + return foldExtractFromElementsHelper(*this, fromElementsOp, + indices); } // If this is an elements attribute, query the value at the given indices. @@ -1175,7 +1224,56 @@ OpFoldResult ExtractOp::fold(FoldAdaptor adaptor) { void ExtractOp::getCanonicalizationPatterns(RewritePatternSet &results, MLIRContext *context) { - results.add(context); + results.add>(context); +} + +//===----------------------------------------------------------------------===// +// ExtractStaticOp +//===----------------------------------------------------------------------===// + +void ExtractStaticOp::getAsmResultNames( + function_ref setNameFn) { + setNameFn(getResult(), "extracted"); +} + +LogicalResult ExtractStaticOp::verify() { + // Verify the # indices match if we have a ranked type. + auto tensorType = llvm::cast(getTensor().getType()); + if (tensorType.getRank() != static_cast(getStaticIndices().size())) + return emitOpError("incorrect number of indices for extract_static"); + if (failed(verifyStaticIndicesInBound(tensorType, getStaticIndices()))) + return emitOpError("static index out of bound for extract_static"); + return success(); +} + +OpFoldResult ExtractStaticOp::fold(FoldAdaptor adaptor) { + // If this is a splat elements attribute, simply return the value. All of + // the elements of a splat attribute are the same. + if (Attribute tensor = adaptor.getTensor()) { + if (auto splatTensor = llvm::dyn_cast(tensor)) + return splatTensor.getSplatValue(); + } + + SmallVector indices(getStaticIndices()); + // Fold extract(from_elements(...)). + if (auto fromElementsOp = getTensor().getDefiningOp()) { + return foldExtractFromElementsHelper(*this, fromElementsOp, + indices); + } + + // If this is an elements attribute, query the value at the given indices. + if (Attribute tensor = adaptor.getTensor()) { + auto elementsAttr = llvm::dyn_cast(tensor); + if (elementsAttr && elementsAttr.isValidIndex(indices)) + return elementsAttr.getValues()[indices]; + } + + return {}; +} + +void ExtractStaticOp::getCanonicalizationPatterns(RewritePatternSet &results, + MLIRContext *context) { + results.add>(context); } //===----------------------------------------------------------------------===// @@ -1368,13 +1466,34 @@ LogicalResult InsertOp::verify() { } OpFoldResult InsertOp::fold(FoldAdaptor adaptor) { - Attribute scalar = adaptor.getScalar(); - Attribute dest = adaptor.getDest(); - if (scalar && dest) - if (auto splatDest = llvm::dyn_cast(dest)) - if (scalar == splatDest.getSplatValue()) - return dest; - return {}; + return insertOpFoldHelper>>( + *this, adaptor); +} + +//===----------------------------------------------------------------------===// +// InsertStaticOp +//===----------------------------------------------------------------------===// + +void InsertStaticOp::getAsmResultNames( + function_ref setNameFn) { + setNameFn(getResult(), "inserted"); +} + +LogicalResult InsertStaticOp::verify() { + // Verify the # indices match if we have a ranked type. + auto destType = llvm::cast(getDest().getType()); + if (destType.getRank() != static_cast(getStaticIndices().size())) + return emitOpError("incorrect number of indices for insert_static"); + if (failed(verifyStaticIndicesInBound(destType, getStaticIndices()))) + return emitOpError("static index out of bound for insert_static"); + return success(); +} + +OpFoldResult InsertStaticOp::fold(FoldAdaptor adaptor) { + return insertOpFoldHelper>>( + *this, adaptor); } //===----------------------------------------------------------------------===// diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir index 86754c1c37536..25b46e7877cba 100644 --- a/mlir/test/Dialect/Tensor/canonicalize.mlir +++ b/mlir/test/Dialect/Tensor/canonicalize.mlir @@ -173,6 +173,40 @@ func.func @fold_extract(%arg0 : index) -> (f32, f16, f16, i32, complex) { // ----- +// CHECK-LABEL: func @fold_extract_static +func.func @fold_extract_static() -> (f32, f16, f16, i32, complex) { + // CHECK-DAG: [[C64:%.+]] = arith.constant 64 : i32 + // CHECK-DAG: [[C0:%.+]] = arith.constant 0.{{0*}}e+00 : f16 + // CHECK-DAG: [[CM2:%.+]] = arith.constant -2.{{0*}}e+00 : f16 + + // Fold an extract into a splat. + // CHECK-DAG: [[C4:%.+]] = arith.constant 4.{{0*}}e+00 : f32 + %0 = arith.constant dense<4.0> : tensor<4xf32> + %ext_1 = tensor.extract_static %0[1] : tensor<4xf32> + + // Fold an extract into a sparse with a sparse index. + %1 = arith.constant sparse<[[0, 0, 0], [1, 1, 1]], [-5.0, -2.0]> : tensor<4x4x4xf16> + %ext_2 = tensor.extract_static %1[1, 1, 1] : tensor<4x4x4xf16> + + // Fold an extract into a sparse with a non sparse index. + %2 = arith.constant sparse<[[1, 1, 1]], [-2.0]> : tensor<2x2x2xf16> + %ext_3 = tensor.extract_static %2[0, 0, 0] : tensor<2x2x2xf16> + + // Fold an extract into a dense tensor. + %3 = arith.constant dense<[[[1, -2, 1, 36]], [[0, 2, -1, 64]]]> : tensor<2x1x4xi32> + %ext_4 = tensor.extract_static %3[1, 0, 3] : tensor<2x1x4xi32> + + // Fold an extract into a complex constant. + // CHECK-DAG: [[C5:%.+]] = complex.constant [1.200000e+00 : f32, 2.300000e+00 : f32] : complex + %4 = arith.constant dense<(1.2, 2.3)> : tensor> + %ext_5 = tensor.extract_static %4[] : tensor> + + // CHECK-NEXT: return [[C4]], [[CM2]], [[C0]], [[C64]], [[C5]] + return %ext_1, %ext_2, %ext_3, %ext_4, %ext_5 : f32, f16, f16, i32, complex +} + +// ----- + // CHECK-LABEL: func @fold_insert func.func @fold_insert(%arg0 : index) -> (tensor<4xf32>) { // Fold an insert into a splat. @@ -186,6 +220,19 @@ func.func @fold_insert(%arg0 : index) -> (tensor<4xf32>) { // ----- +// CHECK-LABEL: func @fold_insert_static +func.func @fold_insert_static() -> (tensor<4xf32>) { + // Fold an insert into a splat. + // CHECK-DAG: %[[C4:.+]] = arith.constant dense<4.{{0*}}e+00> : tensor<4xf32> + %0 = arith.constant dense<4.0> : tensor<4xf32> + %1 = arith.constant 4.0 : f32 + %ins_1 = tensor.insert_static %1 into %0[3] : tensor<4xf32> + // CHECK-NEXT: return %[[C4]] + return %ins_1 : tensor<4xf32> +} + +// ----- + // CHECK-LABEL: func @extract_from_tensor.cast // CHECK-SAME: %[[TENSOR:.*]]: tensor<9xf32> func.func @extract_from_tensor.cast(%tensor: tensor<9xf32>) -> f32 { @@ -200,6 +247,18 @@ func.func @extract_from_tensor.cast(%tensor: tensor<9xf32>) -> f32 { // ----- +// CHECK-LABEL: func @extract_static_from_tensor.cast +// CHECK-SAME: %[[TENSOR:.*]]: tensor<9xf32> +func.func @extract_static_from_tensor.cast(%tensor: tensor<9xf32>) -> f32 { + // CHECK-NOT: tensor.cast + %casted = tensor.cast %tensor : tensor<9xf32> to tensor + // CHECK-NEXT: tensor.extract_static %[[TENSOR]][0] + %result = tensor.extract_static %casted[0] : tensor + return %result : f32 +} + +// ----- + // CHECK-LABEL: func @extract_from_tensor.from_elements func.func @extract_from_tensor.from_elements(%element : index) -> index { // CHECK-SAME: ([[ARG:%.*]]: index) @@ -212,6 +271,17 @@ func.func @extract_from_tensor.from_elements(%element : index) -> index { // ----- +// CHECK-LABEL: func @extract_static_from_tensor.from_elements +func.func @extract_static_from_tensor.from_elements(%element : index) -> index { + // CHECK-SAME: ([[ARG:%.*]]: index) + %tensor = tensor.from_elements %element : tensor<1xindex> + %extracted_element = tensor.extract_static %tensor[0] : tensor<1xindex> + // CHECK: [[ARG]] : index + return %extracted_element : index +} + +// ----- + // CHECK-LABEL: func @extract_from_tensor.from_elements_0d func.func @extract_from_tensor.from_elements_0d(%element : index) -> index { // CHECK-SAME: ([[ARG:%.*]]: index) @@ -224,6 +294,17 @@ func.func @extract_from_tensor.from_elements_0d(%element : index) -> index { // ----- +// CHECK-LABEL: func @extract_static_from_tensor.from_elements_0d +func.func @extract_static_from_tensor.from_elements_0d(%element : index) -> index { + // CHECK-SAME: ([[ARG:%.*]]: index) + %tensor = tensor.from_elements %element : tensor + %extracted_element = tensor.extract_static %tensor[] : tensor + // CHECK: [[ARG]] : index + return %extracted_element : index +} + +// ----- + // CHECK-LABEL: func @extract_from_tensor.from_elements_3d func.func @extract_from_tensor.from_elements_3d() -> (f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32) { @@ -261,6 +342,61 @@ func.func @extract_from_tensor.from_elements_3d() return %r0,%r1,%r2,%r3,%r4,%r5,%r6,%r7,%r8,%r9,%r10,%r11 : f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32 } + +// CHECK-DAG: %[[F0:.*]] = arith.constant 0.0 +// CHECK-DAG: %[[F1:.*]] = arith.constant 1.0{{0+}}e+00 +// CHECK-DAG: %[[F2:.*]] = arith.constant 2.0 +// CHECK-DAG: %[[F3:.*]] = arith.constant 3.0 +// CHECK-DAG: %[[F4:.*]] = arith.constant 4.0 +// CHECK-DAG: %[[F5:.*]] = arith.constant 5.0 +// CHECK-DAG: %[[F6:.*]] = arith.constant 6.0 +// CHECK-DAG: %[[F7:.*]] = arith.constant 7.0 +// CHECK-DAG: %[[F8:.*]] = arith.constant 8.0 +// CHECK-DAG: %[[F9:.*]] = arith.constant 9.0 +// CHECK-DAG: %[[F10:.*]] = arith.constant 1.0{{0+}}e+01 +// CHECK-DAG: %[[F11:.*]] = arith.constant 1.1{{0+}}e+01 + +// CHECK: return %[[F0]], %[[F1]], %[[F2]], %[[F3]], %[[F4]], %[[F5]], +// CHECK-SAME: %[[F6]], %[[F7]], %[[F8]], %[[F9]], %[[F10]], %[[F11]] + + +// ----- + +// CHECK-LABEL: func @extract_static_from_tensor.from_elements_3d +func.func @extract_static_from_tensor.from_elements_3d() + -> (f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32) { + %f0 = arith.constant 0.0 : f32 + %f1 = arith.constant 1.0 : f32 + %f2 = arith.constant 2.0 : f32 + %f3 = arith.constant 3.0 : f32 + %f4 = arith.constant 4.0 : f32 + %f5 = arith.constant 5.0 : f32 + %f6 = arith.constant 6.0 : f32 + %f7 = arith.constant 7.0 : f32 + %f8 = arith.constant 8.0 : f32 + %f9 = arith.constant 9.0 : f32 + %f10 = arith.constant 10.0 : f32 + %f11 = arith.constant 11.0 : f32 + + %tensor = tensor.from_elements %f0,%f1,%f2,%f3,%f4,%f5,%f6,%f7,%f8,%f9,%f10,%f11 + : tensor<3x2x2xf32> + + %r0 = tensor.extract_static %tensor[0, 0, 0] : tensor<3x2x2xf32> + %r1 = tensor.extract_static %tensor[0, 0, 1] : tensor<3x2x2xf32> + %r2 = tensor.extract_static %tensor[0, 1, 0] : tensor<3x2x2xf32> + %r3 = tensor.extract_static %tensor[0, 1, 1] : tensor<3x2x2xf32> + %r4 = tensor.extract_static %tensor[1, 0, 0] : tensor<3x2x2xf32> + %r5 = tensor.extract_static %tensor[1, 0, 1] : tensor<3x2x2xf32> + %r6 = tensor.extract_static %tensor[1, 1, 0] : tensor<3x2x2xf32> + %r7 = tensor.extract_static %tensor[1, 1, 1] : tensor<3x2x2xf32> + %r8 = tensor.extract_static %tensor[2, 0, 0] : tensor<3x2x2xf32> + %r9 = tensor.extract_static %tensor[2, 0, 1] : tensor<3x2x2xf32> + %r10 = tensor.extract_static %tensor[2, 1, 0] : tensor<3x2x2xf32> + %r11 = tensor.extract_static %tensor[2, 1, 1] : tensor<3x2x2xf32> + return %r0,%r1,%r2,%r3,%r4,%r5,%r6,%r7,%r8,%r9,%r10,%r11 + : f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32 +} + // CHECK-DAG: %[[F0:.*]] = arith.constant 0.0 // CHECK-DAG: %[[F1:.*]] = arith.constant 1.0{{0+}}e+00 // CHECK-DAG: %[[F2:.*]] = arith.constant 2.0 diff --git a/mlir/test/Dialect/Tensor/invalid.mlir b/mlir/test/Dialect/Tensor/invalid.mlir index 84e6c59e403dd..4be9b6a9c8718 100644 --- a/mlir/test/Dialect/Tensor/invalid.mlir +++ b/mlir/test/Dialect/Tensor/invalid.mlir @@ -72,6 +72,22 @@ func.func @extract_too_many_indices(%arg0: tensor) { // ----- +func.func @extract_static_too_many_indices(%arg0: tensor) { + // expected-error@+1 {{incorrect number of indices for extract_static}} + %0 = tensor.extract_static %arg0[] : tensor + return +} + +// ----- + +func.func @extract_static_indices_out_of_bound(%arg0: tensor<2xf32>) { + // expected-error@+1 {{static index out of bound for extract_static}} + %0 = tensor.extract_static %arg0[4] : tensor<2xf32> + return +} + +// ----- + func.func @insert_too_many_indices(%arg0: f32, %arg1: tensor) { // expected-error@+1 {{incorrect number of indices}} %0 = tensor.insert %arg0 into %arg1[] : tensor @@ -80,6 +96,29 @@ func.func @insert_too_many_indices(%arg0: f32, %arg1: tensor) { // ----- +func.func @insert_static_too_many_indices(%arg0: f32, %arg1: tensor) { + // expected-error@+1 {{incorrect number of indices for insert_static}} + %0 = tensor.insert_static %arg0 into %arg1[] : tensor + return +} + +// ----- + +func.func @insert_static_indices_out_of_bound(%arg0: f32, %arg1: tensor<2xf32>) { + // expected-error@+1 {{static index out of bound for insert_static}} + %0 = tensor.insert_static %arg0 into %arg1[4] : tensor<2xf32> + return +} + +// ----- + +func.func @insert_static_indices_dynamic(%arg0: f32, %arg1: tensor) { + %0 = tensor.insert_static %arg0 into %arg1[4] : tensor + return +} + +// ----- + func.func @tensor.from_elements_wrong_result_type() { // expected-error@+2 {{'tensor.from_elements' invalid kind of type specified}} %c0 = arith.constant 0 : i32 diff --git a/mlir/test/Dialect/Tensor/ops.mlir b/mlir/test/Dialect/Tensor/ops.mlir index 378137a14b59f..fcfee9978d4bd 100644 --- a/mlir/test/Dialect/Tensor/ops.mlir +++ b/mlir/test/Dialect/Tensor/ops.mlir @@ -63,6 +63,16 @@ func.func @extract(%arg0: tensor, %arg1: index) { // ----- +// CHECK-LABEL: func.func @extract_static( +// CHECK-SAME: %[[TENSOR:.*]]: tensor) { +func.func @extract_static(%arg0: tensor) { + // CHECK: tensor.extract_static %[[TENSOR]][1, 2, 3] : tensor + %0 = tensor.extract_static %arg0[1, 2, 3] : tensor + return +} + +// ----- + // CHECK-LABEL: func @insert( // CHECK-SAME: %[[SCALAR:.*]]: f32 // CHECK-SAME: %[[INDEX:.*]]: index @@ -75,6 +85,17 @@ func.func @insert(%arg0: f32, %arg1: index, %arg2: tensor) { // ----- +// CHECK-LABEL: func @insert_static( +// CHECK-SAME: %[[SCALAR:.*]]: f32 +// CHECK-SAME: %[[DEST1:.*]]: tensor +func.func @insert_static(%arg0: f32, %arg1: tensor) { + // CHECK: tensor.insert_static %[[SCALAR]] into %[[DEST1]][1, 2, 3] : tensor + %0 = tensor.insert_static %arg0 into %arg1[1, 2, 3] : tensor + return +} + +// ----- + // CHECK-LABEL: func @tensor.from_elements() { func.func @tensor.from_elements() { %c0 = "arith.constant"() {value = 0: index} : () -> index