Skip to content
This repository was archived by the owner on Jan 25, 2023. It is now read-only.

Commit 345e700

Browse files
author
Ivan Butygin
authored
move mlir numba passes to mlir dir (#212)
1 parent 99de6b5 commit 345e700

File tree

4 files changed

+142
-132
lines changed

4 files changed

+142
-132
lines changed

numba/core/compiler.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,12 +28,13 @@
2828
NopythonRewrites, PreParforPass,
2929
ParforPass, DumpParforDiagnostics,
3030
IRLegalization, NoPythonBackend,
31-
InlineOverloads, PreLowerStripPhis,
32-
MlirDumpPlier, MlirBackend)
31+
InlineOverloads, PreLowerStripPhis)
3332

3433
from numba.core.object_mode_passes import (ObjectModeFrontEnd,
3534
ObjectModeBackEnd, CompileInterpMode)
3635

36+
from numba.mlir.passes import (MlirDumpPlier, MlirBackend)
37+
3738
class Flags(utils.ConfigOptions):
3839
# These options are all false by default, but the defaults are
3940
# different with the @jit decorator (see targets.options.TargetOptions).

numba/core/typed_passes.py

Lines changed: 0 additions & 128 deletions
Original file line numberDiff line numberDiff line change
@@ -467,134 +467,6 @@ def run_pass(self, state):
467467
)
468468
return True
469469

