Skip to content

Commit

Permalink
[mypyc] Generate smaller code for casts (#12839)
Browse files Browse the repository at this point in the history
Merge a cast op followed by a branch that does an error check and adds a
traceback entry. Since casts are very common, this reduces the size of
the generated code a fair amount.

Old code generated for a cast:
```
    if (likely(PyUnicode_Check(cpy_r_x)))
        cpy_r_r0 = cpy_r_x;
    else {
        CPy_TypeError("str", cpy_r_x);
        cpy_r_r0 = NULL;
    }
    if (unlikely(cpy_r_r0 == NULL)) {
        CPy_AddTraceback("t/t.py", "foo", 2, CPyStatic_globals);
        goto CPyL2;
    }
```

New code:
```
    if (likely(PyUnicode_Check(cpy_r_x)))
        cpy_r_r0 = cpy_r_x;
    else {
        CPy_TypeErrorTraceback("t/t.py", "foo", 2, CPyStatic_globals, "str", cpy_r_x);
        goto CPyL2;
    }
```
  • Loading branch information
JukkaL committed May 23, 2022
1 parent c8efeed commit 040f3ab
Show file tree
Hide file tree
Showing 7 changed files with 270 additions and 57 deletions.
156 changes: 118 additions & 38 deletions mypyc/codegen/emit.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

from mypy.backports import OrderedDict
from typing import List, Set, Dict, Optional, Callable, Union, Tuple
from typing_extensions import Final

import sys

from mypyc.common import (
Expand All @@ -23,6 +25,10 @@
from mypyc.sametype import is_same_type
from mypyc.codegen.literals import Literals

# Whether to insert debug asserts for all error handling, to quickly
# catch errors propagating without exceptions set.
DEBUG_ERRORS: Final = False


class HeaderDeclaration:
"""A representation of a declaration in C.
Expand Down Expand Up @@ -104,6 +110,20 @@ def __init__(self, label: str) -> None:
self.label = label


class TracebackAndGotoHandler(ErrorHandler):
"""Add traceback item and goto label on error."""

def __init__(self,
label: str,
source_path: str,
module_name: str,
traceback_entry: Tuple[str, int]) -> None:
self.label = label
self.source_path = source_path
self.module_name = module_name
self.traceback_entry = traceback_entry


class ReturnHandler(ErrorHandler):
"""Return a constant value on error."""

Expand Down Expand Up @@ -439,18 +459,6 @@ def emit_cast(self,
likely: If the cast is likely to succeed (can be False for unions)
"""
error = error or AssignHandler()
if isinstance(error, AssignHandler):
handle_error = '%s = NULL;' % dest
elif isinstance(error, GotoHandler):
handle_error = 'goto %s;' % error.label
else:
assert isinstance(error, ReturnHandler)
handle_error = 'return %s;' % error.value
if raise_exception:
raise_exc = f'CPy_TypeError("{self.pretty_name(typ)}", {src}); '
err = raise_exc + handle_error
else:
err = handle_error

# Special case casting *from* optional
if src_type and is_optional_type(src_type) and not is_object_rprimitive(typ):
Expand All @@ -465,9 +473,9 @@ def emit_cast(self,
self.emit_arg_check(src, dest, typ, check.format(src), optional)
self.emit_lines(
f' {dest} = {src};',
'else {',
err,
'}')
'else {')
self.emit_cast_error_handler(error, src, dest, typ, raise_exception)
self.emit_line('}')
return

# TODO: Verify refcount handling.
Expand Down Expand Up @@ -500,9 +508,9 @@ def emit_cast(self,
self.emit_arg_check(src, dest, typ, check.format(prefix, src), optional)
self.emit_lines(
f' {dest} = {src};',
'else {',
err,
'}')
'else {')
self.emit_cast_error_handler(error, src, dest, typ, raise_exception)
self.emit_line('}')
elif is_bytes_rprimitive(typ):
if declare_dest:
self.emit_line(f'PyObject *{dest};')
Expand All @@ -512,9 +520,9 @@ def emit_cast(self,
self.emit_arg_check(src, dest, typ, check.format(src, src), optional)
self.emit_lines(
f' {dest} = {src};',
'else {',
err,
'}')
'else {')
self.emit_cast_error_handler(error, src, dest, typ, raise_exception)
self.emit_line('}')
elif is_tuple_rprimitive(typ):
if declare_dest:
self.emit_line(f'{self.ctype(typ)} {dest};')
Expand All @@ -525,9 +533,9 @@ def emit_cast(self,
check.format(src), optional)
self.emit_lines(
f' {dest} = {src};',
'else {',
err,
'}')
'else {')
self.emit_cast_error_handler(error, src, dest, typ, raise_exception)
self.emit_line('}')
elif isinstance(typ, RInstance):
if declare_dest:
self.emit_line(f'PyObject *{dest};')
Expand All @@ -551,10 +559,10 @@ def emit_cast(self,
check = f'(likely{check})'
self.emit_arg_check(src, dest, typ, check, optional)
self.emit_lines(
f' {dest} = {src};',
'else {',
err,
'}')
f' {dest} = {src};'.format(dest, src),
'else {')
self.emit_cast_error_handler(error, src, dest, typ, raise_exception)
self.emit_line('}')
elif is_none_rprimitive(typ):
if declare_dest:
self.emit_line(f'PyObject *{dest};')
Expand All @@ -565,9 +573,9 @@ def emit_cast(self,
check.format(src), optional)
self.emit_lines(
f' {dest} = {src};',
'else {',
err,
'}')
'else {')
self.emit_cast_error_handler(error, src, dest, typ, raise_exception)
self.emit_line('}')
elif is_object_rprimitive(typ):
if declare_dest:
self.emit_line(f'PyObject *{dest};')
Expand All @@ -576,21 +584,51 @@ def emit_cast(self,
if optional:
self.emit_line('}')
elif isinstance(typ, RUnion):
self.emit_union_cast(src, dest, typ, declare_dest, err, optional, src_type)
self.emit_union_cast(src, dest, typ, declare_dest, error, optional, src_type,
raise_exception)
elif isinstance(typ, RTuple):
assert not optional
self.emit_tuple_cast(src, dest, typ, declare_dest, err, src_type)
self.emit_tuple_cast(src, dest, typ, declare_dest, error, src_type)
else:
assert False, 'Cast not implemented: %s' % typ

def emit_cast_error_handler(self,
error: ErrorHandler,
src: str,
dest: str,
typ: RType,
raise_exception: bool) -> None:
if raise_exception:
if isinstance(error, TracebackAndGotoHandler):
# Merge raising and emitting traceback entry into a single call.
self.emit_type_error_traceback(
error.source_path, error.module_name, error.traceback_entry,
typ=typ,
src=src)
self.emit_line('goto %s;' % error.label)
return
self.emit_line('CPy_TypeError("{}", {}); '.format(self.pretty_name(typ), src))
if isinstance(error, AssignHandler):
self.emit_line('%s = NULL;' % dest)
elif isinstance(error, GotoHandler):
self.emit_line('goto %s;' % error.label)
elif isinstance(error, TracebackAndGotoHandler):
self.emit_line('%s = NULL;' % dest)
self.emit_traceback(error.source_path, error.module_name, error.traceback_entry)
self.emit_line('goto %s;' % error.label)
else:
assert isinstance(error, ReturnHandler)
self.emit_line('return %s;' % error.value)

def emit_union_cast(self,
src: str,
dest: str,
typ: RUnion,
declare_dest: bool,
err: str,
error: ErrorHandler,
optional: bool,
src_type: Optional[RType]) -> None:
src_type: Optional[RType],
raise_exception: bool) -> None:
"""Emit cast to a union type.
The arguments are similar to emit_cast.
Expand All @@ -613,11 +651,11 @@ def emit_union_cast(self,
likely=False)
self.emit_line(f'if ({dest} != NULL) goto {good_label};')
# Handle cast failure.
self.emit_line(err)
self.emit_cast_error_handler(error, src, dest, typ, raise_exception)
self.emit_label(good_label)

def emit_tuple_cast(self, src: str, dest: str, typ: RTuple, declare_dest: bool,
err: str, src_type: Optional[RType]) -> None:
error: ErrorHandler, src_type: Optional[RType]) -> None:
"""Emit cast to a tuple type.
The arguments are similar to emit_cast.
Expand Down Expand Up @@ -740,7 +778,8 @@ def emit_unbox(self,
self.emit_line('} else {')

cast_temp = self.temp_name()
self.emit_tuple_cast(src, cast_temp, typ, declare_dest=True, err='', src_type=None)
self.emit_tuple_cast(src, cast_temp, typ, declare_dest=True, error=error,
src_type=None)
self.emit_line(f'if (unlikely({cast_temp} == NULL)) {{')

# self.emit_arg_check(src, dest, typ,
Expand Down Expand Up @@ -886,3 +925,44 @@ def emit_gc_clear(self, target: str, rtype: RType) -> None:
self.emit_line(f'Py_CLEAR({target});')
else:
assert False, 'emit_gc_clear() not implemented for %s' % repr(rtype)

def emit_traceback(self,
source_path: str,
module_name: str,
traceback_entry: Tuple[str, int]) -> None:
return self._emit_traceback('CPy_AddTraceback', source_path, module_name, traceback_entry)

def emit_type_error_traceback(
self,
source_path: str,
module_name: str,
traceback_entry: Tuple[str, int],
*,
typ: RType,
src: str) -> None:
func = 'CPy_TypeErrorTraceback'
type_str = f'"{self.pretty_name(typ)}"'
return self._emit_traceback(
func, source_path, module_name, traceback_entry, type_str=type_str, src=src)

def _emit_traceback(self,
func: str,
source_path: str,
module_name: str,
traceback_entry: Tuple[str, int],
type_str: str = '',
src: str = '') -> None:
globals_static = self.static_name('globals', module_name)
line = '%s("%s", "%s", %d, %s' % (
func,
source_path.replace("\\", "\\\\"),
traceback_entry[0],
traceback_entry[1],
globals_static)
if type_str:
assert src
line += f', {type_str}, {src}'
line += ');'
self.emit_line(line)
if DEBUG_ERRORS:
self.emit_line('assert(PyErr_Occurred() != NULL && "failure w/o err!");')
35 changes: 20 additions & 15 deletions mypyc/codegen/emitfunc.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from mypyc.common import (
REG_PREFIX, NATIVE_PREFIX, STATIC_PREFIX, TYPE_PREFIX, MODULE_PREFIX,
)
from mypyc.codegen.emit import Emitter
from mypyc.codegen.emit import Emitter, TracebackAndGotoHandler, DEBUG_ERRORS
from mypyc.ir.ops import (
Op, OpVisitor, Goto, Branch, Return, Assign, Integer, LoadErrorValue, GetAttr, SetAttr,
LoadStatic, InitStatic, TupleGet, TupleSet, Call, IncRef, DecRef, Box, Cast, Unbox,
Expand All @@ -23,10 +23,6 @@
from mypyc.ir.pprint import generate_names_for_ir
from mypyc.analysis.blockfreq import frequently_executed_blocks

# Whether to insert debug asserts for all error handling, to quickly
# catch errors propagating without exceptions set.
DEBUG_ERRORS = False


def native_function_type(fn: FuncIR, emitter: Emitter) -> str:
args = ', '.join(emitter.ctype(arg.type) for arg in fn.args) or 'void'
Expand Down Expand Up @@ -322,7 +318,7 @@ def visit_get_attr(self, op: GetAttr) -> None:
and branch.traceback_entry is not None
and not branch.negated):
# Generate code for the following branch here to avoid
# redundant branches in the generate code.
# redundant branches in the generated code.
self.emit_attribute_error(branch, cl.name, op.attr)
self.emit_line('goto %s;' % self.label(branch.true))
merged_branch = branch
Expand Down Expand Up @@ -485,8 +481,24 @@ def visit_box(self, op: Box) -> None:
self.emitter.emit_box(self.reg(op.src), self.reg(op), op.src.type, can_borrow=True)

def visit_cast(self, op: Cast) -> None:
branch = self.next_branch()
handler = None
if branch is not None:
if (branch.value is op
and branch.op == Branch.IS_ERROR
and branch.traceback_entry is not None
and not branch.negated
and branch.false is self.next_block):
# Generate code also for the following branch here to avoid
# redundant branches in the generated code.
handler = TracebackAndGotoHandler(self.label(branch.true),
self.source_path,
self.module_name,
branch.traceback_entry)
self.op_index += 1

self.emitter.emit_cast(self.reg(op.src), self.reg(op), op.type,
src_type=op.src.type)
src_type=op.src.type, error=handler)

