diff --git a/python/tvm/te/schedule.py b/python/tvm/te/schedule.py index fdd08f9208c9..50f9a22ec205 100644 --- a/python/tvm/te/schedule.py +++ b/python/tvm/te/schedule.py @@ -25,7 +25,7 @@ from tvm.runtime import Object, convert from tvm.ir import container as _container -from tvm.tir import IterVar, Buffer, Var +from tvm.tir import IterVar, Buffer, Var, IndexMap from . import tensor as _tensor from . import _ffi_api @@ -599,65 +599,12 @@ def transform_layout(self, mapping_function: Callable[..., List[tvm.tir.PrimExpr """ - args = [] - var_arg_name = None - kwargs = collections.OrderedDict() - default_index_dtype = "int32" - - # Make a dummy variable for each explicitly named input index. - # We may have some keyword-only arguments, if the function has - # *args before the last argument. - params = inspect.signature(mapping_function).parameters - for name, param in params.items(): - if param.kind in [ - inspect.Parameter.POSITIONAL_ONLY, - inspect.Parameter.POSITIONAL_OR_KEYWORD, - ]: - args.append(tvm.tir.Var(name, default_index_dtype)) - - elif param.kind == inspect.Parameter.VAR_POSITIONAL: - var_arg_name = name - - elif param.kind == inspect.Parameter.KEYWORD_ONLY: - kwargs[name] = tvm.tir.Var(name, default_index_dtype) - - elif param.kind in [inspect.Parameter.VAR_KEYWORD]: - raise ValueError("transform_layout mapping may not have **kwargs") - ndim = len(self.op.output(0).shape) + index_map, axis_separators = IndexMap.from_func_with_separators(mapping_function, ndim=ndim) - # Now that all the named arguments have been collected, - # everything that remains should go to the *args, if - # specified. - if var_arg_name is not None: - num_var_args = ndim - len(args) - len(kwargs) - for i in range(num_var_args): - args.append(tvm.tir.Var(f"{var_arg_name}[{i}]", default_index_dtype)) - - initial_indices = args + list(kwargs.values()) - if len(initial_indices) != ndim: - raise ValueError( - f"transform_layout mapping accepts {len(params)} initial indices, " - f"but {self.op.name} is {len(self.op.shape)}-dimensional" - ) - - mapping = mapping_function(*args, **kwargs) - - final_indices = [] - axis_separators = [] - for val in mapping: - if isinstance(val, tvm.ir.PrimExpr): - final_indices.append(val) - elif val is AXIS_SEPARATOR: - axis_separators.append(len(final_indices)) - else: - raise TypeError( - "Expected mapping function to return list of " - "either tvm.ir.PrimExpr or tvm.te.AXIS_SEPARATOR. " - "Instead received {val} of type {type(val)}." - ) - - new_iter_vars = _ffi_api.StageTransformLayout(self, initial_indices, final_indices) + new_iter_vars = _ffi_api.StageTransformLayout( + self, index_map.initial_indices, index_map.final_indices + ) _ffi_api.StageSetAxisSeparators(self, axis_separators) return new_iter_vars or None @@ -700,9 +647,10 @@ def __exit__(self, ptype, value, trace): # Sentinel value used to indicate which groups of pre-flattening axes -# should be used to post-flattening axes axes. See -# Stage.transform_layout for more details. -AXIS_SEPARATOR = "axis_separator" +# should be used to post-flattening axes axes. Moved from +# te.AXIS_SEPARATOR to tir.IndexMap.AXIS_SEPARATOR for general use, +# maintained here for backwards compatibility. +AXIS_SEPARATOR = IndexMap.AXIS_SEPARATOR tvm._ffi._init_api("schedule", __name__) diff --git a/python/tvm/tir/function.py b/python/tvm/tir/function.py index d84513e072d3..a921c5b9fc40 100644 --- a/python/tvm/tir/function.py +++ b/python/tvm/tir/function.py @@ -16,8 +16,9 @@ # under the License. """Function data types.""" -from typing import Callable, List, Mapping, Optional, Union, Tuple +import collections import inspect +from typing import Callable, List, Mapping, Optional, Union, Tuple import tvm import tvm._ffi @@ -258,6 +259,11 @@ class IndexMap(Object): initial_indices: List[Var] final_indices: List[PrimExpr] + # Sentinel value used to indicate which groups of pre-flattening axes + # should be used to post-flattening axes axes. See + # Stage.transform_layout for more details. + AXIS_SEPARATOR = "axis_separator" + def __init__(self, initial_indices, final_indices): self.__init_handle_by_constructor__(_ffi_api.IndexMap, initial_indices, final_indices) @@ -268,34 +274,117 @@ def from_func(mapping_function: Callable, ndim: Optional[int] = None): Parameters ---------- mapping_function : Callable - The function to map from source indices to target indices + + The function to map from source indices to target indices. + The function should accept `tir.Var` parameters and return + a list. Each element of the returned list should be a + `tir.PrimExpr`. + + ndim: Optional[int] + + The dimensionality of the buffer to which this + transformation should be applied. If mapping_function uses + variadic argument `*args`, `ndim` must be specified. If + mapping_function does not use variadic arguments, ndim is + optional. + + Returns + ------- + index_map: IndexMap + + Returns an IndexMap representing the `mapping_function`. + + """ + index_map, axis_separators = IndexMap.from_func_with_separators(mapping_function, ndim) + assert not axis_separators, ( + "The mapping_function provided to IndexMap.from_func " + "may not return IndexMap.AXIS_SEPARATOR. " + "If required, please use IndexMap.from_func_with_separators instead." + ) + return index_map + + @staticmethod + def from_func_with_separators(mapping_function: Callable, ndim: Optional[int] = None): + """Create an index map from a function + + Parameters + ---------- + mapping_function : Callable + + The function to map from source indices to target indices. + The function should accept tir.Var parameters and return a + list. Each element of the returned list should be either a + `tir.PrimExpr` or the object `IndexMap.AXIS_SEPARATOR`. + + ndim: Optional[int] + + The dimensionality of the buffer to which this + transformation should be applied. If mapping_function uses + variadic argument `*args`, ndim must be specified. If + mapping_function does not use variadic arguments, ndim is + optional. + + Returns + ------- + ret: Tuple[IndexMap, List[int]] + + Returns a tuple whose first element is an IndexMap + representing the `mapping_function`, and whose second index + is a list of indices at which `IndexMap.AXIS_SEPARATOR` + occurred. + """ params = inspect.signature(mapping_function).parameters - default_index_dtype = "int32" + args = [] var_arg_name = None + kwargs = collections.OrderedDict() + default_index_dtype = "int32" + for name, param in params.items(): if param.kind in [ inspect.Parameter.POSITIONAL_ONLY, inspect.Parameter.POSITIONAL_OR_KEYWORD, ]: args.append(tvm.tir.Var(name, default_index_dtype)) + elif param.kind == inspect.Parameter.VAR_POSITIONAL: var_arg_name = name + + elif param.kind == inspect.Parameter.KEYWORD_ONLY: + kwargs[name] = tvm.tir.Var(name, default_index_dtype) + else: - raise ValueError("transform_layout mapping may not have *args or **kwargs") + raise ValueError("transform_layout mapping may not have *args") # Now that all the named arguments have been collected, # everything that remains should go to the *args, if # specified. if var_arg_name is not None: assert ndim is not None, "ndim must be specified when *args is used" - num_var_args = ndim - len(args) + num_var_args = ndim - len(args) - len(kwargs) for i in range(num_var_args): args.append(tvm.tir.Var(f"{var_arg_name}_{i}", default_index_dtype)) - final_indices = mapping_function(*args) - return IndexMap(args, final_indices) + mapping = mapping_function(*args, **kwargs) + + initial_indices = args + list(kwargs.values()) + + final_indices = [] + axis_separators = [] + for val in mapping: + if isinstance(val, tvm.ir.PrimExpr): + final_indices.append(val) + elif val is IndexMap.AXIS_SEPARATOR: + axis_separators.append(len(final_indices)) + else: + raise TypeError( + "Expected mapping function to return list of " + "either tvm.ir.PrimExpr or IndexMap.AXIS_SEPARATOR. " + "Instead received {val} of type {type(val)}." + ) + + return IndexMap(initial_indices, final_indices), axis_separators def is_equivalent_to(self, other_map: "IndexMap") -> bool: """Return if the index maps are equivalent. diff --git a/python/tvm/tir/schedule/schedule.py b/python/tvm/tir/schedule/schedule.py index 6474ba0baa3d..dc687b1eaef1 100644 --- a/python/tvm/tir/schedule/schedule.py +++ b/python/tvm/tir/schedule/schedule.py @@ -15,13 +15,13 @@ # specific language governing permissions and limitations # under the License. """The TensorIR schedule class""" -from typing import Callable, Dict, List, Optional, Union +from typing import Callable, Dict, List, Optional, Union, Tuple from tvm._ffi import register_object as _register_object from tvm.error import TVMError, register_error from tvm.ir import IRModule, PrimExpr from tvm.runtime import Object, String -from tvm.tir import Block, FloatImm, For, IntImm, PrimFunc +from tvm.tir import Block, FloatImm, For, IntImm, PrimFunc, Buffer from ..function import IndexMap from . import _ffi_api @@ -2114,25 +2114,111 @@ def after_unannotate(a: T.handle, b: T.handle) -> None: ########## Schedule: Layout transformation ########## + def _normalize_block_arg(self, block: Union[BlockRV, str]) -> BlockRV: + if isinstance(block, str): + return self.get_block(block) + + return block + + def _normalize_buffer_arg( + self, block: BlockRV, buffer: Union[Tuple[str, int], str, Buffer] + ) -> Tuple[str, int, Buffer]: + + block_name = self.get(block).name_hint + + def iter_buffers(): + block_obj = self.get(block) + for i, read in enumerate(block_obj.reads): + yield "read", i, read.buffer + for i, write in enumerate(block_obj.writes): + yield "write", i, write.buffer + + if isinstance(buffer, str): + possible_buffers = {} + # String lookup requires ensuring that the name is unique + for buffer_index, buffer_index_type, buf in iter_buffers(): + if buf.name == buffer: + possible_buffers[buf] = (buffer_index_type, buffer_index) + + assert possible_buffers, f"Could not find buffer '{buffer}' in block '{block_name}'" + assert ( + len(possible_buffers) == 1 + ), f"Multiple buffers named '{buffer}' in block '{block_name}'" + buffer_obj, (buffer_index, buffer_index_type) = next(iter(possible_buffers.items())) + + elif isinstance(buffer, Buffer): + # Buffer lookup has unique id, can break out early + found = False + for buffer_index, buffer_index_type, buffer_obj in iter_buffers(): + if buffer_obj.same_as(buffer): + found = True + break + + assert found, "Could not find buffer '{buffer.name}' in block '{block_name}'" + + elif isinstance(buffer, tuple): + buffer_index_type, buffer_index = buffer + assert buffer_index_type in ["read", "write",], ( + f"Invalid buffer_index_type. " + f"Expected 'read' or 'write', " + f"but received {buffer_index_type}" + ) + buffer_list = ( + self.get(block).reads if buffer_index_type == "read" else self.get(block).writes + ) + assert 0 <= buffer_index < len(buffer_list), ( + f"Invalid buffer_index {buffer_index}. " + f"Block {block_name} has only " + f"{len(buffer_list)} {buffer_index_type} buffers." + ) + buffer_obj = buffer_list[buffer_index].buffer + + else: + raise TypeError(f"Invalid type for argument 'buffer': {type(buffer)}") + + return (buffer_index_type, buffer_index, buffer_obj) + @type_checked def transform_layout( self, - block: BlockRV, - buffer_index: int, - buffer_index_type: str, + block: Union[BlockRV, str], + buffer: Union[Tuple[str, int], str, Buffer], index_map: Union[IndexMap, Callable], ) -> None: """Apply a transformation represented by IndexMap to buffer + Parameters ---------- - block : BlockRV - The block that accesses the target buffer - buffer_index: int - The index of the buffer in block's read or write region - buffer_index_type : str - Type of the buffer index, "read" or "write" + block : Union[BlockRV, str] + + The block that accesses the target buffer. If a string, + this must uniquely identify a block. + + buffer: Union[Tuple[str,int], Buffer, str] + + The buffer to be transformed, or a specification of how to + identify the buffer to be transformed. + + If `buffer` if a tuple of ``(str,int)``, the first item + should be either "read" or "write", and the second item is + an index into the block's read or write regions. + + If `buffer` is a string, it is the name of the buffer, + which must exist within the reads/writes of the block. In + addition, the reads/writes of the block may not contain + more than one buffer with this name. + + If `buffer` is a Buffer object, it must exist within the + reads/writes of the block. + index_map : Union[IndexMap, Callable] - The transformation to apply + + The transformation to apply. + + If `index_map` is a callable, and the returned list + contains IndexMap.AXIS_SEPARATOR, the SetAxisSeparators + primitive will be called in addition to the + TransformLayout primitive. Examples -------- @@ -2159,7 +2245,7 @@ def before_transform_layout(a: T.handle, c: T.handle) -> None: .. code-block:: python sch = tir.Schedule(before_storage_align) - sch.transform_layout(sch.get_block("B"), buffer_index=0, "write", + sch.transform_layout(sch.get_block("B"), buffer=("write",0), index_map=lambda m, n: (m // 16, n // 16, m % 16, n % 16)) print(sch.mod["main"].script()) @@ -2182,20 +2268,29 @@ def two_elementwise_transformed_intermediate_buffer(a: T.handle, c: T.handle) -> C[vi, vj] = B[vi // 16, vj // 16, vi % 16, vj % 16] + 1.0 """ + block = self._normalize_block_arg(block) + buffer_index_type, buffer_index, buffer_obj = self._normalize_buffer_arg(block, buffer) + + ndim = len(buffer_obj.shape) if callable(index_map): - index_map = IndexMap.from_func(index_map) - assert buffer_index_type in ["read", "write"], "Invalid buffer_index_type" + index_map, axis_separators = IndexMap.from_func_with_separators(index_map, ndim=ndim) + else: + axis_separators = [] + buffer_index_type_enum = 0 if buffer_index_type == "read" else 1 _ffi_api.ScheduleTransformLayout( # type: ignore # pylint: disable=no-member self, block, buffer_index, buffer_index_type_enum, index_map ) + if axis_separators: + _ffi_api.ScheduleSetAxisSeparator( # type: ignore # pylint: disable=no-member + self, block, buffer_index, buffer_index_type_enum, axis_separators + ) @type_checked def set_axis_separator( self, - block: BlockRV, - buffer_index: int, - buffer_index_type: str, + block: Union[BlockRV, str], + buffer: Union[Tuple[str, int], str, Buffer], axis_separators: Optional[List[int]], ) -> None: """Set the axis separator of a buffer, where the buffer is specified by a block and a read @@ -2203,13 +2298,30 @@ def set_axis_separator( Parameters ---------- - block : BlockRV - The block that accesses the target buffer - buffer_index: int - The index of the buffer in block's read or write region - buffer_index_type : str - Type of the buffer index, "read" or "write" + block : Union[BlockRV, str] + + The block that accesses the target buffer. If a string, + this must uniquely identify a block. + + buffer: Union[Tuple[str,int], Buffer, str] + + The buffer to be transformed, or a specification of how to + identify the buffer to be transformed. + + If `buffer` if a tuple of ``(str,int)``, the first item + should be either "read" or "write", and the second item is + an index into the block's read or write regions. + + If `buffer` is a string, it is the name of the buffer, + which must exist within the reads/writes of the block. In + addition, the reads/writes of the block may not contain + more than one buffer with this name. + + If `buffer` is a Buffer object, it must exist within the + reads/writes of the block. + axis_separators : Optional[List[int]] + The axis separators. Examples @@ -2263,7 +2375,10 @@ def after_set_axis_separators( C[vi, vj] = B[vi, vj] + T.float32(1) """ axis_separators = axis_separators or [] - assert buffer_index_type in ["read", "write"], "Invalid buffer_index_type" + + block = self._normalize_block_arg(block) + buffer_index_type, buffer_index, _ = self._normalize_buffer_arg(block, buffer) + buffer_index_type_enum = 0 if buffer_index_type == "read" else 1 _ffi_api.ScheduleSetAxisSeparator( # type: ignore # pylint: disable=no-member self, block, buffer_index, buffer_index_type_enum, axis_separators diff --git a/python/tvm/tir/schedule/testing.py b/python/tvm/tir/schedule/testing.py index 04cbffcd4d87..3689f756e83c 100644 --- a/python/tvm/tir/schedule/testing.py +++ b/python/tvm/tir/schedule/testing.py @@ -15,8 +15,9 @@ # specific language governing permissions and limitations # under the License. """Testing utilities for the TensorIR schedule API""" -from typing import Union +from typing import Union, Sequence +import tvm from tvm.ir import IRModule, structural_equal from tvm.tir import PrimFunc from tvm.tir.schedule import Trace, Schedule @@ -27,6 +28,7 @@ def verify_trace_roundtrip( mod: Union[PrimFunc, IRModule], *, debug_mask: Union[str, int] = "all", + text_format: Union[str, Sequence[str]] = ["python", "json"], ) -> Schedule: """Serialize a traced schedule to JSON, then replay the JSON trace by applying to a fresh new schedule, verifying the reproducibility of scheduling. @@ -44,18 +46,36 @@ def verify_trace_roundtrip( 1) "all" - Turn on all the checks 2) "none" - Turn off all the checks 3) An integer - Turn on checks according to the bitmasks provided in ScheduleDebugMask + text_format: Union[str, Sequence[str]] + The text format or formats whose round-trip behavior should be + validated. If a single string, validate round-trips through """ - # Step 1. Serialize the trace to JSON + if not isinstance(text_format, str): + for opt in text_format: + new_sch = verify_trace_roundtrip(sch, mod, debug_mask=debug_mask, text_format=opt) + return new_sch + trace = sch.trace assert trace is not None - json_obj = trace.as_json() - # Step 2. Apply the JSON trace to a new schedule, then check if it reproduces the scheduling + + # Step 1. Perform a round-trip through the text-format new_sch = Schedule(mod=mod, debug_mask=debug_mask) - Trace.apply_json_to_schedule(json_obj=json_obj, sch=new_sch) + if text_format == "json": + json_obj = trace.as_json() + Trace.apply_json_to_schedule(json_obj=json_obj, sch=new_sch) + elif text_format == "python": + py_trace = "\n".join(trace.as_python()) + exec(py_trace, tvm.tir.__dict__, {"sch": new_sch}) # pylint: disable=exec-used + else: + assert text_format in ("json", "python"), f"Unknown text format: {text_format}" + + # Step 2. Verify that the round-trip produced the same scheduling assert structural_equal(new_sch.mod, sch.mod) + # Step 3. Check the consistency of the text format between the old and new traces py_repr = "\n".join(trace.as_python()) new_py_repr = "\n".join(new_sch.trace.as_python()) assert py_repr == new_py_repr + # Step 4. Return the new schedule in case it could be useful return new_sch diff --git a/src/tir/schedule/primitive/blockize_tensorize.cc b/src/tir/schedule/primitive/blockize_tensorize.cc index 331d098347b0..7ed80a1c5b8f 100644 --- a/src/tir/schedule/primitive/blockize_tensorize.cc +++ b/src/tir/schedule/primitive/blockize_tensorize.cc @@ -699,7 +699,7 @@ struct TensorizeTraits : public UnpackedInstTraits { static String UnpackedAsPython(Array outputs, String block_or_loop_rv, String intrin) { PythonAPICall py("tensorize"); py.Input("block_or_loop", block_or_loop_rv); - py.Input("intrin", intrin); + py.Input("tensor_intrin", intrin); return py.Str(); } diff --git a/src/tir/schedule/primitive/layout_transformation.cc b/src/tir/schedule/primitive/layout_transformation.cc index fb63b1b289b1..cf95665ee828 100644 --- a/src/tir/schedule/primitive/layout_transformation.cc +++ b/src/tir/schedule/primitive/layout_transformation.cc @@ -291,11 +291,12 @@ struct TransformLayoutTraits : public UnpackedInstTraits Integer buffer_index_type, IndexMap index_map) { PythonAPICall py("transform_layout"); py.Input("block", block_rv); - py.Input("buffer_index", buffer_index); - py.Input("buffer_index_type", '"' + - std::string(BufferIndexType2Str( - static_cast(buffer_index_type->value))) + - '"'); + + std::ostringstream os; + os << "(\"" << BufferIndexType2Str(static_cast(buffer_index_type->value)) + << "\", " << buffer_index << ")"; + py.Input("buffer", os.str()); + py.Input("index_map", index_map->ToPythonString()); return py.Str(); } @@ -343,11 +344,12 @@ struct SetAxisSeparatorTraits : public UnpackedInstTraits axis_separators) { PythonAPICall py("set_axis_separator"); py.Input("block", block_rv); - py.Input("buffer_index", buffer_index); - py.Input("buffer_index_type", '"' + - std::string(BufferIndexType2Str( - static_cast(buffer_index_type->value))) + - '"'); + + std::ostringstream os; + os << "(\"" << BufferIndexType2Str(static_cast(buffer_index_type->value)) + << "\", " << buffer_index << ")"; + py.Input("buffer", os.str()); + py.Input("axis_separators", axis_separators); return py.Str(); } diff --git a/tests/python/unittest/test_tir_schedule_set_axis_separator.py b/tests/python/unittest/test_tir_schedule_set_axis_separator.py index 8c3d1e673571..102b3d1cd710 100644 --- a/tests/python/unittest/test_tir_schedule_set_axis_separator.py +++ b/tests/python/unittest/test_tir_schedule_set_axis_separator.py @@ -20,6 +20,7 @@ import tvm import tvm.testing from tvm import tir +from tvm.tir import IndexMap from tvm.script import tir as T from tvm.tir.schedule.testing import verify_trace_roundtrip @@ -102,11 +103,19 @@ def element_wise_subregion_match_set_axis_separator(A: T.Buffer[(128, 128), "flo # pylint: enable=no-member,invalid-name,unused-variable,unexpected-keyword-arg +use_sugared_transform = tvm.testing.parameter( + by_dict={"set_axis_separators": False, "transform_layout_sugared": True} +) -def test_set_axis_separator(): +def test_set_axis_separator(use_sugared_transform): func = element_wise s = tir.Schedule(func, debug_mask='all') - s.set_axis_separator(s.get_block("B"), 0, "write", [1]) + + if use_sugared_transform: + s.set_axis_separator(s.get_block("B"), ("write",0), [1]) + else: + s.transform_layout(block='B', buffer='B', index_map=lambda i,j: [i,IndexMap.AXIS_SEPARATOR,j]) + tvm.ir.assert_structural_equal(element_wise_set_axis_separator, s.mod["main"]) verify_trace_roundtrip(sch=s, mod=func) @@ -114,24 +123,35 @@ def test_set_axis_separator(): def test_set_scope_fail_on_index_out_of_bound(): func = element_wise s = tir.Schedule(func, debug_mask='all') - with pytest.raises(tvm.tir.ScheduleError): - s.set_axis_separator(s.get_block("B"), 1, "write",[1]) - with pytest.raises(tvm.tir.ScheduleError): - s.set_axis_separator(s.get_block("B"), -1, "read",[1]) + with pytest.raises(AssertionError): + s.set_axis_separator(s.get_block("B"), ("write",1),[1]) + with pytest.raises(AssertionError): + s.set_axis_separator(s.get_block("B"), ("read",-1),[1]) -def test_set_axis_separator_input_buffer(): +def test_set_axis_separator_input_buffer(use_sugared_transform): func = element_wise s = tir.Schedule(func, debug_mask='all') - s.set_axis_separator(s.get_block("B"), 0, "read", [1]) + + if use_sugared_transform: + s.transform_layout(block='B', buffer='A', index_map=lambda i,j: [i,IndexMap.AXIS_SEPARATOR,j]) + else: + s.set_axis_separator(s.get_block("B"), ("read",0), [1]) + + tvm.ir.assert_structural_equal(element_wise_set_axis_separator_input_buffer, s.mod["main"]) verify_trace_roundtrip(sch=s, mod=func) -def test_set_axis_separator_subregion(): +def test_set_axis_separator_subregion(use_sugared_transform): func = element_wise_subregion_match s = tir.Schedule(func, debug_mask='all') - s.set_axis_separator(s.get_block("B"), 0, "write", [1]) + + if use_sugared_transform: + s.transform_layout(block='B', buffer='B', index_map=lambda i,j: [i,IndexMap.AXIS_SEPARATOR,j]) + else: + s.set_axis_separator(s.get_block("B"), ("write",0), [1]) + tvm.ir.assert_structural_equal(element_wise_subregion_match_set_axis_separator, s.mod["main"]) verify_trace_roundtrip(sch=s, mod=func) diff --git a/tests/python/unittest/test_tir_schedule_tensorize_ldmatrix_mma.py b/tests/python/unittest/test_tir_schedule_tensorize_ldmatrix_mma.py index 67e8ae0ad836..e9ee990a2415 100644 --- a/tests/python/unittest/test_tir_schedule_tensorize_ldmatrix_mma.py +++ b/tests/python/unittest/test_tir_schedule_tensorize_ldmatrix_mma.py @@ -177,9 +177,9 @@ def tile_wmma_fragment(block_read, height, width): else: loop_b = tile_wmma_fragment(B_warp, k_inner, 16) - sch.transform_layout(A_warp, 0, "write", index_map_A) - sch.transform_layout(B_warp, 0, "write", index_map_B) - sch.transform_layout(C_warp, 0, "read", index_map_C) + sch.transform_layout(A_warp, ("write", 0), index_map_A) + sch.transform_layout(B_warp, ("write", 0), index_map_B) + sch.transform_layout(C_warp, ("read", 0), index_map_C) sch.tensorize(loop_a, ldmatrix_a_intrin) sch.tensorize(loop_b, ldmatrix_b_intrin) diff --git a/tests/python/unittest/test_tir_schedule_transform_layout.py b/tests/python/unittest/test_tir_schedule_transform_layout.py index 9e7cad4d8526..699eaf1236ac 100644 --- a/tests/python/unittest/test_tir_schedule_transform_layout.py +++ b/tests/python/unittest/test_tir_schedule_transform_layout.py @@ -94,27 +94,58 @@ def two_elementwise_transformed_output_buffer( # pylint: enable=no-member,invalid-name,unused-variable,line-too-long,redefined-outer-name,unexpected-keyword-arg,too-many-nested-blocks # fmt: on +use_sugared_transform = tvm.testing.parameter( + by_dict={"transform_layout": False, "transform_layout_sugared": True} +) -def test_two_elementwise_transform_intermediate_buffer(): + +def test_two_elementwise_transform_intermediate_buffer(use_sugared_transform): sch = tir.Schedule(two_elementwise, debug_mask="all") - block = sch.get_block("B") - sch.transform_layout(block, 0, "write", lambda m, n: (m // 16, n // 16, m % 16, n % 16)) + + if use_sugared_transform: + sch.transform_layout( + block="B", + buffer="B", + index_map=packed_index_map_func, + ) + else: + block = sch.get_block("B") + sch.transform_layout(block, ("write", 0), packed_index_map_func) + tvm.ir.assert_structural_equal(two_elementwise_transformed_intermediate_buffer, sch.mod["main"]) verify_trace_roundtrip(sch=sch, mod=two_elementwise) -def test_two_elementwise_transform_input_buffer(): +def test_two_elementwise_transform_input_buffer(use_sugared_transform): sch = tir.Schedule(two_elementwise, debug_mask="all") - block = sch.get_block("B") - sch.transform_layout(block, 0, "read", packed_index_map_func) + + if use_sugared_transform: + sch.transform_layout( + index_map=packed_index_map_func, + block="B", + buffer="A", + ) + else: + block = sch.get_block("B") + sch.transform_layout(block, ("read", 0), packed_index_map_func) + tvm.ir.assert_structural_equal(two_elementwise_transformed_input_buffer, sch.mod["main"]) verify_trace_roundtrip(sch=sch, mod=two_elementwise) -def test_two_elementwise_transform_output_buffer(): +def test_two_elementwise_transform_output_buffer(use_sugared_transform): sch = tir.Schedule(two_elementwise, debug_mask="all") - block = sch.get_block("C") - sch.transform_layout(block, 0, "write", packed_index_map_func) + + if use_sugared_transform: + sch.transform_layout( + index_map=packed_index_map_func, + block="C", + buffer="C", + ) + else: + block = sch.get_block("C") + sch.transform_layout(block, ("write", 0), packed_index_map_func) + tvm.ir.assert_structural_equal(two_elementwise_transformed_output_buffer, sch.mod["main"]) verify_trace_roundtrip(sch=sch, mod=two_elementwise) @@ -136,7 +167,7 @@ def test_simplify(): block_outer = sch.blockize(i_inner) B = sch.cache_read(block_outer, 0, "global") - sch.transform_layout(B, 0, "write", lambda i, j: (i // 16, j // 16, i % 16, j % 16)) + sch.transform_layout(B, ("write", 0), lambda i, j: (i // 16, j // 16, i % 16, j % 16)) @T.prim_func def ref(B: T.Buffer[(8, 8, 16, 16), "float32"], C: T.Buffer[(128, 128), "float32"]): @@ -159,5 +190,33 @@ def ref(B: T.Buffer[(8, 8, 16, 16), "float32"], C: T.Buffer[(128, 128), "float32 tvm.ir.assert_structural_equal(ref.body.block.body, sch.get(sch.get_loops(block_outer)[0])) +def test_var_args_sugar(): + @T.prim_func + def summation_3d( + A: T.Buffer[(1024, 1024, 32), "float32"], B: T.Buffer[(1,), "float32"] + ) -> None: + B[0] = 0 + for i, j, k in T.grid(1024, 1024, 32): + with T.block("compute"): + vi, vj, vk = T.axis.remap("SSS", [i, j, k]) + B[0] = B[0] + A[vi, vj, vk] + + @T.prim_func + def summation_3d_split( + A: T.Buffer[(1024, 1024, 8, 4), "float32"], B: T.Buffer[(1,), "float32"] + ) -> None: + B[0] = 0 + for i, j, k in T.grid(1024, 1024, 32): + with T.block("compute"): + vi, vj, vk = T.axis.remap("SSS", [i, j, k]) + B[0] = B[0] + A[vi, vj, vk // 4, vk % 4] + + sch = tir.Schedule(summation_3d, debug_mask="all") + sch.transform_layout( + index_map=lambda *indices, k: [*indices, k // 4, k % 4], block="compute", buffer="A" + ) + tvm.ir.assert_structural_equal(summation_3d_split, sch.mod["main"]) + + if __name__ == "__main__": tvm.testing.main()