Skip to content

Commit

Permalink
[mypyc] Support __(r)divmod__ dunders (#14613)
Browse files Browse the repository at this point in the history
Pretty simple. Towards mypyc/mypyc#553.
  • Loading branch information
ichard26 committed Feb 5, 2023
1 parent f527656 commit b64bd3d
Show file tree
Hide file tree
Showing 5 changed files with 44 additions and 0 deletions.
2 changes: 2 additions & 0 deletions mypyc/codegen/emitclass.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,8 @@ def wrapper_slot(cl: ClassIR, fn: FuncIR, emitter: Emitter) -> str:
"__rtruediv__": ("nb_true_divide", generate_bin_op_wrapper),
"__floordiv__": ("nb_floor_divide", generate_bin_op_wrapper),
"__rfloordiv__": ("nb_floor_divide", generate_bin_op_wrapper),
"__divmod__": ("nb_divmod", generate_bin_op_wrapper),
"__rdivmod__": ("nb_divmod", generate_bin_op_wrapper),
"__lshift__": ("nb_lshift", generate_bin_op_wrapper),
"__rlshift__": ("nb_lshift", generate_bin_op_wrapper),
"__rshift__": ("nb_rshift", generate_bin_op_wrapper),
Expand Down
11 changes: 11 additions & 0 deletions mypyc/primitives/generic_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,17 @@
priority=0,
)


function_op(
name="builtins.divmod",
arg_types=[object_rprimitive, object_rprimitive],
return_type=object_rprimitive,
c_function_name="PyNumber_Divmod",
error_kind=ERR_MAGIC,
priority=0,
)


for op, funcname in [
("+=", "PyNumber_InPlaceAdd"),
("-=", "PyNumber_InPlaceSubtract"),
Expand Down
11 changes: 11 additions & 0 deletions mypyc/test-data/fixtures/ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,19 @@

T = TypeVar('T')
T_co = TypeVar('T_co', covariant=True)
T_contra = TypeVar('T_contra', contravariant=True)
S = TypeVar('S')
K = TypeVar('K') # for keys in mapping
V = TypeVar('V') # for values in mapping

class __SupportsAbs(Protocol[T_co]):
def __abs__(self) -> T_co: pass

class __SupportsDivMod(Protocol[T_contra, T_co]):
def __divmod__(self, other: T_contra) -> T_co: ...

class __SupportsRDivMod(Protocol[T_contra, T_co]):
def __rdivmod__(self, other: T_contra) -> T_co: ...

class object:
def __init__(self) -> None: pass
Expand Down Expand Up @@ -42,6 +48,7 @@ def __pow__(self, n: int, modulo: Optional[int] = None) -> int: pass
def __floordiv__(self, x: int) -> int: pass
def __truediv__(self, x: float) -> float: pass
def __mod__(self, x: int) -> int: pass
def __divmod__(self, x: float) -> Tuple[float, float]: pass
def __neg__(self) -> int: pass
def __pos__(self) -> int: pass
def __abs__(self) -> int: pass
Expand Down Expand Up @@ -307,6 +314,10 @@ def zip(x: Iterable[T], y: Iterable[S]) -> Iterator[Tuple[T, S]]: ...
def zip(x: Iterable[T], y: Iterable[S], z: Iterable[V]) -> Iterator[Tuple[T, S, V]]: ...
def eval(e: str) -> Any: ...
def abs(x: __SupportsAbs[T]) -> T: ...
@overload
def divmod(x: __SupportsDivMod[T_contra, T_co], y: T_contra) -> T_co: ...
@overload
def divmod(x: T_contra, y: __SupportsRDivMod[T_contra, T_co]) -> T_co: ...
def exit() -> None: ...
def min(x: T, y: T) -> T: ...
def max(x: T, y: T) -> T: ...
Expand Down
15 changes: 15 additions & 0 deletions mypyc/test-data/irbuild-any.test
Original file line number Diff line number Diff line change
Expand Up @@ -198,3 +198,18 @@ L0:
b = r4
return 1

[case testFunctionBasedOps]
def f() -> None:
a = divmod(5, 2)
[out]
def f():
r0, r1, r2 :: object
r3, a :: tuple[float, float]
L0:
r0 = object 5
r1 = object 2
r2 = PyNumber_Divmod(r0, r1)
r3 = unbox(tuple[float, float], r2)
a = r3
return 1

5 changes: 5 additions & 0 deletions mypyc/test-data/run-dunders.test
Original file line number Diff line number Diff line change
Expand Up @@ -402,6 +402,9 @@ class C:
def __floordiv__(self, y: int) -> int:
return self.x + y + 30

def __divmod__(self, y: int) -> int:
return self.x + y + 40

def test_generic() -> None:
a: Any = C()
assert a + 3 == 8
Expand All @@ -417,11 +420,13 @@ def test_generic() -> None:
assert a @ 3 == 18
assert a / 2 == 27
assert a // 2 == 37
assert divmod(a, 2) == 47

def test_native() -> None:
c = C()
assert c + 3 == 8
assert c - 3 == 2
assert divmod(c, 3) == 48

def test_error() -> None:
a: Any = C()
Expand Down

0 comments on commit b64bd3d

Please sign in to comment.