Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix invalid specification of default_value without regard to default_value_type #1631

Merged
merged 3 commits into from
Apr 20, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions traits/ctraits.c
Original file line number Diff line number Diff line change
Expand Up @@ -3064,6 +3064,22 @@ _trait_set_default_value(trait_object *trait, PyObject *args)
return NULL;
Copy link
Contributor

@rahulporuri rahulporuri Apr 20, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is tangential but while looking up the difference between PyErr_Format and PyErr_SetString, I came across the fact that PyErr_Format returns NULL - so we don't need to explicitly return NULL here so we can do return PyErr_Format(...). Right? Or is the convention not to do so?

Ref https://docs.python.org/3/c-api/exceptions.html#c.PyErr_Format

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, we could! Not in this PR, but I'd be happy to look at a PR that made this and similar cleanups across ctraits.c.

}

/* Validate the value */
switch (value_type) {
/* We only do sufficient validation to avoid segfaults when
unwrapping the value in `default_value_for`. */
case CALLABLE_AND_ARGS_DEFAULT_VALUE:
if (!PyTuple_Check(value) || PyTuple_GET_SIZE(value) != 3) {
PyErr_SetString(
PyExc_ValueError,
"default value for type DefaultValue.callable_and_args "
"must be a tuple of the form (callable, args, kwds)"
);
return NULL;
}
break;
}

trait->default_value_type = value_type;

/* The DECREF on the old value can call arbitrary code, so take care not to
Expand Down
16 changes: 16 additions & 0 deletions traits/tests/test_ctraits.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,22 @@ def test_set_and_get_default_value(self):
trait.default_value(), (DefaultValue.list_copy, [1, 2, 3])
)

def test_validate_default_value_for_callable_and_args(self):

bad_values = [
None,
123,
(int, (2,)),
(int, 2, 3, 4),
]

trait = CTrait(TraitKind.trait)
for value in bad_values:
with self.subTest(value=value):
with self.assertRaises(ValueError):
trait.set_default_value(
DefaultValue.callable_and_args, value)

def test_default_value_for_set_is_deprecated(self):
trait = CTrait(TraitKind.trait)
with warnings.catch_warnings(record=True) as warn_msgs:
Expand Down
14 changes: 14 additions & 0 deletions traits/tests/test_trait_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,20 @@ class MyTraitType(TraitType):
with self.assertRaises(ValueError):
ctrait.default_value_for(None, "<dummy>")

def test_call_sets_default_value_type(self):
class FooTrait(TraitType):
default_value_type = DefaultValue.callable_and_args

def __init__(self, default_value=NoDefaultSpecified, **metadata):
default_value = (pow, (3, 4), {})
super().__init__(default_value, **metadata)

trait = FooTrait()
ctrait = trait.as_ctrait()
self.assertEqual(ctrait.default_value_for(None, "dummy"), 81)
cloned_ctrait = trait(30)
self.assertEqual(cloned_ctrait.default_value_for(None, "dummy"), 30)


class TestDeprecatedTraitTypes(unittest.TestCase):
def test_function_deprecated(self):
Expand Down
11 changes: 7 additions & 4 deletions traits/trait_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -305,15 +305,18 @@ def clone(self, default_value=NoDefaultSpecified, **metadata):
new._metadata.update(metadata)

if default_value is not NoDefaultSpecified:
new.default_value = default_value
if self.validate is not None:
try:
new.default_value = self.validate(
None, None, default_value
)
default_value = self.validate(None, None, default_value)
except Exception:
pass

# Known issue: this doesn't do the right thing for
# List, Dict and Set, where we really want to make a copy.
# xref: enthought/traits#1630
new.default_value_type = DefaultValue.constant
new.default_value = default_value

return new

def get_value(self, object, name, trait=None):
Expand Down