Skip to content

Commit 860b95a

Browse files
authored
[red-knot] Binary operator inference for union types (#16601)
## Summary Properly handle binary operator inference for union types. This fixes a bug I noticed while looking at ecosystem results. The MRE version of it is this: ```py def sub(x: float, y: float): # Red Knot: Operator `-` is unsupported between objects of type `int | float` and `int | float` return x - y ``` ## Test Plan - New Markdown tests. - Expected diff in the ecosystem checks
1 parent 6de2b28 commit 860b95a

File tree

2 files changed

+68
-2
lines changed

2 files changed

+68
-2
lines changed
Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
# Binary operations on union types
2+
3+
Binary operations on union types are only available if they are supported for all possible
4+
combinations of types:
5+
6+
```py
7+
def f1(i: int, u: int | None):
8+
# error: [unsupported-operator] "Operator `+` is unsupported between objects of type `int` and `int | None`"
9+
reveal_type(i + u) # revealed: Unknown
10+
# error: [unsupported-operator] "Operator `+` is unsupported between objects of type `int | None` and `int`"
11+
reveal_type(u + i) # revealed: Unknown
12+
```
13+
14+
`int` can be added to `int`, and `str` can be added to `str`, but expressions of type `int | str`
15+
cannot be added, because that would require addition of `int` and `str` or vice versa:
16+
17+
```py
18+
def f2(i: int, s: str, int_or_str: int | str):
19+
i + i
20+
s + s
21+
# error: [unsupported-operator] "Operator `+` is unsupported between objects of type `int | str` and `int | str`"
22+
reveal_type(int_or_str + int_or_str) # revealed: Unknown
23+
```
24+
25+
However, if an operation is supported for all possible combinations, the result will be a union of
26+
the possible outcomes:
27+
28+
```py
29+
from typing import Literal
30+
31+
def f3(two_or_three: Literal[2, 3], a_or_b: Literal["a", "b"]):
32+
reveal_type(two_or_three + two_or_three) # revealed: Literal[4, 5, 6]
33+
reveal_type(two_or_three**two_or_three) # revealed: Literal[4, 8, 9, 27]
34+
35+
reveal_type(a_or_b + a_or_b) # revealed: Literal["aa", "ab", "ba", "bb"]
36+
37+
reveal_type(two_or_three * a_or_b) # revealed: Literal["aa", "bb", "aaa", "bbb"]
38+
```
39+
40+
We treat a type annotation of `float` as a union of `int` and `float`, so union handling is relevant
41+
here:
42+
43+
```py
44+
def f4(x: float, y: float):
45+
reveal_type(x + y) # revealed: int | float
46+
reveal_type(x - y) # revealed: int | float
47+
reveal_type(x * y) # revealed: int | float
48+
reveal_type(x / y) # revealed: int | float
49+
reveal_type(x // y) # revealed: int | float
50+
reveal_type(x % y) # revealed: int | float
51+
```

crates/red_knot_python_semantic/src/types/infer.rs

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4144,6 +4144,23 @@ impl<'db> TypeInferenceBuilder<'db> {
41444144
op: ast::Operator,
41454145
) -> Option<Type<'db>> {
41464146
match (left_ty, right_ty, op) {
4147+
(Type::Union(lhs_union), rhs, _) => {
4148+
let mut union = UnionBuilder::new(self.db());
4149+
for lhs in lhs_union.elements(self.db()) {
4150+
let result = self.infer_binary_expression_type(*lhs, rhs, op)?;
4151+
union = union.add(result);
4152+
}
4153+
Some(union.build())
4154+
}
4155+
(lhs, Type::Union(rhs_union), _) => {
4156+
let mut union = UnionBuilder::new(self.db());
4157+
for rhs in rhs_union.elements(self.db()) {
4158+
let result = self.infer_binary_expression_type(lhs, *rhs, op)?;
4159+
union = union.add(result);
4160+
}
4161+
Some(union.build())
4162+
}
4163+
41474164
// Non-todo Anys take precedence over Todos (as if we fix this `Todo` in the future,
41484165
// the result would then become Any or Unknown, respectively).
41494166
(any @ Type::Dynamic(DynamicType::Any), _, _)
@@ -4275,7 +4292,6 @@ impl<'db> TypeInferenceBuilder<'db> {
42754292
| Type::SubclassOf(_)
42764293
| Type::Instance(_)
42774294
| Type::KnownInstance(_)
4278-
| Type::Union(_)
42794295
| Type::Intersection(_)
42804296
| Type::AlwaysTruthy
42814297
| Type::AlwaysFalsy
@@ -4292,7 +4308,6 @@ impl<'db> TypeInferenceBuilder<'db> {
42924308
| Type::SubclassOf(_)
42934309
| Type::Instance(_)
42944310
| Type::KnownInstance(_)
4295-
| Type::Union(_)
42964311
| Type::Intersection(_)
42974312
| Type::AlwaysTruthy
42984313
| Type::AlwaysFalsy

0 commit comments

Comments
 (0)