Skip to content

Commit e9e9555

Browse files
cantoniostensorflower-gardener
authored andcommitted
Fix pywrap attribute read security vulnerability.
If a list of quantized tensors is assigned to an attribute, the pywrap code was failing to parse the tensor and returning a `nullptr`, which wasn't caught. Here we check the return value and set an appropriate error status. PiperOrigin-RevId: 476981029
1 parent e7ed22e commit e9e9555

File tree

2 files changed

+27
-5
lines changed

2 files changed

+27
-5
lines changed

Diff for: tensorflow/python/eager/pywrap_tfe_src.cc

+14-5
Original file line numberDiff line numberDiff line change
@@ -397,11 +397,20 @@ bool SetOpAttrList(TFE_Context* ctx, TFE_Op* op, const char* key,
397397
const int num_values = PySequence_Size(py_list);
398398
if (attr_list_sizes != nullptr) (*attr_list_sizes)[key] = num_values;
399399

400-
#define PARSE_LIST(c_type, parse_fn) \
401-
std::unique_ptr<c_type[]> values(new c_type[num_values]); \
402-
for (int i = 0; i < num_values; ++i) { \
403-
tensorflow::Safe_PyObjectPtr py_value(PySequence_ITEM(py_list, i)); \
404-
if (!parse_fn(key, py_value.get(), status, &values[i])) return false; \
400+
#define PARSE_LIST(c_type, parse_fn) \
401+
std::unique_ptr<c_type[]> values(new c_type[num_values]); \
402+
for (int i = 0; i < num_values; ++i) { \
403+
tensorflow::Safe_PyObjectPtr py_value(PySequence_ITEM(py_list, i)); \
404+
if (py_value == nullptr) { \
405+
TF_SetStatus(status, TF_INVALID_ARGUMENT, \
406+
tensorflow::strings::StrCat( \
407+
"Expecting sequence of " #c_type " for attr ", key, \
408+
", got ", py_list->ob_type->tp_name) \
409+
.c_str()); \
410+
return false; \
411+
} else if (!parse_fn(key, py_value.get(), status, &values[i])) { \
412+
return false; \
413+
} \
405414
}
406415

407416
if (type == TF_ATTR_STRING) {

Diff for: tensorflow/python/kernel_tests/image_ops/extract_image_patches_op_test.py

+13
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,9 @@
1717
import numpy as np
1818

1919
from tensorflow.python.framework import constant_op
20+
from tensorflow.python.framework import dtypes
2021
from tensorflow.python.ops import array_ops
22+
from tensorflow.python.ops import math_ops
2123
from tensorflow.python.platform import test
2224

2325

@@ -139,6 +141,17 @@ def testComplexDataTypes(self):
139141
padding=padding,
140142
patches=patches)
141143

144+
def testInvalidAttributes(self):
145+
"""Test for passing weird things into ksizes."""
146+
with self.assertRaisesRegex(TypeError, "Expected list"):
147+
image = constant_op.constant([0.0])
148+
ksizes = math_ops.cast(
149+
constant_op.constant(dtype=dtypes.int16, value=[[1, 4], [5, 2]]),
150+
dtype=dtypes.qint16)
151+
strides = [1, 1, 1, 1]
152+
self.evaluate(
153+
array_ops.extract_image_patches(
154+
image, ksizes=ksizes, strides=strides, padding="SAME"))
142155

143156
if __name__ == "__main__":
144157
test.main()

0 commit comments

Comments
 (0)