diff --git a/CMakeLists.txt b/CMakeLists.txt
index 32c0a9333..3e982c11b 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -25,6 +25,7 @@ add_library(hidet_runtime SHARED
         src/hidet/runtime/callbacks.cpp
         src/hidet/runtime/logging.cpp
         src/hidet/runtime/symbols.cpp
+        src/hidet/runtime/int_fastdiv.cpp
         src/hidet/runtime/llm/tokenizer/decoders.cpp
         src/hidet/runtime/llm/tokenizer/models.cpp
         src/hidet/runtime/llm/tokenizer/normalizers.cpp
diff --git a/include/hidet/runtime/int_fastdiv.h b/include/hidet/runtime/int_fastdiv.h
new file mode 100644
index 000000000..0fc3d7344
--- /dev/null
+++ b/include/hidet/runtime/int_fastdiv.h
@@ -0,0 +1,22 @@
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+#pragma once
+
+#include 
+
+#ifdef __CUDA_ARCH__
+#define HOST_DEVICE __host__ __device__
+#else
+#define HOST_DEVICE
+#endif
+
+HOST_DEVICE void calculate_magic_numbers(int d, int &m, int &s, int &as);
diff --git a/python/hidet/backend/codegen.py b/python/hidet/backend/codegen.py
index 345bcb405..0e3846abf 100644
--- a/python/hidet/backend/codegen.py
+++ b/python/hidet/backend/codegen.py
@@ -712,7 +712,7 @@ def require_headers(self) -> Doc:
         doc += Text('#include ') + NewLine()
         doc += Text('#include ') + NewLine()
         doc += Text("#include ") + NewLine()
-
+        doc += Text('#include <hidet/runtime/int_fastdiv.h>') + NewLine()
         for header in self.ir_module.include_headers:
             doc += Text('#include <{}>').format(header) + NewLine()

@@ -801,6 +801,7 @@ def require_headers(self) -> Doc:
         doc += Text('#include ') + NewLine()
         doc += Text('#include ') + NewLine()
         doc += Text("#include ") + NewLine()
+        doc += Text('#include <hidet/runtime/int_fastdiv.h>') + NewLine()
         if self.require_complex:
             doc += Text('#include ') + NewLine()

diff --git a/python/hidet/ir/primitives/cuda/fastintdiv.py b/python/hidet/ir/primitives/cuda/fastintdiv.py
new file mode 100644
index 000000000..727c61447
--- /dev/null
+++ b/python/hidet/ir/primitives/cuda/fastintdiv.py
@@ -0,0 +1,64 @@
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from hidet.ir.dtypes import i32
+from hidet.ir.expr import Expr, cast
+from hidet.ir.primitives.func import register_primitive_function, call_primitive_func
+from hidet.utils import initialize
+
+
+@initialize()
+def register_fastdiv_functions():
+    from hidet.lang import script, attrs, asm
+
+    dtype = i32
+    div_func_name = 'fastintdiv'
+
+    @script
+    def div_op(dividend: dtype, divisor: dtype, m: dtype, s: dtype, ads: dtype) -> dtype:
+        attrs.func_kind = 'cuda_internal'
+        attrs.func_name = div_func_name
+        q = 0
+        asm('mul.hi.s32 %0, %1, %2;', outputs=[q], inputs=[m, dividend])
+        q = q + dividend * ads
+        if s >= 0:
+            q = q >> s
+            q = q + (cast(q, 'uint32') >> 31)
+        return q
+
+    register_primitive_function(name=div_func_name, func_or_type=div_op)
+
+    mod_func_name = 'fastintmod'
+
+    @script
+    def mod_op(dividend: dtype, divisor: dtype, m: dtype, s: dtype, ads: dtype) -> dtype:
+        attrs.func_kind = 'cuda_internal'
+        attrs.func_name = mod_func_name
+        q = 0
+        asm('mul.hi.s32 %0, %1, %2;', outputs=[q], inputs=[m, dividend])
+        q = q + dividend * ads
+        if s >= 0:
+            q = q >> s
+            q = q + (cast(q, 'uint32') >> 31)
+        remainder = dividend - q * divisor
+        return remainder
+
+    register_primitive_function(name=mod_func_name, func_or_type=mod_op)
+
+
+# The fast int div and fast int mod implementations are borrowed from:
+# https://github.com/milakov/int_fastdiv
+def fast_intdiv(dividend: Expr, divisor: Expr, m: int, s: int, ads: int):
+    return call_primitive_func('fastintdiv', [dividend, divisor, m, s, ads])
+
+
+def fast_intmod(dividend: Expr, divisor: Expr, m: int, s: int, ads: int):
+    return call_primitive_func('fastintmod', [dividend, divisor, m, s, ads])
diff --git a/python/hidet/ir/primitives/runtime.py b/python/hidet/ir/primitives/runtime.py
index e16e55490..a5c1f740e 100644
--- a/python/hidet/ir/primitives/runtime.py
+++ b/python/hidet/ir/primitives/runtime.py
@@ -56,6 +56,11 @@ def register_functions():
     register_primitive_function(
         name='get_nccl_comm', func_or_type=FuncType([int32], void_p), codegen_name='get_nccl_comm'
     )
+    register_primitive_function(
+        name='calculate_magic_numbers',
+        func_or_type=FuncType([int32, int32, int32, int32], void_p),
+        codegen_name='calculate_magic_numbers',
+    )


 def get_cuda_stream() -> void_p:
@@ -96,3 +101,7 @@ def memory_planner_used(idx: Union[int, Expr]):

 def get_nccl_comm(idx: int) -> void_p:
     return call_primitive_func('get_nccl_comm', [idx])
+
+
+def calculate_magic_numbers(divisor: int, m: int, s: int, ads: int):
+    return call_primitive_func('calculate_magic_numbers', [divisor, m, s, ads])
diff --git a/python/hidet/lang/cuda.py b/python/hidet/lang/cuda.py
index f67e6ca93..8aea62b6c 100644
--- a/python/hidet/lang/cuda.py
+++ b/python/hidet/lang/cuda.py
@@ -26,6 +26,7 @@
 from hidet.ir.primitives.cuda.time import nano_sleep
 from hidet.ir.primitives.cuda.memcpy import memcpy_async, memcpy
 from hidet.ir.primitives.cuda.atomic import atomic_add, atomic_sub, atomic_min, atomic_max, atomic_exchange, atomic_cas
+from hidet.ir.primitives.cuda.fastintdiv import fast_intdiv, fast_intmod
 from hidet.ir.primitives.cuda.shfl import shfl_sync, shfl_up_sync, shfl_xor_sync, shfl_down_sync
 from hidet.ir.primitives.cuda.mutex import acquire_lock, release_lock, acquire_seq_semaphore, release_seq_semaphore
 from hidet.lang.constructs.declare import register_tensor, shared_tensor
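
Since hidet.lang.cuda now re-exports fast_intdiv and fast_intmod, hand-written kernels can call them directly once the magic numbers are known. The following is an untested usage sketch, not part of this patch: it assumes the usual hidet.script_module() / hidet.script workflow, and it hard-codes the magic numbers for the divisor 7 (in the compiled flow they would instead be produced by calculate_magic_numbers in the launch function). The kernel name div_mod_by_7 and the constants M7, S7, ADS7 are invented for the example.

import hidet
import numpy as np
from hidet.lang import attrs
from hidet.ir.dtypes import int32
from hidet.lang.cuda import fast_intdiv, fast_intmod, threadIdx, blockIdx

# Magic numbers for the fixed divisor 7 (what calculate_magic_numbers would return):
# m = 0x92492493 interpreted as a signed i32, shift s = 2, add indicator ads = 1.
M7, S7, ADS7 = -1840700269, 2, 1

with hidet.script_module() as script_module:

    @hidet.script
    def div_mod_by_7(x: int32[1024], q: int32[1024], r: int32[1024]):
        attrs.func_kind = 'cuda_kernel'
        attrs.cuda.grid_dim = 8
        attrs.cuda.block_dim = 128
        i = blockIdx.x * 128 + threadIdx.x
        q[i] = fast_intdiv(x[i], 7, M7, S7, ADS7)  # x[i] // 7 without an integer divide
        r[i] = fast_intmod(x[i], 7, M7, S7, ADS7)  # x[i] % 7

module = script_module.build()
x = hidet.asarray(np.arange(1024, dtype=np.int32)).cuda()
q = hidet.zeros([1024], dtype='int32', device='cuda')
r = hidet.zeros([1024], dtype='int32', device='cuda')
module(x, q, r)
assert np.array_equal(q.cpu().numpy(), np.arange(1024) // 7)
assert np.array_equal(r.cpu().numpy(), np.arange(1024) % 7)
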
diff --git a/python/hidet/testing/torch_utils.py b/python/hidet/testing/torch_utils.py
index 6440e2efb..eb95aa37e 100644
--- a/python/hidet/testing/torch_utils.py
+++ b/python/hidet/testing/torch_utils.py
@@ -84,6 +84,7 @@ def init_hidet(self):
         # hidet.option.debug_cache_tuning(True)
         # hidet.option.save_lower_ir(True)
         # hidet.option.debug_show_verbose_flow_graph(True)
+        # hidet.torch.dynamo_config.dump_graph_ir("./graph_ir")

         # Initialise compiler server
         if os.environ.get('CI_CS_HOSTNAME'):
diff --git a/python/hidet/transforms/__init__.py b/python/hidet/transforms/__init__.py
index 991a7fc1c..ab309a6e2 100644
--- a/python/hidet/transforms/__init__.py
+++ b/python/hidet/transforms/__init__.py
@@ -17,6 +17,7 @@
 from .attach_hash_to_signature import attach_hash_to_signature
 from .unify_global_objects import unify_global_objects_pass
+from .convert_div_to_fastintdiv import convert_div_to_fastintdiv_pass
 from .flatten_tensor_slice import flatten_tensor_slice_pass
 from .flatten_tensor_index import flatten_tensor_index_pass
 from .generate_launch_func import generate_launch_func_pass
@@ -97,6 +98,8 @@ def lower(ir_module: IRModule) -> IRModule:
         add_explicit_cast_pass(),
         declare_to_let_pass(),
         instantiate_symbols_pass(),
+        convert_div_to_fastintdiv_pass(),
+        import_primitive_functions_pass(),
         check_launch_configuration_pass(),
         # simplification
         expand_let_expr_pass(),
diff --git a/python/hidet/transforms/convert_div_to_fastintdiv.py b/python/hidet/transforms/convert_div_to_fastintdiv.py
new file mode 100644
index 000000000..990a671ec
--- /dev/null
+++ b/python/hidet/transforms/convert_div_to_fastintdiv.py
@@ -0,0 +1,189 @@
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import List, Optional, Dict
+from hidet.ir.expr import Var, Div, Mod
+from hidet.ir.func import Function
+from hidet.ir.stmt import DeclareStmt, LetStmt
+from hidet.ir.module import IRModule
+from hidet.ir.functors import IRRewriter, IRVisitor
+from hidet.ir.type import FuncType
+from hidet.ir.dtypes.integer import i32
+from hidet.transforms import Pass
+from hidet.ir.builders import StmtBuilder
+from hidet.ir.stmt import LaunchKernelStmt
+from hidet.ir.expr import Call, constant
+from hidet.ir.primitives.runtime import calculate_magic_numbers
+from hidet.logging import logger
+
+
+def is_launch_function(func: Function):
+    return func.kind == 'public' and 'launch' in func.name
+
+
+def is_kernel_function(func: Function, func_name: str):
+    return func.kind == "cuda_kernel" and func.name == func_name
+
+
+def is_required_letstmt(stmt: LetStmt):
+    return (
+        len(stmt.bind_values) != 0
+        and all(hasattr(bv, 'func_var') for bv in stmt.bind_values)
+        and all(bv.func_var.name == 'get_symbol_value' for bv in stmt.bind_values)
+    )
+
+
+class CollectSymVarsAndFuncNames(IRVisitor):
+    def __init__(self):
+        super().__init__()
+        self.sym_var_names: Optional[List[str]] = None
+        self.kernel_function_name: Optional[str] = None
+
+    def visit_Function(self, func: Function):
+        if is_launch_function(func):
+            super().visit_Function(func)
+        return func
+
+    def visit_LetStmt(self, stmt: LetStmt):
+        if not is_required_letstmt(stmt):
+            logger.warning(
+                f'public launch function contains a LetStmt that the fast int div '
+                f'optimization does not recognize; skipping it: {stmt}'
+            )
+            return stmt
+        super().visit_LetStmt(stmt)
+        self.sym_var_names = [bv.hint for bv in stmt.bind_vars]
+        return stmt
+
+    def visit_LaunchKernelStmt(self, stmt: LaunchKernelStmt):
+        self.kernel_function_name = stmt.func_var.name
+        return stmt
+
+
+# This visitor creates a filtered symbol var pool that contains only
+# symbol vars that are used in the kernel function as a divisor (/ symbolvar)
+# or modulus (% symbolvar)
+class FilterSymbolVar(IRVisitor):
+    def __init__(self, sym_var_names: List[str], kernel_function_name: str):
+        super().__init__()
+        self.sym_var_names = sym_var_names
+        self.kernel_function_name = kernel_function_name
+        self.filtered_sym_var_names = set()
+
+    def visit_Function(self, func: Function):
+        if is_kernel_function(func, self.kernel_function_name):
+            super().visit_Function(func)
+        return func
+
+    def visit_Div(self, e: Div):
+        if isinstance(e.b, Var) and e.b.hint in self.sym_var_names:
+            self.filtered_sym_var_names.add(e.b.hint)
+        return super().visit_Div(e)
+
+    def visit_Mod(self, e: Mod):
+        if isinstance(e.b, Var) and e.b.hint in self.sym_var_names:
+            self.filtered_sym_var_names.add(e.b.hint)
+        return super().visit_Mod(e)
+
+
+class GenerateMagicVarsRewriter(IRRewriter):
+    def __init__(self, filtered_sym_var_names: List[str]):
+        super().__init__()
+        self.magic_vars: Dict[str, List[Var]] = {}
+        self.filtered_sym_var_names = filtered_sym_var_names
+
+    def visit_Function(self, func: Function):
+        if is_launch_function(func):
+            new_func = super().visit_Function(func)
+            return new_func
+        return func
+
+    def visit_LetStmt(self, stmt: LetStmt):
+        if not is_required_letstmt(stmt):
+            return stmt
+        sb = StmtBuilder()
+        for bind_var in stmt.bind_vars:
+            if bind_var.hint not in self.filtered_sym_var_names:
+                continue
+            magic_m = Var(f'magic_m_{bind_var.hint}', bind_var.type, f'magic_m_{bind_var.hint}')
+            magic_s = Var(f'magic_s_{bind_var.hint}', bind_var.type, f'magic_s_{bind_var.hint}')
+            magic_as = Var(f'magic_as_{bind_var.hint}', bind_var.type, f'magic_as_{bind_var.hint}')
+            sb += DeclareStmt(magic_m, constant(0, i32))
+            sb += DeclareStmt(magic_s, constant(0, i32))
+            sb += DeclareStmt(magic_as, constant(0, i32))
+            sb += calculate_magic_numbers(bind_var, magic_m, magic_s, magic_as)
+            self.magic_vars[bind_var.hint] = [magic_m, magic_s, magic_as]
+        super().visit_LetStmt(stmt)
+        sb += stmt.body
+        stmt.body = sb.finish()
+        return LetStmt(stmt.bind_vars, stmt.bind_values, stmt.body)
+
+    def visit_LaunchKernelStmt(self, stmt: LaunchKernelStmt):
+        if self.magic_vars:
+            stmt.args = stmt.args + [item for sublist in list(self.magic_vars.values()) for item in sublist]
+            return LaunchKernelStmt(
+                stmt.func_var, stmt.args, stmt.grid_dim, stmt.cluster_dim, stmt.block_dim, stmt.shared_mem_bytes
+            )
+        else:
+            return stmt
+
+
+class ExpandFunctionParamRewriter(IRRewriter):
+    def __init__(self, magic_vars: Dict[str, List[Var]], kernel_function_name: str):
+        super().__init__()
+        self.magic_vars = magic_vars
+        self.kernel_function_name = kernel_function_name
+
+    def visit_Function(self, func: Function):
+        if is_kernel_function(func, self.kernel_function_name):
+            func = super().visit_Function(func)
+            func.params = func.params + [item for sublist in list(self.magic_vars.values()) for item in sublist]
+            return Function(func.name, func.params, func.body, func.ret_type, func.kind, func.attrs)
+        return func
+
+    def visit_Div(self, e: Div):
+        if isinstance(e.b, Var) and e.b.hint in self.magic_vars.keys():
+            fastdiv_prim = Var('fastintdiv', FuncType([i32, i32, i32, i32, i32], i32), 'fastintdiv')
+            return Call(
+                fastdiv_prim,
+                (e.a, e.b, self.magic_vars[e.b.hint][0], self.magic_vars[e.b.hint][1], self.magic_vars[e.b.hint][2]),
+            )
+        return super().visit_Div(e)
+
+    def visit_Mod(self, e: Mod):
+        if isinstance(e.b, Var) and e.b.hint in self.magic_vars.keys():
+            fastmod_prim = Var('fastintmod', FuncType([i32, i32, i32, i32, i32], i32), 'fastintmod')
+            return Call(
+                fastmod_prim,
+                (e.a, e.b, self.magic_vars[e.b.hint][0], self.magic_vars[e.b.hint][1], self.magic_vars[e.b.hint][2]),
+            )
+        return super().visit_Mod(e)
+
+
+class ConvertDivToFastIntDivPass(Pass):
+    def process_module(self, ir_module: IRModule) -> IRModule:
+        collector = CollectSymVarsAndFuncNames()
+        collector.visit(ir_module)
+        if collector.sym_var_names is None or collector.kernel_function_name is None:
+            return ir_module
+        filter = FilterSymbolVar(collector.sym_var_names, collector.kernel_function_name)
+        filter.visit(ir_module)
+        if not filter.filtered_sym_var_names:
+            return ir_module
+        generate_rewriter = GenerateMagicVarsRewriter(list(filter.filtered_sym_var_names))
+        ir_module = generate_rewriter.visit(ir_module)
+        expand_rewriter = ExpandFunctionParamRewriter(generate_rewriter.magic_vars, collector.kernel_function_name)
+        ir_module = expand_rewriter.visit(ir_module)
+        return ir_module
+
+
+def convert_div_to_fastintdiv_pass() -> Pass:
+    return ConvertDivToFastIntDivPass()
diff --git a/src/hidet/runtime/int_fastdiv.cpp b/src/hidet/runtime/int_fastdiv.cpp
new file mode 100644
index 000000000..bd893f1b1
--- /dev/null
+++ b/src/hidet/runtime/int_fastdiv.cpp
@@ -0,0 +1,64 @@
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+#include <hidet/runtime/int_fastdiv.h>
+
+HOST_DEVICE void calculate_magic_numbers(int d, int &m, int &s, int &as) {
+    if (d == 1) {
+        m = 0;
+        s = -1;
+        as = 1;
+        return;
+    } else if (d == -1) {
+        m = 0;
+        s = -1;
+        as = -1;
+        return;
+    }
+
+    int p;
+    unsigned int ad, anc, delta, q1, r1, q2, r2, t;
+    const unsigned two31 = 0x80000000;
+    ad = (d == 0) ? 1 : abs(d);
+    t = two31 + ((unsigned int)d >> 31);
+    anc = t - 1 - t % ad;
+    p = 31;
+    q1 = two31 / anc;
+    r1 = two31 - q1 * anc;
+    q2 = two31 / ad;
+    r2 = two31 - q2 * ad;
+    do {
+        ++p;
+        q1 = 2 * q1;
+        r1 = 2 * r1;
+        if (r1 >= anc) {
+            ++q1;
+            r1 -= anc;
+        }
+        q2 = 2 * q2;
+        r2 = 2 * r2;
+        if (r2 >= ad) {
+            ++q2;
+            r2 -= ad;
+        }
+        delta = ad - r2;
+    } while (q1 < delta || (q1 == delta && r1 == 0));
+    m = q2 + 1;
+    if (d < 0) m = -m;
+    s = p - 32;
+
+    if ((d > 0) && (m < 0))
+        as = 1;
+    else if ((d < 0) && (m > 0))
+        as = -1;
+    else
+        as = 0;
+}
diff --git a/tests/benchmarks/bench_dynamic.py b/tests/benchmarks/bench_dynamic.py
new file mode 100644
index 000000000..7e5218ddb
--- /dev/null
+++ b/tests/benchmarks/bench_dynamic.py
@@ -0,0 +1,74 @@
+import argparse
+import torch
+from hidet.testing.torch_utils import Backend, bench_torch_model
+
+
+def bench_reduce(backend, mode, dtype, cache):
+    comp_backend = Backend(backend, mode, dtype, cache)
+
+    class ReduceModule(torch.nn.Module):
+        def forward(self, x):
+            return torch.sum(x)
+
+    model = ReduceModule().to(torch.float16).cuda()
+    example_inputs = torch.randn((1024), device='cuda', dtype=torch.float16)
+
+    torch._dynamo.mark_dynamic(example_inputs, 0)
+    model_op = comp_backend.compile(model)
+    model_op(example_inputs)
+
+
+def bench_matmul(backend, mode, dtype, cache):
+    comp_backend = Backend(backend, mode, dtype, cache)
+    M = 1024
+    N = 1024
+    K = 1024
+
+    class MatMulModule(torch.nn.Module):
+        def __init__(self):
+            super(MatMulModule, self).__init__()
+            self.weight = torch.nn.Parameter(torch.randn(K, N))
+
+        def forward(self, x):
+            return torch.matmul(x, self.weight)
+
+    model = MatMulModule().to(torch.float16).cuda()
+    example_inputs = torch.randn((M, K), device='cuda', dtype=torch.float16)
+    torch._dynamo.mark_dynamic(example_inputs, 0)
+    model_op = comp_backend.compile(model)
+    model_op(example_inputs)
+
+
+def bench_conv(backend, mode, dtype, cache):
+    comp_backend = Backend(backend, mode, dtype, cache)
+    N = 2
+    C = 3
+    H = 224
+    W = 224
+
+    class SingleConv(torch.nn.Module):
+        def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0):
+            super(SingleConv, self).__init__()
+            self.conv = torch.nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding)
+
+        def forward(self, x):
+            x = self.conv(x)
+            return x
+
+    model = SingleConv(3, 32, (3, 3), stride=(2, 2)).to(torch.float16).cuda()
+    example_inputs = torch.randn((N, C, H, W), device='cuda', dtype=torch.float16)
+    torch._dynamo.mark_dynamic(example_inputs, 0)
+    model_op = comp_backend.compile(model)
+    bench_torch_model(model_op, [example_inputs])
+
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser(prog='Benchmark Dynamic operators')
+    parser.add_argument('--dtype', type=str, default='float16', help='Specify precision. E.g., float16')
+    parser.add_argument('--backend', type=str, default='hidet', help='torch.compile backend')
+    parser.add_argument('--mode', type=str, default='max-autotune', help='torch.compile mode')
+    parser.add_argument('--cache', type=str, default='', help='')
+
+    args = parser.parse_args()
+    dtype, backend, mode, cache = args.dtype, args.backend, args.mode, args.cache
+    bench_conv(backend, mode, dtype, cache)
diff --git a/tests/benchmarks/bench_transformer.py b/tests/benchmarks/bench_transformer.py
index 2d3ea2cf6..b10029be7 100644
--- a/tests/benchmarks/bench_transformer.py
+++ b/tests/benchmarks/bench_transformer.py
@@ -117,8 +117,8 @@ def bench_causal_lm(model_name, bs, genlen, dtype, backend, mode, cache):
     return latency


-def bench_masked_lm(model_name, seqlen, bs, dtype, backend, mode):
-    comp_backend = Backend(backend, mode, dtype)
+def bench_masked_lm(model_name, seqlen, bs, dtype, backend, mode, cache):
+    comp_backend = Backend(backend, mode, dtype, cache)

     dtype = getattr(torch, dtype)
     model_name = get_full_model_name(model_name)
@@ -179,7 +179,7 @@ def bench_masked_lm(model_name, seqlen, bs, dtype, backend, mode):
             genlen = int(value)

     if model_class[get_full_model_name(model_name)] == 'AutoModelForMaskedLM':
-        latency = bench_masked_lm(model_name, seqlen, bs, dtype, backend, mode)
+        latency = bench_masked_lm(model_name, seqlen, bs, dtype, backend, mode, cache)
     elif model_class[get_full_model_name(model_name)] == 'AutoModelForCausalLM':
         latency = bench_causal_lm(
             model_name, bs=bs, genlen=genlen, dtype=dtype, backend=backend, mode=mode, cache=cache
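
Aside (not part of the patch): the magic-number scheme used by int_fastdiv.cpp and the fastintdiv/fastintmod primitives, borrowed from https://github.com/milakov/int_fastdiv, can be sanity-checked in plain Python with no GPU or hidet involved. The sketch below ports calculate_magic_numbers, emulates the device-side sequence (mul.hi.s32, add, arithmetic shift, sign fix) with 32-bit wrapping, and asserts that it reproduces C-style truncated division. The helper names (magic_numbers, fast_div, to_i32) are local to this sketch.

MASK32 = 0xFFFFFFFF


def to_i32(x: int) -> int:
    """Wrap a Python int to a signed 32-bit value, like assignment to a C int."""
    x &= MASK32
    return x - (1 << 32) if x & 0x80000000 else x


def magic_numbers(d: int):
    """Plain-Python port of calculate_magic_numbers: returns (m, s, ads) for divisor d."""
    if d == 1:
        return 0, -1, 1
    if d == -1:
        return 0, -1, -1
    two31 = 0x80000000
    ad = abs(d) if d != 0 else 1
    t = two31 + ((d & MASK32) >> 31)
    anc = t - 1 - t % ad
    p = 31
    q1, r1 = two31 // anc, two31 % anc
    q2, r2 = two31 // ad, two31 % ad
    while True:
        p += 1
        q1, r1 = (2 * q1) & MASK32, (2 * r1) & MASK32
        if r1 >= anc:
            q1, r1 = q1 + 1, r1 - anc
        q2, r2 = (2 * q2) & MASK32, (2 * r2) & MASK32
        if r2 >= ad:
            q2, r2 = q2 + 1, r2 - ad
        delta = ad - r2
        if not (q1 < delta or (q1 == delta and r1 == 0)):
            break
    m = to_i32(q2 + 1)
    if d < 0:
        m = to_i32(-m)
    s = p - 32
    ads = 1 if d > 0 and m < 0 else (-1 if d < 0 and m > 0 else 0)
    return m, s, ads


def fast_div(n: int, d: int, m: int, s: int, ads: int) -> int:
    """Emulates the fastintdiv primitive: mul.hi.s32, adjust, arithmetic shift, sign fix."""
    q = to_i32((m * n) >> 32)  # high 32 bits of the signed 64-bit product
    q = to_i32(q + n * ads)
    if s >= 0:
        q >>= s  # Python's >> on a negative int is an arithmetic shift
        q = to_i32(q + ((q & MASK32) >> 31))
    return q


if __name__ == '__main__':
    divisors = [d for d in range(-12, 13) if d != 0] + [17, 101, 1024, 65537, 123456789, -123456789]
    for d in divisors:
        m, s, ads = magic_numbers(d)
        for n in range(-1000, 1000, 7):
            expected_q = (abs(n) // abs(d)) * (1 if (n < 0) == (d < 0) else -1)  # C-style truncation
            assert fast_div(n, d, m, s, ads) == expected_q, (n, d)
            assert n - fast_div(n, d, m, s, ads) * d == n - expected_q * d  # fastintmod == n - q * d
    print('fast int div/mod emulation matches truncated division')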