From dc12b076b92a4b892dd42ea54f61a348514d9bc5 Mon Sep 17 00:00:00 2001 From: Mark Dickinson Date: Tue, 19 Apr 2022 12:11:23 +0100 Subject: [PATCH 1/3] Fix invalid specification of default_value without regard to default_value_type --- traits/ctraits.c | 16 ++++++++++++++++ traits/tests/test_ctraits.py | 16 ++++++++++++++++ traits/tests/test_list.py | 13 +++++++++++++ traits/tests/test_trait_types.py | 14 ++++++++++++++ traits/trait_type.py | 11 +++++++---- 5 files changed, 66 insertions(+), 4 deletions(-) diff --git a/traits/ctraits.c b/traits/ctraits.c index 3c2834bd8..d5b8a7f74 100644 --- a/traits/ctraits.c +++ b/traits/ctraits.c @@ -3064,6 +3064,22 @@ _trait_set_default_value(trait_object *trait, PyObject *args) return NULL; } + /* Validate the value */ + switch (value_type) { + /* We only do sufficient validation to avoid segfaults when + unwrapping the value in `default_value_for`. */ + case CALLABLE_AND_ARGS_DEFAULT_VALUE: + if (!PyTuple_Check(value) || PyTuple_GET_SIZE(value) != 3) { + PyErr_SetString( + PyExc_ValueError, + "default value for type DefaultValue.callable_and_args " + "must be a tuple of length 3" + ); + return NULL; + } + break; + } + trait->default_value_type = value_type; /* The DECREF on the old value can call arbitrary code, so take care not to diff --git a/traits/tests/test_ctraits.py b/traits/tests/test_ctraits.py index ea9e29407..2b1e233c7 100644 --- a/traits/tests/test_ctraits.py +++ b/traits/tests/test_ctraits.py @@ -56,6 +56,22 @@ def test_set_and_get_default_value(self): trait.default_value(), (DefaultValue.list_copy, [1, 2, 3]) ) + def test_validate_default_value_for_callable_and_args(self): + + bad_values = [ + None, + 123, + (int, (2,)), + (int, 2, 3, 4), + ] + + trait = CTrait(TraitKind.trait) + for value in bad_values: + with self.subTest(value=value): + with self.assertRaises(ValueError): + trait.set_default_value( + DefaultValue.callable_and_args, value) + def test_default_value_for_set_is_deprecated(self): trait = CTrait(TraitKind.trait) with warnings.catch_warnings(record=True) as warn_msgs: diff --git a/traits/tests/test_list.py b/traits/tests/test_list.py index 5742f377b..78cba000b 100644 --- a/traits/tests/test_list.py +++ b/traits/tests/test_list.py @@ -214,6 +214,19 @@ def test_clone_ref(self): for bar in baz.bars: self.assertIn(bar, baz_copy.bars) + def test_subclass_with_default(self): + class A(HasTraits): + foo = List(Int) + + class B(A): + foo = [1, 2, 3] + + b = B() + self.assertEqual(b.foo, [1, 2, 3]) + # b.foo should still support the usual validation + with self.assertRaises(TraitError): + b.foo.append("a string") + def test_clone_deep_baz(self): baz = Baz() for name in ["a", "b", "c", "d"]: diff --git a/traits/tests/test_trait_types.py b/traits/tests/test_trait_types.py index 59a740b23..561426639 100644 --- a/traits/tests/test_trait_types.py +++ b/traits/tests/test_trait_types.py @@ -132,6 +132,20 @@ class MyTraitType(TraitType): with self.assertRaises(ValueError): ctrait.default_value_for(None, "") + def test_call_sets_default_value_type(self): + class FooTrait(TraitType): + default_value_type = DefaultValue.callable_and_args + + def __init__(self, default_value=NoDefaultSpecified, **metadata): + default_value = (pow, (3, 4), {}) + super().__init__(default_value, **metadata) + + trait = FooTrait() + ctrait = trait.as_ctrait() + self.assertEqual(ctrait.default_value_for(None, "dummy"), 81) + cloned_ctrait = trait(30) + self.assertEqual(cloned_ctrait.default_value_for(None, "dummy"), 30) + class TestDeprecatedTraitTypes(unittest.TestCase): def test_function_deprecated(self): diff --git a/traits/trait_type.py b/traits/trait_type.py index 257a058b0..95d89cd26 100644 --- a/traits/trait_type.py +++ b/traits/trait_type.py @@ -305,15 +305,18 @@ def clone(self, default_value=NoDefaultSpecified, **metadata): new._metadata.update(metadata) if default_value is not NoDefaultSpecified: - new.default_value = default_value if self.validate is not None: try: - new.default_value = self.validate( - None, None, default_value - ) + default_value = self.validate(None, None, default_value) except Exception: pass + # Known issue: this doesn't do the right thing for + # List, Dict and Set, where we really want to make a copy. + # xref: enthought/traits#1630 + new.default_value_type = DefaultValue.constant + new.default_value = default_value + return new def get_value(self, object, name, trait=None): From 4f2245807dac019a38ff4a9a2d30a509284a1667 Mon Sep 17 00:00:00 2001 From: Mark Dickinson Date: Tue, 19 Apr 2022 12:40:15 +0100 Subject: [PATCH 2/3] Remove extra test that was accidentally included in the wrong PR --- traits/tests/test_list.py | 13 ------------- 1 file changed, 13 deletions(-) diff --git a/traits/tests/test_list.py b/traits/tests/test_list.py index 78cba000b..5742f377b 100644 --- a/traits/tests/test_list.py +++ b/traits/tests/test_list.py @@ -214,19 +214,6 @@ def test_clone_ref(self): for bar in baz.bars: self.assertIn(bar, baz_copy.bars) - def test_subclass_with_default(self): - class A(HasTraits): - foo = List(Int) - - class B(A): - foo = [1, 2, 3] - - b = B() - self.assertEqual(b.foo, [1, 2, 3]) - # b.foo should still support the usual validation - with self.assertRaises(TraitError): - b.foo.append("a string") - def test_clone_deep_baz(self): baz = Baz() for name in ["a", "b", "c", "d"]: From ffeb4e23e915a8be47dbe7ff73ad69ef4dc12be1 Mon Sep 17 00:00:00 2001 From: Mark Dickinson Date: Tue, 19 Apr 2022 12:42:20 +0100 Subject: [PATCH 3/3] Better exception message --- traits/ctraits.c | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/traits/ctraits.c b/traits/ctraits.c index d5b8a7f74..0bc42ce02 100644 --- a/traits/ctraits.c +++ b/traits/ctraits.c @@ -3073,7 +3073,7 @@ _trait_set_default_value(trait_object *trait, PyObject *args) PyErr_SetString( PyExc_ValueError, "default value for type DefaultValue.callable_and_args " - "must be a tuple of length 3" + "must be a tuple of the form (callable, args, kwds)" ); return NULL; }