Skip to content

Commit 0118d10

Browse files
authored
[3.7] bpo-16575: Add checks for unions passed by value to functions. (GH-16799) (GH-17017)
(cherry picked from commit 79d4ed1)
1 parent a28cf14 commit 0118d10

File tree

5 files changed

+177
-0
lines changed

5 files changed

+177
-0
lines changed

Lib/ctypes/test/test_structures.py

+80
Original file line numberDiff line numberDiff line change
@@ -532,6 +532,86 @@ class U(Union):
532532
self.assertEqual(f2, [0x4567, 0x0123, 0xcdef, 0x89ab,
533533
0x3210, 0x7654, 0xba98, 0xfedc])
534534

535+
def test_union_by_value(self):
536+
# See bpo-16575
537+
538+
# These should mirror the structures in Modules/_ctypes/_ctypes_test.c
539+
540+
class Nested1(Structure):
541+
_fields_ = [
542+
('an_int', c_int),
543+
('another_int', c_int),
544+
]
545+
546+
class Test4(Union):
547+
_fields_ = [
548+
('a_long', c_long),
549+
('a_struct', Nested1),
550+
]
551+
552+
class Nested2(Structure):
553+
_fields_ = [
554+
('an_int', c_int),
555+
('a_union', Test4),
556+
]
557+
558+
class Test5(Structure):
559+
_fields_ = [
560+
('an_int', c_int),
561+
('nested', Nested2),
562+
('another_int', c_int),
563+
]
564+
565+
test4 = Test4()
566+
dll = CDLL(_ctypes_test.__file__)
567+
with self.assertRaises(TypeError) as ctx:
568+
func = dll._testfunc_union_by_value1
569+
func.restype = c_long
570+
func.argtypes = (Test4,)
571+
result = func(test4)
572+
self.assertEqual(ctx.exception.args[0], 'item 1 in _argtypes_ passes '
573+
'a union by value, which is unsupported.')
574+
test5 = Test5()
575+
with self.assertRaises(TypeError) as ctx:
576+
func = dll._testfunc_union_by_value2
577+
func.restype = c_long
578+
func.argtypes = (Test5,)
579+
result = func(test5)
580+
self.assertEqual(ctx.exception.args[0], 'item 1 in _argtypes_ passes '
581+
'a union by value, which is unsupported.')
582+
583+
# passing by reference should be OK
584+
test4.a_long = 12345;
585+
func = dll._testfunc_union_by_reference1
586+
func.restype = c_long
587+
func.argtypes = (POINTER(Test4),)
588+
result = func(byref(test4))
589+
self.assertEqual(result, 12345)
590+
self.assertEqual(test4.a_long, 0)
591+
self.assertEqual(test4.a_struct.an_int, 0)
592+
self.assertEqual(test4.a_struct.another_int, 0)
593+
test4.a_struct.an_int = 0x12340000
594+
test4.a_struct.another_int = 0x5678
595+
func = dll._testfunc_union_by_reference2
596+
func.restype = c_long
597+
func.argtypes = (POINTER(Test4),)
598+
result = func(byref(test4))
599+
self.assertEqual(result, 0x12345678)
600+
self.assertEqual(test4.a_long, 0)
601+
self.assertEqual(test4.a_struct.an_int, 0)
602+
self.assertEqual(test4.a_struct.another_int, 0)
603+
test5.an_int = 0x12000000
604+
test5.nested.an_int = 0x345600
605+
test5.another_int = 0x78
606+
func = dll._testfunc_union_by_reference3
607+
func.restype = c_long
608+
func.argtypes = (POINTER(Test5),)
609+
result = func(byref(test5))
610+
self.assertEqual(result, 0x12345678)
611+
self.assertEqual(test5.an_int, 0)
612+
self.assertEqual(test5.nested.an_int, 0)
613+
self.assertEqual(test5.another_int, 0)
614+
535615
class PointerMemberTestCase(unittest.TestCase):
536616

537617
def test(self):

Modules/_ctypes/_ctypes.c

+23
Original file line numberDiff line numberDiff line change
@@ -446,6 +446,9 @@ StructUnionType_new(PyTypeObject *type, PyObject *args, PyObject *kwds, int isSt
446446
Py_DECREF(result);
447447
return NULL;
448448
}
449+
if (!isStruct) {
450+
dict->flags |= TYPEFLAG_HASUNION;
451+
}
449452
/* replace the class dict by our updated stgdict, which holds info
450453
about storage requirements of the instances */
451454
if (-1 == PyDict_Update((PyObject *)dict, result->tp_dict)) {
@@ -2276,6 +2279,26 @@ converters_from_argtypes(PyObject *ob)
22762279
PyObject *cnv = PyObject_GetAttrString(tp, "from_param");
22772280
if (!cnv)
22782281
goto argtypes_error_1;
2282+
StgDictObject *stgdict = PyType_stgdict(tp);
2283+
2284+
if (stgdict != NULL) {
2285+
if (stgdict->flags & TYPEFLAG_HASUNION) {
2286+
Py_DECREF(converters);
2287+
Py_DECREF(ob);
2288+
if (!PyErr_Occurred()) {
2289+
PyErr_Format(PyExc_TypeError,
2290+
"item %zd in _argtypes_ passes a union by "
2291+
"value, which is unsupported.",
2292+
i + 1);
2293+
}
2294+
return NULL;
2295+
}
2296+
/*
2297+
if (stgdict->flags & TYPEFLAG_HASBITFIELD) {
2298+
printf("found stgdict with bitfield\n");
2299+
}
2300+
*/
2301+
}
22792302
PyTuple_SET_ITEM(converters, i, cnv);
22802303
}
22812304
Py_DECREF(ob);

