From 1618f15878f63d24b83ce582a13299fc25b0b27a Mon Sep 17 00:00:00 2001 From: 6clc Date: Wed, 11 Oct 2023 11:01:02 +0800 Subject: [PATCH] cinn(py-dsl): parse schedule of python dsl (#57981) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 拆分新特性:CINN Python DSL, 主PR和单测见:#56393 此PR只负责 解析python dsl中的schedule定义 --- python/cinn/compiler/compiler.py | 7 +- .../cinn/compiler/schedule_code_generator.py | 189 ++++++++++++++++++ python/cinn/runtime/data_array.py | 97 +++++++++ test/cinn/CMakeLists.txt | 36 ++++ test/cinn/ir/test_llir_constructor.py | 36 ++++ .../ir/test_llir_schedule_cache_read_write.py | 73 +++++++ test/cinn/ir/test_llir_schedule_compute_at.py | 111 ++++++++++ .../ir/test_llir_schedule_compute_inline.py | 95 +++++++++ test/cinn/ir/test_llir_schedule_fuse_split.py | 131 ++++++++++++ test/cinn/ir/test_llir_schedule_reorder.py | 80 ++++++++ test/cinn/ir/test_llir_schedule_sequence.py | 70 +++++++ test/cinn/utils/testing.py | 28 +++ 12 files changed, 951 insertions(+), 2 deletions(-) create mode 100644 python/cinn/compiler/schedule_code_generator.py create mode 100644 python/cinn/runtime/data_array.py create mode 100644 test/cinn/ir/test_llir_constructor.py create mode 100644 test/cinn/ir/test_llir_schedule_cache_read_write.py create mode 100644 test/cinn/ir/test_llir_schedule_compute_at.py create mode 100644 test/cinn/ir/test_llir_schedule_compute_inline.py create mode 100644 test/cinn/ir/test_llir_schedule_fuse_split.py create mode 100644 test/cinn/ir/test_llir_schedule_reorder.py create mode 100644 test/cinn/ir/test_llir_schedule_sequence.py create mode 100644 test/cinn/utils/testing.py diff --git a/python/cinn/compiler/compiler.py b/python/cinn/compiler/compiler.py index 330d34962641d6..12f1ffb79d6407 100644 --- a/python/cinn/compiler/compiler.py +++ b/python/cinn/compiler/compiler.py @@ -15,6 +15,7 @@ from ..runtime import CinnLowerLevelIrJit from .compute_code_generator import ComputeCodeGenerator +from 
.schedule_code_generator import ScheduleCodeGenerator def ast_to_llir(fn, inputs_signature): @@ -24,7 +25,10 @@ def ast_to_llir(fn, inputs_signature): fn, function_name, inputs_signature ) cinn_llir_func = llir_compute_generator.parse() - return cinn_llir_func + + # 2. Parse CINN Schedule + llir_schedule_generator = ScheduleCodeGenerator(fn, cinn_llir_func) + return llir_schedule_generator.parse() def compile(fn, just_convert=False, jit_inputs_signature=[], **kwargs): @@ -35,4 +39,3 @@ def compile(fn, just_convert=False, jit_inputs_signature=[], **kwargs): if just_convert: return llir_func - return llir_func diff --git a/python/cinn/compiler/schedule_code_generator.py b/python/cinn/compiler/schedule_code_generator.py new file mode 100644 index 00000000000000..6cc4c2973464b9 --- /dev/null +++ b/python/cinn/compiler/schedule_code_generator.py @@ -0,0 +1,189 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import ast + +from cinn.schedule import IRSchedule + +from .expr_executor import ExprExecutor, exec_assign +from .utils import ( + VariableTable, + is_node_parsed_in_schedule, + node_is_schedule_block_context, +) + + +class ScheduleCodeGenerator(ast.NodeVisitor): + """ + Convert python ast to CINN Lower Level IR, + containing only the semantics of the schedule part + """ + + def __init__(self, fn, cinn_llir_func): + self.fn = fn + self.cinn_llir_func = cinn_llir_func + self.scheduler = IRSchedule.make(self.cinn_llir_func) + self.variable_table = VariableTable() + self.global_variable_table = VariableTable() + # Set the schedule-related variable to global + self.extra_scope = { + "ScheduleBlockVariable": ScheduleBlockVariable, + "scheduler": self.scheduler, + } + self.loop_var_stack = [] + self.block_stack = [] + self.sch_block_tmp_var_name = "__CINN_SCHEDULE_BLOCK_VAR_NAME__" + self.tmp_var_count = 1 + + def parse(self): + with self.variable_table, self.global_variable_table: + ast_node = self.fn.parse() + for k, v in self.fn.scope.items(): + self.variable_table.add(k, v) + for k, v in self.extra_scope.items(): + self.variable_table.add(k, v) + self.visit(ast_node) + return self.cinn_llir_func + + def visit_For(self, node): + assert isinstance( + node.target, ast.Name + ), "Current only support range() to make ForLoop" + with self.variable_table: + self.loop_var_stack.append(node.target) + self.generic_visit(node) + self.loop_var_stack.pop() + + def visit_compound_statement(self, stmts): + for stmt in stmts: + self.visit(stmt) + + def visit_With(self, node): + with self.variable_table: + for item in node.items: + if isinstance( + item.context_expr, ast.Call + ) and not node_is_schedule_block_context(item.context_expr): + continue + # 1. 
replace ScheduleBlockContext with ScheduleBlockVariable + sch_ctx_node = item.context_expr + sch_block_node = ast.copy_location( + ast.Call( + func=ast.Name( + id="ScheduleBlockVariable", ctx=ast.Load() + ), + args=sch_ctx_node.args, + keywords=[], + starargs=None, + kwargs=None, + ), + item.context_expr, + ) + item.context_expr = sch_block_node + + # 2. store ScheduleBlockVariable node + sch_block = ExprExecutor(self.variable_table.get()).exec( + item.context_expr + ) + if item.optional_vars is None: + tmp_var_name = self.sch_block_tmp_var_name + str( + self.tmp_var_count + ) + sch_block_var_node = ast.Name( + id=tmp_var_name, ctx=ast.Store() + ) + item.optional_vars = sch_block_var_node + local_var_table = exec_assign( + target=item.optional_vars, source=sch_block + ) + # 3. Set the block's loop to its attribute + sch_block.set_scheduler(self.scheduler) + self.block_stack.append(sch_block) + for k, v in local_var_table.items(): + self.variable_table.add(k, v) + self.global_variable_table.add(k, v) + for loop_var in self.loop_var_stack: + loop_var_value = ast.Attribute( + value=ast.Name(id=k, ctx=ast.Load()), + attr=loop_var.id, + ctx=ast.Load(), + ) + loop_var_value = ExprExecutor( + self.variable_table.get() + ).exec(loop_var_value) + for_loop_var_table = exec_assign( + loop_var, loop_var_value + ) + for ( + loop_var_k, + loop_var_v, + ) in for_loop_var_table.items(): + self.variable_table.add(loop_var_k, loop_var_v) + + body = self.visit_compound_statement(node.body) + + def visit_Assign(self, node): + if isinstance(node.value, ast.Call) and is_node_parsed_in_schedule( + node.value + ): + sch_ret = self.exec_schedule_primitive(node.value) + local_var_table = exec_assign( + target=node.targets[0], source=sch_ret + ) + for k, v in local_var_table.items(): + self.variable_table.add(k, v) + return + self.generic_visit(node) + + def visit_Call(self, node): + if isinstance(node, ast.Call) and is_node_parsed_in_schedule(node): + self.exec_schedule_primitive(node) + 
return + + def exec_schedule_primitive(self, node): + # Reflect ScheduleBlockContext to ScheduleBlockVariable + sch_primitive = node + args = [ast.Name(id="scheduler", ctx=ast.Load()), *sch_primitive.args] + sch_primitive.args = args + all_variable_table = self.variable_table.get() + for k, v in self.global_variable_table.get().items(): + all_variable_table[k] = v + sch_ret = ExprExecutor(all_variable_table).exec(node) + + return sch_ret + + +class ScheduleBlockVariable: + """ + The parse Schedule process replaces ScheduleBlockContext with this class on the ast layer to improve schedule usability on the python layer + For example, split a loop in c++ requires two steps: + 1. Gets the loop for the corresponding block: `x, y = sch.get_loops(block)` + 2. Apply schedule to loop: tx, xi = sch.split(x, [2]) + This class allows you to directly manipulate the loop name of a block + `sch.split(block.x, [2])` + """ + + def __init__(self, name): + self.name = name + self.scheduler = None + + def set_scheduler(self, scheduler): + self.scheduler = scheduler + + def __getattr__(self, k): + if k == "block": + return self.scheduler.get_block(self.name) + else: + name2loops = self.scheduler.get_name2loops_dict(self.name) + return name2loops[k] diff --git a/python/cinn/runtime/data_array.py b/python/cinn/runtime/data_array.py new file mode 100644 index 00000000000000..4e7c58eced3358 --- /dev/null +++ b/python/cinn/runtime/data_array.py @@ -0,0 +1,97 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +import numpy as np +from cinn import common, runtime +from cinn.common import BFloat16, Bool, Float, Float16, Int, UInt + + +class DataArray: + """ + Provides Python encapsulation of the cinn_buffer_t + data interface in the CINN RunTime module. + """ + + def __init__( + self, + shape: list, + dtype: common.Type = common.Float(32), + data: runtime.cinn_buffer_t = None, + ) -> None: + self.shape = shape + self.dtype = dtype + self.data = data + + def to_numpy(self): + """ + Convert DataArray to numpy array + """ + cinn_dtype_to_np_dtype = { + # numpy has no 'bfloat16', we use uint16 to hold bfloat16 data, same to Paddle + BFloat16(): "uint16", + BFloat16(): "bfloat16", + Float16(): "float16", + Float(32): "float32", + Float(64): "float64", + Int(8): "int8", + Int(16): "int16", + Int(32): "int32", + Int(64): "int64", + UInt(8): "uint8", + # numpy has no 'bfloat16', we use uint16 to hold bfloat16 data, same to Paddle + # "UInt(16): uint16" + UInt(32): "uint32", + UInt(64): "uint64", + Bool(): "bool", + } + for cinn_dtype, np_dtype in cinn_dtype_to_np_dtype.items(): + if isinstance(self.dtype, cinn_dtype): + np_arr = np.empty(self.shape, np_dtype) + assert np_arr.flags["C_CONTIGUOUS"] + self.data.copy_to(np_arr) + return np_arr + + raise TypeError(f"no support {self._dtype} in CINN") + + @staticmethod + def from_numpy(np_array, target=common.DefaultHostTarget()): + """ + Create DataArray from numpy array + """ + assert isinstance(np_array, np.ndarray) + data = runtime.cinn_buffer_t(np_array, target) + dtype_np_to_common = { + # numpy has no 'bfloat16', we use uint16 to hold bfloat16 data, same to Paddle + "uint16": BFloat16(), + "bfloat16": BFloat16(), + "float16": Float16(), + "float32": Float(32), + "float64": Float(64), + "int8": Int(8), + "int16": Int(16), + "int32": Int(32), + "int64": Int(64), + "uint8": UInt(8), + # numpy has no 'bfloat16', we use uint16 to 
hold bfloat16 data, same to Paddle + # "uint16": UInt(16), + "uint32": UInt(32), + "uint64": UInt(64), + "bool": Bool(), + } + dtype_np = str(np_array.dtype).split(".")[-1] + assert str(dtype_np) in dtype_np_to_common, ( + str(dtype_np) + " not support in CINN" + ) + assert dtype_np in dtype_np_to_common.keys() + + return DataArray(np_array.shape, dtype_np_to_common[dtype_np], data) diff --git a/test/cinn/CMakeLists.txt b/test/cinn/CMakeLists.txt index ca9989b745826d..3158c4372d8fdb 100644 --- a/test/cinn/CMakeLists.txt +++ b/test/cinn/CMakeLists.txt @@ -274,4 +274,40 @@ if(WITH_GPU) WORKING_DIRECTORY ${CMAKE_BINARY_DIR}) endforeach() + file( + GLOB CINN_RUNTIME_TEST + RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}" + "runtime/test_*.py") + + foreach(runtime_test_name ${EXCLUDE_RUNTIME}) + list(REMOVE_ITEM CINN_RUNTIME_TEST runtime/${runtime_test_name}.py) + endforeach() + + foreach(runtime_test_name ${CINN_RUNTIME_TEST}) + string(REGEX REPLACE ".py" "" runtime_test_name ${runtime_test_name}) + add_test( + NAME ${runtime_test_name} + COMMAND + ${CMAKE_COMMAND} -E env + PYTHONPATH=${CMAKE_BINARY_DIR}:${CMAKE_BINARY_DIR}/python/:$ENV{PYTHONPATH} + python3 ${CMAKE_CURRENT_SOURCE_DIR}/${runtime_test_name}.py + WORKING_DIRECTORY ${CMAKE_BINARY_DIR}) + endforeach() + + file( + GLOB CINN_IR_TEST + RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}" + "ir/test_*.py") + + foreach(ir_test_name ${CINN_IR_TEST}) + string(REGEX REPLACE ".py" "" ir_test_name ${ir_test_name}) + add_test( + NAME ${ir_test_name} + COMMAND + ${CMAKE_COMMAND} -E env + PYTHONPATH=${CMAKE_BINARY_DIR}:${CMAKE_BINARY_DIR}/python/:$ENV{PYTHONPATH} + python3 ${CMAKE_CURRENT_SOURCE_DIR}/${ir_test_name}.py + WORKING_DIRECTORY ${CMAKE_BINARY_DIR}) + endforeach() + endif() diff --git a/test/cinn/ir/test_llir_constructor.py b/test/cinn/ir/test_llir_constructor.py new file mode 100644 index 00000000000000..05c44e8935dfbd --- /dev/null +++ b/test/cinn/ir/test_llir_constructor.py @@ -0,0 +1,36 @@ +# Copyright (c) 2023 PaddlePaddle 
Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from cinn import ir, lang, to_cinn_llir +from cinn.runtime.data_array import DataArray + + +def test_call_extern(): + @to_cinn_llir + def call_sinh(A: DataArray((1, 4, 256, 512)), B: DataArray((1, 4, 256))): + for i1 in range(1): + for j1 in range(4): + for k1 in range(256): + with ir.ScheduleBlockContext("init") as init: + vi, vj, vk = ir.AxisMap("SSS", [i1, j1, k1]) + B[vi, vj, vk] = lang.call_extern( + "sinh", [A[vi, vi, vj, vk]], {} + ) + + str(call_sinh) + + +if __name__ == "__main__": + test_call_extern() diff --git a/test/cinn/ir/test_llir_schedule_cache_read_write.py b/test/cinn/ir/test_llir_schedule_cache_read_write.py new file mode 100644 index 00000000000000..85badc819f8f55 --- /dev/null +++ b/test/cinn/ir/test_llir_schedule_cache_read_write.py @@ -0,0 +1,73 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from test.cinn.utils.testing import assert_llir_equal + +from cinn import ir, to_cinn_llir +from cinn.runtime.data_array import DataArray +from cinn.schedule import IRSchedule as sch + + +def test_cache_read_elementwise(): + @to_cinn_llir + def elementwise_add_cache_read( + X: DataArray((128, 128)), + Y: DataArray((128, 128)), + A: DataArray((128, 128)), + ): + for i in range(128): + for j in range(128): + with ir.ScheduleBlockContext("A") as A_block: + i1, j1 = ir.AxisMap("SS", [i, j]) + A[i1, j1] = X[i1, j1] * 2.0 + for i3 in range(128): + for j3 in range(128): + with ir.ScheduleBlockContext("B") as B_block: + i1, j1 = ir.AxisMap("SS", [i3, j3]) + Y[i1, j1] = -A[i1, j1] + 3.0 + + cached_a = sch.cache_read(A_block.block, 0, "global") + cached_b = sch.cache_read(B_block.block, 0, "local") + + assert_llir_equal(elementwise_add_cache_read, elementwise_add_cache_read) + + +def test_cache_write_elementwise(): + @to_cinn_llir + def elementwise_add_cache_write( + X: DataArray((128, 128)), + Y: DataArray((128, 128)), + A: DataArray((128, 128)), + ): + for i in range(128): + for j in range(128): + with ir.ScheduleBlockContext("A") as A_block: + i1, j1 = ir.AxisMap("SS", [i, j]) + A[i1, j1] = X[i1, j1] * 2.0 + for i3 in range(128): + for j3 in range(128): + with ir.ScheduleBlockContext("B") as B_block: + i1, j1 = ir.AxisMap("SS", [i3, j3]) + Y[i1, j1] = -A[i1, j1] + 3.0 + + cached_a = sch.cache_write(A_block.block, 0, "global") + cached_b = sch.cache_write(B_block.block, 0, "local") + + # TODO(6clc): core dump + # assert_llir_equal(elementwise_add_cache_write, elementwise_add_cache_write) + + +if __name__ == "__main__": + test_cache_read_elementwise() + test_cache_write_elementwise() diff --git a/test/cinn/ir/test_llir_schedule_compute_at.py b/test/cinn/ir/test_llir_schedule_compute_at.py new file mode 100644 index 00000000000000..0f82786935b411 --- /dev/null +++ b/test/cinn/ir/test_llir_schedule_compute_at.py @@ -0,0 +1,111 @@ +# Copyright (c) 2023 PaddlePaddle Authors. 
All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from test.cinn.utils.testing import assert_llir_equal + +from cinn import ir, to_cinn_llir +from cinn.runtime.data_array import DataArray +from cinn.schedule import IRSchedule as sch + + +def test_compute_at_elementwise(): + @to_cinn_llir + def elementwise_add( + X: DataArray((128, 128)), + Y: DataArray((128, 128)), + A: DataArray((128, 128)), + ): + for i in range(128): + for j in range(128): + with ir.ScheduleBlockContext("A") as A_block: + i1, j1 = ir.AxisMap("SS", [i, j]) + A[i1, j1] = X[i1, j1] * 2.0 + for i in range(128): + for j in range(128): + with ir.ScheduleBlockContext("Y"): + i1, j1 = ir.AxisMap("SS", [i, j]) + sch.compute_at(A_block.block, i, False) + Y[i1, j1] = A[i1, j1] + 2.0 + + @to_cinn_llir + def elementwise_add_gt( + X: DataArray((128, 128)), + Y: DataArray((128, 128)), + A: DataArray((128, 128)), + ): + for i in range(128): + for j in range(128): + with ir.ScheduleBlockContext("A"): + i1, j1 = ir.AxisMap("SS", [i, 0 + j]) + A[i1, j1] = X[i1, j1] * 2.0 + for k in range(128): + with ir.ScheduleBlockContext("Y"): + i2, k1 = ir.AxisMap("SS", [i, k]) + Y[i2, k1] = A[i2, k1] + 2.0 + + assert_llir_equal(elementwise_add, elementwise_add_gt) + + +def test_reverse_compute_at(): + @to_cinn_llir + def reverse_compute_at_tiled( + A: DataArray((128, 128)), + B: DataArray((128, 128)), + C: DataArray((128, 128)), + ): + for i0 in range(8): + for j0 in range(8): + for i1 in range(16): + for j1 
in range(16): + with ir.ScheduleBlockContext("B") as B_block: + vi, vj = ir.AxisMap( + "SS", [i0 * 16 + i1, j0 * 16 + j1] + ) + B[vi, vj] = A[vi, vj] * 2.0 + for i in range(128): + for j in range(128): + with ir.ScheduleBlockContext("C") as C_block: + vi, vj = ir.AxisMap("SS", [i, j]) + C[vi, vj] = B[vi, vj] + 1.0 + + sch.reverse_compute_at(C_block.block, B_block.i1) + + @to_cinn_llir + def reverse_compute_at_tiled_gt( + A: DataArray((128, 128)), + B: DataArray((128, 128)), + C: DataArray((128, 128)), + ): + for i0 in range(8): + for j0 in range(8): + for i1 in range(16): + for j1 in range(16): + with ir.ScheduleBlockContext("B") as B_block: + vi, vj = ir.AxisMap( + "SS", [i0 * 16 + i1, j0 * 16 + j1] + ) + B[vi, vj] = A[vi, vj] * 2.0 + for j2 in range(16): + with ir.ScheduleBlockContext("C") as C_block: + vi, vj = ir.AxisMap( + "SS", [16 * i0 + i1, 16 * j0 + j2] + ) + C[vi, vj] = B[vi, vj] + 1.0 + + assert_llir_equal(reverse_compute_at_tiled, reverse_compute_at_tiled_gt) + + +if __name__ == '__main__': + test_compute_at_elementwise() + test_reverse_compute_at() diff --git a/test/cinn/ir/test_llir_schedule_compute_inline.py b/test/cinn/ir/test_llir_schedule_compute_inline.py new file mode 100644 index 00000000000000..a95d1dd8174495 --- /dev/null +++ b/test/cinn/ir/test_llir_schedule_compute_inline.py @@ -0,0 +1,95 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from test.cinn.utils.testing import assert_llir_equal + +from cinn import ir, to_cinn_llir +from cinn.runtime.data_array import DataArray +from cinn.schedule import IRSchedule as sch + + +def test_compute_inline_elementwise(): + @to_cinn_llir + def elementwise_add_inline( + X: DataArray((128, 128)), + Y: DataArray((128, 128)), + A: DataArray((128, 128)), + ): + for i in range(128): + for j in range(128): + with ir.ScheduleBlockContext("A") as A_block: + i1, j1 = ir.AxisMap("SS", [i, j]) + A[i1, j1] = X[i1, j1] * 2.0 + for i3 in range(128): + for j3 in range(128): + with ir.ScheduleBlockContext("Y"): + i1, j1 = ir.AxisMap("SS", [i3, j3]) + Y[i1, j1] = -A[i1, j1] + 3.0 + + block_a = sch.get_block("A") + sch.compute_inline(block_a) + + @to_cinn_llir + def elementwise_add_inline_gt( + X: DataArray((128, 128)), + Y: DataArray((128, 128)), + A: DataArray((128, 128)), + ): + for i in range(128): + for j in range(128): + with ir.ScheduleBlockContext("Y"): + i1, j1 = ir.AxisMap("SS", [i, j]) + Y[i1, j1] = -(X[i1, j1] * 2.0) + 3.0 + + assert_llir_equal(elementwise_add_inline, elementwise_add_inline_gt) + + +def test_reverse_compute_inline_elementwise(): + @to_cinn_llir + def elementwise_add_inline( + X: DataArray((128, 128)), + Y: DataArray((128, 128)), + A: DataArray((128, 128)), + ): + for i in range(128): + for j in range(128): + with ir.ScheduleBlockContext("A") as A_block: + i1, j1 = ir.AxisMap("SS", [i, j]) + A[i1, j1] = X[i1, j1] * 2.0 + for i3 in range(128): + for j3 in range(128): + with ir.ScheduleBlockContext("Y") as Y_block: + i1, j1 = ir.AxisMap("SS", [i3, j3]) + Y[i1, j1] = -A[i1, j1] + 3.0 + + sch.reverse_compute_inline(Y_block.block) + + @to_cinn_llir + def elementwise_add_inline_gt( + X: DataArray((128, 128)), + Y: DataArray((128, 128)), + A: DataArray((128, 128)), + ): + for i in range(128): + for j in range(128): + with ir.ScheduleBlockContext("A"): + i1, j1 = ir.AxisMap("SS", [i, j]) + Y[i1, j1] = -(X[i1, j1] * 2.0) + 3.0 + + 
assert_llir_equal(elementwise_add_inline, elementwise_add_inline_gt) + + +if __name__ == "__main__": + test_compute_inline_elementwise() + test_reverse_compute_inline_elementwise() diff --git a/test/cinn/ir/test_llir_schedule_fuse_split.py b/test/cinn/ir/test_llir_schedule_fuse_split.py new file mode 100644 index 00000000000000..f22b1a1f8d3a94 --- /dev/null +++ b/test/cinn/ir/test_llir_schedule_fuse_split.py @@ -0,0 +1,131 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ + +from test.cinn.utils.testing import assert_llir_equal + +from cinn import ir, to_cinn_llir +from cinn.runtime.data_array import DataArray +from cinn.schedule import IRSchedule as sch + + +def test_fuse(): + @to_cinn_llir + def elementwise_fuse_assign_loop( + X: DataArray((128, 128, 128)), Y: DataArray((128, 128, 128)) + ): + for i in range(128): + for j in range(128): + for k in range(128): + with ir.ScheduleBlockContext("Y") as block_y: + sch.fuse([i, j, k]) + i1, j1, k1 = ir.AxisMap("SSS", [i, j, k]) + Y[i1, j1, k1] = X[i1, j1, k1] * 2.0 + + @to_cinn_llir + def elementwise_fuse_assign_loop_gt( + X: DataArray((128, 128, 128)), Y: DataArray((128, 128, 128)) + ): + for i in range(2097152): + with ir.ScheduleBlockContext("Y") as block_y: + i1_1, j1_1, k1_1 = ir.AxisMap( + "SSS", [(i / 128) / 128, (i / 128) % 128, i % 128] + ) + Y[i1_1, j1_1, k1_1] = X[i1_1, j1_1, k1_1] * 2.0 + + assert_llir_equal( + elementwise_fuse_assign_loop, elementwise_fuse_assign_loop_gt + ) + + +def test_split(): + @to_cinn_llir + def elementwise_split( + X: DataArray((128, 128, 128)), Y: DataArray((128, 128, 128)) + ): + for i in range(128): + for j in range(128): + for k in range(128): + with ir.ScheduleBlockContext("Y") as Y_block: + i1, j1, k1 = ir.AxisMap("SSS", [i, j, k]) + sch.split(Y_block.i, factors=[2, 1, 64]) + sch.split(Y_block.j, factors=[4, 32]) + sch.split(Y_block.k, factors=[16, 8]) + Y[i1, j1, k1] = X[i1, j1, k1] * 2.0 + + @to_cinn_llir + def elementwise_split_inferred_factor( + X: DataArray((128, 128, 128)), Y: DataArray((128, 128, 128)) + ): + for i in range(128): + for j in range(128): + for k in range(128): + with ir.ScheduleBlockContext("Y") as Y_block: + i1, j1, k1 = ir.AxisMap("SSS", [i, j, k]) + sch.split(Y_block.i, factors=[-1, 1, 64]) + sch.split(Y_block.j, factors=[4, -1]) + sch.split(Y_block.k, factors=[-1, 8]) + Y[i1, j1, k1] = X[i1, j1, k1] * 2.0 + + assert_llir_equal(elementwise_split, elementwise_split_inferred_factor) + + +def test_split_predicate(): + 
@to_cinn_llir + def elementwise_split_predicate( + X: DataArray((128, 128, 128)), Y: DataArray((128, 128, 128)) + ): + for i in range(128): + for j in range(128): + for k in range(128): + with ir.ScheduleBlockContext("Y") as Y_block: + i1, j1, k1 = ir.AxisMap("SSS", [i, j, k]) + sch.split(Y_block.i, factors=[1000, 1, 64]) + sch.split(Y_block.j, factors=[4, 32]) + sch.split(Y_block.k, factors=[16, 8]) + Y[i1, j1, k1] = X[i1, j1, k1] * 2.0 + + @to_cinn_llir + def elementwise_split_predicate_gt( + X: DataArray((128, 128, 128)), Y: DataArray((128, 128, 128)) + ): + for i in range(1000): + for i_0 in range(1): + for i_1 in range(64): + if ((64 * i) + ((64 * i_0) + i_1)) < 128: + for j in range(4): + for j_0 in range(32): + for k in range(16): + for k_0 in range(8): + with ir.ScheduleBlockContext("Y"): + i1, j1, k1 = ir.AxisMap( + "SSS", + [ + (64 * i) + + ((64 * i_0) + i_1), + (32 * j) + j_0, + (8 * k) + k_0, + ], + ) + Y[i1, j1, k1] = X[i1, j1, k1] * 2.0 + + assert_llir_equal( + elementwise_split_predicate, elementwise_split_predicate_gt + ) + + +if __name__ == "__main__": + test_fuse() + test_split() + test_split_predicate() diff --git a/test/cinn/ir/test_llir_schedule_reorder.py b/test/cinn/ir/test_llir_schedule_reorder.py new file mode 100644 index 00000000000000..00ca99388ba941 --- /dev/null +++ b/test/cinn/ir/test_llir_schedule_reorder.py @@ -0,0 +1,80 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from test.cinn.utils.testing import assert_llir_equal + +from cinn import ir, to_cinn_llir +from cinn.runtime.data_array import DataArray +from cinn.schedule import IRSchedule as sch + + +def test_reorder_elementwise(): + @to_cinn_llir + def reorder_elementwise( + X: DataArray((64, 64, 64, 64)), Y: DataArray((64, 64, 64, 64)) + ): + for i in range(64): + for j in range(64): + for k in range(64): + for l in range(8): + with ir.ScheduleBlockContext("Y") as Y_block: + vi, vj, vk, vl = ir.AxisMap( + "SSSS", [i, j, k, 8 * l] + ) + Y[vi, vj, vk, vl] = X[vi, vj, vk, vl] * 2.0 + sch.reorder([Y_block.k, Y_block.l, Y_block.i]) + + @to_cinn_llir + def reorder_elementwise_gt( + X: DataArray((64, 64, 64, 64)), Y: DataArray((64, 64, 64, 64)) + ): + for k in range(64): + for j in range(64): + for l in range(8): + for i in range(64): + with ir.ScheduleBlockContext("Y"): + vi, vj, vk, vl = ir.AxisMap( + "SSSS", [i, j, k, 8 * l] + ) + Y[vi, vj, vk, vl] = X[vi, vj, vk, vl] * 2.0 + + assert_llir_equal(reorder_elementwise, reorder_elementwise_gt) + + +def test_reorder_overlapped(): + @to_cinn_llir + def reorder_overlapped(X: DataArray((28, 8)), Y: DataArray((28, 8))): + for i in range(12): + for j in range(4): + for k in range(4): + with ir.ScheduleBlockContext("Y"): + vi, vj = ir.AxisMap("SS", [i, j]) + sch.reorder([i, k, j]) + Y[vi, vj] = X[vi, vj] + 1.0 + + @to_cinn_llir + def reorder_overlapped_gt(X: DataArray((28, 8)), Y: DataArray((28, 8))): + for i in range(12): + for k in range(4): + for j in range(4): + with ir.ScheduleBlockContext("Y"): + vi, vj = ir.AxisMap("SS", [i, j]) + Y[vi, vj] = X[vi, vj] + 1.0 + + assert_llir_equal(reorder_overlapped, reorder_overlapped_gt) + + +if __name__ == '__main__': + test_reorder_elementwise() + test_reorder_overlapped() diff --git a/test/cinn/ir/test_llir_schedule_sequence.py b/test/cinn/ir/test_llir_schedule_sequence.py new file mode 100644 index 00000000000000..2cff0c650fd632 --- /dev/null +++ b/test/cinn/ir/test_llir_schedule_sequence.py 
@@ -0,0 +1,70 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from test.cinn.utils.testing import assert_llir_equal + +from cinn import ir, to_cinn_llir +from cinn.runtime.data_array import DataArray +from cinn.schedule import IRSchedule as sch + + +def test_split_reorder_elementwise(): + @to_cinn_llir + def split_reorder_elementwise( + X: DataArray((1024, 1024)), + Y: DataArray((1024, 1024)), + Z: DataArray((1024, 1024)), + ): + for i in range(1024): + for j in range(1024): + for k in range(1024): + with ir.ScheduleBlockContext("Z"): + i_split_0, i_split_1, i_split_2, i_split_3 = sch.split( + i, factors=[2, 4, 64, 2] + ) + sch.reorder([i_split_2, i_split_0]) + i1, j1, k1 = ir.AxisMap("SSS", [i, j, k]) + Z[i1, j1] = Z[i1, j1] + X[i1, k] * Y[k, j1] + + @to_cinn_llir + def split_reorder_elementwise_gt( + X: DataArray((1024, 1024)), + Y: DataArray((1024, 1024)), + Z: DataArray((1024, 1024)), + ): + for i_1 in range(64): + for i_0 in range(4): + for i in range(2): + for i_2 in range(2): + for j in range(1024): + for k in range(1024): + with ir.ScheduleBlockContext("Z"): + i1, j1, k1 = ir.AxisMap( + "SSS", + [ + (512 * i) + + ((128 * i_0) + ((2 * i_1) + i_2)), + j, + k, + ], + ) + Z[i1, j1] = Z[i1, j1] + ( + X[i1, k] * Y[k, j1] + ) + + assert_llir_equal(split_reorder_elementwise, split_reorder_elementwise_gt) + + +if __name__ == "__main__": + test_split_reorder_elementwise() diff --git 
a/test/cinn/utils/testing.py b/test/cinn/utils/testing.py new file mode 100644 index 00000000000000..b67432a17c189a --- /dev/null +++ b/test/cinn/utils/testing.py @@ -0,0 +1,28 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from cinn.ir import IrCompare +from cinn.runtime import CinnLowerLevelIrJit + + +def assert_llir_equal( + llir1, llir2, allow_name_suffix_diff=True, only_compare_structure=True +): + comparer = IrCompare(allow_name_suffix_diff, only_compare_structure) + + if isinstance(llir1, CinnLowerLevelIrJit): + llir1_expr = llir1.convert_to_llir().body() + llir2_expr = llir2.convert_to_llir().body() + assert comparer.compare( + llir1_expr, llir2_expr + ), f'llir1: {llir1} \n llir2: {llir2}'