Skip to content

Commit

Permalink
gh-103365: [Enum] STRICT boundary corrections (GH-103494)
Browse files Browse the repository at this point in the history
STRICT boundary:

- fix bitwise operations
- make default for Flag
  • Loading branch information
ethanfurman authored Apr 13, 2023
1 parent efb8a25 commit 2194071
Show file tree
Hide file tree
Showing 4 changed files with 82 additions and 38 deletions.
5 changes: 3 additions & 2 deletions Doc/library/enum.rst
Original file line number Diff line number Diff line change
Expand Up @@ -696,7 +696,8 @@ Data Types

.. attribute:: STRICT

Out-of-range values cause a :exc:`ValueError` to be raised::
Out-of-range values cause a :exc:`ValueError` to be raised. This is the
default for :class:`Flag`::

>>> from enum import Flag, STRICT, auto
>>> class StrictFlag(Flag, boundary=STRICT):
Expand All @@ -714,7 +715,7 @@ Data Types
.. attribute:: CONFORM

Out-of-range values have invalid values removed, leaving a valid *Flag*
value. This is the default for :class:`Flag`::
value::

>>> from enum import Flag, CONFORM, auto
>>> class ConformFlag(Flag, boundary=CONFORM):
Expand Down
67 changes: 39 additions & 28 deletions Lib/enum.py
Original file line number Diff line number Diff line change
Expand Up @@ -275,6 +275,13 @@ def __set_name__(self, enum_class, member_name):
enum_member.__objclass__ = enum_class
enum_member.__init__(*args)
enum_member._sort_order_ = len(enum_class._member_names_)

if Flag is not None and issubclass(enum_class, Flag):
enum_class._flag_mask_ |= value
if _is_single_bit(value):
enum_class._singles_mask_ |= value
enum_class._all_bits_ = 2 ** ((enum_class._flag_mask_).bit_length()) - 1

