Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow Collections in valid values for Enum trait #889

Merged
merged 19 commits into from
Mar 27, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
103 changes: 100 additions & 3 deletions traits/tests/test_enum.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,22 @@ def _get_valid_models(self):
return ["model1", "model2", "model3"]


class EnumListExample(HasTraits):
class CustomCollection:

def __init__(self, *data):
self.data = data

def __len__(self):
return len(self.data)

def __iter__(self):
return iter(self.data)

def __contains__(self, x):
return x in self.data


class EnumListExample(HasTraits):
values = Any(['foo', 'bar', 'baz'])

value = Enum(['foo', 'bar', 'baz'])
Expand All @@ -48,7 +62,6 @@ class EnumListExample(HasTraits):


class EnumTupleExample(HasTraits):

values = Any(('foo', 'bar', 'baz'))

value = Enum(('foo', 'bar', 'baz'))
Expand All @@ -61,7 +74,6 @@ class EnumTupleExample(HasTraits):


class EnumEnumExample(HasTraits):

values = Any(FooEnum)

value = Enum(FooEnum)
Expand All @@ -73,6 +85,26 @@ class EnumEnumExample(HasTraits):
value_name_default = Enum(FooEnum.bar, values='values')


class EnumCollectionExample(HasTraits):
rgb = Enum("red", CustomCollection("red", "green", "blue"))

rgb_char = Enum("r", "g", "b")

numbers = Enum(CustomCollection("one", "two", "three"))

letters = Enum("abcdefg")

int_set_enum = Enum(1, {1, 2})

correct_int_set_enum = Enum([1, {1, 2}])

yes_no = Enum("yes", "no")

digits = Enum(0, 1, 2, 3, 4, 5, 6, 7, 8, 9)

two_digits = Enum(1, 2)


class EnumTestCase(unittest.TestCase):
def test_valid_enum(self):
example_model = ExampleModel(root="model1")
Expand Down Expand Up @@ -154,3 +186,68 @@ def test_enum_enum(self):

with self.assertRaises(TraitError):
example.value_name = FooEnum.bar

def test_enum_collection(self):
collection_enum = EnumCollectionExample()

# Test the default values.
self.assertEqual("red", collection_enum.rgb)
self.assertEqual("r", collection_enum.rgb_char)
self.assertEqual("one", collection_enum.numbers)
self.assertEqual("abcdefg", collection_enum.letters)
self.assertEqual("yes", collection_enum.yes_no)
self.assertEqual(0, collection_enum.digits)
self.assertEqual(1, collection_enum.int_set_enum)
self.assertEqual(1, collection_enum.two_digits)

# Test assigning valid values
collection_enum.rgb = "blue"
self.assertEqual("blue", collection_enum.rgb)

collection_enum.rgb_char = 'g'
self.assertEqual("g", collection_enum.rgb_char)

collection_enum.yes_no = "no"
self.assertEqual("no", collection_enum.yes_no)

for i in range(10):
collection_enum.digits = i
self.assertEqual(i, collection_enum.digits)

collection_enum.two_digits = 2
self.assertEqual(2, collection_enum.two_digits)

# Test assigning invalid values
with self.assertRaises(TraitError):
collection_enum.rgb = "two"

with self.assertRaises(TraitError):
collection_enum.letters = 'b'

with self.assertRaises(TraitError):
collection_enum.yes_no = "n"

with self.assertRaises(TraitError):
collection_enum.digits = 10

# Fixing issue #835 introduces the following behaviour, which would
# have otherwise not thrown a TraitError
with self.assertRaises(TraitError):
collection_enum.int_set_enum = {1, 2}

# But the behaviour can be fixed
# by defining it like correct_int_set_enum
self.assertEqual(1, collection_enum.correct_int_set_enum)

# No more error on assignment
collection_enum.correct_int_set_enum = {1, 2}

with self.assertRaises(TraitError):
collection_enum.correct_int_set_enum = 20

def test_empty_enum(self):
with self.assertRaises(TraitError):
class EmptyEnum(HasTraits):
a = Enum()

EmptyEnum()
31 changes: 9 additions & 22 deletions traits/trait_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,28 +182,6 @@ def safe_contains(value, container):
return False


def enum_default(values):
""" Get a default value from the valid values of an Enum trait.

Parameters
----------
values : tuple, list or enum.Enum
The collection of valid values for an enum trait.

Returns
-------
default : any
The first valid value, or None if the collection is empty.
"""
if isinstance(values, enum.EnumMeta):
default = next(iter(values), None)
elif len(values) > 0:
default = values[0]
else:
default = None
return default


