From cb07596066eb2497b651203e3557e69b48e79cd9 Mon Sep 17 00:00:00 2001 From: Vadim Gimpelson <156319763+vadiklyutiy@users.noreply.github.com> Date: Fri, 20 Sep 2024 04:18:29 +0400 Subject: [PATCH] [PERF] Indexes optimization (#458) ### Performance - synthetic test, elementwise add `a = b + 1`. Improvement around 2% - models (together with #452 ) |model|latency|prev_latency|delta| |--------|--------|--------|--------| |bert-base-uncased|19.6886|20.0891|2.034 |densenet121|35.0828|35.2025|0.341 |efficientnet_b0|18.8482|18.9655|0.622 |mobilenet_v2|11.5672|11.5901|0.198 |resnet50|29.0077|29.6469|2.204 |vit_b_16|123.344|126.705|2.725 |**GMEAN**| | |**1.349** |model|latency|prev_latency|delta| |--------|--------|--------|--------| |bert-base-uncased|19.6886|19.9814|1.487 |densenet121|35.0828|35.2032|0.343 |efficientnet_b0|18.8482|18.9711|0.652 |mobilenet_v2|11.5672|11.5913|0.208 |resnet50|29.0077|29.101|0.322 |vit_b_16|123.344|124.837|1.21 |**GMEAN**| | |**0.702** The above comparisons were done against exactly the same `main` branch. Yes, we have a big fluctuation in perf results :( ### TODO - doesn't work with dynamic shapes yet - only `spatial` mapping is supported right now. Another mapping support should be investigating --- .../hidet-script-dynamic-kernel.py | 4 +- gallery/hidet-script/5-efficient-matmul.py | 7 +- gallery/tutorials/optimize-onnx-model.py | 2 +- python/hidet/ir/__init__.py | 1 + python/hidet/ir/functors/expr_functor.py | 2 - python/hidet/ir/polinomial.py | 154 ++++++++++++++++++ python/hidet/ir/tools/rewriter.py | 46 +++++- python/hidet/testing/__init__.py | 2 +- python/hidet/transforms/__init__.py | 2 +- python/hidet/transforms/lower_task_mapping.py | 16 +- python/hidet/utils/__init__.py | 1 - tests/ir/parser/test_parser.py | 5 +- 12 files changed, 228 insertions(+), 14 deletions(-) create mode 100644 python/hidet/ir/polinomial.py diff --git a/gallery/developer-guides/hidet-script-dynamic-kernel.py b/gallery/developer-guides/hidet-script-dynamic-kernel.py index ee895f110..b10d1612c 100644 --- a/gallery/developer-guides/hidet-script-dynamic-kernel.py +++ b/gallery/developer-guides/hidet-script-dynamic-kernel.py @@ -124,6 +124,8 @@ def matmul_kernel( def main(): + from hidet.utils.benchmark import benchmark_func + func = matmul_simt_kernel() for m, n, k in [(1024, 1024, 1024), (333, 444, 555), (1, 12, 13)]: @@ -135,7 +137,7 @@ def main(): actual=c.cpu().numpy(), desired=a.cpu().numpy() @ b.cpu().numpy(), rtol=1e-4, atol=1e-4 ) - hidet_latency = hidet.utils.benchmark_func(lambda: func(a, b, c, m, n, k), repeat=50) + hidet_latency = benchmark_func(lambda: func(a, b, c, m, n, k), repeat=50) print(f'{m}x{k}x{n}: hidet takes {hidet_latency:.2f} ms') diff --git a/gallery/hidet-script/5-efficient-matmul.py b/gallery/hidet-script/5-efficient-matmul.py index 19c9c1cbf..1942ba52d 100644 --- a/gallery/hidet-script/5-efficient-matmul.py +++ b/gallery/hidet-script/5-efficient-matmul.py @@ -20,6 +20,7 @@ from hidet.lang.cuda import threadIdx, blockIdx, syncthreads from hidet.lang.mapping import spatial, auto_map from hidet.lang.layout import row_major, local_layout +from hidet.utils.benchmark import benchmark_func # the hyperparameters of the kernel warps_m, warps_n = 4, 2 # we use 4x2 warps @@ -154,10 +155,10 @@ def torch_matmul_relu(a: torch.Tensor, b: torch.Tensor): torch.testing.assert_close(c1, c2, atol=1e-4, rtol=1e-4) - hidet_latency = hidet.utils.benchmark_func(lambda: hidet_matmul_relu(a, b), repeat=50) + hidet_latency = benchmark_func(lambda: hidet_matmul_relu(a, b), repeat=50) print(f'{m}x{k}x{n}:') - print(' torch: {:.3f} ms'.format(hidet.utils.benchmark_func(lambda: torch_matmul_relu(a, b)))) - print(' hidet: {:.3f} ms'.format(hidet.utils.benchmark_func(lambda: hidet_matmul_relu(a, b)))) + print(' torch: {:.3f} ms'.format(benchmark_func(lambda: torch_matmul_relu(a, b)))) + print(' hidet: {:.3f} ms'.format(benchmark_func(lambda: hidet_matmul_relu(a, b)))) # %% # Get the source code: diff --git a/gallery/tutorials/optimize-onnx-model.py b/gallery/tutorials/optimize-onnx-model.py index 6487ebfff..a61ff667d 100644 --- a/gallery/tutorials/optimize-onnx-model.py +++ b/gallery/tutorials/optimize-onnx-model.py @@ -44,7 +44,7 @@ # The :func:`benchmark_func() ` function runs the given function multiple times to # get the median latency. -from hidet.utils import benchmark_func +from hidet.utils.benchmark import benchmark_func print('PyTorch: {:.3f} ms'.format(benchmark_func(lambda: torch_model(torch_data)))) diff --git a/python/hidet/ir/__init__.py b/python/hidet/ir/__init__.py index e65661ba8..677f8c998 100644 --- a/python/hidet/ir/__init__.py +++ b/python/hidet/ir/__init__.py @@ -18,6 +18,7 @@ from . import layout from . import mapping from . import task +from . import polinomial from .node import Node from .module import IRModule diff --git a/python/hidet/ir/functors/expr_functor.py b/python/hidet/ir/functors/expr_functor.py index 81ca79bf1..8f80cff4c 100644 --- a/python/hidet/ir/functors/expr_functor.py +++ b/python/hidet/ir/functors/expr_functor.py @@ -25,8 +25,6 @@ def visit_dispatch(self, node): return self.visit_Var(node) elif isinstance(node, Add): return self.visit_Add(node) - elif isinstance(node, Add): - return self.visit_Add(node) elif isinstance(node, Sub): return self.visit_Sub(node) elif isinstance(node, Multiply): diff --git a/python/hidet/ir/polinomial.py b/python/hidet/ir/polinomial.py new file mode 100644 index 000000000..34e0d4735 --- /dev/null +++ b/python/hidet/ir/polinomial.py @@ -0,0 +1,154 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# pylint: disable=import-outside-toplevel, useless-parent-delegation, redefined-outer-name, redefined-builtin +# pylint: disable=useless-super-delegation, protected-access +from hidet.ir.functors import ExprVisitor +from .node import Node +from .expr import Expr, Var, Constant + +POLINOMIAL_BIAS_NAME = 'zzz_polinomial_bias' + +# Right now implemented Linear polinomial with integer coeficients only. +# TODO: expand functionality. +class Poli(Node): + def __init__(self, var: str = POLINOMIAL_BIAS_NAME, coef: int = 0) -> None: + super().__init__() + # monos is a list of monomials of polinomial + if var == POLINOMIAL_BIAS_NAME: + self.monos: dict[str, int] = {var: coef} + else: + self.monos: dict[str, int] = {var: coef, POLINOMIAL_BIAS_NAME: 0} + + def is_constant(self): + self.remove_zeros() + if len(self.monos) == 1 and POLINOMIAL_BIAS_NAME in self.monos: # pylint: disable=simplifiable-if-statement + return True + else: + return False + + def remove_zeros(self): + copy = self.monos.copy() + for key in copy: + if self.monos[key] == 0 and key != POLINOMIAL_BIAS_NAME: + del self.monos[key] + + def get_bias(self): + return self.monos.get(POLINOMIAL_BIAS_NAME, 0) + + @staticmethod + def _binary_add(oper, a, b): + if a is None or b is None: + return None + res_monos = {} + all_keys = set(a.monos.keys()).union(set(b.monos.keys())) + for key in all_keys: + res_monos[key] = oper(a.monos.get(key, 0), b.monos.get(key, 0)) + res = Poli() + res.monos = res_monos + return res + + @staticmethod + def _binary_mul(a, b: int): + if a is None or b is None: + return None + # Only linear polinomials with int coefs are supported now + if not isinstance(b, int): + return None + res_monos = {} + for key in a.monos.keys(): + res_monos[key] = a.monos[key] * b + res = Poli() + res.monos = res_monos + return res + + @staticmethod + def _convert(obj): + if isinstance(obj, Poli): + return obj + if isinstance(obj, int): + return Poli(coef=obj) + if isinstance(obj, str): + return Poli(var=obj, coef=1) + if isinstance(obj, Var): + return Poli(var=obj.name, coef=1) + # Cannot convert. Return None + return None + + def __add__(self, other): + return self._binary_add(lambda a, b: a + b, self, self._convert(other)) + + def __radd__(self, other): + return self._binary_add(lambda a, b: a + b, self._convert(other), self) + + def __sub__(self, other): + return self._binary_add(lambda a, b: a - b, self, self._convert(other)) + + def __rsub__(self, other): + return self._binary_add(lambda a, b: a - b, self._convert(other), self) + + def __mul__(self, other): + return self._binary_mul(self, other) + + def __rmul__(self, other): + assert isinstance(other, int) + return self._binary_mul(self, other) + + def __eq__(self, other): + assert False + + def __str__(self): + res = '' + sorted_keys = sorted(self.monos.keys()) + for key in sorted_keys: + res += f"{self.monos[key]}*{key}+" + res = res.replace('*' + POLINOMIAL_BIAS_NAME, '') + res = res[:-1] + return res + + +# Convert expression `Expr` to polinomial `Poli`. +# +# Return: +# None if conversion isn't successful +def from_expr_to_poli(expr: Expr) -> Poli: + assert isinstance(expr, Expr) + visitor = Expr2PoliConverter() + poli = visitor.visit(expr) + return poli + + +class Expr2PoliConverter(ExprVisitor): + def __init__(self): + super().__init__(use_memo=True) + self.poli: Poli = Poli() + + def visit_Constant(self, c: Constant): + return c.value + + def visit_Var(self, v: Var): + return Poli(v.hint, coef=1) + + def _visit_binary(self, e: Expr, op): + a = self.visit(e.a) + b = self.visit(e.b) + if a is None or b is None: + return None + return op(a, b) + + def visit_Add(self, e): + return self._visit_binary(e, lambda x, y: x + y) + + def visit_Sub(self, e): + return self._visit_binary(e, lambda x, y: x - y) + + def visit_Multiply(self, e): + return self._visit_binary(e, lambda x, y: x * y) diff --git a/python/hidet/ir/tools/rewriter.py b/python/hidet/ir/tools/rewriter.py index 2443086d1..24e0ea4ea 100644 --- a/python/hidet/ir/tools/rewriter.py +++ b/python/hidet/ir/tools/rewriter.py @@ -11,11 +11,53 @@ # limitations under the License. from typing import Dict, List, Union, Mapping -from hidet.ir.expr import Let, Var +from hidet.ir.expr import Let, Var, Expr, TensorElement from hidet.ir.functors import IRRewriter from hidet.ir.node import Node from hidet.ir.stmt import ForMappingStmt, DeclareStmt, ForStmt -from hidet.ir.stmt import LetStmt +from hidet.ir.stmt import LetStmt, BufferStoreStmt +from hidet.ir.polinomial import Poli, from_expr_to_poli + +# Rewriter that search for given polinomial `old: Expr` and change it on another `new: Expr` +# Search only throught indeces of tensors. +# We calculate d = indeces - old. +# If d is const than we change indeces = new + d +class PolinomialExpr2ExprRewriter(IRRewriter): + def __init__(self, old: Expr, new: Expr): + super().__init__() + self.old: Poli = from_expr_to_poli(old) + self.new = new + + def visit_TensorElement(self, te: TensorElement): + assert len(te.indices) == 1 + indices = from_expr_to_poli(te.indices[0]) + # TODO indices is None mean fail of conversion. unsqeeze produce i % 40 + if indices is None or self.old is None: + return te + diff = indices - self.old + if isinstance(diff, int): + return TensorElement(te.base, (self.new + diff,), te.protected) + elif diff.is_constant(): + return TensorElement(te.base, (self.new + diff.get_bias(),), te.protected) + else: + return te + + def visit_BufferStoreStmt(self, stmt: BufferStoreStmt): + assert len(stmt.indices) == 1 + indices = stmt.indices + indices_poli = from_expr_to_poli(indices[0]) + if indices_poli is None or self.old is None: + return stmt + diff = indices_poli - self.old + if isinstance(diff, int): + indices = (self.new + diff,) + elif diff.is_constant(): + indices = (self.new + diff.get_bias(),) + value = self.visit(stmt.value) + if indices[0] is stmt.indices[0] and value is stmt.value: + return stmt + else: + return BufferStoreStmt(stmt.buf, indices, value, stmt.protected) class MapBasedRewriter(IRRewriter): diff --git a/python/hidet/testing/__init__.py b/python/hidet/testing/__init__.py index 8d241f813..a3ea8f12d 100644 --- a/python/hidet/testing/__init__.py +++ b/python/hidet/testing/__init__.py @@ -9,7 +9,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from hidet.utils import benchmark_func +from hidet.utils.benchmark import benchmark_func from . import models from . import utils diff --git a/python/hidet/transforms/__init__.py b/python/hidet/transforms/__init__.py index e32e8a90b..0e1a295d9 100644 --- a/python/hidet/transforms/__init__.py +++ b/python/hidet/transforms/__init__.py @@ -76,11 +76,11 @@ def lower(ir_module: IRModule) -> IRModule: flatten_tensor_slice_pass(), lower_protect_access_pass(), spatial_simplification_pass(), + flatten_tensor_index_pass(), lower_task_mapping_pass(), normalize_const_tensor_pass(), declare_to_let_pass(), rule_based_simplify_pass(), # make ir more readable - flatten_tensor_index_pass(), lower_integer_subbyte_pass(), lower_special_cast_pass(), inline_function_pass(), diff --git a/python/hidet/transforms/lower_task_mapping.py b/python/hidet/transforms/lower_task_mapping.py index 116fe449e..47cbc81f9 100644 --- a/python/hidet/transforms/lower_task_mapping.py +++ b/python/hidet/transforms/lower_task_mapping.py @@ -11,11 +11,13 @@ # limitations under the License. from typing import List, Dict, Sequence, Union, Optional from hidet.ir import Var, ForMappingStmt, Stmt, ForStmt, Expr, SeqStmt +from hidet.ir.dtypes import int32 from hidet.ir.expr import var from hidet.ir.mapping import TaskMapping, SpatialTaskMapping, RepeatTaskMapping, ComposedTaskMapping from hidet.ir.func import Function from hidet.ir.functors import IRRewriter from hidet.ir.tools import rewrite, simplify +from hidet.ir.tools.rewriter import PolinomialExpr2ExprRewriter from hidet.transforms.base import Pass, FunctionPass from hidet.utils import prod @@ -36,7 +38,19 @@ def __init__(self): self.loop_nests: List[ForStmt] = [] def expand(self, mapping: TaskMapping, worker: Expr, loop_vars: List[Var], body: Stmt) -> Stmt: - tasks: List[TaskIndex] = self.visit(mapping, worker) + # Here we try to find expression that represent the flatten index and + # change it on worker (because worker is a same as a flatten index). + # Just default expand - when we represent every loop var as worker expression is + # compilcated. In many cases either hidet's passes or nvcc cannot optimise it. + if isinstance(mapping, SpatialTaskMapping): + flatten = int32.zero + for loop_var, stride in zip(loop_vars, mapping.strides): + flatten += loop_var * stride + + rewriter = PolinomialExpr2ExprRewriter(flatten, worker) + body = rewriter.rewrite(body) + + tasks = self.visit(mapping, worker) seq = [] for task in tasks: remap: Dict[Var, Expr] = {a: b for a, b in zip(loop_vars, task)} diff --git a/python/hidet/utils/__init__.py b/python/hidet/utils/__init__.py index 1432332b2..9464b129f 100644 --- a/python/hidet/utils/__init__.py +++ b/python/hidet/utils/__init__.py @@ -20,7 +20,6 @@ from .py import prod, Timer, repeat_until_converge, COLORS, get_next_file_index, factorize, HidetProfiler, TableBuilder from .py import same_list, strict_zip, index_of, initialize, gcd, lcm, error_tolerance, green, red, cyan, bold, blue from .py import str_indent, unique, assert_close, cdiv -from .benchmark import Bench, benchmark_func from .structure import DirectedGraph from .cache_utils import cache_dir, cache_file, clear_op_cache, clear_cache_dir from .net_utils import download diff --git a/tests/ir/parser/test_parser.py b/tests/ir/parser/test_parser.py index 47616c3f3..4bea74aa1 100644 --- a/tests/ir/parser/test_parser.py +++ b/tests/ir/parser/test_parser.py @@ -35,6 +35,8 @@ from hidet.transforms.check_launch_configuration import check_launch_configuration_pass from hidet.transforms.lower_special_cast import lower_special_cast_pass from hidet.transforms.annotate_header_and_libs import annotate_header_and_libs_pass +from hidet.transforms.spatial_simplification import spatial_simplification_pass + # from hidet.graph.ops.softmax import SoftmaxTask from hidet.graph.ops.matmul.matmul_f16 import MatmulF16Task @@ -89,11 +91,12 @@ def generate_ir_modules(): generate_launch_func_pass(), flatten_tensor_slice_pass(), lower_protect_access_pass(), + spatial_simplification_pass(), + flatten_tensor_index_pass(), lower_task_mapping_pass(), normalize_const_tensor_pass(), declare_to_let_pass(), rule_based_simplify_pass(), - flatten_tensor_index_pass(), lower_special_cast_pass(), inline_function_pass(), resolve_primitive_func_pass(),