Skip to content

Commit 4ca6ed0

Browse files
committed
[mypyc] Use table-driven helper for imports
Change how imports (not from imports!) are processed so they can be table-driven and compact. Here's how it works:

Import nodes are divided into groups (in the prebuild visitor). Each group consists of consecutive Import nodes:

    import mod         <| group #1
    import mod2         |

    def foo() -> None:
        import mod3    <- group #2

    import mod4        <- group #3

Every time we encounter the first import of a group, build IR to call CPyImport_ImportMany(), which will perform all of the group's imports in one go.

Previously, each module would be imported and placed in globals manually in IR, leading to some pretty verbose code.

The other option, collecting all imports and performing them all at once in the helper, would remove even more ops; however, it's problematic for the same reasons as in the previous commit (spoiler: it's not safe).

Implementation notes:

- I had to add support for loading the address of a static directly, so I shoehorned in LoadStatic support for LoadAddress.
- Unfortunately, by replacing multiple nodes with a single function call at the IR level, the traceback line number is static. Even if an import several lines down a group fails, the line # of the first import in the group would be printed. To fix this, I had to make CPyImport_ImportMany() add the traceback entry itself on failure (instead of letting codegen handle it automatically). This is admittedly ugly.
1 parent 11bfb04 commit 4ca6ed0

File tree

12 files changed

+352
-67
lines changed

12 files changed

+352
-67
lines changed

mypyc/codegen/emitfunc.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -727,7 +727,13 @@ def visit_get_element_ptr(self, op: GetElementPtr) -> None:
727727
def visit_load_address(self, op: LoadAddress) -> None:
728728
typ = op.type
729729
dest = self.reg(op)
730-
src = self.reg(op.src) if isinstance(op.src, Register) else op.src
730+
if isinstance(op.src, Register):
731+
src = self.reg(op.src)
732+
elif isinstance(op.src, LoadStatic):
733+
prefix = self.PREFIX_MAP[op.src.namespace]
734+
src = self.emitter.static_name(op.src.identifier, op.src.module_name, prefix)
735+
else:
736+
src = op.src
731737
self.emit_line(f"{dest} = ({typ._ctype})&{src};")
732738

733739
def visit_keep_alive(self, op: KeepAlive) -> None:

