Skip to content

Commit

Permalink
[mypyc] Switch to table-driven imports for smaller IR (python#14917)
Browse files Browse the repository at this point in the history
Add CPyImport_ImportFromMany() which imports a module and a tuple of
names, placing them in the globals dict in one go.

Previously, each name would be imported and placed in globals manually
in IR, leading to some pretty verbose code.

The other option to collect all from imports and perform them all at
once in the helper would remove even more ops, however, it has some
major downsides:

- It wouldn't be able to be done in IRBuild directly, instead being
  handled in the prebuild visitor and codegen... which all sounds
  really involved.

- It would cause from imports to be performed eagerly, potentially
  causing circular imports (especially in functions whose imports are
  probably there to avoid a circular import!).

The latter is the nail in the coffin for this idea.

---

Change how imports (not from imports!) are processed so they can be
table-driven (tuple-driven, really) and compact. Here's how it works:

Import nodes are divided in groups (in the prebuild visitor). Each group
consists of consecutive Import nodes:

    import mod         <| group 1
    import mod2         |

    def foo() -> None:
        import mod3    <- group 2 (*)
  
    import mod4        <- group 3

Every time we encounter the first import of a group, build IR to call
CPyImport_ImportMany() that will perform all of the group's imports in
one go.

(*) Imports in functions or classes are still transformed into the
original, verbose IR, as speed is more important than code size.

Previously, each module would be imported and placed in globals manually
in IR, leading to some pretty verbose code.

The other option to collect all imports and perform them all at once in
the helper would remove even more ops, however, it's problematic for
the same reasons from the previous commit (spoiler: it's not safe).

Implementation notes:

  - I had to add support for loading the address of a static directly,
    so I shoehorned in LoadStatic support for LoadAddress.

  - Unfortunately, by replacing multiple import nodes with a single
    function call at the IR level, if any import within a group fails,
    the traceback line number is static and will probably be wrong
    (pointing to the first import node in my original impl.).

    To fix this, I had to make CPyImport_ImportMany() add the traceback
    entry itself on failure (instead of letting codegen handle it
    automatically). This is admittedly ugly.

Overall, this doesn't speed up initialization. The only real speed
impact is that back-to-back imports in a tight loop seem to be 10-20%
slower. I believe that's acceptable given the code size reduction.

---

**Other changes:**

- Don't declare internal static for non-compiled modules

  It won't be read anywhere as the internal statics are only used to
  avoid runaway recursion with import cycles in our module init
  functions.

- Wrap long RArray initializers and long annotations in codegen

  Table-driven imports can load some rather large RArrays and tuple
  literals so this was needed to keep the generated C readable.

- Add LLBuilder helper for setting up a RArray

Resolves mypyc/mypyc#591.
  • Loading branch information
ichard26 committed Apr 24, 2023
1 parent aee983e commit ba35026
Show file tree
Hide file tree
Showing 19 changed files with 954 additions and 622 deletions.
59 changes: 57 additions & 2 deletions mypyc/codegen/emit.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@

from __future__ import annotations

import pprint
import sys
import textwrap
from typing import Callable
from typing_extensions import Final

Expand Down Expand Up @@ -191,10 +193,31 @@ def reg(self, reg: Value) -> str:
def attr(self, name: str) -> str:
return ATTR_PREFIX + name

def emit_line(self, line: str = "") -> None:
def object_annotation(self, obj: object, line: str) -> str:
    """Render an object's pretty-printed form as a trailing C comment.

    The object is formatted with ``pprint.pformat`` using a width budget
    of 90 columns minus the current indent and the length of ``line``
    (never less than 20). If the formatted text spans several lines, the
    continuation lines are indented by the width of the indent and
    ``line`` plus three columns, so they line up under the first line of
    the comment.

    Returns an empty string when the formatted text contains a C comment
    delimiter or a NUL character, since it could not be embedded safely.
    """
    column = self._indent + len(line)
    text = pprint.pformat(obj, compact=True, width=max(90 - column, 20))
    for forbidden in ("/*", "*/", "\0"):
        if forbidden in text:
            return ""

    head, newline, tail = text.partition("\n")
    if not newline:
        # Single-line representation: no wrapping needed.
        return f" /* {text} */"
    padded_tail = textwrap.indent(tail, " " * (column + 3))
    return f" /* {head}\n{padded_tail} */"

def emit_line(self, line: str = "", *, ann: object = None) -> None:
if line.startswith("}"):
self.dedent()
self.fragments.append(self._indent * " " + line + "\n")
comment = self.object_annotation(ann, line) if ann is not None else ""
self.fragments.append(self._indent * " " + line + comment + "\n")
if line.endswith("{"):
self.indent()

Expand Down Expand Up @@ -1119,3 +1142,35 @@ def _emit_traceback(
self.emit_line(line)
if DEBUG_ERRORS:
self.emit_line('assert(PyErr_Occurred() != NULL && "failure w/o err!");')


def c_array_initializer(components: list[str], *, indented: bool = False) -> str:
    """Build the brace-enclosed initializer for a C array variable.

    Each component must already be a C expression that is valid inside an
    initializer. For instance, ["1", "2"] produces "{1, 2}", usable as:

        int a[] = {1, 2};

    Components are packed onto a row until adding another would bring the
    running length (including the optional indent) to 70 or more; then a
    new row is started. With ``indented`` set, every row and the closing
    brace are prefixed with four spaces.
    """
    prefix = " " * 4 if indented else ""
    rows: list[str] = []
    pending: list[str] = []
    width = 0
    for item in components:
        if pending and width + 2 + len(prefix) + len(item) >= 70:
            # Current row is full: flush it and start a new one.
            rows.append(prefix + ", ".join(pending))
            pending = [item]
            width = len(item)
        else:
            pending.append(item)
            width += len(item) + 2
    if not rows:
        # Everything fits on a single line.
        return "{" + ", ".join(pending) + "}"
    # Multi-line result: flush the last row, then join the rows.
    rows.append(prefix + ", ".join(pending))
    return "{\n " + ",\n ".join(rows) + "\n" + prefix + "}"
43 changes: 18 additions & 25 deletions mypyc/codegen/emitfunc.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from typing_extensions import Final

from mypyc.analysis.blockfreq import frequently_executed_blocks
from mypyc.codegen.emit import DEBUG_ERRORS, Emitter, TracebackAndGotoHandler
from mypyc.codegen.emit import DEBUG_ERRORS, Emitter, TracebackAndGotoHandler, c_array_initializer
from mypyc.common import MODULE_PREFIX, NATIVE_PREFIX, REG_PREFIX, STATIC_PREFIX, TYPE_PREFIX
from mypyc.ir.class_ir import ClassIR
from mypyc.ir.func_ir import FUNC_CLASSMETHOD, FUNC_STATICMETHOD, FuncDecl, FuncIR, all_values
Expand Down Expand Up @@ -262,12 +262,12 @@ def visit_assign_multi(self, op: AssignMulti) -> None:
# RArray values can only be assigned to once, so we can always
# declare them on initialization.
self.emit_line(
"%s%s[%d] = {%s};"
"%s%s[%d] = %s;"
% (
self.emitter.ctype_spaced(typ.item_type),
dest,
len(op.src),
", ".join(self.reg(s) for s in op.src),
c_array_initializer([self.reg(s) for s in op.src], indented=True),
)
)

Expand All @@ -282,15 +282,12 @@ def visit_load_error_value(self, op: LoadErrorValue) -> None:

def visit_load_literal(self, op: LoadLiteral) -> None:
index = self.literals.literal_index(op.value)
s = repr(op.value)
if not any(x in s for x in ("/*", "*/", "\0")):
ann = " /* %s */" % s
else:
ann = ""
if not is_int_rprimitive(op.type):
self.emit_line("%s = CPyStatics[%d];%s" % (self.reg(op), index, ann))
self.emit_line("%s = CPyStatics[%d];" % (self.reg(op), index), ann=op.value)
else:
self.emit_line("%s = (CPyTagged)CPyStatics[%d] | 1;%s" % (self.reg(op), index, ann))
self.emit_line(
"%s = (CPyTagged)CPyStatics[%d] | 1;" % (self.reg(op), index), ann=op.value
)

def get_attr_expr(self, obj: str, op: GetAttr | SetAttr, decl_cl: ClassIR) -> str:
"""Generate attribute accessor for normal (non-property) access.
Expand Down Expand Up @@ -468,12 +465,7 @@ def visit_load_static(self, op: LoadStatic) -> None:
name = self.emitter.static_name(op.identifier, op.module_name, prefix)
if op.namespace == NAMESPACE_TYPE:
name = "(PyObject *)%s" % name
ann = ""
if op.ann:
s = repr(op.ann)
if not any(x in s for x in ("/*", "*/", "\0")):
ann = " /* %s */" % s
self.emit_line(f"{dest} = {name};{ann}")
self.emit_line(f"{dest} = {name};", ann=op.ann)

def visit_init_static(self, op: InitStatic) -> None:
value = self.reg(op.value)
Expand Down Expand Up @@ -636,12 +628,7 @@ def visit_extend(self, op: Extend) -> None:

def visit_load_global(self, op: LoadGlobal) -> None:
dest = self.reg(op)
ann = ""
if op.ann:
s = repr(op.ann)
if not any(x in s for x in ("/*", "*/", "\0")):
ann = " /* %s */" % s
self.emit_line(f"{dest} = {op.identifier};{ann}")
self.emit_line(f"{dest} = {op.identifier};", ann=op.ann)

def visit_int_op(self, op: IntOp) -> None:
dest = self.reg(op)
Expand Down Expand Up @@ -727,7 +714,13 @@ def visit_get_element_ptr(self, op: GetElementPtr) -> None:
def visit_load_address(self, op: LoadAddress) -> None:
typ = op.type
dest = self.reg(op)
src = self.reg(op.src) if isinstance(op.src, Register) else op.src
if isinstance(op.src, Register):
src = self.reg(op.src)
elif isinstance(op.src, LoadStatic):
prefix = self.PREFIX_MAP[op.src.namespace]
src = self.emitter.static_name(op.src.identifier, op.src.module_name, prefix)
else:
src = op.src
self.emit_line(f"{dest} = ({typ._ctype})&{src};")

def visit_keep_alive(self, op: KeepAlive) -> None:
Expand Down Expand Up @@ -776,8 +769,8 @@ def c_error_value(self, rtype: RType) -> str:
def c_undefined_value(self, rtype: RType) -> str:
return self.emitter.c_undefined_value(rtype)

def emit_line(self, line: str) -> None:
self.emitter.emit_line(line)
def emit_line(self, line: str, *, ann: object = None) -> None:
self.emitter.emit_line(line, ann=ann)

def emit_lines(self, *lines: str) -> None:
self.emitter.emit_lines(*lines)
Expand Down
56 changes: 13 additions & 43 deletions mypyc/codegen/emitmodule.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
from mypy.plugin import Plugin, ReportConfigContext
from mypy.util import hash_digest
from mypyc.codegen.cstring import c_string_initializer
from mypyc.codegen.emit import Emitter, EmitterContext, HeaderDeclaration
from mypyc.codegen.emit import Emitter, EmitterContext, HeaderDeclaration, c_array_initializer
from mypyc.codegen.emitclass import generate_class, generate_class_type_decl
from mypyc.codegen.emitfunc import generate_native_function, native_function_header
from mypyc.codegen.emitwrapper import (
Expand Down Expand Up @@ -296,11 +296,11 @@ def compile_ir_to_c(
# compiled into a separate extension module.
ctext: dict[str | None, list[tuple[str, str]]] = {}
for group_sources, group_name in groups:
group_modules = [
(source.module, modules[source.module])
group_modules = {
source.module: modules[source.module]
for source in group_sources
if source.module in modules
]
}
if not group_modules:
ctext[group_name] = []
continue
Expand Down Expand Up @@ -465,7 +465,7 @@ def group_dir(group_name: str) -> str:
class GroupGenerator:
def __init__(
self,
modules: list[tuple[str, ModuleIR]],
modules: dict[str, ModuleIR],
source_paths: dict[str, str],
group_name: str | None,
group_map: dict[str, str | None],
Expand Down Expand Up @@ -512,7 +512,7 @@ def generate_c_for_modules(self) -> list[tuple[str, str]]:
multi_file = self.use_shared_lib and self.multi_file

# Collect all literal refs in IR.
for _, module in self.modules:
for module in self.modules.values():
for fn in module.functions:
collect_literals(fn, self.context.literals)

Expand All @@ -528,7 +528,7 @@ def generate_c_for_modules(self) -> list[tuple[str, str]]:

self.generate_literal_tables()

for module_name, module in self.modules:
for module_name, module in self.modules.items():
if multi_file:
emitter = Emitter(self.context)
emitter.emit_line(f'#include "__native{self.short_group_suffix}.h"')
Expand Down Expand Up @@ -582,7 +582,7 @@ def generate_c_for_modules(self) -> list[tuple[str, str]]:
declarations.emit_line("int CPyGlobalsInit(void);")
declarations.emit_line()

for module_name, module in self.modules:
for module_name, module in self.modules.items():
self.declare_finals(module_name, module.final_names, declarations)
for cl in module.classes:
generate_class_type_decl(cl, emitter, ext_declarations, declarations)
Expand Down Expand Up @@ -790,7 +790,7 @@ def generate_shared_lib_init(self, emitter: Emitter) -> None:
"",
)

for mod, _ in self.modules:
for mod in self.modules:
name = exported_name(mod)
emitter.emit_lines(
f"extern PyObject *CPyInit_{name}(void);",
Expand Down Expand Up @@ -1023,12 +1023,13 @@ def module_internal_static_name(self, module_name: str, emitter: Emitter) -> str
return emitter.static_name(module_name + "_internal", None, prefix=MODULE_PREFIX)

def declare_module(self, module_name: str, emitter: Emitter) -> None:
# We declare two globals for each module:
# We declare two globals for each compiled module:
# one used internally in the implementation of module init to cache results
# and prevent infinite recursion in import cycles, and one used
# by other modules to refer to it.
internal_static_name = self.module_internal_static_name(module_name, emitter)
self.declare_global("CPyModule *", internal_static_name, initializer="NULL")
if module_name in self.modules:
internal_static_name = self.module_internal_static_name(module_name, emitter)
self.declare_global("CPyModule *", internal_static_name, initializer="NULL")
static_name = emitter.static_name(module_name, None, prefix=MODULE_PREFIX)
self.declare_global("CPyModule *", static_name)
self.simple_inits.append((static_name, "Py_None"))
Expand Down Expand Up @@ -1126,37 +1127,6 @@ def collect_literals(fn: FuncIR, literals: Literals) -> None:
literals.record_literal(op.value)


def c_array_initializer(components: list[str]) -> str:
    """Return a C array initializer built from the given C expressions.

    Every component must be a C expression valid inside an initializer.
    As an example, ["1", "2"] yields "{1, 2}", which can be used as:

        int a[] = {1, 2};

    A long result is broken across several lines: a new line is started
    once adding the next component would bring the running length of the
    current line to 70 or more.
    """
    lines: list[str] = []
    row: list[str] = []
    row_len = 0
    for comp in components:
        fits = not row or row_len + 2 + len(comp) < 70
        if fits:
            row.append(comp)
            row_len += len(comp) + 2
            continue
        # The component does not fit: close the current line first.
        lines.append(", ".join(row))
        row, row_len = [comp], len(comp)
    if lines:
        # Multi-line result
        lines.append(", ".join(row))
        return "{\n " + ",\n ".join(lines) + "\n}"
    # Result fits on a single line
    return "{%s}" % ", ".join(row)


def c_string_array_initializer(components: list[bytes]) -> str:
result = []
result.append("{\n")
Expand Down
5 changes: 3 additions & 2 deletions mypyc/ir/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -1348,13 +1348,14 @@ class LoadAddress(RegisterOp):
Attributes:
type: Type of the loaded address(e.g. ptr/object_ptr)
src: Source value (str for globals like 'PyList_Type',
Register for temporary values or locals)
Register for temporary values or locals, LoadStatic
for statics.)
"""

error_kind = ERR_NEVER
is_borrowed = True

def __init__(self, type: RType, src: str | Register, line: int = -1) -> None:
def __init__(self, type: RType, src: str | Register | LoadStatic, line: int = -1) -> None:
super().__init__(line)
self.type = type
self.src = src
Expand Down
5 changes: 5 additions & 0 deletions mypyc/ir/pprint.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,6 +266,11 @@ def visit_get_element_ptr(self, op: GetElementPtr) -> str:
def visit_load_address(self, op: LoadAddress) -> str:
    """Format a LoadAddress op according to the kind of its source.

    A Register source is printed with the %r conversion. A LoadStatic
    source is printed as its identifier, qualified with the module name
    when one is set, followed by " :: " and its namespace. Any other
    source is printed as-is with %s.
    """
    src = op.src
    if isinstance(src, Register):
        return self.format("%r = load_address %r", op, src)
    if not isinstance(src, LoadStatic):
        return self.format("%r = load_address %s", op, src)
    if src.module_name is None:
        qualified = src.identifier
    else:
        qualified = f"{src.module_name}.{src.identifier}"
    return self.format("%r = load_address %s :: %s", op, qualified, src.namespace)

Expand Down
26 changes: 3 additions & 23 deletions mypyc/irbuild/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,6 @@
RType,
RUnion,
bitmap_rprimitive,
c_int_rprimitive,
c_pyssize_t_rprimitive,
dict_rprimitive,
int_rprimitive,
Expand Down Expand Up @@ -127,12 +126,7 @@
from mypyc.primitives.dict_ops import dict_get_item_op, dict_set_item_op
from mypyc.primitives.generic_ops import iter_op, next_op, py_setattr_op
from mypyc.primitives.list_ops import list_get_item_unsafe_op, list_pop_last, to_list
from mypyc.primitives.misc_ops import (
check_unpack_count_op,
get_module_dict_op,
import_extra_args_op,
import_op,
)
from mypyc.primitives.misc_ops import check_unpack_count_op, get_module_dict_op, import_op
from mypyc.primitives.registry import CFunctionDescription, function_ops

# These int binary operations can borrow their operands safely, since the
Expand Down Expand Up @@ -194,6 +188,8 @@ def __init__(
self.encapsulating_funcs = pbv.encapsulating_funcs
self.nested_fitems = pbv.nested_funcs.keys()
self.fdefs_to_decorators = pbv.funcs_to_decorators
self.module_import_groups = pbv.module_import_groups

self.singledispatch_impls = singledispatch_impls

self.visitor = visitor
Expand Down Expand Up @@ -395,22 +391,6 @@ def add_to_non_ext_dict(
key_unicode = self.load_str(key)
self.call_c(dict_set_item_op, [non_ext.dict, key_unicode, val], line)

def gen_import_from(
    self, id: str, globals_dict: Value, imported: list[str], line: int
) -> Value:
    """Emit IR that imports module ``id`` for a from-import and return it.

    The module id is recorded in ``self.imports``. The names in
    ``imported`` are loaded as strings and collected into a new list,
    which is passed to ``import_extra_args_op`` together with the module
    name, ``globals_dict``, a zero-valued dict and a zero-valued C int.
    The result is stored in the module's static via InitStatic and then
    returned.
    """
    self.imports[id] = None

    empty_dict = Integer(0, dict_rprimitive, line)
    name_values = []
    for name in imported:
        name_values.append(self.load_str(name))
    fromlist = self.new_list_op(name_values, line)
    zero = Integer(0, c_int_rprimitive, line)
    # Load the module name only after the name list has been built, so
    # the ops are emitted in the same order as before.
    module_name = self.load_str(id)
    module = self.call_c(
        import_extra_args_op, [module_name, globals_dict, empty_dict, fromlist, zero], line
    )
    self.add(InitStatic(module, id, namespace=NAMESPACE_MODULE))
    return module

def gen_import(self, id: str, line: int) -> None:
self.imports[id] = None

Expand Down
2 changes: 1 addition & 1 deletion mypyc/irbuild/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -515,7 +515,7 @@ def gen_func_ns(builder: IRBuilder) -> str:
return "_".join(
info.name + ("" if not info.class_name else "_" + info.class_name)
for info in builder.fn_infos
if info.name and info.name != "<top level>"
if info.name and info.name != "<module>"
)


Expand Down
Loading

0 comments on commit ba35026

Please sign in to comment.