diff --git a/traits/tests/test_enum.py b/traits/tests/test_enum.py index fa231134b..64136e4f3 100644 --- a/traits/tests/test_enum.py +++ b/traits/tests/test_enum.py @@ -34,8 +34,22 @@ def _get_valid_models(self): return ["model1", "model2", "model3"] -class EnumListExample(HasTraits): +class CustomCollection: + + def __init__(self, *data): + self.data = data + + def __len__(self): + return len(self.data) + + def __iter__(self): + return iter(self.data) + + def __contains__(self, x): + return x in self.data + +class EnumListExample(HasTraits): values = Any(['foo', 'bar', 'baz']) value = Enum(['foo', 'bar', 'baz']) @@ -48,7 +62,6 @@ class EnumListExample(HasTraits): class EnumTupleExample(HasTraits): - values = Any(('foo', 'bar', 'baz')) value = Enum(('foo', 'bar', 'baz')) @@ -61,7 +74,6 @@ class EnumTupleExample(HasTraits): class EnumEnumExample(HasTraits): - values = Any(FooEnum) value = Enum(FooEnum) @@ -73,6 +85,26 @@ class EnumEnumExample(HasTraits): value_name_default = Enum(FooEnum.bar, values='values') +class EnumCollectionExample(HasTraits): + rgb = Enum("red", CustomCollection("red", "green", "blue")) + + rgb_char = Enum("r", "g", "b") + + numbers = Enum(CustomCollection("one", "two", "three")) + + letters = Enum("abcdefg") + + int_set_enum = Enum(1, {1, 2}) + + correct_int_set_enum = Enum([1, {1, 2}]) + + yes_no = Enum("yes", "no") + + digits = Enum(0, 1, 2, 3, 4, 5, 6, 7, 8, 9) + + two_digits = Enum(1, 2) + + class EnumTestCase(unittest.TestCase): def test_valid_enum(self): example_model = ExampleModel(root="model1") @@ -154,3 +186,68 @@ def test_enum_enum(self): with self.assertRaises(TraitError): example.value_name = FooEnum.bar + + def test_enum_collection(self): + collection_enum = EnumCollectionExample() + + # Test the default values. + self.assertEqual("red", collection_enum.rgb) + self.assertEqual("r", collection_enum.rgb_char) + self.assertEqual("one", collection_enum.numbers) + self.assertEqual("abcdefg", collection_enum.letters) + self.assertEqual("yes", collection_enum.yes_no) + self.assertEqual(0, collection_enum.digits) + self.assertEqual(1, collection_enum.int_set_enum) + self.assertEqual(1, collection_enum.two_digits) + + # Test assigning valid values + collection_enum.rgb = "blue" + self.assertEqual("blue", collection_enum.rgb) + + collection_enum.rgb_char = 'g' + self.assertEqual("g", collection_enum.rgb_char) + + collection_enum.yes_no = "no" + self.assertEqual("no", collection_enum.yes_no) + + for i in range(10): + collection_enum.digits = i + self.assertEqual(i, collection_enum.digits) + + collection_enum.two_digits = 2 + self.assertEqual(2, collection_enum.two_digits) + + # Test assigning invalid values + with self.assertRaises(TraitError): + collection_enum.rgb = "two" + + with self.assertRaises(TraitError): + collection_enum.letters = 'b' + + with self.assertRaises(TraitError): + collection_enum.yes_no = "n" + + with self.assertRaises(TraitError): + collection_enum.digits = 10 + + # Fixing issue #835 introduces the following behaviour, which would + # have otherwise not thrown a TraitError + with self.assertRaises(TraitError): + collection_enum.int_set_enum = {1, 2} + + # But the behaviour can be fixed + # by defining it like correct_int_set_enum + self.assertEqual(1, collection_enum.correct_int_set_enum) + + # No more error on assignment + collection_enum.correct_int_set_enum = {1, 2} + + with self.assertRaises(TraitError): + collection_enum.correct_int_set_enum = 20 + + def test_empty_enum(self): + with self.assertRaises(TraitError): + class EmptyEnum(HasTraits): + a = Enum() + + EmptyEnum() diff --git a/traits/trait_base.py b/traits/trait_base.py index 82a37585b..ba8475784 100644 --- a/traits/trait_base.py +++ b/traits/trait_base.py @@ -182,28 +182,6 @@ def safe_contains(value, container): return False -def enum_default(values): - """ Get a default value from the valid values of an Enum trait. - - Parameters - ---------- - values : tuple, list or enum.Enum - The collection of valid values for an enum trait. - - Returns - ------- - default : any - The first valid value, or None if the collection is empty. - """ - if isinstance(values, enum.EnumMeta): - default = next(iter(values), None) - elif len(values) > 0: - default = values[0] - else: - default = None - return default - - def class_of(object): """ Returns a string containing the class name of an object with the correct indefinite article ('a' or 'an') preceding it (e.g., 'an Image', @@ -390,3 +368,12 @@ def not_event(value): def is_str(value): return isinstance(value, str) + + +def is_collection(value): + """ Returns true if the value can be iterated over. """ + try: + iter(value) + return True + except TypeError: + return False diff --git a/traits/trait_types.py b/traits/trait_types.py index 1a9aed4d2..e08e34d48 100644 --- a/traits/trait_types.py +++ b/traits/trait_types.py @@ -31,7 +31,6 @@ get_module_name, HandleWeakRef, class_of, - enum_default, EnumTypes, RangeTypes, safe_contains, @@ -40,6 +39,7 @@ Undefined, TraitsCache, xgetattr, + is_collection, ) from .trait_converters import trait_from, trait_cast from .trait_dict_object import TraitDictEvent, TraitDictObject @@ -1913,16 +1913,24 @@ class BaseEnum(TraitType): The enumeration of all legal values for the trait. The expected signatures are either: - - a single list, enum.Enum or tuple. The default value is the first - item in the collection. + - a collection. The default value is the first item in the + collection. The collection should conform to the + collections.abc.Collection interface. That is, it at least + provides the __contains__, __len__ and __iter__ methods. + Note that although the types str, bytes, and bytearray + conform to the collection interface, these are handled + as discrete units. - a single default value, combined with the values keyword argument. - - a default value, followed by a single list enum.Enum or tuple. + - a default value, followed by a single list enum.Enum, tuple or + collection conforming to collections.abc.Collection - arbitrary positional arguments each giving a valid value. values : str The name of a trait holding the legal values. A default value may - be provided via a positional argument, otherwise it is the first - item stored in the . + be provided via a positional argument, otherwise the first item in + the collection is used as the default value. Note that if the + collection does not have a notion of order like a set, the default + value will be an arbitrary element from the set. **metadata Trait metadata for the trait. @@ -1953,19 +1961,51 @@ def __init__(self, *args, **metadata): "when using the 'values' keyword" ) else: - default_value = args[0] - if (len(args) == 1) and isinstance(default_value, EnumTypes): - args = default_value - default_value = enum_default(args) - elif (len(args) == 2) and isinstance(args[1], EnumTypes): - args = args[1] + if len(args) < 1: + raise TraitError("Enum trait requires at " + "least 1 argument.") + + elif len(args) == 1: + arg = args[0] + if isinstance(arg, EnumTypes): + default_value = next(iter(arg), None) + self.values = tuple(arg) + + # Treat str, bytes and bytearray as discrete units, + # and not as a collection. + elif isinstance(arg, (str, bytes, bytearray)): + default_value = arg + self.values = (arg,) + + # Handle a collection + else: + default_value = next(iter(arg), None) + self.values = arg + + elif len(args) == 2: + # If 2 args, the first is default, second is allowed values. + default_value = args[0] + allowed_vals = args[1] + + # Treat str, bytes and bytearray as discrete units, + # and not as a collection. + if isinstance(allowed_vals, (str, bytes, bytearray)): + self.values = tuple(args) + + elif is_collection(allowed_vals): + self.values = tuple(allowed_vals) + + else: + self.values = tuple(args) + else: + default_value = args[0] + self.values = tuple(args) if isinstance(args, enum.EnumMeta): metadata.setdefault('format_func', operator.attrgetter('name')) metadata.setdefault('evaluate', args) self.name = "" - self.values = tuple(args) self.init_fast_validate(ValidateTrait.enum, self.values) super(BaseEnum, self).__init__(default_value, **metadata) @@ -2019,7 +2059,7 @@ def _get(self, object, name, trait): value = self.get_value(object, name, trait) values = xgetattr(object, self.name) if not safe_contains(value, values): - value = enum_default(values) + value = next(iter(values), None) return value def _set(self, object, name, value): @@ -2044,17 +2084,24 @@ class Enum(BaseEnum): The enumeration of all legal values for the trait. The expected signatures are either: - - a single list, enum.Enum or tuple. The default value is the first - item in the collection. + - a single list, enum.Enum, tuple or a collection. The default value + is the first item in the collection. The collection should conform to + the collections.abc.Collection interface. That is, it at least + provides the __contains__, __len__ and __iter__ methods. + Note that although the types str, bytes, and bytearray + conform to the collection interface, these are handled + as discrete units. - a single default value, combined with the values keyword argument. - - a default value, followed by a single list enum.Enum or tuple. + - a default value, followed by a single list enum.Enum, tuple or + collection conforming to collections.abc.Collection - arbitrary positional arguments each giving a valid value. - values : str The name of a trait holding the legal values. A default value may - be provided via a positional argument, otherwise it is the first - item stored in the . + be provided via a positional argument, otherwise the first item in + the collection is used as the default value. Note that if the + collection does not have a notion of order like a set, the default + value will be an arbitrary element from the set. **metadata Trait metadata for the trait. @@ -2070,6 +2117,10 @@ class Enum(BaseEnum): def init_fast_validate(self, *args): """ Set up C-level fast validation. """ + # Don't use fast validation if second arg is not a tuple. + if len(args) == 2 and not isinstance(args[1], tuple): + return + self.fast_validate = args