Skip to content

Commit

Permalink
[Dynamic][Enhancement] Convert div and mod including symbolvars to fa…
Browse files Browse the repository at this point in the history
…st int div/mod (#464)

Reopen CentML/hidet#405
  • Loading branch information
maxyanghu authored and vadiklyutiy committed Dec 20, 2024
1 parent 6c8ad3e commit c8d9158
Show file tree
Hide file tree
Showing 12 changed files with 433 additions and 4 deletions.
1 change: 1 addition & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ add_library(hidet_runtime SHARED
src/hidet/runtime/callbacks.cpp
src/hidet/runtime/logging.cpp
src/hidet/runtime/symbols.cpp
src/hidet/runtime/int_fastdiv.cpp
src/hidet/runtime/llm/tokenizer/decoders.cpp
src/hidet/runtime/llm/tokenizer/models.cpp
src/hidet/runtime/llm/tokenizer/normalizers.cpp
Expand Down
22 changes: 22 additions & 0 deletions include/hidet/runtime/int_fastdiv.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once

#include <hidet/runtime/common.h>

#ifdef __CUDA_ARCH__
#define HOST_DEVICE __host__ __device__
#else
#define HOST_DEVICE
#endif

HOST_DEVICE void calculate_magic_numbers(int d, int &m, int &s, int &as);
3 changes: 2 additions & 1 deletion python/hidet/backend/codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -712,7 +712,7 @@ def require_headers(self) -> Doc:
doc += Text('#include <hidet/runtime/cuda/complex.h>') + NewLine()
doc += Text('#include <hidet/runtime/cuda/context.h>') + NewLine()
doc += Text("#include <hidet/runtime/logging.h>") + NewLine()

doc += Text('#include <hidet/runtime/int_fastdiv.h>') + NewLine()
for header in self.ir_module.include_headers:
doc += Text('#include <{}>').format(header) + NewLine()

Expand Down Expand Up @@ -801,6 +801,7 @@ def require_headers(self) -> Doc:
doc += Text('#include <hidet/runtime/cpu/context.h>') + NewLine()
doc += Text('#include <hidet/runtime/cpu/float32.h>') + NewLine()
doc += Text("#include <hidet/runtime/logging.h>") + NewLine()
doc += Text('#include <hidet/runtime/int_fastdiv.h>') + NewLine()

if self.require_complex:
doc += Text('#include <hidet/runtime/cpu/complex.h>') + NewLine()
Expand Down
64 changes: 64 additions & 0 deletions python/hidet/ir/primitives/cuda/fastintdiv.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from hidet.ir.dtypes import i32
from hidet.ir.expr import Expr, cast
from hidet.ir.primitives.func import register_primitive_function, call_primitive_func
from hidet.utils import initialize


@initialize()
def register_fastdiv_functions():
from hidet.lang import script, attrs, asm

dtype = i32
div_func_name = 'fastintdiv'

@script
def div_op(dividend: dtype, divisor: dtype, m: dtype, s: dtype, ads: dtype) -> dtype:
attrs.func_kind = 'cuda_internal'
attrs.func_name = div_func_name
q = 0
asm('mul.hi.s32 %0, %1, %2;', outputs=[q], inputs=[m, dividend])
q = q + dividend * ads
if s >= 0:
q = q >> s
q = q + (cast(q, 'uint32') >> 31)
return q

register_primitive_function(name=div_func_name, func_or_type=div_op)

mod_func_name = 'fastintmod'

@script
def mod_op(dividend: dtype, divisor: dtype, m: dtype, s: dtype, ads: dtype) -> dtype:
attrs.func_kind = 'cuda_internal'
attrs.func_name = mod_func_name
q = 0
asm('mul.hi.s32 %0, %1, %2;', outputs=[q], inputs=[m, dividend])
q = q + dividend * ads
if s >= 0:
q = q >> s
q = q + (cast(q, 'uint32') >> 31)
remainder = dividend - q * divisor
return remainder

register_primitive_function(name=mod_func_name, func_or_type=mod_op)


# fast int div and fast int mod's implementation are borrowed from:
# https://github.com/milakov/int_fastdiv
def fast_intdiv(dividend: Expr, divisor: Expr, m: int, s: int, ads: int):
return call_primitive_func('fastintdiv', [dividend, divisor, m, s, ads])


def fast_intmod(dividend: Expr, divisor: Expr, m: int, s: int, ads: int):
return call_primitive_func('fastintdiv', [dividend, divisor, m, s, ads])
9 changes: 9 additions & 0 deletions python/hidet/ir/primitives/runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,11 @@ def register_functions():
register_primitive_function(
name='get_nccl_comm', func_or_type=FuncType([int32], void_p), codegen_name='get_nccl_comm'
)
register_primitive_function(
name='calculate_magic_numbers',
func_or_type=FuncType([int32, int32, int32, int32], void_p),
codegen_name='calculate_magic_numbers',
)


def get_cuda_stream() -> void_p:
Expand Down Expand Up @@ -96,3 +101,7 @@ def memory_planner_used(idx: Union[int, Expr]):

def get_nccl_comm(idx: int) -> void_p:
return call_primitive_func('get_nccl_comm', [idx])


def calculate_magic_numbers(divisor: int, m: int, s: int, ads: int):
return call_primitive_func('calculate_magic_numbers', [divisor, m, s, ads])
1 change: 1 addition & 0 deletions python/hidet/lang/cuda.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from hidet.ir.primitives.cuda.time import nano_sleep
from hidet.ir.primitives.cuda.memcpy import memcpy_async, memcpy
from hidet.ir.primitives.cuda.atomic import atomic_add, atomic_sub, atomic_min, atomic_max, atomic_exchange, atomic_cas
from hidet.ir.primitives.cuda.fastintdiv import fast_intdiv, fast_intmod
from hidet.ir.primitives.cuda.shfl import shfl_sync, shfl_up_sync, shfl_xor_sync, shfl_down_sync
from hidet.ir.primitives.cuda.mutex import acquire_lock, release_lock, acquire_seq_semaphore, release_seq_semaphore
from hidet.lang.constructs.declare import register_tensor, shared_tensor
Expand Down
1 change: 1 addition & 0 deletions python/hidet/testing/torch_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@ def init_hidet(self):
# hidet.option.debug_cache_tuning(True)
# hidet.option.save_lower_ir(True)
# hidet.option.debug_show_verbose_flow_graph(True)
# hidet.torch.dynamo_config.dump_graph_ir("./graph_ir")

# Initialise compiler server
if os.environ.get('CI_CS_HOSTNAME'):
Expand Down
3 changes: 3 additions & 0 deletions python/hidet/transforms/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

from .attach_hash_to_signature import attach_hash_to_signature
from .unify_global_objects import unify_global_objects_pass
from .convert_div_to_fastintdiv import convert_div_to_fastintdiv_pass
from .flatten_tensor_slice import flatten_tensor_slice_pass
from .flatten_tensor_index import flatten_tensor_index_pass
from .generate_launch_func import generate_launch_func_pass
Expand Down Expand Up @@ -97,6 +98,8 @@ def lower(ir_module: IRModule) -> IRModule:
add_explicit_cast_pass(),
declare_to_let_pass(),
instantiate_symbols_pass(),
convert_div_to_fastintdiv_pass(),
import_primitive_functions_pass(),
check_launch_configuration_pass(),
# simplification
expand_let_expr_pass(),
Expand Down
189 changes: 189 additions & 0 deletions python/hidet/transforms/convert_div_to_fastintdiv.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,189 @@
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List, Optional, Dict
from hidet.ir.expr import Var, Div, Mod
from hidet.ir.func import Function
from hidet.ir.stmt import DeclareStmt, LetStmt
from hidet.ir.module import IRModule
from hidet.ir.functors import IRRewriter, IRVisitor
from hidet.ir.type import FuncType
from hidet.ir.dtypes.integer import i32
from hidet.transforms import Pass
from hidet.ir.builders import StmtBuilder
from hidet.ir.stmt import LaunchKernelStmt
from hidet.ir.expr import Call, constant
from hidet.ir.primitives.runtime import calculate_magic_numbers
from hidet.logging import logger


def is_launch_function(func: Function):
return func.kind == 'public' and 'launch' in func.name


def is_kernel_function(func: Function, func_name: str):
return func.kind == "cuda_kernel" and func.name == func_name


def is_required_letstmt(stmt: LetStmt):
return (
len(stmt.bind_values) != 0
and all(hasattr(bv, 'func_var') for bv in stmt.bind_values)
and all(bv.func_var.name == 'get_symbol_value' for bv in stmt.bind_values)
)


class CollectSymVarsAndFuncNames(IRVisitor):
def __init__(self):
super().__init__()
self.sym_var_names: Optional[List[str]] = None
self.kernel_function_name: Optional[str] = None

def visit_Function(self, func: Function):
if is_launch_function(func):
super().visit_Function(func)
return func

def visit_LetStmt(self, stmt: LetStmt):
if not is_required_letstmt(stmt):
logger.warning(
f'public launch function contains LetStmt {stmt} \
that may be optimized with fast int div'
)
return stmt
super().visit_LetStmt(stmt)
self.sym_var_names = [bv.hint for bv in stmt.bind_vars]
return stmt

def visit_LaunchKernelStmt(self, stmt: LaunchKernelStmt):
self.kernel_function_name = stmt.func_var.name
return stmt


# This visitor creates a filtered symbol var pool that contains only
# symbol vars that are used in the kernel function as a divisor (/ symbolvar)
# or modulus (% symbolvar)
class FilterSymbolVar(IRVisitor):
def __init__(self, sym_var_names: List[str], kernel_function_name: str):
super().__init__()
self.sym_var_names = sym_var_names
self.kernel_function_name = kernel_function_name
self.filtered_sym_var_names = set()

def visit_Function(self, func: Function):
if is_kernel_function(func, self.kernel_function_name):
super().visit_Function(func)
return func

def visit_Div(self, e: Div):
if isinstance(e.b, Var) and e.b.hint in self.sym_var_names:
self.filtered_sym_var_names.add(e.b.hint)
return super().visit_Div(e)

def visit_Mod(self, e: Mod):
if isinstance(e.b, Var) and e.b.hint in self.sym_var_names:
self.filtered_sym_var_names.add(e.b.hint)
return super().visit_Mod(e)


class GenerateMagicVarsRewriter(IRRewriter):
def __init__(self, filtered_sym_var_names: List[str]):
super().__init__()
self.magic_vars: Dict[str, List[Var]] = {}
self.filtered_sym_var_names = filtered_sym_var_names

def visit_Function(self, func: Function):
if is_launch_function(func):
new_func = super().visit_Function(func)
return new_func
return func

def visit_LetStmt(self, stmt: LetStmt):
if not is_required_letstmt(stmt):
return stmt
sb = StmtBuilder()
for bind_var in stmt.bind_vars:
if bind_var.hint not in self.filtered_sym_var_names:
continue
magic_m = Var(f'magic_m_{bind_var.hint}', bind_var.type, f'magic_m_{bind_var.hint}')
magic_s = Var(f'magic_s_{bind_var.hint}', bind_var.type, f'magic_s_{bind_var.hint}')
magic_as = Var(f'magic_as_{bind_var.hint}', bind_var.type, f'magic_as_{bind_var.hint}')
sb += DeclareStmt(magic_m, constant(0, i32))
sb += DeclareStmt(magic_s, constant(0, i32))
sb += DeclareStmt(magic_as, constant(0, i32))
sb += calculate_magic_numbers(bind_var, magic_m, magic_s, magic_as)
self.magic_vars[bind_var.hint] = [magic_m, magic_s, magic_as]
super().visit_LetStmt(stmt)
sb += stmt.body
stmt.body = sb.finish()
return LetStmt(stmt.bind_vars, stmt.bind_values, stmt.body)

def visit_LaunchKernelStmt(self, stmt: LaunchKernelStmt):
if self.magic_vars:
stmt.args = stmt.args + [item for sublist in list(self.magic_vars.values()) for item in sublist]
return LaunchKernelStmt(
stmt.func_var, stmt.args, stmt.grid_dim, stmt.cluster_dim, stmt.block_dim, stmt.shared_mem_bytes
)
else:
return stmt


class ExpandFunctionParamRewriter(IRRewriter):
def __init__(self, magic_vars: Dict[str, List[Var]], kernel_function_name: str):
super().__init__()
self.magic_vars = magic_vars
self.kernel_function_name = kernel_function_name

def visit_Function(self, func: Function):
if is_kernel_function(func, self.kernel_function_name):
func = super().visit_Function(func)
func.params = func.params + [item for sublist in list(self.magic_vars.values()) for item in sublist]
return Function(func.name, func.params, func.body, func.ret_type, func.kind, func.attrs)
return func

def visit_Div(self, e: Div):
if isinstance(e.b, Var) and e.b.hint in self.magic_vars.keys():
fastdiv_prim = Var('fastintdiv', FuncType([i32, i32, i32, i32, i32], i32), 'fastintdiv')
return Call(
fastdiv_prim,
(e.a, e.b, self.magic_vars[e.b.hint][0], self.magic_vars[e.b.hint][1], self.magic_vars[e.b.hint][2]),
)
return super().visit_Div(e)

def visit_Mod(self, e: Mod):
if isinstance(e.b, Var) and e.b.hint in self.magic_vars.keys():
fastmod_prim = Var('fastintmod', FuncType([i32, i32, i32, i32, i32], i32), 'fastintmod')
return Call(
fastmod_prim,
(e.a, e.b, self.magic_vars[e.b.hint][0], self.magic_vars[e.b.hint][1], self.magic_vars[e.b.hint][2]),
)
return super().visit_Mod(e)


class ConvertDivToFastIntDivPass(Pass):
def process_module(self, ir_module: IRModule) -> IRModule:
collector = CollectSymVarsAndFuncNames()
collector.visit(ir_module)
if collector.sym_var_names is None or collector.kernel_function_name is None:
return ir_module
filter = FilterSymbolVar(collector.sym_var_names, collector.kernel_function_name)
filter.visit(ir_module)
if not filter.filtered_sym_var_names:
return ir_module
generate_rewriter = GenerateMagicVarsRewriter(list(filter.filtered_sym_var_names))
ir_module = generate_rewriter.visit(ir_module)
expand_rewriter = ExpandFunctionParamRewriter(generate_rewriter.magic_vars, collector.kernel_function_name)
ir_module = expand_rewriter.visit(ir_module)
return ir_module


def convert_div_to_fastintdiv_pass() -> Pass:
return ConvertDivToFastIntDivPass()
Loading

0 comments on commit c8d9158

Please sign in to comment.