mypyc/ir/ops.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1348,13 +1348,14 @@ class LoadAddress(RegisterOp):
13481348
Attributes:
13491349
type: Type of the loaded address(e.g. ptr/object_ptr)
13501350
src: Source value (str for globals like 'PyList_Type',
1351-
Register for temporary values or locals)
1351+
Register for temporary values or locals, LoadStatic
1352+
for statics.)
13521353
"""
13531354

13541355
error_kind = ERR_NEVER
13551356
is_borrowed = True
13561357

1357-
def __init__(self, type: RType, src: str | Register, line: int = -1) -> None:
1358+
def __init__(self, type: RType, src: str | Register | LoadStatic, line: int = -1) -> None:
13581359
super().__init__(line)
13591360
self.type = type
13601361
self.src = src

mypyc/ir/pprint.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -266,6 +266,11 @@ def visit_get_element_ptr(self, op: GetElementPtr) -> str:
266266
def visit_load_address(self, op: LoadAddress) -> str:
267267
if isinstance(op.src, Register):
268268
return self.format("%r = load_address %r", op, op.src)
269+
elif isinstance(op.src, LoadStatic):
270+
name = op.src.identifier
271+
if op.src.module_name is not None:
272+
name = f"{op.src.module_name}.{name}"
273+
return self.format("%r = load_address %s :: %s", op, name, op.src.namespace)
269274
else:
270275
return self.format("%r = load_address %s", op, op.src)
271276

mypyc/irbuild/builder.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -187,6 +187,8 @@ def __init__(
187187
self.encapsulating_funcs = pbv.encapsulating_funcs
188188
self.nested_fitems = pbv.nested_funcs.keys()
189189
self.fdefs_to_decorators = pbv.funcs_to_decorators
190+
self.module_import_groups = pbv.module_import_groups
191+
190192
self.singledispatch_impls = singledispatch_impls
191193

192194
self.visitor = visitor

mypyc/irbuild/ll_builder.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1693,6 +1693,12 @@ def new_list_op(self, values: list[Value], line: int) -> Value:
16931693
def new_set_op(self, values: list[Value], line: int) -> Value:
16941694
return self.call_c(new_set_op, values, line)
16951695

1696+
def setup_rarray(self, item_type: RType, values: Sequence[Value]) -> Value:
1697+
"""Declare and initialize a new RArray, returning its address."""
1698+
array = Register(RArray(item_type, len(values)))
1699+
self.add(AssignMulti(array, list(values)))
1700+
return self.add(LoadAddress(c_pointer_rprimitive, array))
1701+
16961702
def shortcircuit_helper(
16971703
self,
16981704
op: str,

mypyc/irbuild/prebuildvisitor.py

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,18 +5,20 @@
55
Expression,
66
FuncDef,
77
FuncItem,
8+
Import,
89
LambdaExpr,
910
MemberExpr,
1011
MypyFile,
1112
NameExpr,
13+
Node,
1214
SymbolNode,
1315
Var,
1416
)
15-
from mypy.traverser import TraverserVisitor
17+
from mypy.traverser import ExtendedTraverserVisitor
1618
from mypyc.errors import Errors
1719

1820

19-
class PreBuildVisitor(TraverserVisitor):
21+
class PreBuildVisitor(ExtendedTraverserVisitor):
2022
"""Mypy file AST visitor run before building the IR.
2123
2224
This collects various things, including:
@@ -26,6 +28,7 @@ class PreBuildVisitor(TraverserVisitor):
2628
* Find non-local variables (free variables)
2729
* Find property setters
2830
* Find decorators of functions
31+
* Find module import groups
2932
3033
The main IR build pass uses this information.
3134
"""
@@ -68,10 +71,28 @@ def __init__(
6871
# Map function to indices of decorators to remove
6972
self.decorators_to_remove: dict[FuncDef, list[int]] = decorators_to_remove
7073

74+
# Map starting module import to import groups. Each group is a
75+
# series of imports with nothing between.
76+
self.module_import_groups: dict[Import, list[Import]] = {}
77+
self._current_import_group: Import | None = None
78+
7179
self.errors: Errors = errors
7280

7381
self.current_file: MypyFile = current_file
7482

83+
def visit(self, o: Node) -> bool:
84+
if isinstance(o, Import):
85+
if self._current_import_group is not None:
86+
self.module_import_groups[self._current_import_group].append(o)
87+
else:
88+
self.module_import_groups[o] = [o]
89+
self._current_import_group = o
90+
# Don't recurse into the import's assignments.
91+
return False
92+
93+
self._current_import_group = None
94+
return True
95+
7596
def visit_decorator(self, dec: Decorator) -> None:
7697
if dec.decorators:
7798
# Only add the function being decorated if there exist

mypyc/irbuild/statement.py

Lines changed: 64 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@
5353
LoadAddress,
5454
LoadErrorValue,
5555
LoadLiteral,
56+
LoadStatic,
5657
MethodCall,
5758
RaiseStandardError,
5859
Register,
@@ -63,6 +64,7 @@
6364
)
6465
from mypyc.ir.rtypes import (
6566
RInstance,
67+
c_pyssize_t_rprimitive,
6668
exc_rtuple,
6769
is_tagged,
6870
none_rprimitive,
@@ -100,6 +102,7 @@
100102
check_stop_op,
101103
coro_op,
102104
import_from_many_op,
105+
import_many_op,
103106
send_op,
104107
type_op,
105108
yield_from_except_op,
@@ -220,32 +223,69 @@ def transform_operator_assignment_stmt(builder: IRBuilder, stmt: OperatorAssignm
220223
def transform_import(builder: IRBuilder, node: Import) -> None:
221224
if node.is_mypy_only:
222225
return
223-
globals = builder.load_globals_dict()
224-
for node_id, as_name in node.ids:
225-
builder.gen_import(node_id, node.line)
226-
227-
# Update the globals dict with the appropriate module:
228-
# * For 'import foo.bar as baz' we add 'foo.bar' with the name 'baz'
229-
# * For 'import foo.bar' we add 'foo' with the name 'foo'
230-
# Typically we then ignore these entries and access things directly
231-
# via the module static, but we will use the globals version for modules
232-
# that mypy couldn't find, since it doesn't analyze module references
233-
# from those properly.
234-
235-
# TODO: Don't add local imports to the global namespace
236-
237-
# Miscompiling imports inside of functions, like below in import from.
238-
if as_name:
239-
name = as_name
240-
base = node_id
241-
else:
242-
base = name = node_id.split(".")[0]
243226

244-
obj = builder.get_module(base, node.line)
227+
# Imports (not from imports!) are processed in an odd way so they can be
228+
# table-driven and compact. Here's how it works:
229+
#
230+
# Import nodes are divided in groups (in the prebuild visitor). Each group
231+
# consists of consecutive Import nodes:
232+
#
233+
# import mod <| group #1
234+
# import mod2 |
235+
#
236+
# def foo() -> None:
237+
# import mod3 <- group #2
238+
#
239+
# import mod4 <- group #3
240+
#
241+
# Every time we encounter the first import of a group, build IR to call a
242+
# helper function that will perform all of the group's imports in one go.
243+
if node not in builder.module_import_groups:
244+
return
245245

246-
builder.gen_method_call(
247-
globals, "__setitem__", [builder.load_str(name), obj], result_type=None, line=node.line
248-
)
246+
modules = []
247+
statics = []
248+
# To show the right line number on failure, we have to add the traceback
249+
# entry within the helper function (which is admittedly ugly). To drive
250+
# this, we'll need the line number corresponding to each import.
251+
import_lines = []
252+
for import_node in builder.module_import_groups[node]:
253+
for mod_id, as_name in import_node.ids:
254+
builder.imports[mod_id] = None
255+
import_lines.append(Integer(import_node.line, c_pyssize_t_rprimitive))
256+
257+
module_static = LoadStatic(object_rprimitive, mod_id, namespace=NAMESPACE_MODULE)
258+
static_ptr = builder.add(LoadAddress(object_pointer_rprimitive, module_static))
259+
statics.append(static_ptr)
260+
# TODO: Don't add local imports to the global namespace
261+
# Update the globals dict with the appropriate module:
262+
# * For 'import foo.bar as baz' we add 'foo.bar' with the name 'baz'
263+
# * For 'import foo.bar' we add 'foo' with the name 'foo'
264+
# Typically we then ignore these entries and access things directly
265+
# via the module static, but we will use the globals version for
266+
# modules that mypy couldn't find, since it doesn't analyze module
267+
# references from those properly.
268+
if as_name or "." not in mod_id:
269+
globals_base = None
270+
else:
271+
globals_base = mod_id.split(".")[0]
272+
modules.append((mod_id, as_name, globals_base))
273+
274+
static_array_ptr = builder.builder.setup_rarray(object_pointer_rprimitive, statics)
275+
import_line_ptr = builder.builder.setup_rarray(c_pyssize_t_rprimitive, import_lines)
276+
function = "<module>" if builder.fn_info.name == "<top level>" else builder.fn_info.name
277+
builder.call_c(
278+
import_many_op,
279+
[
280+
builder.add(LoadLiteral(tuple(modules), object_rprimitive)),
281+
static_array_ptr,
282+
builder.load_globals_dict(),
283+
builder.load_str(builder.module_path),
284+
builder.load_str(function),
285+
import_line_ptr,
286+
],
287+
NO_TRACEBACK_LINE_NO,
288+
)
249289

250290

251291
def transform_import_from(builder: IRBuilder, node: ImportFrom) -> None:

mypyc/lib-rt/CPy.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -622,6 +622,8 @@ PyObject *CPy_Super(PyObject *builtins, PyObject *self);
622622
PyObject *CPy_CallReverseOpMethod(PyObject *left, PyObject *right, const char *op,
623623
_Py_Identifier *method);
624624

625+
bool CPyImport_ImportMany(PyObject *modules, CPyModule **statics[], PyObject *globals,
626+
PyObject *tb_path, PyObject *tb_function, Py_ssize_t *tb_lines);
625627
PyObject *CPyImport_ImportFromMany(PyObject *mod_id, PyObject *names, PyObject *as_names,
626628
PyObject *globals);
627629

mypyc/lib-rt/misc_ops.c

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -669,6 +669,68 @@ CPy_Super(PyObject *builtins, PyObject *self) {
669669
return result;
670670
}
671671

672+
static bool import_single(PyObject *mod_id,
673+
PyObject *as_name,
674+
PyObject **mod_static,
675+
PyObject *globals_base,
676+
PyObject *globals) {
677+
if (*mod_static == Py_None) {
678+
CPyModule *mod = PyImport_Import(mod_id);
679+
if (mod == NULL) {
680+
return false;
681+
}
682+
*mod_static = mod;
683+
}
684+
685+
if (as_name == Py_None) {
686+
as_name = mod_id;
687+
}
688+
PyObject *globals_id, *globals_name;
689+
if (globals_base == Py_None) {
690+
globals_id = mod_id;
691+
globals_name = as_name;
692+
} else {
693+
globals_id = globals_name = globals_base;
694+
}
695+
PyObject *mod_dict = PyImport_GetModuleDict();
696+
CPyModule *globals_mod = CPyDict_GetItem(mod_dict, globals_id);
697+
if (globals_mod == NULL) {
698+
return false;
699+
}
700+
int ret = CPyDict_SetItem(globals, globals_name, globals_mod);
701+
Py_DECREF(globals_mod);
702+
if (ret < 0) {
703+
return false;
704+
}
705+
706+
return true;
707+
}
708+
709+
// Table-driven import helper. See transform_import() in irbuild for the details.
710+
bool CPyImport_ImportMany(PyObject *modules, CPyModule **statics[], PyObject *globals,
711+
PyObject *tb_path, PyObject *tb_function, Py_ssize_t *tb_lines) {
712+
for (Py_ssize_t i = 0; i < PyTuple_GET_SIZE(modules); i++) {
713+
PyObject *module = PyTuple_GET_ITEM(modules, i);
714+
PyObject *mod_id = PyTuple_GET_ITEM(module, 0);
715+
PyObject *as_name = PyTuple_GET_ITEM(module, 1);
716+
PyObject *globals_base = PyTuple_GET_ITEM(module, 2);
717+
718+
if (!import_single(mod_id, as_name, statics[i], globals_base, globals)) {
719+
const char *path = PyUnicode_AsUTF8(tb_path);
720+
if (path == NULL) {
721+
path = "<unable to display>";
722+
}
723+
const char *function = PyUnicode_AsUTF8(tb_function);
724+
if (function == NULL) {
725+
function = "<unable to display>";
726+
}
727+
CPy_AddTraceback(path, function, tb_lines[i], globals);
728+
return false;
729+
}
730+
}
731+
return true;
732+
}
733+
672734
// This helper function is a simplification of cpython/ceval.c/import_from()
673735
static PyObject *CPyImport_ImportFrom(PyObject *module, PyObject *package_name,
674736
PyObject *import_name, PyObject *as_name) {

mypyc/primitives/misc_ops.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
bit_rprimitive,
88
bool_rprimitive,
99
c_int_rprimitive,
10+
c_pointer_rprimitive,
1011
c_pyssize_t_rprimitive,
1112
dict_rprimitive,
1213
int_rprimitive,
@@ -111,14 +112,29 @@
111112
is_borrowed=True,
112113
)
113114

114-
# Import a module
115+
# Import a module (plain)
115116
import_op = custom_op(
116117
arg_types=[str_rprimitive],
117118
return_type=object_rprimitive,
118119
c_function_name="PyImport_Import",
119120
error_kind=ERR_MAGIC,
120121
)
121122

123+
# Import helper op (handles globals/statics & can import multiple modules)
124+
import_many_op = custom_op(
125+
arg_types=[
126+
object_rprimitive,
127+
c_pointer_rprimitive,
128+
object_rprimitive,
129+
object_rprimitive,
130+
object_rprimitive,
131+
c_pointer_rprimitive,
132+
],
133+
return_type=bit_rprimitive,
134+
c_function_name="CPyImport_ImportMany",
135+
error_kind=ERR_FALSE,
136+
)
137+
122138
# From import helper op
123139
import_from_many_op = custom_op(
124140
arg_types=[object_rprimitive, object_rprimitive, object_rprimitive, object_rprimitive],

0 commit comments

Comments
 (0)