Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[3.8] bpo-16575: Add checks for unions passed by value to functions. … #17016

Merged
merged 1 commit into from
Oct 31, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
80 changes: 80 additions & 0 deletions Lib/ctypes/test/test_structures.py
Original file line number Diff line number Diff line change
Expand Up @@ -571,6 +571,86 @@ class U(Union):
self.assertEqual(f2, [0x4567, 0x0123, 0xcdef, 0x89ab,
0x3210, 0x7654, 0xba98, 0xfedc])

def test_union_by_value(self):
# See bpo-16575

# These should mirror the structures in Modules/_ctypes/_ctypes_test.c

class Nested1(Structure):
_fields_ = [
('an_int', c_int),
('another_int', c_int),
]

class Test4(Union):
_fields_ = [
('a_long', c_long),
('a_struct', Nested1),
]

class Nested2(Structure):
_fields_ = [
('an_int', c_int),
('a_union', Test4),
]

class Test5(Structure):
_fields_ = [
('an_int', c_int),
('nested', Nested2),
('another_int', c_int),
]

test4 = Test4()
dll = CDLL(_ctypes_test.__file__)
with self.assertRaises(TypeError) as ctx:
func = dll._testfunc_union_by_value1
func.restype = c_long
func.argtypes = (Test4,)
result = func(test4)
self.assertEqual(ctx.exception.args[0], 'item 1 in _argtypes_ passes '
'a union by value, which is unsupported.')
test5 = Test5()
with self.assertRaises(TypeError) as ctx:
func = dll._testfunc_union_by_value2
func.restype = c_long
func.argtypes = (Test5,)
result = func(test5)
self.assertEqual(ctx.exception.args[0], 'item 1 in _argtypes_ passes '
'a union by value, which is unsupported.')

# passing by reference should be OK
test4.a_long = 12345;
func = dll._testfunc_union_by_reference1
func.restype = c_long
func.argtypes = (POINTER(Test4),)
result = func(byref(test4))
self.assertEqual(result, 12345)
self.assertEqual(test4.a_long, 0)
self.assertEqual(test4.a_struct.an_int, 0)
self.assertEqual(test4.a_struct.another_int, 0)
test4.a_struct.an_int = 0x12340000
test4.a_struct.another_int = 0x5678
func = dll._testfunc_union_by_reference2
func.restype = c_long
func.argtypes = (POINTER(Test4),)
result = func(byref(test4))
self.assertEqual(result, 0x12345678)
self.assertEqual(test4.a_long, 0)
self.assertEqual(test4.a_struct.an_int, 0)
self.assertEqual(test4.a_struct.another_int, 0)
test5.an_int = 0x12000000
test5.nested.an_int = 0x345600
test5.another_int = 0x78
func = dll._testfunc_union_by_reference3
func.restype = c_long
func.argtypes = (POINTER(Test5),)
result = func(byref(test5))
self.assertEqual(result, 0x12345678)
self.assertEqual(test5.an_int, 0)
self.assertEqual(test5.nested.an_int, 0)
self.assertEqual(test5.another_int, 0)

class PointerMemberTestCase(unittest.TestCase):