# If another member with the same value was already defined, the
# new member becomes an alias to the existing one.
try:
Expand Down Expand Up @@ -532,12 +539,8 @@ def __new__(metacls, cls, bases, classdict, *, boundary=None, _simple=False, **k
classdict['_use_args_'] = use_args
#
# convert future enum members into temporary _proto_members
# and record integer values in case this will be a Flag
flag_mask = 0
for name in member_names:
value = classdict[name]
if isinstance(value, int):
flag_mask |= value
classdict[name] = _proto_member(value)
#
# house-keeping structures
Expand All @@ -554,8 +557,9 @@ def __new__(metacls, cls, bases, classdict, *, boundary=None, _simple=False, **k
boundary
or getattr(first_enum, '_boundary_', None)
)
classdict['_flag_mask_'] = flag_mask
classdict['_all_bits_'] = 2 ** ((flag_mask).bit_length()) - 1
classdict['_flag_mask_'] = 0
classdict['_singles_mask_'] = 0
classdict['_all_bits_'] = 0
classdict['_inverted_'] = None
try:
exc = None
Expand Down Expand Up @@ -644,21 +648,10 @@ def __new__(metacls, cls, bases, classdict, *, boundary=None, _simple=False, **k
):
delattr(enum_class, '_boundary_')
delattr(enum_class, '_flag_mask_')
delattr(enum_class, '_singles_mask_')
delattr(enum_class, '_all_bits_')
delattr(enum_class, '_inverted_')
elif Flag is not None and issubclass(enum_class, Flag):
# ensure _all_bits_ is correct and there are no missing flags
single_bit_total = 0
multi_bit_total = 0
for flag in enum_class._member_map_.values():
flag_value = flag._value_
if _is_single_bit(flag_value):
single_bit_total |= flag_value
else:
# multi-bit flags are considered aliases
multi_bit_total |= flag_value
enum_class._flag_mask_ = single_bit_total
#
# set correct __iter__
member_list = [m._value_ for m in enum_class]
if member_list != sorted(member_list):
Expand Down Expand Up @@ -1303,8 +1296,8 @@ def _reduce_ex_by_global_name(self, proto):
class FlagBoundary(StrEnum):
"""
control how out of range values are handled
"strict" -> error is raised
"conform" -> extra bits are discarded [default for Flag]
"strict" -> error is raised [default for Flag]
"conform" -> extra bits are discarded
"eject" -> lose flag status
"keep" -> keep flag status and all bits [default for IntFlag]
"""
Expand All @@ -1315,7 +1308,7 @@ class FlagBoundary(StrEnum):
STRICT, CONFORM, EJECT, KEEP = FlagBoundary


class Flag(Enum, boundary=CONFORM):
class Flag(Enum, boundary=STRICT):
"""
Support for flags
"""
Expand Down Expand Up @@ -1394,6 +1387,7 @@ def _missing_(cls, value):
# - value must not include any skipped flags (e.g. if bit 2 is not
# defined, then 0d10 is invalid)
flag_mask = cls._flag_mask_
singles_mask = cls._singles_mask_
all_bits = cls._all_bits_
neg_value = None
if (
Expand Down Expand Up @@ -1425,7 +1419,8 @@ def _missing_(cls, value):
value = all_bits + 1 + value
# get members and unknown
unknown = value & ~flag_mask
member_value = value & flag_mask
aliases = value & ~singles_mask
member_value = value & singles_mask
if unknown and cls._boundary_ is not KEEP:
raise ValueError(
'%s(%r) --> unknown values %r [%s]'
Expand All @@ -1439,11 +1434,25 @@ def _missing_(cls, value):
pseudo_member = cls._member_type_.__new__(cls, value)
if not hasattr(pseudo_member, '_value_'):
pseudo_member._value_ = value
if member_value:
pseudo_member._name_ = '|'.join([
m._name_ for m in cls._iter_member_(member_value)
])
if unknown:
if member_value or aliases:
members = []
combined_value = 0
for m in cls._iter_member_(member_value):
members.append(m)
combined_value |= m._value_
if aliases:
value = member_value | aliases
for n, pm in cls._member_map_.items():
if pm not in members and pm._value_ and pm._value_ & value == pm._value_:
members.append(pm)
combined_value |= pm._value_
unknown = value ^ combined_value
pseudo_member._name_ = '|'.join([m._name_ for m in members])
if not combined_value:
pseudo_member._name_ = None
elif unknown and cls._boundary_ is STRICT:
raise ValueError('%r: no members with value %r' % (cls, unknown))
elif unknown:
pseudo_member._name_ += '|%s' % cls._numeric_repr_(unknown)
else:
pseudo_member._name_ = None
Expand Down Expand Up @@ -1675,6 +1684,7 @@ def convert_class(cls):
body['_boundary_'] = boundary or etype._boundary_
body['_flag_mask_'] = None
body['_all_bits_'] = None
body['_singles_mask_'] = None
body['_inverted_'] = None
body['__or__'] = Flag.__or__
body['__xor__'] = Flag.__xor__
Expand Down Expand Up @@ -1750,7 +1760,8 @@ def convert_class(cls):
else:
multi_bits |= value
gnv_last_values.append(value)
enum_class._flag_mask_ = single_bits
enum_class._flag_mask_ = single_bits | multi_bits
enum_class._singles_mask_ = single_bits
enum_class._all_bits_ = 2 ** ((single_bits|multi_bits).bit_length()) - 1
# set correct __iter__
member_list = [m._value_ for m in enum_class]
Expand Down
47 changes: 39 additions & 8 deletions Lib/test/test_enum.py
Original file line number Diff line number Diff line change
Expand Up @@ -2873,6 +2873,8 @@ def __new__(cls, c):
#
a = ord('a')
#
self.assertEqual(FlagFromChar._all_bits_, 316912650057057350374175801343)
self.assertEqual(FlagFromChar._flag_mask_, 158456325028528675187087900672)
self.assertEqual(FlagFromChar.a, 158456325028528675187087900672)
self.assertEqual(FlagFromChar.a|1, 158456325028528675187087900673)
#
Expand All @@ -2887,6 +2889,8 @@ def __new__(cls, c):
a = ord('a')
z = 1
#
self.assertEqual(FlagFromChar._all_bits_, 316912650057057350374175801343)
self.assertEqual(FlagFromChar._flag_mask_, 158456325028528675187087900674)
self.assertEqual(FlagFromChar.a.value, 158456325028528675187087900672)
self.assertEqual((FlagFromChar.a|FlagFromChar.z).value, 158456325028528675187087900674)
#
Expand All @@ -2900,6 +2904,8 @@ def __new__(cls, c):
#
a = ord('a')
#
self.assertEqual(FlagFromChar._all_bits_, 316912650057057350374175801343)
self.assertEqual(FlagFromChar._flag_mask_, 158456325028528675187087900672)
self.assertEqual(FlagFromChar.a, 158456325028528675187087900672)
self.assertEqual(FlagFromChar.a|1, 158456325028528675187087900673)

Expand Down Expand Up @@ -3077,18 +3083,18 @@ def test_bool(self):
self.assertEqual(bool(f.value), bool(f))

def test_boundary(self):
self.assertIs(enum.Flag._boundary_, CONFORM)
class Iron(Flag, boundary=STRICT):
self.assertIs(enum.Flag._boundary_, STRICT)
class Iron(Flag, boundary=CONFORM):
ONE = 1
TWO = 2
EIGHT = 8
self.assertIs(Iron._boundary_, STRICT)
self.assertIs(Iron._boundary_, CONFORM)
#
class Water(Flag, boundary=CONFORM):
class Water(Flag, boundary=STRICT):
ONE = 1
TWO = 2
EIGHT = 8
self.assertIs(Water._boundary_, CONFORM)
self.assertIs(Water._boundary_, STRICT)
#
class Space(Flag, boundary=EJECT):
ONE = 1
Expand All @@ -3101,17 +3107,42 @@ class Bizarre(Flag, boundary=KEEP):
c = 4
d = 6
#
self.assertRaisesRegex(ValueError, 'invalid value 7', Iron, 7)
self.assertRaisesRegex(ValueError, 'invalid value 7', Water, 7)
#
self.assertIs(Water(7), Water.ONE|Water.TWO)
self.assertIs(Water(~9), Water.TWO)
self.assertIs(Iron(7), Iron.ONE|Iron.TWO)
self.assertIs(Iron(~9), Iron.TWO)
#
self.assertEqual(Space(7), 7)
self.assertTrue(type(Space(7)) is int)
#
self.assertEqual(list(Bizarre), [Bizarre.c])
self.assertIs(Bizarre(3), Bizarre.b)
self.assertIs(Bizarre(6), Bizarre.d)
#
class SkipFlag(enum.Flag):
A = 1
B = 2
C = 4 | B
#
self.assertTrue(SkipFlag.C in (SkipFlag.A|SkipFlag.C))
self.assertRaisesRegex(ValueError, 'SkipFlag.. invalid value 42', SkipFlag, 42)
#
class SkipIntFlag(enum.IntFlag):
A = 1
B = 2
C = 4 | B
#
self.assertTrue(SkipIntFlag.C in (SkipIntFlag.A|SkipIntFlag.C))
self.assertEqual(SkipIntFlag(42).value, 42)
#
class MethodHint(Flag):
HiddenText = 0x10
DigitsOnly = 0x01
LettersOnly = 0x02
OnlyMask = 0x0f
#
self.assertEqual(str(MethodHint.HiddenText|MethodHint.OnlyMask), 'MethodHint.HiddenText|DigitsOnly|LettersOnly|OnlyMask')


def test_iter(self):
Color = self.Color
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Set default Flag boundary to ``STRICT`` and fix bitwise operations.

0 comments on commit 2194071

Please sign in to comment.