Skip to content

bpo-41194: Convert _ast extension to PEP 489 #21293

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jul 3, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
119 changes: 73 additions & 46 deletions Parser/asdl_c.py
Original file line number Diff line number Diff line change
Expand Up @@ -692,6 +692,9 @@ def visitModule(self, mod):
int res = -1;
PyObject *key, *value, *fields;
astmodulestate *state = get_global_ast_state();
if (state == NULL) {
goto cleanup;
}
if (_PyObject_LookupAttr((PyObject*)Py_TYPE(self), state->_fields, &fields) < 0) {
goto cleanup;
}
Expand Down Expand Up @@ -761,6 +764,10 @@ def visitModule(self, mod):
ast_type_reduce(PyObject *self, PyObject *unused)
{
astmodulestate *state = get_global_ast_state();
if (state == NULL) {
return NULL;
}

PyObject *dict;
if (_PyObject_LookupAttr(self, state->__dict__, &dict) < 0) {
return NULL;
Expand Down Expand Up @@ -969,9 +976,8 @@ def visitModule(self, mod):

""", 0, reflow=False)

self.emit("static int init_types(void)",0)
self.emit("static int init_types(astmodulestate *state)",0)
self.emit("{", 0)
self.emit("astmodulestate *state = get_global_ast_state();", 1)
self.emit("if (state->initialized) return 1;", 1)
self.emit("if (init_identifiers(state) < 0) return 0;", 1)
self.emit("state->AST_type = PyType_FromSpec(&AST_type_spec);", 1)
Expand Down Expand Up @@ -1046,40 +1052,55 @@ def emit_defaults(self, name, fields, depth):
class ASTModuleVisitor(PickleVisitor):

def visitModule(self, mod):
self.emit("PyMODINIT_FUNC", 0)
self.emit("PyInit__ast(void)", 0)
self.emit("static int", 0)
self.emit("astmodule_exec(PyObject *m)", 0)
self.emit("{", 0)
self.emit("PyObject *m = PyModule_Create(&_astmodule);", 1)
self.emit("if (!m) {", 1)
self.emit("return NULL;", 2)
self.emit("}", 1)
self.emit('astmodulestate *state = get_ast_state(m);', 1)
self.emit('', 1)
self.emit("", 0)

self.emit("if (!init_types()) {", 1)
self.emit("goto error;", 2)
self.emit("if (!init_types(state)) {", 1)
self.emit("return -1;", 2)
self.emit("}", 1)
self.emit('if (PyModule_AddObject(m, "AST", state->AST_type) < 0) {', 1)
self.emit('goto error;', 2)
self.emit('return -1;', 2)
self.emit('}', 1)
self.emit('Py_INCREF(state->AST_type);', 1)
self.emit('if (PyModule_AddIntMacro(m, PyCF_ALLOW_TOP_LEVEL_AWAIT) < 0) {', 1)
self.emit("goto error;", 2)
self.emit("return -1;", 2)
self.emit('}', 1)
self.emit('if (PyModule_AddIntMacro(m, PyCF_ONLY_AST) < 0) {', 1)
self.emit("goto error;", 2)
self.emit("return -1;", 2)
self.emit('}', 1)
self.emit('if (PyModule_AddIntMacro(m, PyCF_TYPE_COMMENTS) < 0) {', 1)
self.emit("goto error;", 2)
self.emit("return -1;", 2)
self.emit('}', 1)
for dfn in mod.dfns:
self.visit(dfn)
self.emit("return m;", 1)
self.emit("", 0)
self.emit("error:", 0)
self.emit("Py_DECREF(m);", 1)
self.emit("return NULL;", 1)
self.emit("return 0;", 1)
self.emit("}", 0)
self.emit("", 0)
self.emit("""
static PyModuleDef_Slot astmodule_slots[] = {
{Py_mod_exec, astmodule_exec},
{0, NULL}
};

static struct PyModuleDef _astmodule = {
PyModuleDef_HEAD_INIT,
.m_name = "_ast",
.m_size = sizeof(astmodulestate),
.m_slots = astmodule_slots,
.m_traverse = astmodule_traverse,
.m_clear = astmodule_clear,
.m_free = astmodule_free,
};

PyMODINIT_FUNC
PyInit__ast(void)
{
return PyModuleDef_Init(&_astmodule);
}
""".strip(), 0, reflow=False)

def visitProduct(self, prod, name):
self.addObj(name)
Expand All @@ -1095,7 +1116,7 @@ def visitConstructor(self, cons, name):
def addObj(self, name):
self.emit("if (PyModule_AddObject(m, \"%s\", "
"state->%s_type) < 0) {" % (name, name), 1)
self.emit("goto error;", 2)
self.emit("return -1;", 2)
self.emit('}', 1)
self.emit("Py_INCREF(state->%s_type);" % name, 1)

Expand Down Expand Up @@ -1255,11 +1276,10 @@ class PartingShots(StaticVisitor):
CODE = """
PyObject* PyAST_mod2obj(mod_ty t)
{
if (!init_types()) {
astmodulestate *state = get_global_ast_state();
if (state == NULL) {
return NULL;
}

astmodulestate *state = get_global_ast_state();
return ast2obj_mod(state, t);
}

Expand All @@ -1281,10 +1301,6 @@ class PartingShots(StaticVisitor):

assert(0 <= mode && mode <= 2);

if (!init_types()) {
return NULL;
}

isinstance = PyObject_IsInstance(ast, req_type[mode]);
if (isinstance == -1)
return NULL;
Expand All @@ -1303,11 +1319,10 @@ class PartingShots(StaticVisitor):

int PyAST_Check(PyObject* obj)
{
if (!init_types()) {
astmodulestate *state = get_global_ast_state();
if (state == NULL) {
return -1;
}

astmodulestate *state = get_global_ast_state();
return PyObject_IsInstance(obj, state->AST_type);
}
"""
Expand Down Expand Up @@ -1358,12 +1373,35 @@ def generate_module_def(f, mod):
f.write(' PyObject *' + s + ';\n')
f.write('} astmodulestate;\n\n')
f.write("""
static astmodulestate global_ast_state;
static astmodulestate*
get_ast_state(PyObject *module)
{
void *state = PyModule_GetState(module);
assert(state != NULL);
return (astmodulestate*)state;
}

static astmodulestate *
get_ast_state(PyObject *Py_UNUSED(module))
static astmodulestate*
get_global_ast_state(void)
{
return &global_ast_state;
_Py_IDENTIFIER(_ast);
PyObject *name = _PyUnicode_FromId(&PyId__ast); // borrowed reference
if (name == NULL) {
return NULL;
}
PyObject *module = PyImport_GetModule(name);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As noted in https://bugs.python.org/issue41631, this is incorrect :(
PyImport_GetModule looks in sys.modules, which is modifiable by the user, and so it's not guaranteed to return _ast.

if (module == NULL) {
if (PyErr_Occurred()) {
return NULL;
}
module = PyImport_Import(name);
if (module == NULL) {
return NULL;
}
}
astmodulestate *state = get_ast_state(module);
Py_DECREF(module);
return state;
}

static int astmodule_clear(PyObject *module)
Expand All @@ -1390,17 +1428,6 @@ def generate_module_def(f, mod):
astmodule_clear((PyObject*)module);
}

static struct PyModuleDef _astmodule = {
PyModuleDef_HEAD_INIT,
.m_name = "_ast",
.m_size = -1,
.m_traverse = astmodule_traverse,
.m_clear = astmodule_clear,
.m_free = astmodule_free,
};

#define get_global_ast_state() (&global_ast_state)

""")
f.write('static int init_identifiers(astmodulestate *state)\n')
f.write('{\n')
Expand Down
Loading