def test(self):
Expand Down
24 changes: 24 additions & 0 deletions Modules/_ctypes/_ctypes.c
Original file line number Diff line number Diff line change
Expand Up @@ -504,6 +504,9 @@ StructUnionType_new(PyTypeObject *type, PyObject *args, PyObject *kwds, int isSt
Py_DECREF(result);
return NULL;
}
if (!isStruct) {
dict->flags |= TYPEFLAG_HASUNION;
}
/* replace the class dict by our updated stgdict, which holds info
about storage requirements of the instances */
if (-1 == PyDict_Update((PyObject *)dict, result->tp_dict)) {
Expand Down Expand Up @@ -2383,6 +2386,27 @@ converters_from_argtypes(PyObject *ob)
for (i = 0; i < nArgs; ++i) {
PyObject *cnv;
PyObject *tp = PyTuple_GET_ITEM(ob, i);
StgDictObject *stgdict = PyType_stgdict(tp);

if (stgdict != NULL) {
if (stgdict->flags & TYPEFLAG_HASUNION) {
Py_DECREF(converters);
Py_DECREF(ob);
if (!PyErr_Occurred()) {
PyErr_Format(PyExc_TypeError,
"item %zd in _argtypes_ passes a union by "
"value, which is unsupported.",
i + 1);
}
return NULL;
}
/*
if (stgdict->flags & TYPEFLAG_HASBITFIELD) {
printf("found stgdict with bitfield\n");
}
*/
}

if (_PyObject_LookupAttrId(tp, &PyId_from_param, &cnv) <= 0) {
Py_DECREF(converters);
Py_DECREF(ob);
Expand Down
63 changes: 63 additions & 0 deletions Modules/_ctypes/_ctypes_test.c
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,69 @@ _testfunc_array_in_struct2a(Test3B in)
return result;
}

typedef union {
long a_long;
struct {
int an_int;
int another_int;
} a_struct;
} Test4;

typedef struct {
int an_int;
struct {
int an_int;
Test4 a_union;
} nested;
int another_int;
} Test5;

EXPORT(long)
_testfunc_union_by_value1(Test4 in) {
long result = in.a_long + in.a_struct.an_int + in.a_struct.another_int;

/* As the union/struct are passed by value, changes to them shouldn't be
* reflected in the caller.
*/
memset(&in, 0, sizeof(in));
return result;
}

EXPORT(long)
_testfunc_union_by_value2(Test5 in) {
long result = in.an_int + in.nested.an_int;

/* As the union/struct are passed by value, changes to them shouldn't be
* reflected in the caller.
*/
memset(&in, 0, sizeof(in));
return result;
}

EXPORT(long)
_testfunc_union_by_reference1(Test4 *in) {
long result = in->a_long;

memset(in, 0, sizeof(Test4));
return result;
}

EXPORT(long)
_testfunc_union_by_reference2(Test4 *in) {
long result = in->a_struct.an_int + in->a_struct.another_int;

memset(in, 0, sizeof(Test4));
return result;
}

EXPORT(long)
_testfunc_union_by_reference3(Test5 *in) {
long result = in->an_int + in->nested.an_int + in->another_int;

memset(in, 0, sizeof(Test5));
return result;
}

EXPORT(void)testfunc_array(int values[4])
{
printf("testfunc_array %d %d %d %d\n",
Expand Down
2 changes: 2 additions & 0 deletions Modules/_ctypes/ctypes.h
Original file line number Diff line number Diff line change
Expand Up @@ -288,6 +288,8 @@ PyObject *_ctypes_callproc(PPROC pProc,

#define TYPEFLAG_ISPOINTER 0x100
#define TYPEFLAG_HASPOINTER 0x200
#define TYPEFLAG_HASUNION 0x400
#define TYPEFLAG_HASBITFIELD 0x800

#define DICTFLAG_FINAL 0x1000

Expand Down
9 changes: 9 additions & 0 deletions Modules/_ctypes/stgdict.c
Original file line number Diff line number Diff line change
Expand Up @@ -440,6 +440,13 @@ PyCStructUnionType_update_stgdict(PyObject *type, PyObject *fields, int isStruct
PyMem_Free(stgdict->ffi_type_pointer.elements);

basedict = PyType_stgdict((PyObject *)((PyTypeObject *)type)->tp_base);
if (basedict) {
stgdict->flags |= (basedict->flags &
(TYPEFLAG_HASUNION | TYPEFLAG_HASBITFIELD));
}
if (!isStruct) {
stgdict->flags |= TYPEFLAG_HASUNION;
}
if (basedict && !use_broken_old_ctypes_semantics) {
size = offset = basedict->size;
align = basedict->align;
Expand Down Expand Up @@ -515,8 +522,10 @@ PyCStructUnionType_update_stgdict(PyObject *type, PyObject *fields, int isStruct
stgdict->ffi_type_pointer.elements[ffi_ofs + i] = &dict->ffi_type_pointer;
if (dict->flags & (TYPEFLAG_ISPOINTER | TYPEFLAG_HASPOINTER))
stgdict->flags |= TYPEFLAG_HASPOINTER;
stgdict->flags |= dict->flags & (TYPEFLAG_HASUNION | TYPEFLAG_HASBITFIELD);
dict->flags |= DICTFLAG_FINAL; /* mark field type final */
if (PyTuple_Size(pair) == 3) { /* bits specified */
stgdict->flags |= TYPEFLAG_HASBITFIELD;
switch(dict->ffi_type_pointer.type) {
case FFI_TYPE_UINT8:
case FFI_TYPE_UINT16:
Expand Down