Skip to content

Commit

Permalink
[mypyc] Raise ImportError instead of AttributeError when from-import …
Browse files Browse the repository at this point in the history
…fails (#10641)

Fixes mypyc/mypyc#707.
  • Loading branch information
97littleleaf11 committed Jun 16, 2021
1 parent 790ab35 commit a9fda34
Show file tree
Hide file tree
Showing 9 changed files with 392 additions and 315 deletions.
12 changes: 6 additions & 6 deletions mypyc/irbuild/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -288,20 +288,20 @@ def add_to_non_ext_dict(self, non_ext: NonExtClassInfo,
key_unicode = self.load_str(key)
self.call_c(dict_set_item_op, [non_ext.dict, key_unicode, val], line)

def gen_import_from(self, id: str, line: int, imported: List[str]) -> None:
def gen_import_from(self, id: str, globals_dict: Value,
imported: List[str], line: int) -> Value:
self.imports[id] = None

globals_dict = self.load_globals_dict()
null = Integer(0, dict_rprimitive, line)
null_dict = Integer(0, dict_rprimitive, line)
names_to_import = self.new_list_op([self.load_str(name) for name in imported], line)

level = Integer(0, c_int_rprimitive, line)
zero_int = Integer(0, c_int_rprimitive, line)
value = self.call_c(
import_extra_args_op,
[self.load_str(id), globals_dict, null, names_to_import, level],
[self.load_str(id), globals_dict, null_dict, names_to_import, zero_int],
line,
)
self.add(InitStatic(value, id, namespace=NAMESPACE_MODULE))
return value

def gen_import(self, id: str, line: int) -> None:
self.imports[id] = None
Expand Down
14 changes: 8 additions & 6 deletions mypyc/irbuild/statement.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
)
from mypyc.ir.rtypes import RInstance, exc_rtuple
from mypyc.primitives.generic_ops import py_delattr_op
from mypyc.primitives.misc_ops import type_op
from mypyc.primitives.misc_ops import type_op, import_from_op
from mypyc.primitives.exc_ops import (
raise_exception_op, reraise_exception_op, error_catch_op, exc_matches_op, restore_exc_info_op,
get_exc_value_op, keep_propagating_op, get_exc_info_op
Expand Down Expand Up @@ -172,18 +172,20 @@ def transform_import_from(builder: IRBuilder, node: ImportFrom) -> None:

id = importlib.util.resolve_name('.' * node.relative + node.id, module_package)

imported = [name for name, _ in node.names]
builder.gen_import_from(id, node.line, imported)
module = builder.load_module(id)
globals = builder.load_globals_dict()
imported_names = [name for name, _ in node.names]
module = builder.gen_import_from(id, globals, imported_names, node.line)

# Copy everything into our module's dict.
# Note that we miscompile import from inside of functions here,
# since that case *shouldn't* load it into the globals dict.
# This probably doesn't matter much and the code runs basically right.
globals = builder.load_globals_dict()
for name, maybe_as_name in node.names:
as_name = maybe_as_name or name
obj = builder.py_get_attr(module, name, node.line)
obj = builder.call_c(import_from_op,
[module, builder.load_str(id),
builder.load_str(name), builder.load_str(as_name)],
node.line)
builder.gen_method_call(
globals, '__setitem__', [builder.load_str(as_name), obj],
result_type=None, line=node.line)
Expand Down
3 changes: 3 additions & 0 deletions mypyc/lib-rt/CPy.h
Original file line number Diff line number Diff line change
Expand Up @@ -542,6 +542,9 @@ int CPyStatics_Initialize(PyObject **statics,
const int *tuples);
PyObject *CPy_Super(PyObject *builtins, PyObject *self);

PyObject *CPyImport_ImportFrom(PyObject *module, PyObject *package_name,
PyObject *import_name, PyObject *as_name);

#ifdef __cplusplus
}
#endif
Expand Down
33 changes: 33 additions & 0 deletions mypyc/lib-rt/misc_ops.c
Original file line number Diff line number Diff line change
Expand Up @@ -643,3 +643,36 @@ CPy_Super(PyObject *builtins, PyObject *self) {
Py_DECREF(super_type);
return result;
}

// This helper function is a simplification of cpython/ceval.c/import_from()
PyObject *CPyImport_ImportFrom(PyObject *module, PyObject *package_name,
PyObject *import_name, PyObject *as_name) {
// check if the imported module has an attribute by that name
PyObject *x = PyObject_GetAttr(module, import_name);
if (x == NULL) {
// if not, attempt to import a submodule with that name
PyObject *fullmodname = PyUnicode_FromFormat("%U.%U", package_name, import_name);
if (fullmodname == NULL) {
goto fail;
}

// The following code is a simplification of cpython/import.c/PyImport_GetModule()
x = PyObject_GetItem(module, fullmodname);
Py_DECREF(fullmodname);
if (x == NULL) {
goto fail;
}
}
return x;

fail:
PyErr_Clear();
PyObject *package_path = PyModule_GetFilenameObject(module);
PyObject *errmsg = PyUnicode_FromFormat("cannot import name %R from %R (%S)",
import_name, package_name, package_path);
// NULL checks for errmsg and package_name done by PyErr_SetImportError.
PyErr_SetImportError(errmsg, package_name, package_path);
Py_DECREF(package_path);
Py_DECREF(errmsg);
return NULL;
}
9 changes: 9 additions & 0 deletions mypyc/primitives/misc_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,15 @@
error_kind=ERR_MAGIC
)

# Import-from helper op
import_from_op = custom_op(
arg_types=[object_rprimitive, str_rprimitive,
str_rprimitive, str_rprimitive],
return_type=object_rprimitive,
c_function_name='CPyImport_ImportFrom',
error_kind=ERR_MAGIC
)

# Get the sys.modules dictionary
get_module_dict_op = custom_op(
arg_types=[],
Expand Down
2 changes: 2 additions & 0 deletions mypyc/test-data/fixtures/ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,6 +229,8 @@ class ValueError(Exception): pass

class AttributeError(Exception): pass

class ImportError(Exception): pass

class NameError(Exception): pass

class LookupError(Exception): pass
Expand Down
Loading

0 comments on commit a9fda34

Please sign in to comment.