@@ -693,26 +693,32 @@ class Status(Enum):
693693 if isinstance (typ , UnionType ):
694694 items = [try_expanding_enum_to_union (item , target_fullname ) for item in typ .items ]
695695 return make_simplified_union (items , contract_literals = False )
696- elif isinstance (typ , Instance ) and typ .type .is_enum and typ .type .fullname == target_fullname :
697- new_items = []
698- for name , symbol in typ .type .names .items ():
699- if not isinstance (symbol .node , Var ):
700- continue
701- # Skip "_order_" and "__order__", since Enum will remove it
702- if name in ("_order_" , "__order__" ):
703- continue
704- new_items .append (LiteralType (name , typ ))
705- # SymbolTables are really just dicts, and dicts are guaranteed to preserve
706- # insertion order only starting with Python 3.7. So, we sort these for older
707- # versions of Python to help make tests deterministic.
708- #
709- # We could probably skip the sort for Python 3.6 since people probably run mypy
710- # only using CPython, but we might as well for the sake of full correctness.
711- if sys .version_info < (3 , 7 ):
712- new_items .sort (key = lambda lit : lit .value )
713- return make_simplified_union (new_items , contract_literals = False )
714- else :
715- return typ
696+ elif isinstance (typ , Instance ) and typ .type .fullname == target_fullname :
697+ if typ .type .is_enum :
698+ new_items = []
699+ for name , symbol in typ .type .names .items ():
700+ if not isinstance (symbol .node , Var ):
701+ continue
702+ # Skip "_order_" and "__order__", since Enum will remove it
703+ if name in ("_order_" , "__order__" ):
704+ continue
705+ new_items .append (LiteralType (name , typ ))
706+ # SymbolTables are really just dicts, and dicts are guaranteed to preserve
707+ # insertion order only starting with Python 3.7. So, we sort these for older
708+ # versions of Python to help make tests deterministic.
709+ #
710+ # We could probably skip the sort for Python 3.6 since people probably run mypy
711+ # only using CPython, but we might as well for the sake of full correctness.
712+ if sys .version_info < (3 , 7 ):
713+ new_items .sort (key = lambda lit : lit .value )
714+ return make_simplified_union (new_items , contract_literals = False )
715+ elif typ .type .fullname == "builtins.bool" :
716+ return make_simplified_union (
717+ [LiteralType (True , typ ), LiteralType (False , typ )],
718+ contract_literals = False
719+ )
720+
721+ return typ
716722
717723
718724def try_contracting_literals_in_union (types : Sequence [Type ]) -> List [ProperType ]:
@@ -730,9 +736,12 @@ def try_contracting_literals_in_union(types: Sequence[Type]) -> List[ProperType]
730736 for idx , typ in enumerate (proper_types ):
731737 if isinstance (typ , LiteralType ):
732738 fullname = typ .fallback .type .fullname
733- if typ .fallback .type .is_enum :
739+ if typ .fallback .type .is_enum or isinstance ( typ . value , bool ) :
734740 if fullname not in sum_types :
735- sum_types [fullname ] = (set (get_enum_values (typ .fallback )), [])
741+ sum_types [fullname ] = (set (get_enum_values (typ .fallback ))
742+ if typ .fallback .type .is_enum
743+ else set ((True , False )),
744+ [])
736745 literals , indexes = sum_types [fullname ]
737746 literals .discard (typ .value )
738747 indexes .append (idx )
0 commit comments