Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Backend] Add openmp support #208

Merged
merged 1 commit into from
May 4, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions python/hidet/backend/build.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,8 @@ def compile(self, src_path: str, out_lib_path: str, options: Optional[Dict[str,
*['-I{}'.format(include_dir) for include_dir in self.include_dirs],
# the library directories.
*['-L{}'.format(library_dir) for library_dir in self.library_dirs],
# enable openmp support for cpu kernels
'-Xcompiler -fopenmp',
# the target PTX and SASS version.
'-gencode arch=compute_{cc},code=sm_{cc}'.format(cc=cc_code),
# allow ptxas (PTX assembler) to output information like register/smem usage.
Expand Down Expand Up @@ -181,6 +183,8 @@ def compile(self, src_path: str, out_lib_path: str, options: Optional[Dict[str,
'-O3',
# compile into position independent code.
'-fPIC',
# enable OpenMP.
'-fopenmp',
# link the hidet runtime, all APIs for communication between kernels and host system are in hidet runtime.
'-lhidet_runtime',
# generate shared library (lib.so).
Expand Down
19 changes: 10 additions & 9 deletions python/hidet/backend/codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -412,16 +412,17 @@ def visit_ForStmt(self, stmt: ForStmt):
cond_doc = self(v < stmt.extent)
update_doc = self(v) + ' = ' + self(v + 1)
doc = Text('')
if stmt.attr.unroll is not None:
assert not stmt.attr.explicit_unroll, 'explicit_unroll should be lowered before codegen'
if isinstance(stmt.attr.unroll, bool):
if stmt.attr.unroll:
doc += NewLine() + '#pragma unroll' # complete unroll
else:
doc += NewLine() + '#pragma unroll 1' # prevent from unrolling
if stmt.attr.unroll:
assert not stmt.attr.unroll_explicit, 'explicit_unroll should be lowered before codegen'
if stmt.attr.unroll_factor:
doc += NewLine() + '#pragma unroll {}'.format(stmt.attr.unroll_factor)
else:
assert isinstance(stmt.attr.unroll, int)
doc += NewLine() + '#pragma unroll {}'.format(stmt.attr.unroll)
doc += NewLine() + '#pragma unroll'
elif stmt.attr.parallel:
if stmt.attr.parallel_threads:
doc += NewLine() + '#pragma omp parallel for num_threads({})'.format(stmt.attr.parallel_threads)
else:
doc += NewLine() + '#pragma omp parallel for'
doc += NewLine() + Text('for (') + init_doc + '; ' + cond_doc + '; ' + update_doc + ') '
body_doc = self(stmt.body)
doc += Text('{') + body_doc.indent() + NewLine() + Text('} ')
Expand Down
157 changes: 89 additions & 68 deletions python/hidet/ir/stmt.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,77 +42,24 @@ def from_str(name):
return DeclareScope.Default


class Stmt(Node):
pass


class EvaluateStmt(Stmt):
def __init__(self, expr):
super().__init__()
self.expr: Expr = convert(expr)


class DeclareStmt(Stmt):
def __init__(self, var, init: Optional[Expr] = None, is_static=False, scope: Optional[DeclareScope] = None):
super().__init__()
self.var: Var = var
self.init: Optional[Expr] = convert(init)
self.is_static: bool = is_static
self.scope: Optional[DeclareScope] = scope if scope else DeclareScope.Default


class BufferStoreStmt(Stmt):
def __init__(self, buf, indices, value, protected=False):
super().__init__()
assert isinstance(indices, (list, tuple)), type(indices)
self.buf: Union[Var, TensorNode] = buf
self.indices = convert(indices)
self.value = convert(value)
self.protected = protected


class AssignStmt(Stmt):
def __init__(self, var, value):
super().__init__()
self.var: Var = var
self.value: Expr = convert(value)


class ReturnStmt(Stmt):
def __init__(self, ret_value: Optional[Expr] = None):
super().__init__()
self.ret_value: Optional[Expr] = ret_value


class LetStmt(Stmt):
def __init__(self, bind_vars, bind_values, body=None):
if not isinstance(bind_vars, (list, tuple)):
bind_vars = [bind_vars]
if not isinstance(bind_values, (list, tuple)):
bind_values = [bind_values]
assert len(bind_vars) == len(bind_values)
assert len(bind_vars) > 0
bind_values = [convert(bind_value) for bind_value in bind_values]
self.bind_vars: List[Var] = bind_vars
self.bind_values: List[Expr] = bind_values
self.body: Optional[Stmt] = body


class ForStmtAttr:
def __init__(self, unroll=None, explicit_unroll=False):
self.unroll: Union[int, bool, None] = unroll
self.explicit_unroll: bool = explicit_unroll
def __init__(self, unroll=False, unroll_factor=None, unroll_explicit=False, parallel=False, parallel_threads=None):
self.unroll: bool = unroll
self.unroll_factor: Optional[int] = unroll_factor
self.unroll_explicit: bool = unroll_explicit
self.parallel: bool = parallel
self.parallel_threads: Optional[int] = parallel_threads

def __str__(self):
if self.unroll is None:
return '.'
elif isinstance(self.unroll, bool):
if self.explicit_unroll:
if self.unroll_explicit:
return 'u+'
else:
return 'u'
else:
if self.explicit_unroll:
if self.unroll_explicit:
return f'u{self.unroll}+'
else:
return f'u{self.unroll}'
Expand All @@ -125,11 +72,15 @@ def parse(attr: str) -> List[ForStmtAttr]:
attr-string: attr+
attr:
| unroll
| parallel
| default
unroll:
| 'u' # unroll
| 'u' INT+ # unroll with factor, e.g., u1 u2 u3. u1 indicates unroll with factor 1 (i.e., no unroll)
| 'u' '+' # explicit unroll, will be unrolled by hidet instead of underlying compiler
parallel:
| 'p' # parallel with available number of threads
| 'p' INT+ # parallel with specified number of threads
default: '.'


Expand All @@ -155,29 +106,99 @@ def cur() -> Optional[str]:
while idx < len(s):
if s[idx] == '.':
idx += 1
attrs.append(ForStmtAttr(unroll=None, explicit_unroll=False))
attrs.append(ForStmtAttr())
elif s[idx] == 'u':
idx += 1
c = cur()
if c == '+':
attrs.append(ForStmtAttr(unroll=True, explicit_unroll=True))
attrs.append(ForStmtAttr(unroll=True, unroll_explicit=True))
idx += 1
elif c and c.isdigit():
unroll = 0
unroll_factor = 0
while c and c.isdigit():
unroll = unroll * 10 + int(c)
unroll_factor = unroll_factor * 10 + int(c)
idx += 1
c = cur()
if unroll == 0:
if unroll_factor == 0:
raise ValueError(f"Invalid attribute string: {attr}")
attrs.append(ForStmtAttr(unroll=unroll, explicit_unroll=False))
attrs.append(ForStmtAttr(unroll=True, unroll_factor=unroll_factor))
else:
attrs.append(ForStmtAttr(unroll=True, explicit_unroll=False))
attrs.append(ForStmtAttr(unroll=True, unroll_explicit=False))
elif s[idx] == 'p':
idx += 1
c = cur()
if c and c.isdigit():
parallel_threads = 0
while c and c.isdigit():
parallel_threads = parallel_threads * 10 + int(c)
idx += 1
c = cur()
if parallel_threads == 0:
raise ValueError(f"Invalid attribute string: {attr}")
attrs.append(ForStmtAttr(parallel=True, parallel_threads=parallel_threads))
else:
attrs.append(ForStmtAttr(parallel=True))
else:
raise ValueError(f"Invalid attribute string: {attr}")
return attrs


class Stmt(Node):
pass


class EvaluateStmt(Stmt):
def __init__(self, expr):
super().__init__()
self.expr: Expr = convert(expr)


class DeclareStmt(Stmt):
def __init__(self, var, init: Optional[Expr] = None, is_static=False, scope: Optional[DeclareScope] = None):
super().__init__()
self.var: Var = var
self.init: Optional[Expr] = convert(init)
self.is_static: bool = is_static
self.scope: Optional[DeclareScope] = scope if scope else DeclareScope.Default


class BufferStoreStmt(Stmt):
def __init__(self, buf, indices, value, protected=False):
super().__init__()
assert isinstance(indices, (list, tuple)), type(indices)
self.buf: Union[Var, TensorNode] = buf
self.indices = convert(indices)
self.value = convert(value)
self.protected = protected


class AssignStmt(Stmt):
def __init__(self, var, value):
super().__init__()
self.var: Var = var
self.value: Expr = convert(value)


class ReturnStmt(Stmt):
def __init__(self, ret_value: Optional[Expr] = None):
super().__init__()
self.ret_value: Optional[Expr] = ret_value


class LetStmt(Stmt):
def __init__(self, bind_vars, bind_values, body=None):
if not isinstance(bind_vars, (list, tuple)):
bind_vars = [bind_vars]
if not isinstance(bind_values, (list, tuple)):
bind_values = [bind_values]
assert len(bind_vars) == len(bind_values)
assert len(bind_vars) > 0
bind_values = [convert(bind_value) for bind_value in bind_values]
self.bind_vars: List[Var] = bind_vars
self.bind_values: List[Expr] = bind_values
self.body: Optional[Stmt] = body


class ForStmt(Stmt):
DEFAULT_UNROLL_LIMIT = 32

Expand Down
6 changes: 3 additions & 3 deletions python/hidet/lang/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,11 +71,11 @@ def grid(*dim_extents, attrs: Optional[str] = None):

Parameters
----------
dim_extents: Sequence[Expr or int]
The length of each dimension.
dim_extents: Sequence[Expr or int or str]
The length of each dimension. The last one can be the attrs.

attrs: Optional[str]
The attributes of each loop. See ForStmtAttr for more information.
The attributes of each loop. See hidet.stmt.ForStmtAttr for more information.

Returns
-------
Expand Down
8 changes: 7 additions & 1 deletion python/hidet/lang/transpiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -776,11 +776,17 @@ def declare_loop_vars(num: int):
attr_string = self.visit(keyword.value)
else:
raise HidetProgramError(self, call, 'Can not recognize keyword argument: {}.'.format(keyword.arg))
extents = [self.visit(arg) for arg in call.args]
if len(extents) > 0 and isinstance(extents[-1], str):
if attr_string is not None:
raise HidetProgramError(
self, call, 'Can not specify attrs in both positional and keyword arguments.'
)
attr_string = extents.pop()
if attr_string is None:
attrs = [ForStmtAttr() for _ in range(len(call.args))]
else:
attrs = ForStmtAttr.parse(attr_string)
extents = [self.visit(arg) for arg in call.args]
declare_loop_vars(num=len(extents))
body = visit_body()
for loop_var, extent, attr in zip(reversed(loop_vars), reversed(extents), reversed(attrs)):
Expand Down
2 changes: 1 addition & 1 deletion python/hidet/transforms/explicit_unroll.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@

class ExplicitUnrollRewriter(IRRewriter):
def visit_ForStmt(self, stmt: ForStmt):
if stmt.attr.unroll and stmt.attr.explicit_unroll:
if stmt.attr.unroll and stmt.attr.unroll_explicit:
if not isinstance(stmt.attr.unroll, bool):
raise NotImplementedError('Explicit unroll with unroll factor is not supported yet')
extent_expr: Expr = simplify(stmt.extent)
Expand Down
93 changes: 93 additions & 0 deletions tests/lang/scripts/test_parallel.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pytest
import numpy.testing
import hidet


def test_parallel():
from hidet.lang import printf, attr, grid, repeat, tensor

with hidet.script_module() as script_module:

@hidet.script
def example():
attr.func_kind = 'host_kernel'
a = tensor('global', 'float32', shape=[10])

for i in grid(10, attrs='p'): # unroll
a[i] = i

for i in grid(10, attrs='p2'): # unroll explicitly
a[i] = i

extent = 10
for i in grid(extent, attrs='u+'): # explicit unroll, extent must be a compilation-time constant
printf("i = %d\n", i)

for i, j in grid(2, 5, attrs='pp'): # unroll the first loop while keep the second loop unchanged
a[i * 5 + j] = i

b = tensor('global', 'float32', shape=[8, 64])
for w in range(32):
for i, j in repeat(2, 8).spatial(4, 8).on(w):
b[i, j] = i

for i, j in repeat(2, 8, attrs='pp').spatial(4, 8).on(w):
b[i, j] = i

for i, j in repeat(2, 8, attrs='p.').spatial(4, 8).on(w):
b[i, j] = i

for i, j in repeat(2, 8, attrs='.p').spatial(4, 8).on(w):
b[i, j] = i

ir_module = script_module.ir_module()
func = hidet.driver.build_ir_module(ir_module)
source_code = func.source()
assert "#pragma omp parallel" in source_code
return func


def matmul(m_size, n_size, k_size):
from hidet.lang import grid, attr, f32
from hidet.lang.mapping import spatial

with hidet.script_module() as script_module:

@hidet.script
def matmul(a: f32[m_size, k_size], b: f32[k_size, n_size], c: f32[m_size, n_size]):
attr.func_kind = 'host_kernel'
ij_size = m_size * n_size
for ij in grid(ij_size, 'p'):
for i, j in spatial(m_size, n_size).on(ij):
c[i, j] = 0.0
for k in range(k_size):
c[i, j] += a[i, k] * b[k, j]

ir_module = script_module.ir_module()
return hidet.driver.build_ir_module(ir_module)


def test_parallel_v2():
m_size, n_size, k_size = 32, 32, 32
func = matmul(m_size, n_size, k_size)
a = hidet.randn((m_size, k_size))
b = hidet.randn((k_size, n_size))
c = hidet.empty((m_size, n_size))
cc = a @ b
func(a, b, c)
numpy.testing.assert_allclose(c.numpy(), cc.numpy(), atol=1e-5, rtol=1e-5)


if __name__ == '__main__':
pytest.main([__file__])