Skip to content

Commit

Permalink
Merge pull request #1194 from Hardcode84/numba-mlir-integr2
Browse files Browse the repository at this point in the history
Numba-mlir integration
  • Loading branch information
ZzEeKkAa authored Nov 14, 2023
2 parents 35890ba + ac4cb5d commit 7315aef
Show file tree
Hide file tree
Showing 7 changed files with 115 additions and 15 deletions.
11 changes: 10 additions & 1 deletion .github/workflows/conda-package.yml
Original file line number Diff line number Diff line change
Expand Up @@ -104,8 +104,9 @@ jobs:
python: ['3.9', '3.10', '3.11']
os: [ubuntu-20.04, ubuntu-latest, windows-latest]
experimental: [false]
use_mlir: [false]

continue-on-error: ${{ matrix.experimental }}
continue-on-error: ${{ matrix.experimental || matrix.use_mlir }}

steps:
- name: Setup miniconda
Expand Down Expand Up @@ -169,6 +170,10 @@ jobs:
- name: Install builded package
run: mamba install ${{ env.PACKAGE_NAME }}=${{ env.PACKAGE_VERSION }} intel::intel-opencl-rt pytest -c ${{ env.CHANNEL_PATH }}

- name: Install numba-mlir
if: matrix.use_mlir
run: mamba install numba-mlir -c dppy/label/dev -c conda-forge -c intel

- name: Setup OpenCL CPU device
if: runner.os == 'Windows'
shell: pwsh
Expand All @@ -184,9 +189,13 @@ jobs:
python -c "import dpcpp_llvm_spirv as p; print(p.get_llvm_spirv_path())"
- name: Smoke test
env:
NUMBA_DPEX_USE_MLIR: ${{ matrix.use_mlir && '1' || '0' }}
run: python -c "import dpnp, dpctl, numba_dpex; dpctl.lsplatform()"

- name: Run tests
env:
NUMBA_DPEX_USE_MLIR: ${{ matrix.use_mlir && '1' || '0' }}
run: |
pytest -q -ra --disable-warnings --pyargs ${{ env.MODULE_NAME }} -vv
Expand Down
2 changes: 2 additions & 0 deletions numba_dpex/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,3 +96,5 @@ def __getattr__(name):
DPEX_OPT = _readenv("NUMBA_DPEX_OPT", int, 2)

INLINE_THRESHOLD = _readenv("NUMBA_DPEX_INLINE_THRESHOLD", int, None)

USE_MLIR = _readenv("NUMBA_DPEX_USE_MLIR", int, 0)
2 changes: 2 additions & 0 deletions numba_dpex/core/descriptor.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,12 +39,14 @@ class DpexTargetOptions(CPUTargetOptions):
experimental = _option_mapping("experimental")
release_gil = _option_mapping("release_gil")
no_compile = _option_mapping("no_compile")
use_mlir = _option_mapping("use_mlir")

def finalize(self, flags, options):
super().finalize(flags, options)
_inherit_if_not_set(flags, options, "experimental", False)
_inherit_if_not_set(flags, options, "release_gil", False)
_inherit_if_not_set(flags, options, "no_compile", True)
_inherit_if_not_set(flags, options, "use_mlir", False)


class DpexKernelTarget(TargetDescriptor):
Expand Down
42 changes: 35 additions & 7 deletions numba_dpex/core/pipelines/dpjit_compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@ class _DpjitPassBuilder(object):
execution.
"""

_use_mlir = False

@staticmethod
def define_typed_pipeline(state, name="dpex_dpjit_typed"):
"""Returns the typed part of the nopython pipeline"""
Expand All @@ -55,19 +57,31 @@ def define_typed_pipeline(state, name="dpex_dpjit_typed"):
pm.add_pass(NopythonRewrites, "nopython rewrites")
pm.add_pass(ParforPass, "convert to parfors")
pm.add_pass(
ParforLegalizeCFDPass, "Legalize parfors for compute follows data"
ParforLegalizeCFDPass,
"Legalize parfors for compute follows data",
)
pm.add_pass(ParforFusionPass, "fuse parfors")
pm.add_pass(ParforPreLoweringPass, "parfor prelowering")

pm.finalize()
return pm

@staticmethod
def define_nopython_lowering_pipeline(state, name="dpex_dpjit_lowering"):
@classmethod
def define_nopython_lowering_pipeline(
cls, state, name="dpex_dpjit_lowering"
):
"""Returns an nopython mode pipeline based PassManager"""
pm = PassManager(name)

flags = state.flags
if cls._use_mlir or hasattr(flags, "use_mlir") and flags.use_mlir:
from numba_mlir.mlir.passes import MlirReplaceParfors

pm.add_pass(
MlirReplaceParfors,
"Lower parfor using MLIR pipeline",
)

# legalize
pm.add_pass(
NoPythonSupportedFeatureValidation,
Expand All @@ -85,11 +99,11 @@ def define_nopython_lowering_pipeline(state, name="dpex_dpjit_lowering"):
pm.finalize()
return pm

@staticmethod
def define_nopython_pipeline(state, name="dpex_dpjit_nopython"):
@classmethod
def define_nopython_pipeline(cls, state, name="dpex_dpjit_nopython"):
"""Returns an nopython mode pipeline based PassManager"""
# compose pipeline from untyped, typed and lowering parts
dpb = _DpjitPassBuilder
dpb = cls
pm = PassManager(name)
untyped_passes = DefaultPassBuilder.define_untyped_pipeline(state)
pm.passes.extend(untyped_passes.passes)
Expand All @@ -104,17 +118,31 @@ def define_nopython_pipeline(state, name="dpex_dpjit_nopython"):
return pm


class _DpjitPassBuilderMlir(_DpjitPassBuilder):
_use_mlir = True


class DpjitCompiler(CompilerBase):
"""Dpex's compiler pipeline to offload parfor nodes into SYCL kernels."""