470-
import numba.mlir.settings
471-
import numba.mlir.func_registry
472-
import numba.core.types.functions
473-
_mlir_last_compiled_func = None
474-
_mlir_active_module = None
475-
476-
class MlirBackendBase(FunctionPass):
477-
478-
def __init__(self):
479-
import numba.mlir.func_registry
480-
self._get_func_name = numba.mlir.func_registry.get_func_name
481-
FunctionPass.__init__(self)
482-
483-
def run_pass(self, state):
484-
numba.mlir.func_registry.push_active_funcs_stack()
485-
try:
486-
res = self.run_pass_impl(state)
487-
finally:
488-
numba.mlir.func_registry.pop_active_funcs_stack()
489-
return res
490-
491-
def _resolve_func_name(self, obj):
492-
name, func = self._resolve_func_name_impl(obj)
493-
if not (name is None or func is None):
494-
numba.mlir.func_registry.add_active_funcs(name, func)
495-
return name
496-
497-
def _resolve_func_name_impl(self, obj):
498-
if isinstance(obj, types.Function):
499-
func = obj.typing_key
500-
return (self._get_func_name(func), None)
501-
if isinstance(obj, types.BoundFunction):
502-
return (str(obj.typing_key), None)
503-
if isinstance(obj, numba.core.types.functions.Dispatcher):
504-
func = obj.dispatcher.py_func
505-
return (func.__module__ + "." + func.__qualname__, func)
506-
return (None, None)
507-
508-
def _get_func_context(self, state):
509-
mangler = state.targetctx.mangler
510-
mangler = default_mangler if mangler is None else mangler
511-
unique_name = state.func_ir.func_id.unique_name
512-
modname = state.func_ir.func_id.func.__module__
513-
from numba.core.funcdesc import qualifying_prefix
514-
qualprefix = qualifying_prefix(modname, unique_name)
515-
fn_name = mangler(qualprefix, state.args)
516-
517-
from numba.np.ufunc.parallel import get_thread_count
518-
519-
ctx = {}
520-
ctx['compiler_settings'] = {'verify': True, 'pass_statistics': False, 'pass_timings': False, 'ir_printing': numba.mlir.settings.PRINT_IR}
521-
ctx['typemap'] = lambda op: state.typemap[op.name]
522-
ctx['fnargs'] = lambda: state.args
523-
ctx['restype'] = lambda: state.return_type
524-
ctx['fnname'] = lambda: fn_name
525-
ctx['resolve_func'] = self._resolve_func_name
526-
ctx['fastmath'] = lambda: state.targetctx.fastmath
527-
ctx['max_concurrency'] = lambda: get_thread_count() if state.flags.auto_parallel.enabled else 0
528-
return ctx
529-
530-
@register_pass(mutates_CFG=True, analysis_only=False)
531-
class MlirDumpPlier(MlirBackendBase):
532-
533-
_name = "mlir_dump_plier"
534-
535-
def __init__(self):
536-
MlirBackendBase.__init__(self)
537-
538-
def run_pass(self, state):
539-
import mlir_compiler
540-
module = mlir_compiler.create_module()
541-
ctx = self._get_func_context(state)
542-
mlir_compiler.lower_function(ctx, module, state.func_ir)
543-
print(mlir_compiler.module_str(module))
544-
return True
545-
546-
def get_mlir_func():
547-
global _mlir_last_compiled_func
548-
return _mlir_last_compiled_func
549-
550-
@register_pass(mutates_CFG=True, analysis_only=False)
551-
class MlirBackend(MlirBackendBase):
552-
553-
_name = "mlir_backend"
554-
555-
def __init__(self):
556-
MlirBackendBase.__init__(self)
557-
558-
def run_pass_impl(self, state):
559-
import mlir_compiler
560-
global _mlir_active_module
561-
old_module = _mlir_active_module
562-
563-
try:
564-
module = mlir_compiler.create_module()
565-
_mlir_active_module = module
566-
global _mlir_last_compiled_func
567-
ctx = self._get_func_context(state)
568-
_mlir_last_compiled_func = mlir_compiler.lower_function(ctx, module, state.func_ir)
569-
mod_ir = mlir_compiler.compile_module(ctx, module)
570-
finally:
571-
_mlir_active_module = old_module
572-
setattr(state, 'mlir_blob', mod_ir)
573-
_reload_parfors()
574-
state.reload_init.append(_reload_parfors)
575-
return True
576-
577-
@register_pass(mutates_CFG=True, analysis_only=False)
578-
class MlirBackendInner(MlirBackendBase):
579-
580-
_name = "mlir_backend_inner"
581-
582-
def __init__(self):
583-
MlirBackendBase.__init__(self)
584-
585-
def run_pass_impl(self, state):
586-
import mlir_compiler
587-
global _mlir_active_module
588-
module = _mlir_active_module
589-
assert not module is None
590-
global _mlir_last_compiled_func
591-
ctx = self._get_func_context(state)
592-
_mlir_last_compiled_func = mlir_compiler.lower_function(ctx, module, state.func_ir)
593-
from numba.core.compiler import compile_result
594-
state.cr = compile_result()
595-
return True
596-
597-
598470
@register_pass(mutates_CFG=True, analysis_only=False)
599471
class InlineOverloads(FunctionPass):
600472
"""

numba/mlir/inner_compiler.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
1-
from numba.core.typed_passes import get_mlir_func, NopythonTypeInference, AnnotateTypes, MlirBackendInner
1+
from numba.core.typed_passes import NopythonTypeInference, AnnotateTypes
22
from numba.core.compiler import CompilerBase, DefaultPassBuilder, DEFAULT_FLAGS, compile_extra
33
from numba.core.compiler_machinery import PassManager
44
from numba.core import typing, cpu
5-
# from numba import njit
5+
6+
from numba.mlir.passes import MlirBackendInner, get_mlir_func
67

78
class MlirTempCompiler(CompilerBase): # custom compiler extends from CompilerBase
89

numba/mlir/passes.py

Lines changed: 136 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,136 @@
1+
from numba.core.compiler_machinery import (FunctionPass, register_pass)
2+
from numba.core import (types)
3+
4+
import numba.mlir.settings
5+
import numba.mlir.func_registry
6+
import numba.core.types.functions
7+
_mlir_last_compiled_func = None
8+
_mlir_active_module = None
9+
10+
def _reload_parfors():
11+
"""Reloader for cached parfors
12+
"""
13+
# Re-initialize the parallel backend when load from cache.
14+
from numba.np.ufunc.parallel import _launch_threads
15+
_launch_threads()
16+
17+
class MlirBackendBase(FunctionPass):
18+
19+
def __init__(self):
20+
import numba.mlir.func_registry
21+
self._get_func_name = numba.mlir.func_registry.get_func_name
22+
FunctionPass.__init__(self)
23+
24+
def run_pass(self, state):
25+
numba.mlir.func_registry.push_active_funcs_stack()
26+
try:
27+
res = self.run_pass_impl(state)
28+
finally:
29+
numba.mlir.func_registry.pop_active_funcs_stack()
30+
return res
31+
32+
def _resolve_func_name(self, obj):
33+
name, func = self._resolve_func_name_impl(obj)
34+
if not (name is None or func is None):
35+
numba.mlir.func_registry.add_active_funcs(name, func)
36+
return name
37+
38+
def _resolve_func_name_impl(self, obj):
39+
if isinstance(obj, types.Function):
40+
func = obj.typing_key
41+
return (self._get_func_name(func), None)
42+
if isinstance(obj, types.BoundFunction):
43+
return (str(obj.typing_key), None)
44+
if isinstance(obj, numba.core.types.functions.Dispatcher):
45+
func = obj.dispatcher.py_func
46+
return (func.__module__ + "." + func.__qualname__, func)
47+
return (None, None)
48+
49+
def _get_func_context(self, state):
50+
mangler = state.targetctx.mangler
51+
mangler = default_mangler if mangler is None else mangler
52+
unique_name = state.func_ir.func_id.unique_name
53+
modname = state.func_ir.func_id.func.__module__
54+
from numba.core.funcdesc import qualifying_prefix
55+
qualprefix = qualifying_prefix(modname, unique_name)
56+
fn_name = mangler(qualprefix, state.args)
57+
58+
from numba.np.ufunc.parallel import get_thread_count
59+
60+
ctx = {}
61+
ctx['compiler_settings'] = {'verify': True, 'pass_statistics': False, 'pass_timings': False, 'ir_printing': numba.mlir.settings.PRINT_IR}
62+
ctx['typemap'] = lambda op: state.typemap[op.name]
63+
ctx['fnargs'] = lambda: state.args
64+
ctx['restype'] = lambda: state.return_type
65+
ctx['fnname'] = lambda: fn_name
66+
ctx['resolve_func'] = self._resolve_func_name
67+
ctx['fastmath'] = lambda: state.targetctx.fastmath
68+
ctx['max_concurrency'] = lambda: get_thread_count() if state.flags.auto_parallel.enabled else 0
69+
return ctx
70+
71+
@register_pass(mutates_CFG=True, analysis_only=False)
72+
class MlirDumpPlier(MlirBackendBase):
73+
74+
_name = "mlir_dump_plier"
75+
76+
def __init__(self):
77+
MlirBackendBase.__init__(self)
78+
79+
def run_pass(self, state):
80+
import mlir_compiler
81+
module = mlir_compiler.create_module()
82+
ctx = self._get_func_context(state)
83+
mlir_compiler.lower_function(ctx, module, state.func_ir)
84+
print(mlir_compiler.module_str(module))
85+
return True
86+
87+
def get_mlir_func():
88+
global _mlir_last_compiled_func
89+
return _mlir_last_compiled_func
90+
91+
@register_pass(mutates_CFG=True, analysis_only=False)
92+
class MlirBackend(MlirBackendBase):
93+
94+
_name = "mlir_backend"
95+
96+
def __init__(self):
97+
MlirBackendBase.__init__(self)
98+
99+
def run_pass_impl(self, state):
100+
import mlir_compiler
101+
global _mlir_active_module
102+
old_module = _mlir_active_module
103+
104+
try:
105+
module = mlir_compiler.create_module()
106+
_mlir_active_module = module
107+
global _mlir_last_compiled_func
108+
ctx = self._get_func_context(state)
109+
_mlir_last_compiled_func = mlir_compiler.lower_function(ctx, module, state.func_ir)
110+
mod_ir = mlir_compiler.compile_module(ctx, module)
111+
finally:
112+
_mlir_active_module = old_module
113+
setattr(state, 'mlir_blob', mod_ir)
114+
_reload_parfors()
115+
state.reload_init.append(_reload_parfors)
116+
return True
117+
118+
@register_pass(mutates_CFG=True, analysis_only=False)
119+
class MlirBackendInner(MlirBackendBase):
120+
121+
_name = "mlir_backend_inner"
122+
123+
def __init__(self):
124+
MlirBackendBase.__init__(self)
125+
126+
def run_pass_impl(self, state):
127+
import mlir_compiler
128+
global _mlir_active_module
129+
module = _mlir_active_module
130+
assert not module is None
131+
global _mlir_last_compiled_func
132+
ctx = self._get_func_context(state)
133+
_mlir_last_compiled_func = mlir_compiler.lower_function(ctx, module, state.func_ir)
134+
from numba.core.compiler import compile_result
135+
state.cr = compile_result()
136+
return True

0 commit comments

Comments
 (0)