From 1c00a88cf2600d8753b2db8f16057afa7e62b65d Mon Sep 17 00:00:00 2001 From: Midhun PM Date: Thu, 6 Feb 2020 14:39:06 -0600 Subject: [PATCH 01/16] Allow Collections in valid values for Enum trait --- traits/ctraits.c | 5 +--- traits/tests/test_enum.py | 49 +++++++++++++++++++++++++++++++++++++++ traits/trait_base.py | 10 ++++++++ traits/trait_types.py | 22 +++++++++++++----- 4 files changed, 76 insertions(+), 10 deletions(-) diff --git a/traits/ctraits.c b/traits/ctraits.c index d3cc02621..9a516223c 100644 --- a/traits/ctraits.c +++ b/traits/ctraits.c @@ -4178,10 +4178,7 @@ _trait_set_validate(trait_object *trait, PyObject *args) case 5: /* Enumerated item check: */ if (n == 2) { - v1 = PyTuple_GET_ITEM(validate, 1); - if (PyTuple_CheckExact(v1)) { - goto done; - } + goto done; } break; diff --git a/traits/tests/test_enum.py b/traits/tests/test_enum.py index fa231134b..9eb11efe2 100644 --- a/traits/tests/test_enum.py +++ b/traits/tests/test_enum.py @@ -8,6 +8,7 @@ # # Thanks for using Enthought open source! +from collections import Collection import enum import unittest @@ -34,6 +35,21 @@ def _get_valid_models(self): return ["model1", "model2", "model3"] +class CustomCollection(Collection): + + def __init__(self, *data): + self.data = tuple(data) + + def __len__(self): + return len(self.data) + + def __iter__(self): + return iter(self.data) + + def __contains__(self, __x: object): + return __x in self.data + + class EnumListExample(HasTraits): values = Any(['foo', 'bar', 'baz']) @@ -73,6 +89,19 @@ class EnumEnumExample(HasTraits): value_name_default = Enum(FooEnum.bar, values='values') +class EnumCollectionExample(HasTraits): + + rgb = Enum("red", CustomCollection("red", "green", "blue")) + + rgb_char = Enum("r", "rgb") + + numbers = Enum(CustomCollection("one", "two", "three")) + + letters = Enum("abcdefg") + + months = Enum("jan", "feb", "mar") + + class EnumTestCase(unittest.TestCase): def test_valid_enum(self): example_model = ExampleModel(root="model1") @@ -154,3 +183,23 @@ def test_enum_enum(self): with self.assertRaises(TraitError): example.value_name = FooEnum.bar + + def test_enum_collection(self): + + collection_enum = EnumCollectionExample() + self.assertEqual("red", collection_enum.rgb) + self.assertEqual("r", collection_enum.rgb_char) + self.assertEqual("one", collection_enum.numbers) + self.assertEqual("a", collection_enum.letters) + self.assertEqual("jan", collection_enum.months) + + collection_enum.rgb_char = 'g' + self.assertEqual("g", collection_enum.rgb_char) + + collection_enum.months = "feb" + + with self.assertRaises(TraitError): + collection_enum.rgb = "two" + + with self.assertRaises(TraitError): + collection_enum.rgb_char = "rgb" diff --git a/traits/trait_base.py b/traits/trait_base.py index c103985cc..0a4ec0390 100644 --- a/traits/trait_base.py +++ b/traits/trait_base.py @@ -204,6 +204,8 @@ def enum_default(values): """ if isinstance(values, enum.EnumMeta): default = next(iter(values), None) + elif is_collection(values): + default = next(iter(values)) elif len(values) > 0: default = values[0] else: @@ -397,3 +399,11 @@ def not_event(value): def is_str(value): return isinstance(value, str) + + +def is_collection(value): + try: + iter(value) + return True + except TypeError: + return False diff --git a/traits/trait_types.py b/traits/trait_types.py index 1ada4593f..907b5fd24 100644 --- a/traits/trait_types.py +++ b/traits/trait_types.py @@ -40,6 +40,7 @@ Undefined, TraitsCache, xgetattr, + is_collection, ) from .trait_converters import trait_from, trait_cast from .trait_dict_object import TraitDictEvent, TraitDictObject @@ -1960,18 +1961,27 @@ def __init__(self, *args, **metadata): ) else: default_value = args[0] - if (len(args) == 1) and isinstance(default_value, EnumTypes): - args = default_value - default_value = enum_default(args) - elif (len(args) == 2) and isinstance(args[1], EnumTypes): - args = args[1] + if len(args) == 1: + self.values = args[0] + default_value = enum_default(self.values) + + # If values is not a Collection + if isinstance(self.values, (EnumTypes, str)): + self.values = tuple(self.values) + + elif len(args) == 2: + if isinstance(args[1], str): + self.values = tuple(args[1]) + elif is_collection(args[1]): + self.values = args[1] + else: + self.values = tuple(args) if isinstance(args, enum.EnumMeta): metadata.setdefault('format_func', operator.attrgetter('name')) metadata.setdefault('evaluate', args) self.name = "" - self.values = tuple(args) self.init_fast_validate(ValidateTrait.enum, self.values) super(BaseEnum, self).__init__(default_value, **metadata) From fefe0a005598241413033f8e93dccb68be5518ac Mon Sep 17 00:00:00 2001 From: Midhun PM Date: Thu, 6 Feb 2020 19:26:18 -0600 Subject: [PATCH 02/16] Handle strings correctly --- traits/tests/test_enum.py | 13 +++++-------- traits/trait_base.py | 2 +- traits/trait_types.py | 6 ++++-- 3 files changed, 10 insertions(+), 11 deletions(-) diff --git a/traits/tests/test_enum.py b/traits/tests/test_enum.py index 9eb11efe2..61caf502b 100644 --- a/traits/tests/test_enum.py +++ b/traits/tests/test_enum.py @@ -99,7 +99,7 @@ class EnumCollectionExample(HasTraits): letters = Enum("abcdefg") - months = Enum("jan", "feb", "mar") + months = Enum("january", "february") class EnumTestCase(unittest.TestCase): @@ -191,15 +191,12 @@ def test_enum_collection(self): self.assertEqual("r", collection_enum.rgb_char) self.assertEqual("one", collection_enum.numbers) self.assertEqual("a", collection_enum.letters) - self.assertEqual("jan", collection_enum.months) + self.assertEqual("january", collection_enum.months) - collection_enum.rgb_char = 'g' - self.assertEqual("g", collection_enum.rgb_char) + collection_enum.rgb_char = 'rgb' + self.assertEqual("rgb", collection_enum.rgb_char) - collection_enum.months = "feb" + collection_enum.months = "february" with self.assertRaises(TraitError): collection_enum.rgb = "two" - - with self.assertRaises(TraitError): - collection_enum.rgb_char = "rgb" diff --git a/traits/trait_base.py b/traits/trait_base.py index 0a4ec0390..4ad7c5aea 100644 --- a/traits/trait_base.py +++ b/traits/trait_base.py @@ -204,7 +204,7 @@ def enum_default(values): """ if isinstance(values, enum.EnumMeta): default = next(iter(values), None) - elif is_collection(values): + elif is_collection(values) and not isinstance(values, str): default = next(iter(values)) elif len(values) > 0: default = values[0] diff --git a/traits/trait_types.py b/traits/trait_types.py index 907b5fd24..877ce18f9 100644 --- a/traits/trait_types.py +++ b/traits/trait_types.py @@ -1966,12 +1966,14 @@ def __init__(self, *args, **metadata): default_value = enum_default(self.values) # If values is not a Collection - if isinstance(self.values, (EnumTypes, str)): + if isinstance(self.values, EnumTypes): self.values = tuple(self.values) + elif isinstance(self.values, str): + self.values = self.values elif len(args) == 2: if isinstance(args[1], str): - self.values = tuple(args[1]) + self.values = tuple(args) elif is_collection(args[1]): self.values = args[1] else: From 4aab0447f16ce914a26595b5e157773bf4857ec0 Mon Sep 17 00:00:00 2001 From: Midhun PM Date: Fri, 7 Feb 2020 10:22:54 -0600 Subject: [PATCH 03/16] Fix import error for python 3.5 --- traits/tests/test_enum.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/traits/tests/test_enum.py b/traits/tests/test_enum.py index 61caf502b..e70ed0c75 100644 --- a/traits/tests/test_enum.py +++ b/traits/tests/test_enum.py @@ -8,7 +8,7 @@ # # Thanks for using Enthought open source! -from collections import Collection +from collections.abc import Collection import enum import unittest From adc5c7045dc0e5c97a1fa62557ba1f5f76ea5b42 Mon Sep 17 00:00:00 2001 From: Midhun PM Date: Fri, 7 Feb 2020 10:52:09 -0600 Subject: [PATCH 04/16] Fix test that uses Collection to work on Python3.5 --- traits/tests/test_enum.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/traits/tests/test_enum.py b/traits/tests/test_enum.py index e70ed0c75..a1c52d77d 100644 --- a/traits/tests/test_enum.py +++ b/traits/tests/test_enum.py @@ -8,7 +8,11 @@ # # Thanks for using Enthought open source! -from collections.abc import Collection +try: + from collections.abc import Collection +except ImportError: + # Python3.5 and below does not have Collection + from collections import Container as Collection import enum import unittest @@ -51,7 +55,6 @@ def __contains__(self, __x: object): class EnumListExample(HasTraits): - values = Any(['foo', 'bar', 'baz']) value = Enum(['foo', 'bar', 'baz']) @@ -64,7 +67,6 @@ class EnumListExample(HasTraits): class EnumTupleExample(HasTraits): - values = Any(('foo', 'bar', 'baz')) value = Enum(('foo', 'bar', 'baz')) @@ -77,7 +79,6 @@ class EnumTupleExample(HasTraits): class EnumEnumExample(HasTraits): - values = Any(FooEnum) value = Enum(FooEnum) @@ -90,7 +91,6 @@ class EnumEnumExample(HasTraits): class EnumCollectionExample(HasTraits): - rgb = Enum("red", CustomCollection("red", "green", "blue")) rgb_char = Enum("r", "rgb") @@ -185,7 +185,6 @@ def test_enum_enum(self): example.value_name = FooEnum.bar def test_enum_collection(self): - collection_enum = EnumCollectionExample() self.assertEqual("red", collection_enum.rgb) self.assertEqual("r", collection_enum.rgb_char) From b11804474c854c39abb965dba5887d7706c4556d Mon Sep 17 00:00:00 2001 From: Midhun PM Date: Fri, 7 Feb 2020 13:59:00 -0600 Subject: [PATCH 05/16] Handle single enum string as a whole item --- traits/tests/test_enum.py | 2 +- traits/trait_base.py | 7 +++++-- traits/trait_types.py | 2 -- 3 files changed, 6 insertions(+), 5 deletions(-) diff --git a/traits/tests/test_enum.py b/traits/tests/test_enum.py index a1c52d77d..2af763ae4 100644 --- a/traits/tests/test_enum.py +++ b/traits/tests/test_enum.py @@ -189,7 +189,7 @@ def test_enum_collection(self): self.assertEqual("red", collection_enum.rgb) self.assertEqual("r", collection_enum.rgb_char) self.assertEqual("one", collection_enum.numbers) - self.assertEqual("a", collection_enum.letters) + self.assertEqual("abcdefg", collection_enum.letters) self.assertEqual("january", collection_enum.months) collection_enum.rgb_char = 'rgb' diff --git a/traits/trait_base.py b/traits/trait_base.py index 4ad7c5aea..db8ca44b7 100644 --- a/traits/trait_base.py +++ b/traits/trait_base.py @@ -204,8 +204,11 @@ def enum_default(values): """ if isinstance(values, enum.EnumMeta): default = next(iter(values), None) - elif is_collection(values) and not isinstance(values, str): - default = next(iter(values)) + elif is_collection(values): + if isinstance(values, str): + default = values + else: + default = next(iter(values)) elif len(values) > 0: default = values[0] else: diff --git a/traits/trait_types.py b/traits/trait_types.py index 877ce18f9..ea068b694 100644 --- a/traits/trait_types.py +++ b/traits/trait_types.py @@ -1968,8 +1968,6 @@ def __init__(self, *args, **metadata): # If values is not a Collection if isinstance(self.values, EnumTypes): self.values = tuple(self.values) - elif isinstance(self.values, str): - self.values = self.values elif len(args) == 2: if isinstance(args[1], str): From 96729b61059e9166858f192ae23a73cdbcd6f318 Mon Sep 17 00:00:00 2001 From: Midhun PM Date: Tue, 25 Feb 2020 10:55:46 -0600 Subject: [PATCH 06/16] Some test changes --- traits/tests/test_enum.py | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/traits/tests/test_enum.py b/traits/tests/test_enum.py index 2af763ae4..fdf72b208 100644 --- a/traits/tests/test_enum.py +++ b/traits/tests/test_enum.py @@ -93,14 +93,12 @@ class EnumEnumExample(HasTraits): class EnumCollectionExample(HasTraits): rgb = Enum("red", CustomCollection("red", "green", "blue")) - rgb_char = Enum("r", "rgb") + rgb_char = Enum("r", "g", "b") numbers = Enum(CustomCollection("one", "two", "three")) letters = Enum("abcdefg") - months = Enum("january", "february") - class EnumTestCase(unittest.TestCase): def test_valid_enum(self): @@ -190,12 +188,12 @@ def test_enum_collection(self): self.assertEqual("r", collection_enum.rgb_char) self.assertEqual("one", collection_enum.numbers) self.assertEqual("abcdefg", collection_enum.letters) - self.assertEqual("january", collection_enum.months) - collection_enum.rgb_char = 'rgb' - self.assertEqual("rgb", collection_enum.rgb_char) + collection_enum.rgb = "blue" + self.assertEqual("blue", collection_enum.rgb) - collection_enum.months = "february" + collection_enum.rgb_char = 'g' + self.assertEqual("g", collection_enum.rgb_char) with self.assertRaises(TraitError): collection_enum.rgb = "two" From 0580dbd7880a0dc7ffc8799038cdca07402d2ebc Mon Sep 17 00:00:00 2001 From: Midhun PM Date: Tue, 17 Mar 2020 15:15:17 -0500 Subject: [PATCH 07/16] PR review changes --- traits/tests/test_enum.py | 30 ++++++++++++++++++++++++++++-- traits/trait_base.py | 9 +++++++-- traits/trait_types.py | 25 +++++++++++++++++-------- 3 files changed, 52 insertions(+), 12 deletions(-) diff --git a/traits/tests/test_enum.py b/traits/tests/test_enum.py index fdf72b208..f585f6812 100644 --- a/traits/tests/test_enum.py +++ b/traits/tests/test_enum.py @@ -50,8 +50,8 @@ def __len__(self): def __iter__(self): return iter(self.data) - def __contains__(self, __x: object): - return __x in self.data + def __contains__(self, x): + return x in self.data class EnumListExample(HasTraits): @@ -99,6 +99,12 @@ class EnumCollectionExample(HasTraits): letters = Enum("abcdefg") + int_set_enum = Enum(1, {1, 2}) + + correct_int_set_enum = Enum([1, {1, 2}]) + + yes_no = Enum("yes", "no") + class EnumTestCase(unittest.TestCase): def test_valid_enum(self): @@ -197,3 +203,23 @@ def test_enum_collection(self): with self.assertRaises(TraitError): collection_enum.rgb = "two" + + with self.assertRaises(TraitError): + collection_enum.letters = 'b' + + collection_enum.yes_no = "no" + with self.assertRaises(TraitError): + collection_enum.yes_no = "n" + + self.assertEqual(1, collection_enum.int_set_enum) + # Fixing issue #835 introduces the following behaviour, which would + # have otherwise not thrown a TraitError + with self.assertRaises(TraitError): + collection_enum.int_set_enum = {1, 2} + + # But the behaviour can be fixed, as seen below + self.assertEqual(1, collection_enum.correct_int_set_enum) + collection_enum.correct_int_set_enum = {1, 2} + + with self.assertRaises(TraitError): + collection_enum.correct_int_set_enum = 2 diff --git a/traits/trait_base.py b/traits/trait_base.py index db8ca44b7..62934d41c 100644 --- a/traits/trait_base.py +++ b/traits/trait_base.py @@ -84,7 +84,6 @@ def __reduce_ex__(self, protocol): #: assignment, and treat it specially. Uninitialized = _Uninitialized() - Undefined = None @@ -205,7 +204,7 @@ def enum_default(values): if isinstance(values, enum.EnumMeta): default = next(iter(values), None) elif is_collection(values): - if isinstance(values, str): + if is_excluded_collection(values): default = values else: default = next(iter(values)) @@ -404,6 +403,12 @@ def is_str(value): return isinstance(value, str) +def is_excluded_collection(value): + if isinstance(value, (str, bytes, bytearray)): + return True + return False + + def is_collection(value): try: iter(value) diff --git a/traits/trait_types.py b/traits/trait_types.py index 0f050fc71..f4af2be2a 100644 --- a/traits/trait_types.py +++ b/traits/trait_types.py @@ -41,6 +41,7 @@ TraitsCache, xgetattr, is_collection, + is_excluded_collection, ) from .trait_converters import trait_from, trait_cast from .trait_dict_object import TraitDictEvent, TraitDictObject @@ -1924,11 +1925,14 @@ class BaseEnum(TraitType): The enumeration of all legal values for the trait. The expected signatures are either: - - a single list, enum.Enum or tuple. The default value is the first - item in the collection. + - a single list, enum.Enum, tuple or a collection. The default value + is the first item in the collection. The collection should conform to + the collections.abc.Collection interface. That is, it at least + provides the __contains__, len and __iter__ methods. - a single default value, combined with the values keyword argument. - - a default value, followed by a single list enum.Enum or tuple. + - a default value, followed by a single list enum.Enum, tuple or + collection conforming to collections.abc.Collection - arbitrary positional arguments each giving a valid value. values : str The name of a trait holding the legal values. A default value may @@ -1969,12 +1973,14 @@ def __init__(self, *args, **metadata): self.values = args[0] default_value = enum_default(self.values) - # If values is not a Collection + # If values is not a Collection, treat them differently if isinstance(self.values, EnumTypes): self.values = tuple(self.values) + elif is_excluded_collection(self.values): + self.values = {self.values} elif len(args) == 2: - if isinstance(args[1], str): + if is_excluded_collection(args[1]): self.values = tuple(args) elif is_collection(args[1]): self.values = args[1] @@ -2064,11 +2070,14 @@ class Enum(BaseEnum): The enumeration of all legal values for the trait. The expected signatures are either: - - a single list, enum.Enum or tuple. The default value is the first - item in the collection. + - a single list, enum.Enum, tuple or a collection. The default value + is the first item in the collection. The collection should conform to + the collections.abc.Collection interface. That is, it at least + provides the __contains__, len and __iter__ methods. - a single default value, combined with the values keyword argument. - - a default value, followed by a single list enum.Enum or tuple. + - a default value, followed by a single list enum.Enum, tuple or + collection conforming to collections.abc.Collection - arbitrary positional arguments each giving a valid value. values : str From c74c1bfeabdfe3672184e56d15eeae337565babb Mon Sep 17 00:00:00 2001 From: Midhun PM Date: Mon, 23 Mar 2020 14:20:55 -0500 Subject: [PATCH 08/16] PR review changes --- traits/ctraits.c | 5 ++++- traits/tests/test_enum.py | 15 +++++++++------ traits/trait_base.py | 10 +++++++--- traits/trait_types.py | 13 ++++++++++--- 4 files changed, 30 insertions(+), 13 deletions(-) diff --git a/traits/ctraits.c b/traits/ctraits.c index 60f2d1b4a..10f252cc2 100644 --- a/traits/ctraits.c +++ b/traits/ctraits.c @@ -4185,7 +4185,10 @@ _trait_set_validate(trait_object *trait, PyObject *args) case 5: /* Enumerated item check: */ if (n == 2) { - goto done; + v1 = PyTuple_GET_ITEM(validate, 1); + if (PyTuple_CheckExact(v1)) { + goto done; + } } break; diff --git a/traits/tests/test_enum.py b/traits/tests/test_enum.py index f585f6812..1af1ce33a 100644 --- a/traits/tests/test_enum.py +++ b/traits/tests/test_enum.py @@ -8,11 +8,6 @@ # # Thanks for using Enthought open source! -try: - from collections.abc import Collection -except ImportError: - # Python3.5 and below does not have Collection - from collections import Container as Collection import enum import unittest @@ -39,7 +34,7 @@ def _get_valid_models(self): return ["model1", "model2", "model3"] -class CustomCollection(Collection): +class CustomCollection: def __init__(self, *data): self.data = tuple(data) @@ -105,6 +100,8 @@ class EnumCollectionExample(HasTraits): yes_no = Enum("yes", "no") + digits = Enum(0, 1, 2, 3, 4, 5, 6, 7, 8, 9) + class EnumTestCase(unittest.TestCase): def test_valid_enum(self): @@ -223,3 +220,9 @@ def test_enum_collection(self): with self.assertRaises(TraitError): collection_enum.correct_int_set_enum = 2 + + for i in range(10): + collection_enum.digits = i + + with self.assertRaises(TraitError): + collection_enum.digits = 10 diff --git a/traits/trait_base.py b/traits/trait_base.py index 4eb09c1af..beb3bf435 100644 --- a/traits/trait_base.py +++ b/traits/trait_base.py @@ -83,6 +83,7 @@ def __reduce_ex__(self, protocol): #: assignment, and treat it specially. Uninitialized = _Uninitialized() + Undefined = None @@ -397,12 +398,15 @@ def is_str(value): def is_excluded_collection(value): - if isinstance(value, (str, bytes, bytearray)): - return True - return False + """ Values of type str, bytes or bytearray can be + iterated over and are therefore collections, however, + we want to treat these as discrete units. + This is used by logic inside the Enum trait""" + return isinstance(value, (str, bytes, bytearray)) def is_collection(value): + """ Returns true if the value can be iterated over. """ try: iter(value) return True diff --git a/traits/trait_types.py b/traits/trait_types.py index 0bdca267d..d3f850195 100644 --- a/traits/trait_types.py +++ b/traits/trait_types.py @@ -1970,10 +1970,13 @@ def __init__(self, *args, **metadata): self.values = {self.values} elif len(args) == 2: - if is_excluded_collection(args[1]): - self.values = tuple(args) - elif is_collection(args[1]): + # If 2 args, the first is default, second is allowed values. + allowed_vals = args[1] + excluded_collection = is_excluded_collection(allowed_vals) + if is_collection(allowed_vals) and not excluded_collection: self.values = args[1] + else: + self.values = tuple(args) else: self.values = tuple(args) @@ -2089,6 +2092,10 @@ class Enum(BaseEnum): def init_fast_validate(self, *args): """ Set up C-level fast validation. """ + # Don't use fast validation if second arg is not a tuple. + if len(args) == 2 and not isinstance(args[1], tuple): + return + self.fast_validate = args From bfd8f3d75d2ef432bb55ce5cf8c50e250c0ecae1 Mon Sep 17 00:00:00 2001 From: Midhun PM Date: Wed, 25 Mar 2020 20:09:52 -0500 Subject: [PATCH 09/16] Some refactoring --- traits/trait_base.py | 17 ++--------------- traits/trait_types.py | 30 ++++++++++++++++++------------ 2 files changed, 20 insertions(+), 27 deletions(-) diff --git a/traits/trait_base.py b/traits/trait_base.py index beb3bf435..8a08d8b42 100644 --- a/traits/trait_base.py +++ b/traits/trait_base.py @@ -187,7 +187,7 @@ def enum_default(values): Parameters ---------- - values : tuple, list or enum.Enum + values : tuple, list, collection or enum.Enum The collection of valid values for an enum trait. Returns @@ -195,13 +195,8 @@ def enum_default(values): default : any The first valid value, or None if the collection is empty. """ - if isinstance(values, enum.EnumMeta): + if isinstance(values, enum.EnumMeta) or is_collection(values): default = next(iter(values), None) - elif is_collection(values): - if is_excluded_collection(values): - default = values - else: - default = next(iter(values)) elif len(values) > 0: default = values[0] else: @@ -397,14 +392,6 @@ def is_str(value): return isinstance(value, str) -def is_excluded_collection(value): - """ Values of type str, bytes or bytearray can be - iterated over and are therefore collections, however, - we want to treat these as discrete units. - This is used by logic inside the Enum trait""" - return isinstance(value, (str, bytes, bytearray)) - - def is_collection(value): """ Returns true if the value can be iterated over. """ try: diff --git a/traits/trait_types.py b/traits/trait_types.py index d3f850195..162c754fb 100644 --- a/traits/trait_types.py +++ b/traits/trait_types.py @@ -41,7 +41,6 @@ TraitsCache, xgetattr, is_collection, - is_excluded_collection, ) from .trait_converters import trait_from, trait_cast from .trait_dict_object import TraitDictEvent, TraitDictObject @@ -1958,26 +1957,33 @@ def __init__(self, *args, **metadata): "when using the 'values' keyword" ) else: - default_value = args[0] if len(args) == 1: - self.values = args[0] - default_value = enum_default(self.values) + arg = args[0] + if isinstance(arg, EnumTypes): + default_value = enum_default(arg) + self.values = tuple(arg) - # If values is not a Collection, treat them differently - if isinstance(self.values, EnumTypes): - self.values = tuple(self.values) - elif is_excluded_collection(self.values): - self.values = {self.values} + elif isinstance(arg, (str, bytes, bytearray)): + default_value = arg + self.values = {arg} + + else: + default_value = enum_default(arg) + self.values = arg elif len(args) == 2: # If 2 args, the first is default, second is allowed values. + default_value = args[0] allowed_vals = args[1] - excluded_collection = is_excluded_collection(allowed_vals) - if is_collection(allowed_vals) and not excluded_collection: - self.values = args[1] + + if isinstance(allowed_vals, (str, bytes, bytearray)): + self.values = tuple(args) + elif is_collection(allowed_vals): + self.values = allowed_vals else: self.values = tuple(args) else: + default_value = args[0] self.values = tuple(args) if isinstance(args, enum.EnumMeta): From bdfa427e2c3f6b600292b56a291d7caac1072177 Mon Sep 17 00:00:00 2001 From: Midhun PM Date: Wed, 25 Mar 2020 21:20:35 -0500 Subject: [PATCH 10/16] Cleanup test code --- traits/tests/test_enum.py | 32 ++++++++++++++++++++++---------- 1 file changed, 22 insertions(+), 10 deletions(-) diff --git a/traits/tests/test_enum.py b/traits/tests/test_enum.py index 1af1ce33a..e318c63a4 100644 --- a/traits/tests/test_enum.py +++ b/traits/tests/test_enum.py @@ -187,42 +187,54 @@ def test_enum_enum(self): def test_enum_collection(self): collection_enum = EnumCollectionExample() + + # Test the default values. self.assertEqual("red", collection_enum.rgb) self.assertEqual("r", collection_enum.rgb_char) self.assertEqual("one", collection_enum.numbers) self.assertEqual("abcdefg", collection_enum.letters) + self.assertEqual("yes", collection_enum.yes_no) + self.assertEqual(0, collection_enum.digits) + self.assertEqual(1, collection_enum.int_set_enum) + # Test assigning valid values collection_enum.rgb = "blue" self.assertEqual("blue", collection_enum.rgb) collection_enum.rgb_char = 'g' self.assertEqual("g", collection_enum.rgb_char) + collection_enum.yes_no = "no" + self.assertEqual("no", collection_enum.yes_no) + + for i in range(10): + collection_enum.digits = i + self.assertEqual(i, collection_enum.digits) + + # Test assigning invalid values with self.assertRaises(TraitError): collection_enum.rgb = "two" with self.assertRaises(TraitError): collection_enum.letters = 'b' - collection_enum.yes_no = "no" with self.assertRaises(TraitError): collection_enum.yes_no = "n" - self.assertEqual(1, collection_enum.int_set_enum) + with self.assertRaises(TraitError): + collection_enum.digits = 10 + # Fixing issue #835 introduces the following behaviour, which would # have otherwise not thrown a TraitError with self.assertRaises(TraitError): collection_enum.int_set_enum = {1, 2} - # But the behaviour can be fixed, as seen below + # But the behaviour can be fixed + # by defining it like correct_int_set_enum self.assertEqual(1, collection_enum.correct_int_set_enum) - collection_enum.correct_int_set_enum = {1, 2} - - with self.assertRaises(TraitError): - collection_enum.correct_int_set_enum = 2 - for i in range(10): - collection_enum.digits = i + # No more error on assignment + collection_enum.correct_int_set_enum = {1, 2} with self.assertRaises(TraitError): - collection_enum.digits = 10 + collection_enum.correct_int_set_enum = 20 From 6b3852b48c326e5df8314fff2d53230216d8ba13 Mon Sep 17 00:00:00 2001 From: Midhun PM Date: Wed, 25 Mar 2020 21:32:32 -0500 Subject: [PATCH 11/16] Fix issue #934 --- traits/tests/test_enum.py | 6 ++++++ traits/trait_types.py | 4 ++++ 2 files changed, 10 insertions(+) diff --git a/traits/tests/test_enum.py b/traits/tests/test_enum.py index e318c63a4..fb283a740 100644 --- a/traits/tests/test_enum.py +++ b/traits/tests/test_enum.py @@ -238,3 +238,9 @@ def test_enum_collection(self): with self.assertRaises(TraitError): collection_enum.correct_int_set_enum = 20 + + def test_empty_enum(self): + with self.assertRaises(TraitError): + class EmptyEnum(HasTraits): + a = Enum() + EmptyEnum() diff --git a/traits/trait_types.py b/traits/trait_types.py index 162c754fb..45ff33b84 100644 --- a/traits/trait_types.py +++ b/traits/trait_types.py @@ -1983,6 +1983,10 @@ def __init__(self, *args, **metadata): else: self.values = tuple(args) else: + if len(args) < 1: + raise TraitError("Enum trait requires at " + "least 1 argument.") + default_value = args[0] self.values = tuple(args) From 0c31e2796075ebe6531151f7d51792d9f855327c Mon Sep 17 00:00:00 2001 From: Midhun PM Date: Wed, 25 Mar 2020 21:39:04 -0500 Subject: [PATCH 12/16] Minor change --- traits/trait_types.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/traits/trait_types.py b/traits/trait_types.py index 45ff33b84..8ec69a821 100644 --- a/traits/trait_types.py +++ b/traits/trait_types.py @@ -1957,6 +1957,10 @@ def __init__(self, *args, **metadata): "when using the 'values' keyword" ) else: + if len(args) < 1: + raise TraitError("Enum trait requires at " + "least 1 argument.") + if len(args) == 1: arg = args[0] if isinstance(arg, EnumTypes): @@ -1983,10 +1987,6 @@ def __init__(self, *args, **metadata): else: self.values = tuple(args) else: - if len(args) < 1: - raise TraitError("Enum trait requires at " - "least 1 argument.") - default_value = args[0] self.values = tuple(args) From 31546dc9c32f65278e34c1011f31229ae5fc2f08 Mon Sep 17 00:00:00 2001 From: Midhun PM Date: Thu, 26 Mar 2020 14:02:05 -0500 Subject: [PATCH 13/16] Some refactoring --- traits/tests/test_enum.py | 7 +++++++ traits/trait_base.py | 22 ---------------------- traits/trait_types.py | 9 ++++----- 3 files changed, 11 insertions(+), 27 deletions(-) diff --git a/traits/tests/test_enum.py b/traits/tests/test_enum.py index fb283a740..38a866b8d 100644 --- a/traits/tests/test_enum.py +++ b/traits/tests/test_enum.py @@ -102,6 +102,8 @@ class EnumCollectionExample(HasTraits): digits = Enum(0, 1, 2, 3, 4, 5, 6, 7, 8, 9) + two_digits = Enum(1, 2) + class EnumTestCase(unittest.TestCase): def test_valid_enum(self): @@ -196,6 +198,7 @@ def test_enum_collection(self): self.assertEqual("yes", collection_enum.yes_no) self.assertEqual(0, collection_enum.digits) self.assertEqual(1, collection_enum.int_set_enum) + self.assertEqual(1, collection_enum.two_digits) # Test assigning valid values collection_enum.rgb = "blue" @@ -211,6 +214,9 @@ def test_enum_collection(self): collection_enum.digits = i self.assertEqual(i, collection_enum.digits) + collection_enum.two_digits = 2 + self.assertEqual(2, collection_enum.two_digits) + # Test assigning invalid values with self.assertRaises(TraitError): collection_enum.rgb = "two" @@ -243,4 +249,5 @@ def test_empty_enum(self): with self.assertRaises(TraitError): class EmptyEnum(HasTraits): a = Enum() + EmptyEnum() diff --git a/traits/trait_base.py b/traits/trait_base.py index 8a08d8b42..ba8475784 100644 --- a/traits/trait_base.py +++ b/traits/trait_base.py @@ -182,28 +182,6 @@ def safe_contains(value, container): return False -def enum_default(values): - """ Get a default value from the valid values of an Enum trait. - - Parameters - ---------- - values : tuple, list, collection or enum.Enum - The collection of valid values for an enum trait. - - Returns - ------- - default : any - The first valid value, or None if the collection is empty. - """ - if isinstance(values, enum.EnumMeta) or is_collection(values): - default = next(iter(values), None) - elif len(values) > 0: - default = values[0] - else: - default = None - return default - - def class_of(object): """ Returns a string containing the class name of an object with the correct indefinite article ('a' or 'an') preceding it (e.g., 'an Image', diff --git a/traits/trait_types.py b/traits/trait_types.py index 8ec69a821..c9170933d 100644 --- a/traits/trait_types.py +++ b/traits/trait_types.py @@ -31,7 +31,6 @@ get_module_name, HandleWeakRef, class_of, - enum_default, EnumTypes, RangeTypes, safe_contains, @@ -1961,10 +1960,10 @@ def __init__(self, *args, **metadata): raise TraitError("Enum trait requires at " "least 1 argument.") - if len(args) == 1: + elif len(args) == 1: arg = args[0] if isinstance(arg, EnumTypes): - default_value = enum_default(arg) + default_value = next(iter(arg), None) self.values = tuple(arg) elif isinstance(arg, (str, bytes, bytearray)): @@ -1972,7 +1971,7 @@ def __init__(self, *args, **metadata): self.values = {arg} else: - default_value = enum_default(arg) + default_value = next(iter(arg), None) self.values = arg elif len(args) == 2: @@ -2048,7 +2047,7 @@ def _get(self, object, name, trait): value = self.get_value(object, name, trait) values = xgetattr(object, self.name) if not safe_contains(value, values): - value = enum_default(values) + value = next(iter(values), None) return value def _set(self, object, name, value): From e5c1412b95750bc4ffa8960c3fcf8826fa433070 Mon Sep 17 00:00:00 2001 From: Midhun PM Date: Thu, 26 Mar 2020 14:02:05 -0500 Subject: [PATCH 14/16] Some refactoring --- traits/tests/test_enum.py | 7 +++++++ traits/trait_base.py | 22 ---------------------- traits/trait_types.py | 29 ++++++++++++++++++----------- 3 files changed, 25 insertions(+), 33 deletions(-) diff --git a/traits/tests/test_enum.py b/traits/tests/test_enum.py index fb283a740..38a866b8d 100644 --- a/traits/tests/test_enum.py +++ b/traits/tests/test_enum.py @@ -102,6 +102,8 @@ class EnumCollectionExample(HasTraits): digits = Enum(0, 1, 2, 3, 4, 5, 6, 7, 8, 9) + two_digits = Enum(1, 2) + class EnumTestCase(unittest.TestCase): def test_valid_enum(self): @@ -196,6 +198,7 @@ def test_enum_collection(self): self.assertEqual("yes", collection_enum.yes_no) self.assertEqual(0, collection_enum.digits) self.assertEqual(1, collection_enum.int_set_enum) + self.assertEqual(1, collection_enum.two_digits) # Test assigning valid values collection_enum.rgb = "blue" @@ -211,6 +214,9 @@ def test_enum_collection(self): collection_enum.digits = i self.assertEqual(i, collection_enum.digits) + collection_enum.two_digits = 2 + self.assertEqual(2, collection_enum.two_digits) + # Test assigning invalid values with self.assertRaises(TraitError): collection_enum.rgb = "two" @@ -243,4 +249,5 @@ def test_empty_enum(self): with self.assertRaises(TraitError): class EmptyEnum(HasTraits): a = Enum() + EmptyEnum() diff --git a/traits/trait_base.py b/traits/trait_base.py index 8a08d8b42..ba8475784 100644 --- a/traits/trait_base.py +++ b/traits/trait_base.py @@ -182,28 +182,6 @@ def safe_contains(value, container): return False -def enum_default(values): - """ Get a default value from the valid values of an Enum trait. - - Parameters - ---------- - values : tuple, list, collection or enum.Enum - The collection of valid values for an enum trait. - - Returns - ------- - default : any - The first valid value, or None if the collection is empty. - """ - if isinstance(values, enum.EnumMeta) or is_collection(values): - default = next(iter(values), None) - elif len(values) > 0: - default = values[0] - else: - default = None - return default - - def class_of(object): """ Returns a string containing the class name of an object with the correct indefinite article ('a' or 'an') preceding it (e.g., 'an Image', diff --git a/traits/trait_types.py b/traits/trait_types.py index 8ec69a821..70e3d0594 100644 --- a/traits/trait_types.py +++ b/traits/trait_types.py @@ -31,7 +31,6 @@ get_module_name, HandleWeakRef, class_of, - enum_default, EnumTypes, RangeTypes, safe_contains, @@ -1914,10 +1913,13 @@ class BaseEnum(TraitType): The enumeration of all legal values for the trait. The expected signatures are either: - - a single list, enum.Enum, tuple or a collection. The default value - is the first item in the collection. The collection should conform to - the collections.abc.Collection interface. That is, it at least - provides the __contains__, len and __iter__ methods. + - a collection. The default value is the first item in the + collection. The collection should conform to the + collections.abc.Collection interface. That is, it at least + provides the __contains__, __len__ and __iter__ methods. + Note that although the types str, bytes, and bytearray are + conform to the collection interface, these are handled + as discrete units. - a single default value, combined with the values keyword argument. - a default value, followed by a single list enum.Enum, tuple or @@ -1925,8 +1927,10 @@ class BaseEnum(TraitType): - arbitrary positional arguments each giving a valid value. values : str The name of a trait holding the legal values. A default value may - be provided via a positional argument, otherwise it is the first - item stored in the . + be provided via a positional argument, otherwise the first item in + the collection is used as the default value. Note that if the + collection does not have a notion of order like a set, the default + value will be an arbitrary element from the set. **metadata Trait metadata for the trait. @@ -1961,18 +1965,21 @@ def __init__(self, *args, **metadata): raise TraitError("Enum trait requires at " "least 1 argument.") - if len(args) == 1: + elif len(args) == 1: arg = args[0] if isinstance(arg, EnumTypes): - default_value = enum_default(arg) + default_value = next(iter(arg), None) self.values = tuple(arg) + # Treat str, bytes and bytearray as discrete units, + # and not as a collection. elif isinstance(arg, (str, bytes, bytearray)): default_value = arg self.values = {arg} + # Handle a collection else: - default_value = enum_default(arg) + default_value = next(iter(arg), None) self.values = arg elif len(args) == 2: @@ -2048,7 +2055,7 @@ def _get(self, object, name, trait): value = self.get_value(object, name, trait) values = xgetattr(object, self.name) if not safe_contains(value, values): - value = enum_default(values) + value = next(iter(values), None) return value def _set(self, object, name, value): From ada61f77de639398120c762e403d28293d7c5bc9 Mon Sep 17 00:00:00 2001 From: Midhun PM Date: Thu, 26 Mar 2020 14:22:03 -0500 Subject: [PATCH 15/16] Add a comment --- traits/trait_types.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/traits/trait_types.py b/traits/trait_types.py index 70e3d0594..953a02353 100644 --- a/traits/trait_types.py +++ b/traits/trait_types.py @@ -1987,10 +1987,14 @@ def __init__(self, *args, **metadata): default_value = args[0] allowed_vals = args[1] + # Treat str, bytes and bytearray as discrete units, + # and not as a collection. if isinstance(allowed_vals, (str, bytes, bytearray)): self.values = tuple(args) + elif is_collection(allowed_vals): self.values = allowed_vals + else: self.values = tuple(args) else: From 78fa1a29938b37a76a8d22c4cabe3818994d7938 Mon Sep 17 00:00:00 2001 From: Midhun PM Date: Fri, 27 Mar 2020 08:40:19 -0500 Subject: [PATCH 16/16] Some more PR changes --- traits/tests/test_enum.py | 2 +- traits/trait_types.py | 18 +++++++++++------- 2 files changed, 12 insertions(+), 8 deletions(-) diff --git a/traits/tests/test_enum.py b/traits/tests/test_enum.py index 38a866b8d..64136e4f3 100644 --- a/traits/tests/test_enum.py +++ b/traits/tests/test_enum.py @@ -37,7 +37,7 @@ def _get_valid_models(self): class CustomCollection: def __init__(self, *data): - self.data = tuple(data) + self.data = data def __len__(self): return len(self.data) diff --git a/traits/trait_types.py b/traits/trait_types.py index 953a02353..e08e34d48 100644 --- a/traits/trait_types.py +++ b/traits/trait_types.py @@ -1917,7 +1917,7 @@ class BaseEnum(TraitType): collection. The collection should conform to the collections.abc.Collection interface. That is, it at least provides the __contains__, __len__ and __iter__ methods. - Note that although the types str, bytes, and bytearray are + Note that although the types str, bytes, and bytearray conform to the collection interface, these are handled as discrete units. - a single default value, combined with the values keyword @@ -1975,7 +1975,7 @@ def __init__(self, *args, **metadata): # and not as a collection. elif isinstance(arg, (str, bytes, bytearray)): default_value = arg - self.values = {arg} + self.values = (arg,) # Handle a collection else: @@ -1993,7 +1993,7 @@ def __init__(self, *args, **metadata): self.values = tuple(args) elif is_collection(allowed_vals): - self.values = allowed_vals + self.values = tuple(allowed_vals) else: self.values = tuple(args) @@ -2087,17 +2087,21 @@ class Enum(BaseEnum): - a single list, enum.Enum, tuple or a collection. The default value is the first item in the collection. The collection should conform to the collections.abc.Collection interface. That is, it at least - provides the __contains__, len and __iter__ methods. + provides the __contains__, __len__ and __iter__ methods. + Note that although the types str, bytes, and bytearray + conform to the collection interface, these are handled + as discrete units. - a single default value, combined with the values keyword argument. - a default value, followed by a single list enum.Enum, tuple or collection conforming to collections.abc.Collection - arbitrary positional arguments each giving a valid value. - values : str The name of a trait holding the legal values. A default value may - be provided via a positional argument, otherwise it is the first - item stored in the . + be provided via a positional argument, otherwise the first item in + the collection is used as the default value. Note that if the + collection does not have a notion of order like a set, the default + value will be an arbitrary element from the set. **metadata Trait metadata for the trait.