@@ -535,6 +535,20 @@ def visit_overloaded(self, left: Overloaded) -> bool:
535535 return False
536536
537537 def visit_union_type (self , left : UnionType ) -> bool :
538+ if isinstance (self .right , Instance ):
539+ literal_types : Set [Instance ] = set ()
540+ # avoid redundant check for union of literals
541+ for item in left .relevant_items ():
542+ item = get_proper_type (item )
543+ lit_type = mypy .typeops .simple_literal_type (item )
544+ if lit_type is not None :
545+ if lit_type in literal_types :
546+ continue
547+ literal_types .add (lit_type )
548+ item = lit_type
549+ if not self ._is_subtype (item , self .orig_right ):
550+ return False
551+ return True
538552 return all (self ._is_subtype (item , self .orig_right ) for item in left .items )
539553
540554 def visit_partial_type (self , left : PartialType ) -> bool :
@@ -1199,6 +1213,27 @@ def report(*args: Any) -> None:
11991213 return applied
12001214
12011215
1216+ def try_restrict_literal_union (t : UnionType , s : Type ) -> Optional [List [Type ]]:
1217+ """Return the items of t, excluding any occurrence of s, if and only if
1218+ - t only contains simple literals
1219+ - s is a simple literal
1220+
1221+ Otherwise, returns None
1222+ """
1223+ ps = get_proper_type (s )
1224+ if not mypy .typeops .is_simple_literal (ps ):
1225+ return None
1226+
1227+ new_items : List [Type ] = []
1228+ for i in t .relevant_items ():
1229+ pi = get_proper_type (i )
1230+ if not mypy .typeops .is_simple_literal (pi ):
1231+ return None
1232+ if pi != ps :
1233+ new_items .append (i )
1234+ return new_items
1235+
1236+
12021237def restrict_subtype_away (t : Type , s : Type , * , ignore_promotions : bool = False ) -> Type :
12031238 """Return t minus s for runtime type assertions.
12041239
@@ -1212,10 +1247,14 @@ def restrict_subtype_away(t: Type, s: Type, *, ignore_promotions: bool = False)
12121247 s = get_proper_type (s )
12131248
12141249 if isinstance (t , UnionType ):
1215- new_items = [restrict_subtype_away (item , s , ignore_promotions = ignore_promotions )
1216- for item in t .relevant_items ()
1217- if (isinstance (get_proper_type (item ), AnyType ) or
1218- not covers_at_runtime (item , s , ignore_promotions ))]
1250+ new_items = try_restrict_literal_union (t , s )
1251+ if new_items is None :
1252+ new_items = [
1253+ restrict_subtype_away (item , s , ignore_promotions = ignore_promotions )
1254+ for item in t .relevant_items ()
1255+ if (isinstance (get_proper_type (item ), AnyType ) or
1256+ not covers_at_runtime (item , s , ignore_promotions ))
1257+ ]
12191258 return UnionType .make_union (new_items )
12201259 elif covers_at_runtime (t , s , ignore_promotions ):
12211260 return UninhabitedType ()
@@ -1285,11 +1324,11 @@ def _is_proper_subtype(left: Type, right: Type, *,
12851324 right = get_proper_type (right )
12861325
12871326 if isinstance (right , UnionType ) and not isinstance (left , UnionType ):
1288- return any ([ is_proper_subtype (orig_left , item ,
1289- ignore_promotions = ignore_promotions ,
1290- erase_instances = erase_instances ,
1291- keep_erased_types = keep_erased_types )
1292- for item in right .items ] )
1327+ return any (is_proper_subtype (orig_left , item ,
1328+ ignore_promotions = ignore_promotions ,
1329+ erase_instances = erase_instances ,
1330+ keep_erased_types = keep_erased_types )
1331+ for item in right .items )
12931332 return left .accept (ProperSubtypeVisitor (orig_right ,
12941333 ignore_promotions = ignore_promotions ,
12951334 erase_instances = erase_instances ,
@@ -1495,7 +1534,7 @@ def visit_overloaded(self, left: Overloaded) -> bool:
14951534 return False
14961535
14971536 def visit_union_type (self , left : UnionType ) -> bool :
1498- return all ([ self ._is_proper_subtype (item , self .orig_right ) for item in left .items ] )
1537+ return all (self ._is_proper_subtype (item , self .orig_right ) for item in left .items )
14991538
15001539 def visit_partial_type (self , left : PartialType ) -> bool :
15011540 # TODO: What's the right thing to do here?
0 commit comments