diff --git a/mlir/python/mlir/dialects/tensor.py b/mlir/python/mlir/dialects/tensor.py index 79dd9476ad0ff..0b30d10209908 100644 --- a/mlir/python/mlir/dialects/tensor.py +++ b/mlir/python/mlir/dialects/tensor.py @@ -13,6 +13,7 @@ from typing import Sequence, Union from ._ods_common import _cext as _ods_cext +from ._ods_common import get_op_result_or_op_results as _get_op_result_or_op_results @_ods_cext.register_operation(_Dialect, replace=True) @@ -43,6 +44,18 @@ def __init__( super().__init__(result_type, dynamic_sizes, loc=loc, ip=ip) +def empty( + sizes: Sequence[Union[int, Value]], + element_type: Type, + *, + loc=None, + ip=None, +) -> _ods_cext.ir.Value: + return _get_op_result_or_op_results( + EmptyOp(sizes=sizes, element_type=element_type, loc=loc, ip=ip) + ) + + generate = region_op( lambda result, dynamic_extents: GenerateOp(result, dynamic_extents), terminator=lambda args: YieldOp(args[0]), diff --git a/mlir/test/python/dialects/linalg/opdsl/emit_matmul.py b/mlir/test/python/dialects/linalg/opdsl/emit_matmul.py index 18c237c68081a..64df4e1276222 100644 --- a/mlir/test/python/dialects/linalg/opdsl/emit_matmul.py +++ b/mlir/test/python/dialects/linalg/opdsl/emit_matmul.py @@ -63,8 +63,8 @@ def matmul_poly( RankedTensorType.get((4, 16), f32), RankedTensorType.get((16, 8), f32) ) def test_matmul_mono(lhs, rhs): - init_result = tensor.EmptyOp([4, 8], f32) - return matmul_mono(lhs, rhs, outs=[init_result.result]) + init_result = tensor.empty([4, 8], f32) + return matmul_mono(lhs, rhs, outs=[init_result]) # CHECK-LABEL: @test_i8i8i32_matmul # CHECK: ^{{.*}}(%[[A_ARG:.+]]: i8, %[[B_ARG:.+]]: i8, %[[C_ARG:.+]]: i32) diff --git a/mlir/test/python/dialects/linalg/ops.py b/mlir/test/python/dialects/linalg/ops.py index b147551c2e73d..3bfbcf7d7f7c8 100644 --- a/mlir/test/python/dialects/linalg/ops.py +++ b/mlir/test/python/dialects/linalg/ops.py @@ -97,7 +97,7 @@ def testNamedStructuredOpGenericForm(): RankedTensorType.get((4, 16), f32), RankedTensorType.get((16, 8), f32) ) def named_form(lhs, rhs): - init_result = tensor.EmptyOp([4, 8], f32) + init_result = tensor.empty([4, 8], f32) # CHECK: "linalg.matmul"(%{{.*}}) # CHECK-SAME: cast = #linalg.type_fn # CHECK-SAME: operandSegmentSizes = array @@ -106,7 +106,7 @@ def named_form(lhs, rhs): # CHECK-NEXT: arith.addf{{.*}} (f32, f32) -> f32 # CHECK-NEXT: linalg.yield{{.*}} (f32) -> () # CHECK-NEXT: (tensor<4x16xf32>, tensor<16x8xf32>, tensor<4x8xf32>) -> tensor<4x8xf32> - return linalg.matmul(lhs, rhs, outs=[init_result.result]) + return linalg.matmul(lhs, rhs, outs=[init_result]) module.operation.print(print_generic_op_form=True)