Skip to content

Commit b64bd3d

Browse files
authored
[mypyc] Support __(r)divmod__ dunders (#14613)
Pretty simple. Towards mypyc/mypyc#553.
1 parent f527656 commit b64bd3d

File tree

5 files changed

+44
-0
lines changed

5 files changed

+44
-0
lines changed

mypyc/codegen/emitclass.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,8 @@ def wrapper_slot(cl: ClassIR, fn: FuncIR, emitter: Emitter) -> str:
8282
"__rtruediv__": ("nb_true_divide", generate_bin_op_wrapper),
8383
"__floordiv__": ("nb_floor_divide", generate_bin_op_wrapper),
8484
"__rfloordiv__": ("nb_floor_divide", generate_bin_op_wrapper),
85+
"__divmod__": ("nb_divmod", generate_bin_op_wrapper),
86+
"__rdivmod__": ("nb_divmod", generate_bin_op_wrapper),
8587
"__lshift__": ("nb_lshift", generate_bin_op_wrapper),
8688
"__rlshift__": ("nb_lshift", generate_bin_op_wrapper),
8789
"__rshift__": ("nb_rshift", generate_bin_op_wrapper),

mypyc/primitives/generic_ops.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,17 @@
7575
priority=0,
7676
)
7777

78+
79+
function_op(
80+
name="builtins.divmod",
81+
arg_types=[object_rprimitive, object_rprimitive],
82+
return_type=object_rprimitive,
83+
c_function_name="PyNumber_Divmod",
84+
error_kind=ERR_MAGIC,
85+
priority=0,
86+
)
87+
88+
7889
for op, funcname in [
7990
("+=", "PyNumber_InPlaceAdd"),
8091
("-=", "PyNumber_InPlaceSubtract"),

mypyc/test-data/fixtures/ir.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,13 +8,19 @@
88

99
T = TypeVar('T')
1010
T_co = TypeVar('T_co', covariant=True)
11+
T_contra = TypeVar('T_contra', contravariant=True)
1112
S = TypeVar('S')
1213
K = TypeVar('K') # for keys in mapping
1314
V = TypeVar('V') # for values in mapping
1415

1516
class __SupportsAbs(Protocol[T_co]):
1617
def __abs__(self) -> T_co: pass
1718

19+
class __SupportsDivMod(Protocol[T_contra, T_co]):
20+
def __divmod__(self, other: T_contra) -> T_co: ...
21+
22+
class __SupportsRDivMod(Protocol[T_contra, T_co]):
23+
def __rdivmod__(self, other: T_contra) -> T_co: ...
1824

1925
class object:
2026
def __init__(self) -> None: pass
@@ -42,6 +48,7 @@ def __pow__(self, n: int, modulo: Optional[int] = None) -> int: pass
4248
def __floordiv__(self, x: int) -> int: pass
4349
def __truediv__(self, x: float) -> float: pass
4450
def __mod__(self, x: int) -> int: pass
51+
def __divmod__(self, x: float) -> Tuple[float, float]: pass
4552
def __neg__(self) -> int: pass
4653
def __pos__(self) -> int: pass
4754
def __abs__(self) -> int: pass
@@ -307,6 +314,10 @@ def zip(x: Iterable[T], y: Iterable[S]) -> Iterator[Tuple[T, S]]: ...
307314
def zip(x: Iterable[T], y: Iterable[S], z: Iterable[V]) -> Iterator[Tuple[T, S, V]]: ...
308315
def eval(e: str) -> Any: ...
309316
def abs(x: __SupportsAbs[T]) -> T: ...
317+
@overload
318+
def divmod(x: __SupportsDivMod[T_contra, T_co], y: T_contra) -> T_co: ...
319+
@overload
320+
def divmod(x: T_contra, y: __SupportsRDivMod[T_contra, T_co]) -> T_co: ...
310321
def exit() -> None: ...
311322
def min(x: T, y: T) -> T: ...
312323
def max(x: T, y: T) -> T: ...

mypyc/test-data/irbuild-any.test

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -198,3 +198,18 @@ L0:
198198
b = r4
199199
return 1
200200

201+
[case testFunctionBasedOps]
202+
def f() -> None:
203+
a = divmod(5, 2)
204+
[out]
205+
def f():
206+
r0, r1, r2 :: object
207+
r3, a :: tuple[float, float]
208+
L0:
209+
r0 = object 5
210+
r1 = object 2
211+
r2 = PyNumber_Divmod(r0, r1)
212+
r3 = unbox(tuple[float, float], r2)
213+
a = r3
214+
return 1
215+

mypyc/test-data/run-dunders.test

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -402,6 +402,9 @@ class C:
402402
def __floordiv__(self, y: int) -> int:
403403
return self.x + y + 30
404404

405+
def __divmod__(self, y: int) -> int:
406+
return self.x + y + 40
407+
405408
def test_generic() -> None:
406409
a: Any = C()
407410
assert a + 3 == 8
@@ -417,11 +420,13 @@ def test_generic() -> None:
417420
assert a @ 3 == 18
418421
assert a / 2 == 27
419422
assert a // 2 == 37
423+
assert divmod(a, 2) == 47
420424

421425
def test_native() -> None:
422426
c = C()
423427
assert c + 3 == 8
424428
assert c - 3 == 2
429+
assert divmod(c, 3) == 48
425430

426431
def test_error() -> None:
427432
a: Any = C()

0 commit comments

Comments
 (0)