Modules/_ctypes/_ctypes_test.c

+63
Original file line numberDiff line numberDiff line change
@@ -135,6 +135,69 @@ _testfunc_array_in_struct2a(Test3B in)
135135
return result;
136136
}
137137

138+
typedef union {
139+
long a_long;
140+
struct {
141+
int an_int;
142+
int another_int;
143+
} a_struct;
144+
} Test4;
145+
146+
typedef struct {
147+
int an_int;
148+
struct {
149+
int an_int;
150+
Test4 a_union;
151+
} nested;
152+
int another_int;
153+
} Test5;
154+
155+
EXPORT(long)
156+
_testfunc_union_by_value1(Test4 in) {
157+
long result = in.a_long + in.a_struct.an_int + in.a_struct.another_int;
158+
159+
/* As the union/struct are passed by value, changes to them shouldn't be
160+
* reflected in the caller.
161+
*/
162+
memset(&in, 0, sizeof(in));
163+
return result;
164+
}
165+
166+
EXPORT(long)
167+
_testfunc_union_by_value2(Test5 in) {
168+
long result = in.an_int + in.nested.an_int;
169+
170+
/* As the union/struct are passed by value, changes to them shouldn't be
171+
* reflected in the caller.
172+
*/
173+
memset(&in, 0, sizeof(in));
174+
return result;
175+
}
176+
177+
EXPORT(long)
178+
_testfunc_union_by_reference1(Test4 *in) {
179+
long result = in->a_long;
180+
181+
memset(in, 0, sizeof(Test4));
182+
return result;
183+
}
184+
185+
EXPORT(long)
186+
_testfunc_union_by_reference2(Test4 *in) {
187+
long result = in->a_struct.an_int + in->a_struct.another_int;
188+
189+
memset(in, 0, sizeof(Test4));
190+
return result;
191+
}
192+
193+
EXPORT(long)
194+
_testfunc_union_by_reference3(Test5 *in) {
195+
long result = in->an_int + in->nested.an_int + in->another_int;
196+
197+
memset(in, 0, sizeof(Test5));
198+
return result;
199+
}
200+
138201
EXPORT(void)testfunc_array(int values[4])
139202
{
140203
printf("testfunc_array %d %d %d %d\n",

Modules/_ctypes/ctypes.h

+2
Original file line numberDiff line numberDiff line change
@@ -288,6 +288,8 @@ PyObject *_ctypes_callproc(PPROC pProc,
288288

289289
#define TYPEFLAG_ISPOINTER 0x100
290290
#define TYPEFLAG_HASPOINTER 0x200
291+
#define TYPEFLAG_HASUNION 0x400
292+
#define TYPEFLAG_HASBITFIELD 0x800
291293

292294
#define DICTFLAG_FINAL 0x1000
293295

Modules/_ctypes/stgdict.c

+9
Original file line numberDiff line numberDiff line change
@@ -407,6 +407,13 @@ PyCStructUnionType_update_stgdict(PyObject *type, PyObject *fields, int isStruct
407407
PyMem_Free(stgdict->ffi_type_pointer.elements);
408408

409409
basedict = PyType_stgdict((PyObject *)((PyTypeObject *)type)->tp_base);
410+
if (basedict) {
411+
stgdict->flags |= (basedict->flags &
412+
(TYPEFLAG_HASUNION | TYPEFLAG_HASBITFIELD));
413+
}
414+
if (!isStruct) {
415+
stgdict->flags |= TYPEFLAG_HASUNION;
416+
}
410417
if (basedict && !use_broken_old_ctypes_semantics) {
411418
size = offset = basedict->size;
412419
align = basedict->align;
@@ -482,8 +489,10 @@ PyCStructUnionType_update_stgdict(PyObject *type, PyObject *fields, int isStruct
482489
stgdict->ffi_type_pointer.elements[ffi_ofs + i] = &dict->ffi_type_pointer;
483490
if (dict->flags & (TYPEFLAG_ISPOINTER | TYPEFLAG_HASPOINTER))
484491
stgdict->flags |= TYPEFLAG_HASPOINTER;
492+
stgdict->flags |= dict->flags & (TYPEFLAG_HASUNION | TYPEFLAG_HASBITFIELD);
485493
dict->flags |= DICTFLAG_FINAL; /* mark field type final */
486494
if (PyTuple_Size(pair) == 3) { /* bits specified */
495+
stgdict->flags |= TYPEFLAG_HASBITFIELD;
487496
switch(dict->ffi_type_pointer.type) {
488497
case FFI_TYPE_UINT8:
489498
case FFI_TYPE_UINT16:

0 commit comments

Comments
 (0)