def class_of(object):
""" Returns a string containing the class name of an object with the
correct indefinite article ('a' or 'an') preceding it (e.g., 'an Image',
Expand Down Expand Up @@ -390,3 +368,12 @@ def not_event(value):

def is_str(value):
return isinstance(value, str)


def is_collection(value):
""" Returns true if the value can be iterated over. """
try:
iter(value)
return True
except TypeError:
return False
91 changes: 71 additions & 20 deletions traits/trait_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@
get_module_name,
HandleWeakRef,
class_of,
enum_default,
EnumTypes,
RangeTypes,
safe_contains,
Expand All @@ -40,6 +39,7 @@
Undefined,
TraitsCache,
xgetattr,
is_collection,
)
from .trait_converters import trait_from, trait_cast
from .trait_dict_object import TraitDictEvent, TraitDictObject
Expand Down Expand Up @@ -1913,16 +1913,24 @@ class BaseEnum(TraitType):
The enumeration of all legal values for the trait. The expected
signatures are either:

- a single list, enum.Enum or tuple. The default value is the first
item in the collection.
- a collection. The default value is the first item in the
midhun-pm marked this conversation as resolved.
Show resolved Hide resolved
collection. The collection should conform to the
collections.abc.Collection interface. That is, it at least
provides the __contains__, __len__ and __iter__ methods.
Note that although the types str, bytes, and bytearray
conform to the collection interface, these are handled
as discrete units.
- a single default value, combined with the values keyword
argument.
- a default value, followed by a single list enum.Enum or tuple.
- a default value, followed by a single list enum.Enum, tuple or
collection conforming to collections.abc.Collection
midhun-pm marked this conversation as resolved.
Show resolved Hide resolved
- arbitrary positional arguments each giving a valid value.
values : str
The name of a trait holding the legal values. A default value may
be provided via a positional argument, otherwise it is the first
item stored in the .
be provided via a positional argument, otherwise the first item in
the collection is used as the default value. Note that if the
collection does not have a notion of order like a set, the default
value will be an arbitrary element from the set.
**metadata
Trait metadata for the trait.

Expand Down Expand Up @@ -1953,19 +1961,51 @@ def __init__(self, *args, **metadata):
"when using the 'values' keyword"
)
else:
default_value = args[0]
if (len(args) == 1) and isinstance(default_value, EnumTypes):
args = default_value
default_value = enum_default(args)
elif (len(args) == 2) and isinstance(args[1], EnumTypes):
args = args[1]
if len(args) < 1:
raise TraitError("Enum trait requires at "
"least 1 argument.")

elif len(args) == 1:
arg = args[0]
if isinstance(arg, EnumTypes):
default_value = next(iter(arg), None)
self.values = tuple(arg)

# Treat str, bytes and bytearray as discrete units,
# and not as a collection.
elif isinstance(arg, (str, bytes, bytearray)):
midhun-pm marked this conversation as resolved.
Show resolved Hide resolved
default_value = arg
self.values = (arg,)

# Handle a collection
else:
default_value = next(iter(arg), None)
self.values = arg

elif len(args) == 2:
# If 2 args, the first is default, second is allowed values.
default_value = args[0]
allowed_vals = args[1]

# Treat str, bytes and bytearray as discrete units,
# and not as a collection.
if isinstance(allowed_vals, (str, bytes, bytearray)):
self.values = tuple(args)

elif is_collection(allowed_vals):
self.values = tuple(allowed_vals)

else:
self.values = tuple(args)
midhun-pm marked this conversation as resolved.
Show resolved Hide resolved
else:
default_value = args[0]
self.values = tuple(args)

if isinstance(args, enum.EnumMeta):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@corranwebster Do you know what this branch is for? As far as I can tell, args will always be a tuple here, so this branch is never exercised, and can be removed. (Possibly it ended up here as a result of a bad merge?)

metadata.setdefault('format_func', operator.attrgetter('name'))
metadata.setdefault('evaluate', args)

self.name = ""
self.values = tuple(args)
self.init_fast_validate(ValidateTrait.enum, self.values)

super(BaseEnum, self).__init__(default_value, **metadata)
Expand Down Expand Up @@ -2019,7 +2059,7 @@ def _get(self, object, name, trait):
value = self.get_value(object, name, trait)
values = xgetattr(object, self.name)
if not safe_contains(value, values):
value = enum_default(values)
value = next(iter(values), None)
return value

def _set(self, object, name, value):
Expand All @@ -2044,17 +2084,24 @@ class Enum(BaseEnum):
The enumeration of all legal values for the trait. The expected
signatures are either:

- a single list, enum.Enum or tuple. The default value is the first
item in the collection.
- a single list, enum.Enum, tuple or a collection. The default value
is the first item in the collection. The collection should conform to
the collections.abc.Collection interface. That is, it at least
provides the __contains__, __len__ and __iter__ methods.
Note that although the types str, bytes, and bytearray
conform to the collection interface, these are handled
as discrete units.
- a single default value, combined with the values keyword
argument.
- a default value, followed by a single list enum.Enum or tuple.
- a default value, followed by a single list enum.Enum, tuple or
collection conforming to collections.abc.Collection
- arbitrary positional arguments each giving a valid value.

values : str
The name of a trait holding the legal values. A default value may
be provided via a positional argument, otherwise it is the first
item stored in the .
be provided via a positional argument, otherwise the first item in
the collection is used as the default value. Note that if the
collection does not have a notion of order like a set, the default
value will be an arbitrary element from the set.
**metadata
Trait metadata for the trait.

Expand All @@ -2070,6 +2117,10 @@ class Enum(BaseEnum):

def init_fast_validate(self, *args):
""" Set up C-level fast validation. """
# Don't use fast validation if second arg is not a tuple.
if len(args) == 2 and not isinstance(args[1], tuple):
return

self.fast_validate = args


Expand Down