_pass_builder = _DpjitPassBuilder

def define_pipelines(self):
pms = []
self.state.parfor_diagnostics = ExtendedParforDiagnostics()
self.state.metadata[
"parfor_diagnostics"
] = self.state.parfor_diagnostics
if not self.state.flags.force_pyobject:
pms.append(_DpjitPassBuilder.define_nopython_pipeline(self.state))
pms.append(self._pass_builder.define_nopython_pipeline(self.state))
if self.state.status.can_fallback or self.state.flags.force_pyobject:
raise UnsupportedCompilationModeError()
return pms


class DpjitCompilerMlir(DpjitCompiler):
_pass_builder = _DpjitPassBuilderMlir


def get_compiler(use_mlir):
return DpjitCompilerMlir if use_mlir else DpjitCompiler
9 changes: 7 additions & 2 deletions numba_dpex/decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,9 @@
compile_func,
compile_func_template,
)
from numba_dpex.core.pipelines.dpjit_compiler import DpjitCompiler
from numba_dpex.core.pipelines.dpjit_compiler import get_compiler

from .config import USE_MLIR


def kernel(
Expand Down Expand Up @@ -152,9 +154,12 @@ def dpjit(*args, **kws):
"pipeline class is set for dpjit and is ignored", RuntimeWarning
)
del kws["forceobj"]

use_mlir = kws.pop("use_mlir", bool(USE_MLIR))

kws.update({"nopython": True})
kws.update({"parallel": True})
kws.update({"pipeline_class": DpjitCompiler})
kws.update({"pipeline_class": get_compiler(use_mlir)})

# FIXME: When trying to use dpex's target context, overloads do not work
# properly. We will turn on dpex target once the issue is fixed.
Expand Down
25 changes: 24 additions & 1 deletion numba_dpex/tests/_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,23 @@

import contextlib
import shutil
from functools import cache

import dpctl
import dpnp
import pytest

from numba_dpex import config, numba_sem_version
from numba_dpex import config, dpjit, numba_sem_version


@cache
def has_numba_mlir():
try:
import numba_mlir
except ImportError:
return False

return True


def has_opencl_gpu():
Expand Down Expand Up @@ -89,6 +100,10 @@ def is_windows():
not has_level_zero(),
reason="No level-zero GPU platforms available",
)
skip_no_numba_mlir = pytest.mark.skipif(
not has_numba_mlir(),
reason="numba-mlir package is not availabe",
)

filter_strings = [
pytest.param("level_zero:gpu:0", marks=skip_no_level_zero_gpu),
Expand Down Expand Up @@ -123,6 +138,14 @@ def is_windows():
)


decorators = [
pytest.param(dpjit, id="dpjit"),
pytest.param(
dpjit(use_mlir=True), id="dpjit_mlir", marks=skip_no_numba_mlir
),
]


@contextlib.contextmanager
def override_config(name, value, config=config):
"""
Expand Down
39 changes: 35 additions & 4 deletions numba_dpex/tests/test_prange.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,12 @@

from numba_dpex import dpjit, prange

from ._helper import decorators

def test_one_prange_mul():
@dpjit

@pytest.mark.parametrize("jit", decorators)
def test_one_prange_mul(jit):
@jit
def f(a, b):
for i in prange(4):
b[i, 0] = a[i, 0] * 10
Expand All @@ -35,6 +38,33 @@ def f(a, b):
assert nb[i, 0] == na[i, 0] * 10


@pytest.mark.parametrize("jit", decorators)
def test_one_prange_mul_nested(jit):
@jit
def f_inner(a, b):
for i in prange(4):
b[i, 0] = a[i, 0] * 10
return

@jit
def f(a, b):
return f_inner(a, b)

device = dpctl.select_default_device()

m = 8
n = 8
a = dpnp.ones((m, n), device=device)
b = dpnp.ones((m, n), device=device)

f(a, b)
na = dpnp.asnumpy(a)
nb = dpnp.asnumpy(b)

for i in range(4):
assert nb[i, 0] == na[i, 0] * 10


@pytest.mark.skip(reason="dpnp.add() doesn't support variable + scalar.")
def test_one_prange_add_scalar():
@dpjit
Expand Down Expand Up @@ -155,8 +185,9 @@ def f(a, b):
assert np.all(b.asnumpy() == 12)


def test_two_consecutive_prange():
@dpjit
@pytest.mark.parametrize("jit", decorators)
def test_two_consecutive_prange(jit):
@jit
def prange_example(a, b, c, d):
for i in prange(n):
c[i] = a[i] + b[i]
Expand Down

0 comments on commit 7315aef

Please sign in to comment.