diff --git a/traits/ctraits.c b/traits/ctraits.c index 3c2834bd8..0bc42ce02 100644 --- a/traits/ctraits.c +++ b/traits/ctraits.c @@ -3064,6 +3064,22 @@ _trait_set_default_value(trait_object *trait, PyObject *args) return NULL; } + /* Validate the value */ + switch (value_type) { + /* We only do sufficient validation to avoid segfaults when + unwrapping the value in `default_value_for`. */ + case CALLABLE_AND_ARGS_DEFAULT_VALUE: + if (!PyTuple_Check(value) || PyTuple_GET_SIZE(value) != 3) { + PyErr_SetString( + PyExc_ValueError, + "default value for type DefaultValue.callable_and_args " + "must be a tuple of the form (callable, args, kwds)" + ); + return NULL; + } + break; + } + trait->default_value_type = value_type; /* The DECREF on the old value can call arbitrary code, so take care not to diff --git a/traits/tests/test_ctraits.py b/traits/tests/test_ctraits.py index ea9e29407..2b1e233c7 100644 --- a/traits/tests/test_ctraits.py +++ b/traits/tests/test_ctraits.py @@ -56,6 +56,22 @@ def test_set_and_get_default_value(self): trait.default_value(), (DefaultValue.list_copy, [1, 2, 3]) ) + def test_validate_default_value_for_callable_and_args(self): + + bad_values = [ + None, + 123, + (int, (2,)), + (int, 2, 3, 4), + ] + + trait = CTrait(TraitKind.trait) + for value in bad_values: + with self.subTest(value=value): + with self.assertRaises(ValueError): + trait.set_default_value( + DefaultValue.callable_and_args, value) + def test_default_value_for_set_is_deprecated(self): trait = CTrait(TraitKind.trait) with warnings.catch_warnings(record=True) as warn_msgs: diff --git a/traits/tests/test_trait_types.py b/traits/tests/test_trait_types.py index 59a740b23..561426639 100644 --- a/traits/tests/test_trait_types.py +++ b/traits/tests/test_trait_types.py @@ -132,6 +132,20 @@ class MyTraitType(TraitType): with self.assertRaises(ValueError): ctrait.default_value_for(None, "") + def test_call_sets_default_value_type(self): + class FooTrait(TraitType): + default_value_type = DefaultValue.callable_and_args + + def __init__(self, default_value=NoDefaultSpecified, **metadata): + default_value = (pow, (3, 4), {}) + super().__init__(default_value, **metadata) + + trait = FooTrait() + ctrait = trait.as_ctrait() + self.assertEqual(ctrait.default_value_for(None, "dummy"), 81) + cloned_ctrait = trait(30) + self.assertEqual(cloned_ctrait.default_value_for(None, "dummy"), 30) + class TestDeprecatedTraitTypes(unittest.TestCase): def test_function_deprecated(self): diff --git a/traits/trait_type.py b/traits/trait_type.py index 257a058b0..95d89cd26 100644 --- a/traits/trait_type.py +++ b/traits/trait_type.py @@ -305,15 +305,18 @@ def clone(self, default_value=NoDefaultSpecified, **metadata): new._metadata.update(metadata) if default_value is not NoDefaultSpecified: - new.default_value = default_value if self.validate is not None: try: - new.default_value = self.validate( - None, None, default_value - ) + default_value = self.validate(None, None, default_value) except Exception: pass + # Known issue: this doesn't do the right thing for + # List, Dict and Set, where we really want to make a copy. + # xref: enthought/traits#1630 + new.default_value_type = DefaultValue.constant + new.default_value = default_value + return new def get_value(self, object, name, trait=None):