def visit_unbox(self, op: Unbox) -> None:
self.emitter.emit_unbox(self.reg(op.src), self.reg(op), op.type)
Expand Down Expand Up @@ -647,14 +659,7 @@ def emit_declaration(self, line: str) -> None:

def emit_traceback(self, op: Branch) -> None:
if op.traceback_entry is not None:
globals_static = self.emitter.static_name('globals', self.module_name)
self.emit_line('CPy_AddTraceback("%s", "%s", %d, %s);' % (
self.source_path.replace("\\", "\\\\"),
op.traceback_entry[0],
op.traceback_entry[1],
globals_static))
if DEBUG_ERRORS:
self.emit_line('assert(PyErr_Occurred() != NULL && "failure w/o err!");')
self.emitter.emit_traceback(self.source_path, self.module_name, op.traceback_entry)

def emit_attribute_error(self, op: Branch, class_name: str, attr: str) -> None:
assert op.traceback_entry is not None
Expand Down
2 changes: 2 additions & 0 deletions mypyc/lib-rt/CPy.h
Original file line number Diff line number Diff line change
Expand Up @@ -500,6 +500,8 @@ void _CPy_GetExcInfo(PyObject **p_type, PyObject **p_value, PyObject **p_traceba
void CPyError_OutOfMemory(void);
void CPy_TypeError(const char *expected, PyObject *value);
void CPy_AddTraceback(const char *filename, const char *funcname, int line, PyObject *globals);
void CPy_TypeErrorTraceback(const char *filename, const char *funcname, int line,
PyObject *globals, const char *expected, PyObject *value);
void CPy_AttributeError(const char *filename, const char *funcname, const char *classname,
const char *attrname, int line, PyObject *globals);

Expand Down
7 changes: 7 additions & 0 deletions mypyc/lib-rt/exc_ops.c
Original file line number Diff line number Diff line change
Expand Up @@ -233,6 +233,13 @@ void CPy_AddTraceback(const char *filename, const char *funcname, int line, PyOb
_PyErr_ChainExceptions(exc, val, tb);
}

CPy_NOINLINE
void CPy_TypeErrorTraceback(const char *filename, const char *funcname, int line,
PyObject *globals, const char *expected, PyObject *value) {
CPy_TypeError(expected, value);
CPy_AddTraceback(filename, funcname, line, globals);
}

void CPy_AttributeError(const char *filename, const char *funcname, const char *classname,
const char *attrname, int line, PyObject *globals) {
char buf[500];
Expand Down
15 changes: 15 additions & 0 deletions mypyc/test-data/run-functions.test
Original file line number Diff line number Diff line change
Expand Up @@ -1220,3 +1220,18 @@ def sub(s: str, f: Callable[[str], str]) -> str: ...
def sub(s: bytes, f: Callable[[bytes], bytes]) -> bytes: ...
def sub(s, f):
return f(s)

[case testContextManagerSpecialCase]
from typing import Generator, Callable, Iterator
from contextlib import contextmanager

@contextmanager
def f() -> Iterator[None]:
yield

def g() -> None:
a = ['']
with f():
a.pop()

g()
Loading

0 comments on commit 040f3ab

Please sign in to comment.