|
3 | 3 |
|
4 | 4 | #include "Python.h" |
5 | 5 | #include "pycore_call.h" // _PyObject_CallNoArgs() |
| 6 | +#include "pycore_fileutils.h" // _Py_wgetcwd |
6 | 7 | #include "pycore_interp.h" // PyInterpreterState.importlib |
7 | 8 | #include "pycore_modsupport.h" // _PyModule_CreateInitialized() |
8 | 9 | #include "pycore_moduleobject.h" // _PyModule_GetDef() |
9 | 10 | #include "pycore_object.h" // _PyType_AllocNoTrack |
10 | 11 | #include "pycore_pyerrors.h" // _PyErr_FormatFromCause() |
11 | 12 | #include "pycore_pystate.h" // _PyInterpreterState_GET() |
12 | 13 |
|
| 14 | +#include "osdefs.h" // MAXPATHLEN |
| 15 | +#include "Python/stdlib_module_names.h" // _Py_stdlib_module_names |
13 | 16 |
|
14 | 17 |
|
15 | 18 | static PyMemberDef module_members[] = { |
@@ -784,6 +787,20 @@ _PyModuleSpec_IsUninitializedSubmodule(PyObject *spec, PyObject *name) |
784 | 787 | return rc; |
785 | 788 | } |
786 | 789 |
|
| 790 | +// TODO: deduplicate with suggestions.c. Where should this go? |
| 791 | +static bool |
| 792 | +is_name_stdlib_module(PyObject* name) |
| 793 | +{ |
| 794 | + const char* the_name = PyUnicode_AsUTF8(name); |
| 795 | + Py_ssize_t len = Py_ARRAY_LENGTH(_Py_stdlib_module_names); |
| 796 | + for (Py_ssize_t i = 0; i < len; i++) { |
| 797 | + if (strcmp(the_name, _Py_stdlib_module_names[i]) == 0) { |
| 798 | + return 1; |
| 799 | + } |
| 800 | + } |
| 801 | + return 0; |
| 802 | +} |
| 803 | + |
787 | 804 | PyObject* |
788 | 805 | _Py_module_getattro_impl(PyModuleObject *m, PyObject *name, int suppress) |
789 | 806 | { |
@@ -854,62 +871,34 @@ _Py_module_getattro_impl(PyModuleObject *m, PyObject *name, int suppress) |
854 | 871 | int is_script_shadowing_stdlib = 0; |
855 | 872 | // Check mod.__name__ in sys.stdlib_module_names |
856 | 873 | // and os.path.dirname(mod.__spec__.origin) == os.getcwd() |
857 | | - PyObject *stdlib = NULL; |
858 | | - if (origin) { |
859 | | - // Checks against mod_name are to avoid bad recursion |
860 | | - if ( |
861 | | - PyUnicode_CompareWithASCIIString(mod_name, "sys") != 0 |
862 | | - && PyUnicode_CompareWithASCIIString(mod_name, "builtins") != 0 |
863 | | - ) { |
864 | | - stdlib = _PyImport_GetModuleAttrString("sys", "stdlib_module_names"); |
865 | | - if (!stdlib) { |
866 | | - if (PyErr_ExceptionMatches(PyExc_AttributeError)) { |
867 | | - PyErr_Clear(); |
868 | | - } else { |
869 | | - goto done; |
870 | | - } |
| 874 | + if (origin && is_name_stdlib_module(mod_name)) { |
| 875 | + wchar_t cwdbuf[MAXPATHLEN]; |
| 876 | + if(_Py_wgetcwd(cwdbuf, MAXPATHLEN)) { |
| 877 | + PyObject *cwd = PyUnicode_FromWideChar(cwdbuf, wcslen(cwdbuf)); |
| 878 | + if (!cwd) { |
| 879 | + goto done; |
871 | 880 | } |
872 | | - if (stdlib && PyFrozenSet_Check(stdlib) && PySet_Contains(stdlib, mod_name)) { |
873 | | - if ( |
874 | | - PyUnicode_CompareWithASCIIString(mod_name, "os") != 0 |
875 | | - && PyUnicode_CompareWithASCIIString(mod_name, "posixpath") != 0 |
876 | | - && PyUnicode_CompareWithASCIIString(mod_name, "ntpath") != 0 |
877 | | - ) { |
878 | | - PyObject *os_path = _PyImport_GetModuleAttrString("os", "path"); |
879 | | - if (!os_path) { |
880 | | - goto done; |
881 | | - } |
882 | | - PyObject *dirname = PyObject_GetAttrString(os_path, "dirname"); |
883 | | - Py_DECREF(os_path); |
884 | | - if (!dirname) { |
885 | | - goto done; |
886 | | - } |
887 | | - PyObject *origin_dir = _PyObject_CallOneArg(dirname, origin); |
888 | | - Py_DECREF(dirname); |
889 | | - if (!origin_dir) { |
890 | | - goto done; |
891 | | - } |
892 | | - |
893 | | - PyObject *getcwd = _PyImport_GetModuleAttrString("os", "getcwd"); |
894 | | - if (!getcwd) { |
895 | | - Py_DECREF(origin_dir); |
896 | | - goto done; |
897 | | - } |
898 | | - PyObject *cwd = _PyObject_CallNoArgs(getcwd); |
899 | | - Py_DECREF(getcwd); |
900 | | - if (!cwd) { |
901 | | - Py_DECREF(origin_dir); |
902 | | - goto done; |
903 | | - } |
904 | | - |
905 | | - is_script_shadowing_stdlib = PyObject_RichCompareBool(origin_dir, cwd, Py_EQ); |
906 | | - Py_DECREF(origin_dir); |
907 | | - Py_DECREF(cwd); |
908 | | - if (is_script_shadowing_stdlib < 0) { |
909 | | - goto done; |
910 | | - } |
911 | | - } |
| 881 | + const char sep_char = SEP; |
| 882 | + PyObject *sep = PyUnicode_FromStringAndSize(&sep_char, 1); |
| 883 | + if (!sep) { |
| 884 | + Py_DECREF(cwd); |
| 885 | + goto done; |
| 886 | + } |
| 887 | + PyObject *parts = PyUnicode_RPartition(origin, sep); |
| 888 | + Py_DECREF(sep); |
| 889 | + if (!parts) { |
| 890 | + Py_DECREF(cwd); |
| 891 | + goto done; |
| 892 | + } |
| 893 | + int rc = PyUnicode_Compare(cwd, PyTuple_GET_ITEM(parts, 0)); |
| 894 | + if (rc == -1 && PyErr_Occurred()) { |
| 895 | + Py_DECREF(parts); |
| 896 | + Py_DECREF(cwd); |
| 897 | + goto done; |
912 | 898 | } |
| 899 | + is_script_shadowing_stdlib = rc == 0; |
| 900 | + Py_DECREF(parts); |
| 901 | + Py_DECREF(cwd); |
913 | 902 | } |
914 | 903 | } |
915 | 904 |
|
@@ -954,7 +943,6 @@ _Py_module_getattro_impl(PyModuleObject *m, PyObject *name, int suppress) |
954 | 943 | } |
955 | 944 |
|
956 | 945 | done: |
957 | | - Py_XDECREF(stdlib); |
958 | 946 | Py_XDECREF(origin); |
959 | 947 | Py_XDECREF(spec); |
960 | 948 | Py_DECREF(mod_name); |
|
0 commit comments