diff --git a/python/hidet/backend/build.py b/python/hidet/backend/build.py index 327df1c89..0022a4acb 100644 --- a/python/hidet/backend/build.py +++ b/python/hidet/backend/build.py @@ -122,6 +122,8 @@ def compile(self, src_path: str, out_lib_path: str, options: Optional[Dict[str, *['-I{}'.format(include_dir) for include_dir in self.include_dirs], # the library directories. *['-L{}'.format(library_dir) for library_dir in self.library_dirs], + # enable openmp support for cpu kernels + '-Xcompiler -fopenmp', # the target PTX and SASS version. '-gencode arch=compute_{cc},code=sm_{cc}'.format(cc=cc_code), # allow ptxas (PTX assembler) to output information like register/smem usage. @@ -181,6 +183,8 @@ def compile(self, src_path: str, out_lib_path: str, options: Optional[Dict[str, '-O3', # compile into position independent code. '-fPIC', + # enable OpenMP. + '-fopenmp', # link the hidet runtime, all APIs for communication between kernels and host system are in hidet runtime. '-lhidet_runtime', # generate shared library (lib.so). 
diff --git a/python/hidet/backend/codegen.py b/python/hidet/backend/codegen.py index 3a425573c..c617856dd 100644 --- a/python/hidet/backend/codegen.py +++ b/python/hidet/backend/codegen.py @@ -412,16 +412,17 @@ def visit_ForStmt(self, stmt: ForStmt): cond_doc = self(v < stmt.extent) update_doc = self(v) + ' = ' + self(v + 1) doc = Text('') - if stmt.attr.unroll is not None: - assert not stmt.attr.explicit_unroll, 'explicit_unroll should be lowered before codegen' - if isinstance(stmt.attr.unroll, bool): - if stmt.attr.unroll: - doc += NewLine() + '#pragma unroll' # complete unroll - else: - doc += NewLine() + '#pragma unroll 1' # prevent from unrolling + if stmt.attr.unroll: + assert not stmt.attr.unroll_explicit, 'explicit_unroll should be lowered before codegen' + if stmt.attr.unroll_factor: + doc += NewLine() + '#pragma unroll {}'.format(stmt.attr.unroll_factor) else: - assert isinstance(stmt.attr.unroll, int) - doc += NewLine() + '#pragma unroll {}'.format(stmt.attr.unroll) + doc += NewLine() + '#pragma unroll' + elif stmt.attr.parallel: + if stmt.attr.parallel_threads: + doc += NewLine() + '#pragma omp parallel for num_threads({})'.format(stmt.attr.parallel_threads) + else: + doc += NewLine() + '#pragma omp parallel for' doc += NewLine() + Text('for (') + init_doc + '; ' + cond_doc + '; ' + update_doc + ') ' body_doc = self(stmt.body) doc += Text('{') + body_doc.indent() + NewLine() + Text('} ') diff --git a/python/hidet/ir/stmt.py b/python/hidet/ir/stmt.py index d91098513..e0cec7a23 100644 --- a/python/hidet/ir/stmt.py +++ b/python/hidet/ir/stmt.py @@ -42,77 +42,24 @@ def from_str(name): return DeclareScope.Default -class Stmt(Node): - pass - - -class EvaluateStmt(Stmt): - def __init__(self, expr): - super().__init__() - self.expr: Expr = convert(expr) - - -class DeclareStmt(Stmt): - def __init__(self, var, init: Optional[Expr] = None, is_static=False, scope: Optional[DeclareScope] = None): - super().__init__() - self.var: Var = var - self.init: 
Optional[Expr] = convert(init) - self.is_static: bool = is_static - self.scope: Optional[DeclareScope] = scope if scope else DeclareScope.Default - - -class BufferStoreStmt(Stmt): - def __init__(self, buf, indices, value, protected=False): - super().__init__() - assert isinstance(indices, (list, tuple)), type(indices) - self.buf: Union[Var, TensorNode] = buf - self.indices = convert(indices) - self.value = convert(value) - self.protected = protected - - -class AssignStmt(Stmt): - def __init__(self, var, value): - super().__init__() - self.var: Var = var - self.value: Expr = convert(value) - - -class ReturnStmt(Stmt): - def __init__(self, ret_value: Optional[Expr] = None): - super().__init__() - self.ret_value: Optional[Expr] = ret_value - - -class LetStmt(Stmt): - def __init__(self, bind_vars, bind_values, body=None): - if not isinstance(bind_vars, (list, tuple)): - bind_vars = [bind_vars] - if not isinstance(bind_values, (list, tuple)): - bind_values = [bind_values] - assert len(bind_vars) == len(bind_values) - assert len(bind_vars) > 0 - bind_values = [convert(bind_value) for bind_value in bind_values] - self.bind_vars: List[Var] = bind_vars - self.bind_values: List[Expr] = bind_values - self.body: Optional[Stmt] = body - - class ForStmtAttr: - def __init__(self, unroll=None, explicit_unroll=False): - self.unroll: Union[int, bool, None] = unroll - self.explicit_unroll: bool = explicit_unroll + def __init__(self, unroll=False, unroll_factor=None, unroll_explicit=False, parallel=False, parallel_threads=None): + self.unroll: bool = unroll + self.unroll_factor: Optional[int] = unroll_factor + self.unroll_explicit: bool = unroll_explicit + self.parallel: bool = parallel + self.parallel_threads: Optional[int] = parallel_threads def __str__(self): if self.unroll is None: return '.' 
elif isinstance(self.unroll, bool): - if self.explicit_unroll: + if self.unroll_explicit: return 'u+' else: return 'u' else: - if self.explicit_unroll: + if self.unroll_explicit: return f'u{self.unroll}+' else: return f'u{self.unroll}' @@ -125,11 +72,15 @@ def parse(attr: str) -> List[ForStmtAttr]: attr-string: attr+ attr: | unroll + | parallel | default unroll: | 'u' # unroll | 'u' INT+ # unroll with factor, e.g., u1 u2 u3. u1 indicates unroll with factor 1 (i.e., no unroll) | 'u' '+' # explicit unroll, will be unrolled by hidet instead of underlying compiler + parallel: + | 'p' # parallel with available number of threads + | 'p' INT+ # parallel with specified number of threads default: '.' @@ -155,29 +106,99 @@ def cur() -> Optional[str]: while idx < len(s): if s[idx] == '.': idx += 1 - attrs.append(ForStmtAttr(unroll=None, explicit_unroll=False)) + attrs.append(ForStmtAttr()) elif s[idx] == 'u': idx += 1 c = cur() if c == '+': - attrs.append(ForStmtAttr(unroll=True, explicit_unroll=True)) + attrs.append(ForStmtAttr(unroll=True, unroll_explicit=True)) idx += 1 elif c and c.isdigit(): - unroll = 0 + unroll_factor = 0 while c and c.isdigit(): - unroll = unroll * 10 + int(c) + unroll_factor = unroll_factor * 10 + int(c) idx += 1 c = cur() - if unroll == 0: + if unroll_factor == 0: raise ValueError(f"Invalid attribute string: {attr}") - attrs.append(ForStmtAttr(unroll=unroll, explicit_unroll=False)) + attrs.append(ForStmtAttr(unroll=True, unroll_factor=unroll_factor)) else: - attrs.append(ForStmtAttr(unroll=True, explicit_unroll=False)) + attrs.append(ForStmtAttr(unroll=True, unroll_explicit=False)) + elif s[idx] == 'p': + idx += 1 + c = cur() + if c and c.isdigit(): + parallel_threads = 0 + while c and c.isdigit(): + parallel_threads = parallel_threads * 10 + int(c) + idx += 1 + c = cur() + if parallel_threads == 0: + raise ValueError(f"Invalid attribute string: {attr}") + attrs.append(ForStmtAttr(parallel=True, parallel_threads=parallel_threads)) + else: + 
attrs.append(ForStmtAttr(parallel=True)) else: raise ValueError(f"Invalid attribute string: {attr}") return attrs +class Stmt(Node): + pass + + +class EvaluateStmt(Stmt): + def __init__(self, expr): + super().__init__() + self.expr: Expr = convert(expr) + + +class DeclareStmt(Stmt): + def __init__(self, var, init: Optional[Expr] = None, is_static=False, scope: Optional[DeclareScope] = None): + super().__init__() + self.var: Var = var + self.init: Optional[Expr] = convert(init) + self.is_static: bool = is_static + self.scope: Optional[DeclareScope] = scope if scope else DeclareScope.Default + + +class BufferStoreStmt(Stmt): + def __init__(self, buf, indices, value, protected=False): + super().__init__() + assert isinstance(indices, (list, tuple)), type(indices) + self.buf: Union[Var, TensorNode] = buf + self.indices = convert(indices) + self.value = convert(value) + self.protected = protected + + +class AssignStmt(Stmt): + def __init__(self, var, value): + super().__init__() + self.var: Var = var + self.value: Expr = convert(value) + + +class ReturnStmt(Stmt): + def __init__(self, ret_value: Optional[Expr] = None): + super().__init__() + self.ret_value: Optional[Expr] = ret_value + + +class LetStmt(Stmt): + def __init__(self, bind_vars, bind_values, body=None): + if not isinstance(bind_vars, (list, tuple)): + bind_vars = [bind_vars] + if not isinstance(bind_values, (list, tuple)): + bind_values = [bind_values] + assert len(bind_vars) == len(bind_values) + assert len(bind_vars) > 0 + bind_values = [convert(bind_value) for bind_value in bind_values] + self.bind_vars: List[Var] = bind_vars + self.bind_values: List[Expr] = bind_values + self.body: Optional[Stmt] = body + + class ForStmt(Stmt): DEFAULT_UNROLL_LIMIT = 32 diff --git a/python/hidet/lang/__init__.py b/python/hidet/lang/__init__.py index 5dd14389f..ff8cd879e 100644 --- a/python/hidet/lang/__init__.py +++ b/python/hidet/lang/__init__.py @@ -71,11 +71,11 @@ def grid(*dim_extents, attrs: Optional[str] = None): 
Parameters ---------- - dim_extents: Sequence[Expr or int] - The length of each dimension. + dim_extents: Sequence[Expr or int or str] + The length of each dimension. A trailing str argument is interpreted as the attrs. attrs: Optional[str] - The attributes of each loop. See ForStmtAttr for more information. + The attributes of each loop. See hidet.ir.stmt.ForStmtAttr for more information. Returns ------- diff --git a/python/hidet/lang/transpiler.py b/python/hidet/lang/transpiler.py index ca4d3a024..22a61a41e 100644 --- a/python/hidet/lang/transpiler.py +++ b/python/hidet/lang/transpiler.py @@ -776,11 +776,17 @@ def declare_loop_vars(num: int): attr_string = self.visit(keyword.value) else: raise HidetProgramError(self, call, 'Can not recognize keyword argument: {}.'.format(keyword.arg)) + extents = [self.visit(arg) for arg in call.args] + if len(extents) > 0 and isinstance(extents[-1], str): + if attr_string is not None: + raise HidetProgramError( + self, call, 'Can not specify attrs in both positional and keyword arguments.' 
+ ) + attr_string = extents.pop() if attr_string is None: attrs = [ForStmtAttr() for _ in range(len(call.args))] else: attrs = ForStmtAttr.parse(attr_string) - extents = [self.visit(arg) for arg in call.args] declare_loop_vars(num=len(extents)) body = visit_body() for loop_var, extent, attr in zip(reversed(loop_vars), reversed(extents), reversed(attrs)): diff --git a/python/hidet/transforms/explicit_unroll.py b/python/hidet/transforms/explicit_unroll.py index 2fde3b8f9..a25531af3 100644 --- a/python/hidet/transforms/explicit_unroll.py +++ b/python/hidet/transforms/explicit_unroll.py @@ -24,7 +24,7 @@ class ExplicitUnrollRewriter(IRRewriter): def visit_ForStmt(self, stmt: ForStmt): - if stmt.attr.unroll and stmt.attr.explicit_unroll: + if stmt.attr.unroll and stmt.attr.unroll_explicit: if not isinstance(stmt.attr.unroll, bool): raise NotImplementedError('Explicit unroll with unroll factor is not supported yet') extent_expr: Expr = simplify(stmt.extent) diff --git a/tests/lang/scripts/test_parallel.py b/tests/lang/scripts/test_parallel.py new file mode 100644 index 000000000..4458557df --- /dev/null +++ b/tests/lang/scripts/test_parallel.py @@ -0,0 +1,93 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import pytest +import numpy.testing +import hidet + + +def test_parallel(): + from hidet.lang import printf, attr, grid, repeat, tensor + + with hidet.script_module() as script_module: + + @hidet.script + def example(): + attr.func_kind = 'host_kernel' + a = tensor('global', 'float32', shape=[10]) + + for i in grid(10, attrs='p'): # parallel with the available number of threads + a[i] = i + + for i in grid(10, attrs='p2'): # parallel with 2 threads + a[i] = i + + extent = 10 + for i in grid(extent, attrs='u+'): # explicit unroll, extent must be a compilation-time constant + printf("i = %d\n", i) + + for i, j in grid(2, 5, attrs='pp'): # mark both loops as parallel + a[i * 5 + j] = i + + b = tensor('global', 'float32', shape=[8, 64]) + for w in range(32): + for i, j in repeat(2, 8).spatial(4, 8).on(w): + b[i, j] = i + + for i, j in repeat(2, 8, attrs='pp').spatial(4, 8).on(w): + b[i, j] = i + + for i, j in repeat(2, 8, attrs='p.').spatial(4, 8).on(w): + b[i, j] = i + + for i, j in repeat(2, 8, attrs='.p').spatial(4, 8).on(w): + b[i, j] = i + + ir_module = script_module.ir_module() + func = hidet.driver.build_ir_module(ir_module) + source_code = func.source() + assert "#pragma omp parallel" in source_code + return func + + +def matmul(m_size, n_size, k_size): + from hidet.lang import grid, attr, f32 + from hidet.lang.mapping import spatial + + with hidet.script_module() as script_module: + + @hidet.script + def matmul(a: f32[m_size, k_size], b: f32[k_size, n_size], c: f32[m_size, n_size]): + attr.func_kind = 'host_kernel' + ij_size = m_size * n_size + for ij in grid(ij_size, 'p'): + for i, j in spatial(m_size, n_size).on(ij): + c[i, j] = 0.0 + for k in range(k_size): + c[i, j] += a[i, k] * b[k, j] + + ir_module = script_module.ir_module() + return hidet.driver.build_ir_module(ir_module) + + +def test_parallel_v2(): + m_size, n_size, k_size = 32, 32, 32 + func = matmul(m_size, n_size, k_size) + a = hidet.randn((m_size, k_size)) + b = hidet.randn((k_size, n_size)) + c = 
hidet.empty((m_size, n_size)) + cc = a @ b + func(a, b, c) + numpy.testing.assert_allclose(c.numpy(), cc.numpy(), atol=1e-5, rtol=1e-5) + + +if __name__ == '__main__': + pytest.main([__file__])