diff --git a/mypyc/codegen/emitclass.py b/mypyc/codegen/emitclass.py index 79fdd9103371..9e8937434025 100644 --- a/mypyc/codegen/emitclass.py +++ b/mypyc/codegen/emitclass.py @@ -13,6 +13,7 @@ generate_dunder_wrapper, generate_get_wrapper, generate_hash_wrapper, + generate_ipow_wrapper, generate_len_wrapper, generate_richcompare_wrapper, generate_set_del_item_wrapper, @@ -109,6 +110,11 @@ def wrapper_slot(cl: ClassIR, fn: FuncIR, emitter: Emitter) -> str: "__ior__": ("nb_inplace_or", generate_dunder_wrapper), "__ixor__": ("nb_inplace_xor", generate_dunder_wrapper), "__imatmul__": ("nb_inplace_matrix_multiply", generate_dunder_wrapper), + # Ternary operations. (yes, really) + # These are special cased in generate_bin_op_wrapper(). + "__pow__": ("nb_power", generate_bin_op_wrapper), + "__rpow__": ("nb_power", generate_bin_op_wrapper), + "__ipow__": ("nb_inplace_power", generate_ipow_wrapper), } AS_ASYNC_SLOT_DEFS: SlotTable = { diff --git a/mypyc/codegen/emitwrapper.py b/mypyc/codegen/emitwrapper.py index 1fa1e8548e07..ed03bb7948cc 100644 --- a/mypyc/codegen/emitwrapper.py +++ b/mypyc/codegen/emitwrapper.py @@ -301,6 +301,32 @@ def generate_dunder_wrapper(cl: ClassIR, fn: FuncIR, emitter: Emitter) -> str: return gen.wrapper_name() +def generate_ipow_wrapper(cl: ClassIR, fn: FuncIR, emitter: Emitter) -> str: + """Generate a wrapper for native __ipow__. + + Since __ipow__ fills a ternary slot, but almost no one defines __ipow__ to take three + arguments, the wrapper needs to tweaked to force it to accept three arguments. + """ + gen = WrapperGenerator(cl, emitter) + gen.set_target(fn) + assert len(fn.args) in (2, 3), "__ipow__ should only take 2 or 3 arguments" + gen.arg_names = ["self", "exp", "mod"] + gen.emit_header() + gen.emit_arg_processing() + handle_third_pow_argument( + fn, + emitter, + gen, + if_unsupported=[ + 'PyErr_SetString(PyExc_TypeError, "__ipow__ takes 2 positional arguments but 3 were given");', + "return NULL;", + ], + ) + gen.emit_call() + gen.finish() + return gen.wrapper_name() + + def generate_bin_op_wrapper(cl: ClassIR, fn: FuncIR, emitter: Emitter) -> str: """Generates a wrapper for a native binary dunder method. @@ -311,13 +337,16 @@ def generate_bin_op_wrapper(cl: ClassIR, fn: FuncIR, emitter: Emitter) -> str: """ gen = WrapperGenerator(cl, emitter) gen.set_target(fn) - gen.arg_names = ["left", "right"] + if fn.name in ("__pow__", "__rpow__"): + gen.arg_names = ["left", "right", "mod"] + else: + gen.arg_names = ["left", "right"] wrapper_name = gen.wrapper_name() gen.emit_header() if fn.name not in reverse_op_methods and fn.name in reverse_op_method_names: # There's only a reverse operator method. - generate_bin_op_reverse_only_wrapper(emitter, gen) + generate_bin_op_reverse_only_wrapper(fn, emitter, gen) else: rmethod = reverse_op_methods[fn.name] fn_rev = cl.get_method(rmethod) @@ -334,6 +363,7 @@ def generate_bin_op_forward_only_wrapper( fn: FuncIR, emitter: Emitter, gen: WrapperGenerator ) -> None: gen.emit_arg_processing(error=GotoHandler("typefail"), raise_exception=False) + handle_third_pow_argument(fn, emitter, gen, if_unsupported=["goto typefail;"]) gen.emit_call(not_implemented_handler="goto typefail;") gen.emit_error_handling() emitter.emit_label("typefail") @@ -352,19 +382,16 @@ def generate_bin_op_forward_only_wrapper( # if not isinstance(other, int): # return NotImplemented # ... - rmethod = reverse_op_methods[fn.name] - emitter.emit_line(f"_Py_IDENTIFIER({rmethod});") - emitter.emit_line( - 'return CPy_CallReverseOpMethod(obj_left, obj_right, "{}", &PyId_{});'.format( - op_methods_to_symbols[fn.name], rmethod - ) - ) + generate_bin_op_reverse_dunder_call(fn, emitter, reverse_op_methods[fn.name]) gen.finish() -def generate_bin_op_reverse_only_wrapper(emitter: Emitter, gen: WrapperGenerator) -> None: +def generate_bin_op_reverse_only_wrapper( + fn: FuncIR, emitter: Emitter, gen: WrapperGenerator +) -> None: gen.arg_names = ["right", "left"] gen.emit_arg_processing(error=GotoHandler("typefail"), raise_exception=False) + handle_third_pow_argument(fn, emitter, gen, if_unsupported=["goto typefail;"]) gen.emit_call() gen.emit_error_handling() emitter.emit_label("typefail") @@ -390,7 +417,14 @@ def generate_bin_op_both_wrappers( ) ) gen.emit_arg_processing(error=GotoHandler("typefail"), raise_exception=False) - gen.emit_call(not_implemented_handler="goto typefail;") + handle_third_pow_argument(fn, emitter, gen, if_unsupported=["goto typefail2;"]) + # Ternary __rpow__ calls aren't a thing so immediately bail + # if ternary __pow__ returns NotImplemented. + if fn.name == "__pow__" and len(fn.args) == 3: + fwd_not_implemented_handler = "goto typefail2;" + else: + fwd_not_implemented_handler = "goto typefail;" + gen.emit_call(not_implemented_handler=fwd_not_implemented_handler) gen.emit_error_handling() emitter.emit_line("}") emitter.emit_label("typefail") @@ -402,15 +436,11 @@ def generate_bin_op_both_wrappers( gen.set_target(fn_rev) gen.arg_names = ["right", "left"] gen.emit_arg_processing(error=GotoHandler("typefail2"), raise_exception=False) + handle_third_pow_argument(fn_rev, emitter, gen, if_unsupported=["goto typefail2;"]) gen.emit_call() gen.emit_error_handling() emitter.emit_line("} else {") - emitter.emit_line(f"_Py_IDENTIFIER({fn_rev.name});") - emitter.emit_line( - 'return CPy_CallReverseOpMethod(obj_left, obj_right, "{}", &PyId_{});'.format( - op_methods_to_symbols[fn.name], fn_rev.name - ) - ) + generate_bin_op_reverse_dunder_call(fn, emitter, fn_rev.name) emitter.emit_line("}") emitter.emit_label("typefail2") emitter.emit_line("Py_INCREF(Py_NotImplemented);") @@ -418,6 +448,47 @@ def generate_bin_op_both_wrappers( gen.finish() +def generate_bin_op_reverse_dunder_call(fn: FuncIR, emitter: Emitter, rmethod: str) -> None: + if fn.name in ("__pow__", "__rpow__"): + # Ternary pow() will never call the reverse dunder. + emitter.emit_line("if (obj_mod == Py_None) {") + emitter.emit_line(f"_Py_IDENTIFIER({rmethod});") + emitter.emit_line( + 'return CPy_CallReverseOpMethod(obj_left, obj_right, "{}", &PyId_{});'.format( + op_methods_to_symbols[fn.name], rmethod + ) + ) + if fn.name in ("__pow__", "__rpow__"): + emitter.emit_line("} else {") + emitter.emit_line("Py_INCREF(Py_NotImplemented);") + emitter.emit_line("return Py_NotImplemented;") + emitter.emit_line("}") + + +def handle_third_pow_argument( + fn: FuncIR, emitter: Emitter, gen: WrapperGenerator, *, if_unsupported: list[str] +) -> None: + if fn.name not in ("__pow__", "__rpow__", "__ipow__"): + return + + if (fn.name in ("__pow__", "__ipow__") and len(fn.args) == 2) or fn.name == "__rpow__": + # If the power dunder only supports two arguments and the third + # argument (AKA mod) is set to a non-default value, simply bail. + # + # Importantly, this prevents any ternary __rpow__ calls from + # happening (as per the language specification). + emitter.emit_line("if (obj_mod != Py_None) {") + for line in if_unsupported: + emitter.emit_line(line) + emitter.emit_line("}") + # The slot wrapper will receive three arguments, but the call only + # supports two so make sure that the third argument isn't passed + # along. This is needed as two-argument __(i)pow__ is allowed and + # rather common. + if len(gen.arg_names) == 3: + gen.arg_names.pop() + + RICHCOMPARE_OPS = { "__lt__": "Py_LT", "__gt__": "Py_GT", diff --git a/mypyc/lib-rt/CPy.h b/mypyc/lib-rt/CPy.h index befa397051ef..016a6d3ea9e0 100644 --- a/mypyc/lib-rt/CPy.h +++ b/mypyc/lib-rt/CPy.h @@ -344,6 +344,7 @@ CPyTagged CPyObject_Hash(PyObject *o); PyObject *CPyObject_GetAttr3(PyObject *v, PyObject *name, PyObject *defl); PyObject *CPyIter_Next(PyObject *iter); PyObject *CPyNumber_Power(PyObject *base, PyObject *index); +PyObject *CPyNumber_InPlacePower(PyObject *base, PyObject *index); PyObject *CPyObject_GetSlice(PyObject *obj, CPyTagged start, CPyTagged end); diff --git a/mypyc/lib-rt/generic_ops.c b/mypyc/lib-rt/generic_ops.c index 2f4a7941a6da..260cfec5b360 100644 --- a/mypyc/lib-rt/generic_ops.c +++ b/mypyc/lib-rt/generic_ops.c @@ -41,6 +41,11 @@ PyObject *CPyNumber_Power(PyObject *base, PyObject *index) return PyNumber_Power(base, index, Py_None); } +PyObject *CPyNumber_InPlacePower(PyObject *base, PyObject *index) +{ + return PyNumber_InPlacePower(base, index, Py_None); +} + PyObject *CPyObject_GetSlice(PyObject *obj, CPyTagged start, CPyTagged end) { PyObject *start_obj = CPyTagged_AsObject(start); PyObject *end_obj = CPyTagged_AsObject(end); diff --git a/mypyc/primitives/generic_ops.py b/mypyc/primitives/generic_ops.py index 4f04608d11f3..3caec0a9875e 100644 --- a/mypyc/primitives/generic_ops.py +++ b/mypyc/primitives/generic_ops.py @@ -109,14 +109,25 @@ priority=0, ) -binary_op( - name="**", - arg_types=[object_rprimitive, object_rprimitive], - return_type=object_rprimitive, - error_kind=ERR_MAGIC, - c_function_name="CPyNumber_Power", - priority=0, -) +for op, c_function in (("**", "CPyNumber_Power"), ("**=", "CPyNumber_InPlacePower")): + binary_op( + name=op, + arg_types=[object_rprimitive, object_rprimitive], + return_type=object_rprimitive, + error_kind=ERR_MAGIC, + c_function_name=c_function, + priority=0, + ) + +for arg_count, c_function in ((2, "CPyNumber_Power"), (3, "PyNumber_Power")): + function_op( + name="builtins.pow", + arg_types=[object_rprimitive] * arg_count, + return_type=object_rprimitive, + error_kind=ERR_MAGIC, + c_function_name=c_function, + priority=0, + ) binary_op( name="in", diff --git a/mypyc/test-data/fixtures/ir.py b/mypyc/test-data/fixtures/ir.py index 37aab1d826d7..27e225f273bc 100644 --- a/mypyc/test-data/fixtures/ir.py +++ b/mypyc/test-data/fixtures/ir.py @@ -22,6 +22,21 @@ def __divmod__(self, other: T_contra) -> T_co: ... class __SupportsRDivMod(Protocol[T_contra, T_co]): def __rdivmod__(self, other: T_contra) -> T_co: ... +_M = TypeVar("_M", contravariant=True) + +class __SupportsPow2(Protocol[T_contra, T_co]): + def __pow__(self, other: T_contra) -> T_co: ... + +class __SupportsPow3NoneOnly(Protocol[T_contra, T_co]): + def __pow__(self, other: T_contra, modulo: None = ...) -> T_co: ... + +class __SupportsPow3(Protocol[T_contra, _M, T_co]): + def __pow__(self, other: T_contra, modulo: _M) -> T_co: ... + +__SupportsSomeKindOfPow = Union[ + __SupportsPow2[Any, Any], __SupportsPow3NoneOnly[Any, Any] | __SupportsPow3[Any, Any, Any] +] + class object: def __init__(self) -> None: pass def __eq__(self, x: object) -> bool: pass @@ -99,6 +114,7 @@ def __add__(self, n: float) -> float: pass def __sub__(self, n: float) -> float: pass def __mul__(self, n: float) -> float: pass def __truediv__(self, n: float) -> float: pass + def __pow__(self, n: float) -> float: pass def __neg__(self) -> float: pass def __pos__(self) -> float: pass def __abs__(self) -> float: pass @@ -318,6 +334,12 @@ def abs(x: __SupportsAbs[T]) -> T: ... def divmod(x: __SupportsDivMod[T_contra, T_co], y: T_contra) -> T_co: ... @overload def divmod(x: T_contra, y: __SupportsRDivMod[T_contra, T_co]) -> T_co: ... +@overload +def pow(base: __SupportsPow2[T_contra, T_co], exp: T_contra, mod: None = None) -> T_co: ... +@overload +def pow(base: __SupportsPow3NoneOnly[T_contra, T_co], exp: T_contra, mod: None = None) -> T_co: ... +@overload +def pow(base: __SupportsPow3[T_contra, _M, T_co], exp: T_contra, mod: _M) -> T_co: ... def exit() -> None: ... def min(x: T, y: T) -> T: ... def max(x: T, y: T) -> T: ... diff --git a/mypyc/test-data/irbuild-any.test b/mypyc/test-data/irbuild-any.test index 8cc626100262..8d4e085179ae 100644 --- a/mypyc/test-data/irbuild-any.test +++ b/mypyc/test-data/irbuild-any.test @@ -201,6 +201,10 @@ L0: [case testFunctionBasedOps] def f() -> None: a = divmod(5, 2) +def f2() -> int: + return pow(2, 5) +def f3() -> float: + return pow(2, 5, 3) [out] def f(): r0, r1, r2 :: object @@ -212,4 +216,25 @@ L0: r3 = unbox(tuple[float, float], r2) a = r3 return 1 +def f2(): + r0, r1, r2 :: object + r3 :: int +L0: + r0 = object 2 + r1 = object 5 + r2 = CPyNumber_Power(r0, r1) + r3 = unbox(int, r2) + return r3 +def f3(): + r0, r1, r2, r3 :: object + r4 :: int + r5 :: object +L0: + r0 = object 2 + r1 = object 5 + r2 = object 3 + r3 = PyNumber_Power(r0, r1, r2) + r4 = unbox(int, r3) + r5 = box(int, r4) + return r5 diff --git a/mypyc/test-data/run-dunders.test b/mypyc/test-data/run-dunders.test index 23323c7244de..2845187de2c3 100644 --- a/mypyc/test-data/run-dunders.test +++ b/mypyc/test-data/run-dunders.test @@ -405,6 +405,9 @@ class C: def __divmod__(self, y: int) -> int: return self.x + y + 40 + def __pow__(self, y: int) -> int: + return self.x + y + 50 + def test_generic() -> None: a: Any = C() assert a + 3 == 8 @@ -421,12 +424,14 @@ def test_generic() -> None: assert a / 2 == 27 assert a // 2 == 37 assert divmod(a, 2) == 47 + assert a ** 2 == 57 def test_native() -> None: c = C() assert c + 3 == 8 assert c - 3 == 2 assert divmod(c, 3) == 48 + assert c ** 3 == 58 def test_error() -> None: a: Any = C() @@ -442,6 +447,12 @@ def test_error() -> None: assert str(e) == "unsupported operand type(s) for -: 'C' and 'str'" else: assert False + try: + a ** 'x' + except TypeError as e: + assert str(e) == "unsupported operand type(s) for **: 'C' and 'str'" + else: + assert False [case testDundersBinaryReverse] from typing import Any @@ -462,12 +473,20 @@ class C: def __rsub__(self, y: int) -> int: return self.x - y - 1 + def __pow__(self, y: int) -> int: + return self.x**y + + def __rpow__(self, y: int) -> int: + return self.x**y + 1 + def test_generic() -> None: a: Any = C() assert a + 3 == 8 assert 4 + a == 10 assert a - 3 == 2 assert 4 - a == 0 + assert a**3 == 125 + assert 4**a == 626 def test_native() -> None: c = C() @@ -475,6 +494,8 @@ def test_native() -> None: assert 4 + c == 10 assert c - 3 == 2 assert 4 - c == 0 + assert c**3 == 125 + assert 4**c == 626 def test_errors() -> None: a: Any = C() @@ -497,20 +518,37 @@ def test_errors() -> None: 'must be str, not C') else: assert False + try: + 'x' ** a + except TypeError as e: + assert str(e) == "unsupported operand type(s) for ** or pow(): 'str' and 'C'" + else: + assert False + class F: def __add__(self, x: int) -> int: return 5 + def __pow__(self, x: int) -> int: + return -5 + class G: def __add__(self, x: int) -> int: return 33 + def __pow__(self, x: int) -> int: + return -33 + def __radd__(self, x: F) -> int: return 6 + def __rpow__(self, x: F) -> int: + return -6 + def test_type_mismatch_fall_back_to_reverse() -> None: assert F() + G() == 6 + assert F()**G() == -6 [case testDundersBinaryNotImplemented] from typing import Any, Union @@ -718,6 +756,10 @@ class C: self.x += y + 5 return self + def __ipow__(self, y: int, __mod_throwaway: None = None) -> C: + self.x **= y + return self + def test_generic_1() -> None: c: Any = C() c += 3 @@ -732,6 +774,8 @@ def test_generic_1() -> None: assert c.x == 16 c //= 4 assert c.x == 40 + c **= 2 + assert c.x == 1600 def test_generic_2() -> None: c: Any = C() @@ -756,6 +800,8 @@ def test_native() -> None: assert c.x == 3 c *= 3 assert c.x == 9 + c **= 2 + assert c.x == 81 def test_error() -> None: c: Any = C() @@ -812,3 +858,88 @@ def test_dunder_min() -> None: assert max(y2, x2).val == 'xxx' assert min(y2, z2).val == 'zzz' assert max(x2, z2).val == 'zzz' + + +[case testDundersPowerSpecial] +import sys +from typing import Any, Optional +from testutil import assertRaises + +class Forward: + def __pow__(self, exp: int, mod: Optional[int] = None) -> int: + if mod is None: + return 2**exp + else: + return 2**exp % mod + +class ForwardModRequired: + def __pow__(self, exp: int, mod: int) -> int: + return 2**exp % mod + +class ForwardNotImplemented: + def __pow__(self, exp: int, mod: Optional[object] = None) -> Any: + return NotImplemented + +class Reverse: + def __rpow__(self, exp: int) -> int: + return 2**exp + 1 + +class Both: + def __pow__(self, exp: int, mod: Optional[int] = None) -> int: + if mod is None: + return 2**exp + else: + return 2**exp % mod + + def __rpow__(self, exp: int) -> int: + return 2**exp + 1 + +class Child(ForwardNotImplemented): + def __rpow__(self, exp: object) -> int: + return 50 + +class Inplace: + value = 2 + + def __ipow__(self, exp: int, mod: Optional[int] = None) -> "Inplace": + self.value **= exp - (mod or 0) + return self + +def test_native() -> None: + f = Forward() + assert f**3 == 8 + assert pow(f, 3) == 8 + assert pow(f, 3, 3) == 2 + assert pow(ForwardModRequired(), 3, 3) == 2 + b = Both() + assert b**3 == 8 + assert 3**b == 9 + assert pow(b, 3) == 8 + assert pow(b, 3, 3) == 2 + i = Inplace() + i **= 2 + assert i.value == 4 + +def test_errors() -> None: + if sys.version_info[0] >= 3 and sys.version_info[1] >= 10: + op = "** or pow()" + else: + op = "pow()" + + f = Forward() + with assertRaises(TypeError, f"unsupported operand type(s) for {op}: 'Forward', 'int', 'str'"): + pow(f, 3, "x") # type: ignore + with assertRaises(TypeError, "unsupported operand type(s) for **: 'Forward' and 'str'"): + f**"x" # type: ignore + r = Reverse() + with assertRaises(TypeError, "unsupported operand type(s) for ** or pow(): 'str' and 'Reverse'"): + "x"**r # type: ignore + with assertRaises(TypeError, f"unsupported operand type(s) for {op}: 'int', 'Reverse', 'int'"): + # Ternary pow() does not fallback to __rpow__ if LHS's __pow__ returns NotImplemented. + pow(3, r, 3) # type: ignore + with assertRaises(TypeError, f"unsupported operand type(s) for {op}: 'ForwardNotImplemented', 'Child', 'int'"): + # Ternary pow() does not try RHS's __rpow__ first when it's a subclass and redefines + # __rpow__ unlike other ops. + pow(ForwardNotImplemented(), Child(), 3) # type: ignore + with assertRaises(TypeError, "unsupported operand type(s) for ** or pow(): 'ForwardModRequired' and 'int'"): + ForwardModRequired()**3 # type: ignore