Skip to content

Commit

Permalink
[PERF] Indexes optimization (#458)
Browse files Browse the repository at this point in the history
### Performance
 - synthetic test, elementwise add `a = b + 1`. Improvement around 2%

 - models (together with #452 )

|model|latency|prev_latency|delta|
|--------|--------|--------|--------|
|bert-base-uncased|19.6886|20.0891|2.034
|densenet121|35.0828|35.2025|0.341
|efficientnet_b0|18.8482|18.9655|0.622
|mobilenet_v2|11.5672|11.5901|0.198
|resnet50|29.0077|29.6469|2.204
|vit_b_16|123.344|126.705|2.725
|**GMEAN**| | |**1.349**

  
|model|latency|prev_latency|delta|
|--------|--------|--------|--------|
|bert-base-uncased|19.6886|19.9814|1.487
|densenet121|35.0828|35.2032|0.343
|efficientnet_b0|18.8482|18.9711|0.652
|mobilenet_v2|11.5672|11.5913|0.208
|resnet50|29.0077|29.101|0.322
|vit_b_16|123.344|124.837|1.21
|**GMEAN**| | |**0.702**

The above comparisons were done against exactly the same `main` branch.
Yes, we have a big fluctuation in perf results :(


### TODO
- doesn't work with dynamic shapes yet
- only `spatial` mapping is supported right now. Another mapping support
should be investigating
  • Loading branch information
vadiklyutiy committed Dec 19, 2024
1 parent 9308c9f commit cb07596
Show file tree
Hide file tree
Showing 12 changed files with 228 additions and 14 deletions.
4 changes: 3 additions & 1 deletion gallery/developer-guides/hidet-script-dynamic-kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,8 @@ def matmul_kernel(


def main():
from hidet.utils.benchmark import benchmark_func

func = matmul_simt_kernel()

for m, n, k in [(1024, 1024, 1024), (333, 444, 555), (1, 12, 13)]:
Expand All @@ -135,7 +137,7 @@ def main():
actual=c.cpu().numpy(), desired=a.cpu().numpy() @ b.cpu().numpy(), rtol=1e-4, atol=1e-4
)

hidet_latency = hidet.utils.benchmark_func(lambda: func(a, b, c, m, n, k), repeat=50)
hidet_latency = benchmark_func(lambda: func(a, b, c, m, n, k), repeat=50)
print(f'{m}x{k}x{n}: hidet takes {hidet_latency:.2f} ms')


Expand Down
7 changes: 4 additions & 3 deletions gallery/hidet-script/5-efficient-matmul.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from hidet.lang.cuda import threadIdx, blockIdx, syncthreads
from hidet.lang.mapping import spatial, auto_map
from hidet.lang.layout import row_major, local_layout
from hidet.utils.benchmark import benchmark_func

# the hyperparameters of the kernel
warps_m, warps_n = 4, 2 # we use 4x2 warps
Expand Down Expand Up @@ -154,10 +155,10 @@ def torch_matmul_relu(a: torch.Tensor, b: torch.Tensor):

torch.testing.assert_close(c1, c2, atol=1e-4, rtol=1e-4)

hidet_latency = hidet.utils.benchmark_func(lambda: hidet_matmul_relu(a, b), repeat=50)
hidet_latency = benchmark_func(lambda: hidet_matmul_relu(a, b), repeat=50)
print(f'{m}x{k}x{n}:')
print(' torch: {:.3f} ms'.format(hidet.utils.benchmark_func(lambda: torch_matmul_relu(a, b))))
print(' hidet: {:.3f} ms'.format(hidet.utils.benchmark_func(lambda: hidet_matmul_relu(a, b))))
print(' torch: {:.3f} ms'.format(benchmark_func(lambda: torch_matmul_relu(a, b))))
print(' hidet: {:.3f} ms'.format(benchmark_func(lambda: hidet_matmul_relu(a, b))))

# %%
# Get the source code:
Expand Down
2 changes: 1 addition & 1 deletion gallery/tutorials/optimize-onnx-model.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@
# The :func:`benchmark_func() <hidet.utils.benchmark_func>` function runs the given function multiple times to
# get the median latency.

from hidet.utils import benchmark_func
from hidet.utils.benchmark import benchmark_func

print('PyTorch: {:.3f} ms'.format(benchmark_func(lambda: torch_model(torch_data))))

Expand Down
1 change: 1 addition & 0 deletions python/hidet/ir/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from . import layout
from . import mapping
from . import task
from . import polinomial

from .node import Node
from .module import IRModule
Expand Down
2 changes: 0 additions & 2 deletions python/hidet/ir/functors/expr_functor.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,6 @@ def visit_dispatch(self, node):
return self.visit_Var(node)
elif isinstance(node, Add):
return self.visit_Add(node)
elif isinstance(node, Add):
return self.visit_Add(node)
elif isinstance(node, Sub):
return self.visit_Sub(node)
elif isinstance(node, Multiply):
Expand Down
154 changes: 154 additions & 0 deletions python/hidet/ir/polinomial.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,154 @@
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# pylint: disable=import-outside-toplevel, useless-parent-delegation, redefined-outer-name, redefined-builtin
# pylint: disable=useless-super-delegation, protected-access
from hidet.ir.functors import ExprVisitor
from .node import Node
from .expr import Expr, Var, Constant

POLINOMIAL_BIAS_NAME = 'zzz_polinomial_bias'

# Right now implemented Linear polinomial with integer coeficients only.
# TODO: expand functionality.
class Poli(Node):
def __init__(self, var: str = POLINOMIAL_BIAS_NAME, coef: int = 0) -> None:
super().__init__()
# monos is a list of monomials of polinomial
if var == POLINOMIAL_BIAS_NAME:
self.monos: dict[str, int] = {var: coef}
else:
self.monos: dict[str, int] = {var: coef, POLINOMIAL_BIAS_NAME: 0}

def is_constant(self):
self.remove_zeros()
if len(self.monos) == 1 and POLINOMIAL_BIAS_NAME in self.monos: # pylint: disable=simplifiable-if-statement
return True
else:
return False

def remove_zeros(self):
copy = self.monos.copy()
for key in copy:
if self.monos[key] == 0 and key != POLINOMIAL_BIAS_NAME:
del self.monos[key]

def get_bias(self):
return self.monos.get(POLINOMIAL_BIAS_NAME, 0)

@staticmethod
def _binary_add(oper, a, b):
if a is None or b is None:
return None
res_monos = {}
all_keys = set(a.monos.keys()).union(set(b.monos.keys()))
for key in all_keys:
res_monos[key] = oper(a.monos.get(key, 0), b.monos.get(key, 0))
res = Poli()
res.monos = res_monos
return res

@staticmethod
def _binary_mul(a, b: int):
if a is None or b is None:
return None
# Only linear polinomials with int coefs are supported now
if not isinstance(b, int):
return None
res_monos = {}
for key in a.monos.keys():
res_monos[key] = a.monos[key] * b
res = Poli()
res.monos = res_monos
return res

@staticmethod
def _convert(obj):
if isinstance(obj, Poli):
return obj
if isinstance(obj, int):
return Poli(coef=obj)
if isinstance(obj, str):
return Poli(var=obj, coef=1)
if isinstance(obj, Var):
return Poli(var=obj.name, coef=1)
# Cannot convert. Return None
return None

def __add__(self, other):
return self._binary_add(lambda a, b: a + b, self, self._convert(other))

def __radd__(self, other):
return self._binary_add(lambda a, b: a + b, self._convert(other), self)

def __sub__(self, other):
return self._binary_add(lambda a, b: a - b, self, self._convert(other))

def __rsub__(self, other):
return self._binary_add(lambda a, b: a - b, self._convert(other), self)

def __mul__(self, other):
return self._binary_mul(self, other)

def __rmul__(self, other):
assert isinstance(other, int)
return self._binary_mul(self, other)

def __eq__(self, other):
assert False

def __str__(self):
res = ''
sorted_keys = sorted(self.monos.keys())
for key in sorted_keys:
res += f"{self.monos[key]}*{key}+"
res = res.replace('*' + POLINOMIAL_BIAS_NAME, '')
res = res[:-1]
return res


# Convert expression `Expr` to polinomial `Poli`.
#
# Return:
# None if conversion isn't successful
def from_expr_to_poli(expr: Expr) -> Poli:
assert isinstance(expr, Expr)
visitor = Expr2PoliConverter()
poli = visitor.visit(expr)
return poli


class Expr2PoliConverter(ExprVisitor):
def __init__(self):
super().__init__(use_memo=True)
self.poli: Poli = Poli()

def visit_Constant(self, c: Constant):
return c.value

def visit_Var(self, v: Var):
return Poli(v.hint, coef=1)

def _visit_binary(self, e: Expr, op):
a = self.visit(e.a)
b = self.visit(e.b)
if a is None or b is None:
return None
return op(a, b)

def visit_Add(self, e):
return self._visit_binary(e, lambda x, y: x + y)

def visit_Sub(self, e):
return self._visit_binary(e, lambda x, y: x - y)

def visit_Multiply(self, e):
return self._visit_binary(e, lambda x, y: x * y)
46 changes: 44 additions & 2 deletions python/hidet/ir/tools/rewriter.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,53 @@
# limitations under the License.
from typing import Dict, List, Union, Mapping

from hidet.ir.expr import Let, Var
from hidet.ir.expr import Let, Var, Expr, TensorElement
from hidet.ir.functors import IRRewriter
from hidet.ir.node import Node
from hidet.ir.stmt import ForMappingStmt, DeclareStmt, ForStmt
from hidet.ir.stmt import LetStmt
from hidet.ir.stmt import LetStmt, BufferStoreStmt
from hidet.ir.polinomial import Poli, from_expr_to_poli

# Rewriter that search for given polinomial `old: Expr` and change it on another `new: Expr`
# Search only throught indeces of tensors.
# We calculate d = indeces - old.
# If d is const than we change indeces = new + d
class PolinomialExpr2ExprRewriter(IRRewriter):
def __init__(self, old: Expr, new: Expr):
super().__init__()
self.old: Poli = from_expr_to_poli(old)
self.new = new

def visit_TensorElement(self, te: TensorElement):
assert len(te.indices) == 1
indices = from_expr_to_poli(te.indices[0])
# TODO indices is None mean fail of conversion. unsqeeze produce i % 40
if indices is None or self.old is None:
return te
diff = indices - self.old
if isinstance(diff, int):
return TensorElement(te.base, (self.new + diff,), te.protected)
elif diff.is_constant():
return TensorElement(te.base, (self.new + diff.get_bias(),), te.protected)
else:
return te

def visit_BufferStoreStmt(self, stmt: BufferStoreStmt):
assert len(stmt.indices) == 1
indices = stmt.indices
indices_poli = from_expr_to_poli(indices[0])
if indices_poli is None or self.old is None:
return stmt
diff = indices_poli - self.old
if isinstance(diff, int):
indices = (self.new + diff,)
elif diff.is_constant():
indices = (self.new + diff.get_bias(),)
value = self.visit(stmt.value)
if indices[0] is stmt.indices[0] and value is stmt.value:
return stmt
else:
return BufferStoreStmt(stmt.buf, indices, value, stmt.protected)


class MapBasedRewriter(IRRewriter):
Expand Down
2 changes: 1 addition & 1 deletion python/hidet/testing/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from hidet.utils import benchmark_func
from hidet.utils.benchmark import benchmark_func

from . import models
from . import utils
Expand Down
2 changes: 1 addition & 1 deletion python/hidet/transforms/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,11 +76,11 @@ def lower(ir_module: IRModule) -> IRModule:
flatten_tensor_slice_pass(),
lower_protect_access_pass(),
spatial_simplification_pass(),
flatten_tensor_index_pass(),
lower_task_mapping_pass(),
normalize_const_tensor_pass(),
declare_to_let_pass(),
rule_based_simplify_pass(), # make ir more readable
flatten_tensor_index_pass(),
lower_integer_subbyte_pass(),
lower_special_cast_pass(),
inline_function_pass(),
Expand Down
16 changes: 15 additions & 1 deletion python/hidet/transforms/lower_task_mapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,13 @@
# limitations under the License.
from typing import List, Dict, Sequence, Union, Optional
from hidet.ir import Var, ForMappingStmt, Stmt, ForStmt, Expr, SeqStmt
from hidet.ir.dtypes import int32
from hidet.ir.expr import var
from hidet.ir.mapping import TaskMapping, SpatialTaskMapping, RepeatTaskMapping, ComposedTaskMapping
from hidet.ir.func import Function
from hidet.ir.functors import IRRewriter
from hidet.ir.tools import rewrite, simplify
from hidet.ir.tools.rewriter import PolinomialExpr2ExprRewriter
from hidet.transforms.base import Pass, FunctionPass
from hidet.utils import prod

Expand All @@ -36,7 +38,19 @@ def __init__(self):
self.loop_nests: List[ForStmt] = []

def expand(self, mapping: TaskMapping, worker: Expr, loop_vars: List[Var], body: Stmt) -> Stmt:
tasks: List[TaskIndex] = self.visit(mapping, worker)
# Here we try to find expression that represent the flatten index and
# change it on worker (because worker is a same as a flatten index).
# Just default expand - when we represent every loop var as worker expression is
# compilcated. In many cases either hidet's passes or nvcc cannot optimise it.
if isinstance(mapping, SpatialTaskMapping):
flatten = int32.zero
for loop_var, stride in zip(loop_vars, mapping.strides):
flatten += loop_var * stride

rewriter = PolinomialExpr2ExprRewriter(flatten, worker)
body = rewriter.rewrite(body)

tasks = self.visit(mapping, worker)
seq = []
for task in tasks:
remap: Dict[Var, Expr] = {a: b for a, b in zip(loop_vars, task)}
Expand Down
1 change: 0 additions & 1 deletion python/hidet/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
from .py import prod, Timer, repeat_until_converge, COLORS, get_next_file_index, factorize, HidetProfiler, TableBuilder
from .py import same_list, strict_zip, index_of, initialize, gcd, lcm, error_tolerance, green, red, cyan, bold, blue
from .py import str_indent, unique, assert_close, cdiv
from .benchmark import Bench, benchmark_func
from .structure import DirectedGraph
from .cache_utils import cache_dir, cache_file, clear_op_cache, clear_cache_dir
from .net_utils import download
5 changes: 4 additions & 1 deletion tests/ir/parser/test_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@
from hidet.transforms.check_launch_configuration import check_launch_configuration_pass
from hidet.transforms.lower_special_cast import lower_special_cast_pass
from hidet.transforms.annotate_header_and_libs import annotate_header_and_libs_pass
from hidet.transforms.spatial_simplification import spatial_simplification_pass


# from hidet.graph.ops.softmax import SoftmaxTask
from hidet.graph.ops.matmul.matmul_f16 import MatmulF16Task
Expand Down Expand Up @@ -89,11 +91,12 @@ def generate_ir_modules():
generate_launch_func_pass(),
flatten_tensor_slice_pass(),
lower_protect_access_pass(),
spatial_simplification_pass(),
flatten_tensor_index_pass(),
lower_task_mapping_pass(),
normalize_const_tensor_pass(),
declare_to_let_pass(),
rule_based_simplify_pass(),
flatten_tensor_index_pass(),
lower_special_cast_pass(),
inline_function_pass(),
resolve_primitive_func_pass(),
Expand Down

0 comments on commit cb07596

Please sign in to comment.