Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[mypyc] Raise ImportError instead of AttributeError when from-import fails #10641

Merged
merged 7 commits into from
Jun 16, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions mypyc/irbuild/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -288,20 +288,20 @@ def add_to_non_ext_dict(self, non_ext: NonExtClassInfo,
key_unicode = self.load_str(key)
self.call_c(dict_set_item_op, [non_ext.dict, key_unicode, val], line)

def gen_import_from(self, id: str, line: int, imported: List[str]) -> None:
def gen_import_from(self, id: str, globals_dict: Value,
imported: List[str], line: int) -> Value:
self.imports[id] = None

globals_dict = self.load_globals_dict()
null = Integer(0, dict_rprimitive, line)
null_dict = Integer(0, dict_rprimitive, line)
names_to_import = self.new_list_op([self.load_str(name) for name in imported], line)

level = Integer(0, c_int_rprimitive, line)
zero_int = Integer(0, c_int_rprimitive, line)
value = self.call_c(
import_extra_args_op,
[self.load_str(id), globals_dict, null, names_to_import, level],
[self.load_str(id), globals_dict, null_dict, names_to_import, zero_int],
line,
)
self.add(InitStatic(value, id, namespace=NAMESPACE_MODULE))
return value

def gen_import(self, id: str, line: int) -> None:
self.imports[id] = None
Expand Down
14 changes: 8 additions & 6 deletions mypyc/irbuild/statement.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
)
from mypyc.ir.rtypes import RInstance, exc_rtuple
from mypyc.primitives.generic_ops import py_delattr_op
from mypyc.primitives.misc_ops import type_op
from mypyc.primitives.misc_ops import type_op, import_from_op
from mypyc.primitives.exc_ops import (
raise_exception_op, reraise_exception_op, error_catch_op, exc_matches_op, restore_exc_info_op,
get_exc_value_op, keep_propagating_op, get_exc_info_op
Expand Down Expand Up @@ -172,18 +172,20 @@ def transform_import_from(builder: IRBuilder, node: ImportFrom) -> None:

id = importlib.util.resolve_name('.' * node.relative + node.id, module_package)

imported = [name for name, _ in node.names]
builder.gen_import_from(id, node.line, imported)
module = builder.load_module(id)
globals = builder.load_globals_dict()
imported_names = [name for name, _ in node.names]
module = builder.gen_import_from(id, globals, imported_names, node.line)

# Copy everything into our module's dict.
# Note that we miscompile import from inside of functions here,
# since that case *shouldn't* load it into the globals dict.
# This probably doesn't matter much and the code runs basically right.
globals = builder.load_globals_dict()
for name, maybe_as_name in node.names:
as_name = maybe_as_name or name
obj = builder.py_get_attr(module, name, node.line)
obj = builder.call_c(import_from_op,
[module, builder.load_str(id),
builder.load_str(name), builder.load_str(as_name)],
node.line)
builder.gen_method_call(
globals, '__setitem__', [builder.load_str(as_name), obj],
result_type=None, line=node.line)
Expand Down
3 changes: 3 additions & 0 deletions mypyc/lib-rt/CPy.h
Original file line number Diff line number Diff line change
Expand Up @@ -542,6 +542,9 @@ int CPyStatics_Initialize(PyObject **statics,
const int *tuples);
PyObject *CPy_Super(PyObject *builtins, PyObject *self);

PyObject *CPyImport_ImportFrom(PyObject *module, PyObject *package_name,
PyObject *import_name, PyObject *as_name);

#ifdef __cplusplus
}
#endif
Expand Down
33 changes: 33 additions & 0 deletions mypyc/lib-rt/misc_ops.c
Original file line number Diff line number Diff line change
Expand Up @@ -643,3 +643,36 @@ CPy_Super(PyObject *builtins, PyObject *self) {
Py_DECREF(super_type);
return result;
}

// This helper function is a simplification of cpython/ceval.c/import_from()
PyObject *CPyImport_ImportFrom(PyObject *module, PyObject *package_name,
PyObject *import_name, PyObject *as_name) {
// check if the imported module has an attribute by that name
PyObject *x = PyObject_GetAttr(module, import_name);
if (x == NULL) {
// if not, attempt to import a submodule with that name
PyObject *fullmodname = PyUnicode_FromFormat("%U.%U", package_name, import_name);
if (fullmodname == NULL) {
goto fail;
}

// The following code is a simplification of cpython/import.c/PyImport_GetModule()
x = PyObject_GetItem(module, fullmodname);
Py_DECREF(fullmodname);
if (x == NULL) {
goto fail;
}
}
return x;

fail:
PyErr_Clear();
PyObject *package_path = PyModule_GetFilenameObject(module);
PyObject *errmsg = PyUnicode_FromFormat("cannot import name %R from %R (%S)",
import_name, package_name, package_path);
// NULL checks for errmsg and package_name done by PyErr_SetImportError.
PyErr_SetImportError(errmsg, package_name, package_path);
Py_DECREF(package_path);
Py_DECREF(errmsg);
return NULL;
}
9 changes: 9 additions & 0 deletions mypyc/primitives/misc_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,15 @@
error_kind=ERR_MAGIC
)

# Import-from helper op
import_from_op = custom_op(
arg_types=[object_rprimitive, str_rprimitive,
str_rprimitive, str_rprimitive],
return_type=object_rprimitive,
c_function_name='CPyImport_ImportFrom',
error_kind=ERR_MAGIC
)

# Get the sys.modules dictionary
get_module_dict_op = custom_op(
arg_types=[],
Expand Down
2 changes: 2 additions & 0 deletions mypyc/test-data/fixtures/ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,6 +229,8 @@ class ValueError(Exception): pass

class AttributeError(Exception): pass

class ImportError(Exception): pass

class NameError(Exception): pass

class LookupError(Exception): pass
Expand Down
Loading