Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix add_class_trait in the presence of subclasses #1461

Merged
merged 7 commits into from
May 17, 2021
47 changes: 41 additions & 6 deletions traits/has_traits.py
Original file line number Diff line number Diff line change
Expand Up @@ -643,7 +643,7 @@ def update_traits_class_dict(class_name, bases, class_dict):
# Make sure the trait prefixes are sorted longest to shortest
# so that we can easily bind dynamic traits to the longest matching
# prefix:
prefix_list.sort(key=lambda x: -len(x))
prefix_list.sort(key=len, reverse=True)

# Get the list of all possible 'Instance'/'List(Instance)' handlers:
instance_traits = _get_instance_handlers(class_dict, hastraits_bases)
Expand Down Expand Up @@ -1113,6 +1113,8 @@ def _trait_added_changed(self, name):
def add_class_trait(cls, name, *trait):
""" Adds a named trait attribute to this class.

Also adds the same attribute to all subclasses.

Parameters
----------
name : str
Expand All @@ -1135,14 +1137,39 @@ def add_class_trait(cls, name, *trait):
trait = trait_for(trait[0])

# Add the trait to the class:
cls._add_class_trait(name, trait, False)
cls._add_class_trait(name, trait, is_subclass=False)
mdickinson marked this conversation as resolved.
Show resolved Hide resolved

# Also add the trait to all subclasses of this class:
for subclass in cls.trait_subclasses(True):
subclass._add_class_trait(name, trait, True)
subclass._add_class_trait(name, trait, is_subclass=True)

@classmethod
def _add_class_trait(cls, name, trait, is_subclass):
"""
Add a named trait attribute to this class.

Does not affect subclasses.

Parameters
----------
name : str
Name of the attribute to add.
trait : CTrait
The trait to be added.
is_subclass : bool
True if we're adding the trait to a strict subclass of the
original class that add_class_trait was called for. This is used
to decide how to behave if ``cls`` already has a trait named
``name``: in that circumstance, if ``is_subclass`` is False, an
error will be raised, while if ``is_subclass`` is True, no trait
will be added.

Raises
------
TraitError
If a trait with the given name already exists, and is_subclass
is ``False``.
"""
# Get a reference to the class's dictionary and 'prefix' traits:
class_dict = cls.__dict__
prefix_traits = class_dict[PrefixTraits]
Expand All @@ -1161,7 +1188,7 @@ def _add_class_trait(cls, name, trait, is_subclass):
prefix_list.append(name)

# Resort the list from longest to shortest:
prefix_list.sort(lambda x, y: len(y) - len(x))
prefix_list.sort(key=len, reverse=True)

return

Expand All @@ -1177,9 +1204,17 @@ def _add_class_trait(cls, name, trait, is_subclass):
handler = trait.handler
if handler is not None:
if handler.has_items:
cls.add_class_trait(name + "_items", handler.items_event())
cls._add_class_trait(
name + "_items",
handler.items_event(),
is_subclass=is_subclass,
)
if handler.is_mapped:
cls.add_class_trait(name + "_", mapped_trait_for(trait, name))
cls._add_class_trait(
name + "_",
mapped_trait_for(trait, name),
is_subclass=is_subclass,
)

# Make the new trait inheritable (if allowed):
if trait.is_base is not False:
Expand Down
9 changes: 8 additions & 1 deletion traits/tests/test_abc.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,15 @@ def tearDown(self):
warnings.filters[:] = self.old_filters

def test_new(self):
# Previously, this test used HasTraits(x=10). That has the
# side-effect of creating an `x` trait on HasTraits, possibly
# causing interactions with other tests.
# xref: enthought/traits#58
class A(HasTraits):
pass

# Should not raise DeprecationWarning.
HasTraits(x=10)
A(x=10)


class AbstractFoo(ABCHasTraits):
Expand Down
100 changes: 100 additions & 0 deletions traits/tests/test_has_traits.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
)
from traits.traits import ForwardProperty, generic_trait
from traits.trait_types import Event, Float, Instance, Int, List, Map, Str
from traits.trait_errors import TraitError


def _dummy_getter(self):
Expand Down Expand Up @@ -464,6 +465,105 @@ class A(HasTraits):
self.assertIsNot(objs_copy[0], objs[0])
self.assertIs(objs_copy[0], objs_copy[1])

def test_add_class_trait(self):
# Testing basic usage.
class A(HasTraits):
pass

A.add_class_trait("y", Str())

a = A()

self.assertEqual(a.y, "")

def test_add_class_trait_affects_existing_instances(self):
class A(HasTraits):
pass

a = A()

A.add_class_trait("y", Str())

self.assertEqual(a.y, "")

def test_add_class_trait_affects_subclasses(self):
class A(HasTraits):
pass

class B(A):
pass

class C(B):
pass

class D(B):
pass

A.add_class_trait("y", Str())
self.assertEqual(A().y, "")
self.assertEqual(B().y, "")
self.assertEqual(C().y, "")
self.assertEqual(D().y, "")

def test_add_class_trait_has_items_and_subclasses(self):
# Regression test for enthought/traits#1460
class A(HasTraits):
pass

class B(A):
pass

class C(B):
pass

# Code branch for traits with items.
A.add_class_trait("x", List(Int))
self.assertEqual(A().x, [])
self.assertEqual(B().x, [])
self.assertEqual(C().x, [])

# Exercise the code branch for mapped traits.
A.add_class_trait("y", Map({"yes": 1, "no": 0}, default_value="no"))
self.assertEqual(A().y, "no")
self.assertEqual(B().y, "no")
self.assertEqual(C().y, "no")

def test_add_class_trait_add_prefix_traits(self):

class A(HasTraits):
pass

A.add_class_trait("abc_", Str())
A.add_class_trait("abc_def_", Int())

a = A()
self.assertEqual(a.abc_def_g, 0)
self.assertEqual(a.abc_z, "")

def test_add_class_trait_when_trait_already_exists(self):

class A(HasTraits):
foo = Int()

with self.assertRaises(TraitError):
A.add_class_trait("foo", List())

self.assertEqual(A().foo, 0)
with self.assertRaises(AttributeError):
A().foo_items

def test_add_class_trait_when_trait_already_exists_in_subclass(self):
class A(HasTraits):
pass

class B(A):
foo = Int()

A.add_class_trait("foo", Str())

self.assertEqual(A().foo, "")
self.assertEqual(B().foo, 0)


class TestObjectNotifiers(unittest.TestCase):
""" Test calling object notifiers. """
Expand Down