Skip to content

Commit

Permalink
More tests for RHS
Browse files Browse the repository at this point in the history
  • Loading branch information
kngwyu committed Aug 16, 2020
1 parent c2d8f30 commit e385b7d
Show file tree
Hide file tree
Showing 3 changed files with 160 additions and 40 deletions.
2 changes: 1 addition & 1 deletion src/class/macros.rs
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ macro_rules! py_binary_reversed_num_func {
}};
}

macro_rules! py_binary_fallbacked_num_func {
macro_rules! py_binary_fallback_num_func {
($class:ident, $lop_trait: ident :: $lop: ident, $rop_trait: ident :: $rop: ident) => {{
unsafe extern "C" fn wrap<T>(
lhs: *mut ffi::PyObject,
Expand Down
24 changes: 12 additions & 12 deletions src/class/number.rs
Original file line number Diff line number Diff line change
Expand Up @@ -589,7 +589,7 @@ impl ffi::PyNumberMethods {
where
T: for<'p> PyNumberAddProtocol<'p> + for<'p> PyNumberRAddProtocol<'p>,
{
self.nb_add = py_binary_fallbacked_num_func!(
self.nb_add = py_binary_fallback_num_func!(
T,
PyNumberAddProtocol::__add__,
PyNumberRAddProtocol::__radd__
Expand All @@ -611,7 +611,7 @@ impl ffi::PyNumberMethods {
where
T: for<'p> PyNumberSubProtocol<'p> + for<'p> PyNumberRSubProtocol<'p>,
{
self.nb_subtract = py_binary_fallbacked_num_func!(
self.nb_subtract = py_binary_fallback_num_func!(
T,
PyNumberSubProtocol::__sub__,
PyNumberRSubProtocol::__rsub__
Expand All @@ -633,7 +633,7 @@ impl ffi::PyNumberMethods {
where
T: for<'p> PyNumberMulProtocol<'p> + for<'p> PyNumberRMulProtocol<'p>,
{
self.nb_multiply = py_binary_fallbacked_num_func!(
self.nb_multiply = py_binary_fallback_num_func!(
T,
PyNumberMulProtocol::__mul__,
PyNumberRMulProtocol::__rmul__
Expand Down Expand Up @@ -661,7 +661,7 @@ impl ffi::PyNumberMethods {
where
T: for<'p> PyNumberDivmodProtocol<'p> + for<'p> PyNumberRDivmodProtocol<'p>,
{
self.nb_divmod = py_binary_fallbacked_num_func!(
self.nb_divmod = py_binary_fallback_num_func!(
T,
PyNumberDivmodProtocol::__divmod__,
PyNumberRDivmodProtocol::__rdivmod__
Expand Down Expand Up @@ -780,7 +780,7 @@ impl ffi::PyNumberMethods {
where
T: for<'p> PyNumberLShiftProtocol<'p> + for<'p> PyNumberRLShiftProtocol<'p>,
{
self.nb_lshift = py_binary_fallbacked_num_func!(
self.nb_lshift = py_binary_fallback_num_func!(
T,
PyNumberLShiftProtocol::__lshift__,
PyNumberRLShiftProtocol::__rlshift__
Expand All @@ -802,7 +802,7 @@ impl ffi::PyNumberMethods {
where
T: for<'p> PyNumberRShiftProtocol<'p> + for<'p> PyNumberRRShiftProtocol<'p>,
{
self.nb_rshift = py_binary_fallbacked_num_func!(
self.nb_rshift = py_binary_fallback_num_func!(
T,
PyNumberRShiftProtocol::__rshift__,
PyNumberRRShiftProtocol::__rrshift__
Expand All @@ -824,7 +824,7 @@ impl ffi::PyNumberMethods {
where
T: for<'p> PyNumberAndProtocol<'p> + for<'p> PyNumberRAndProtocol<'p>,
{
self.nb_and = py_binary_fallbacked_num_func!(
self.nb_and = py_binary_fallback_num_func!(
T,
PyNumberAndProtocol::__and__,
PyNumberRAndProtocol::__rand__
Expand All @@ -846,7 +846,7 @@ impl ffi::PyNumberMethods {
where
T: for<'p> PyNumberXorProtocol<'p> + for<'p> PyNumberRXorProtocol<'p>,
{
self.nb_xor = py_binary_fallbacked_num_func!(
self.nb_xor = py_binary_fallback_num_func!(
T,
PyNumberXorProtocol::__xor__,
PyNumberRXorProtocol::__rxor__
Expand All @@ -868,7 +868,7 @@ impl ffi::PyNumberMethods {
where
T: for<'p> PyNumberOrProtocol<'p> + for<'p> PyNumberROrProtocol<'p>,
{
self.nb_or = py_binary_fallbacked_num_func!(
self.nb_or = py_binary_fallback_num_func!(
T,
PyNumberOrProtocol::__or__,
PyNumberROrProtocol::__ror__
Expand Down Expand Up @@ -980,7 +980,7 @@ impl ffi::PyNumberMethods {
where
T: for<'p> PyNumberFloordivProtocol<'p> + for<'p> PyNumberRFloordivProtocol<'p>,
{
self.nb_floor_divide = py_binary_fallbacked_num_func!(
self.nb_floor_divide = py_binary_fallback_num_func!(
T,
PyNumberFloordivProtocol::__floordiv__,
PyNumberRFloordivProtocol::__rfloordiv__
Expand All @@ -1003,7 +1003,7 @@ impl ffi::PyNumberMethods {
where
T: for<'p> PyNumberTruedivProtocol<'p> + for<'p> PyNumberRTruedivProtocol<'p>,
{
self.nb_true_divide = py_binary_fallbacked_num_func!(
self.nb_true_divide = py_binary_fallback_num_func!(
T,
PyNumberTruedivProtocol::__truediv__,
PyNumberRTruedivProtocol::__rtruediv__
Expand Down Expand Up @@ -1046,7 +1046,7 @@ impl ffi::PyNumberMethods {
where
T: for<'p> PyNumberMatmulProtocol<'p> + for<'p> PyNumberRMatmulProtocol<'p>,
{
self.nb_matrix_multiply = py_binary_fallbacked_num_func!(
self.nb_matrix_multiply = py_binary_fallback_num_func!(
T,
PyNumberMatmulProtocol::__matmul__,
PyNumberRMatmulProtocol::__rmatmul__
Expand Down
174 changes: 147 additions & 27 deletions tests/test_arithmetics.rs
Original file line number Diff line number Diff line change
Expand Up @@ -280,10 +280,46 @@ fn rhs_arithmetic() {
}

#[pyclass]
struct LhsAndRhsArithmetic {}
struct LhsOverridesRhs {}

#[pyproto]
impl PyNumberProtocol for LhsAndRhsArithmetic {
impl PyNumberProtocol for LhsOverridesRhs {
fn __add__(lhs: &PyAny, rhs: &PyAny) -> String {
format!("{:?} + {:?}", lhs, rhs)
}

fn __sub__(lhs: &PyAny, rhs: &PyAny) -> String {
format!("{:?} - {:?}", lhs, rhs)
}

fn __mul__(lhs: &PyAny, rhs: &PyAny) -> String {
format!("{:?} * {:?}", lhs, rhs)
}

fn __lshift__(lhs: &PyAny, rhs: &PyAny) -> String {
format!("{:?} << {:?}", lhs, rhs)
}

fn __rshift__(lhs: &PyAny, rhs: &PyAny) -> String {
format!("{:?} >> {:?}", lhs, rhs)
}

fn __and__(lhs: &PyAny, rhs: &PyAny) -> String {
format!("{:?} & {:?}", lhs, rhs)
}

fn __xor__(lhs: &PyAny, rhs: &PyAny) -> String {
format!("{:?} ^ {:?}", lhs, rhs)
}

fn __or__(lhs: &PyAny, rhs: &PyAny) -> String {
format!("{:?} | {:?}", lhs, rhs)
}

fn __pow__(lhs: &PyAny, rhs: &PyAny, mod_: Option<u32>) -> String {
format!("{:?} ** {:?} (mod: {:?})", lhs, rhs, mod_)
}

fn __radd__(&self, other: &PyAny) -> String {
format!("{:?} + RA", other)
}
Expand All @@ -292,25 +328,37 @@ impl PyNumberProtocol for LhsAndRhsArithmetic {
format!("{:?} - RA", other)
}

fn __rpow__(&self, other: &PyAny, _mod: Option<&'p PyAny>) -> String {
format!("{:?} ** RA", other)
fn __rmul__(&self, other: &PyAny) -> String {
format!("{:?} * RA", other)
}

fn __add__(lhs: &PyAny, rhs: &PyAny) -> String {
format!("{:?} + {:?}", lhs, rhs)
fn __rlshift__(&self, other: &PyAny) -> String {
format!("{:?} << RA", other)
}

fn __sub__(lhs: &PyAny, rhs: &PyAny) -> String {
format!("{:?} - {:?}", lhs, rhs)
fn __rrshift__(&self, other: &PyAny) -> String {
format!("{:?} >> RA", other)
}

fn __pow__(lhs: &PyAny, rhs: &PyAny, _mod: Option<u32>) -> String {
format!("{:?} ** {:?}", lhs, rhs)
fn __rand__(&self, other: &PyAny) -> String {
format!("{:?} & RA", other)
}

fn __rxor__(&self, other: &PyAny) -> String {
format!("{:?} ^ RA", other)
}

fn __ror__(&self, other: &PyAny) -> String {
format!("{:?} | RA", other)
}

fn __rpow__(&self, other: &PyAny, _mod: Option<&'p PyAny>) -> String {
format!("{:?} ** RA", other)
}
}

#[pyproto]
impl PyObjectProtocol for LhsAndRhsArithmetic {
impl PyObjectProtocol for LhsOverridesRhs {
fn __repr__(&self) -> &'static str {
"BA"
}
Expand All @@ -321,66 +369,138 @@ fn lhs_override_rhs() {
let gil = Python::acquire_gil();
let py = gil.python();

let c = PyCell::new(py, LhsAndRhsArithmetic {}).unwrap();
let c = PyCell::new(py, LhsOverridesRhs {}).unwrap();
// Not overrided
py_run!(py, c, "assert c.__radd__(1) == '1 + RA'");
py_run!(py, c, "assert c.__rsub__(1) == '1 - RA'");
py_run!(py, c, "assert c.__rmul__(1) == '1 * RA'");
py_run!(py, c, "assert c.__rlshift__(1) == '1 << RA'");
py_run!(py, c, "assert c.__rrshift__(1) == '1 >> RA'");
py_run!(py, c, "assert c.__rand__(1) == '1 & RA'");
py_run!(py, c, "assert c.__rxor__(1) == '1 ^ RA'");
py_run!(py, c, "assert c.__ror__(1) == '1 | RA'");
py_run!(py, c, "assert c.__rpow__(1) == '1 ** RA'");
// Overrided
py_run!(py, c, "assert 1 + c == '1 + BA'");
py_run!(py, c, "assert 1 - c == '1 - BA'");
py_run!(py, c, "assert 1 * c == '1 * BA'");
py_run!(py, c, "assert 1 << c == '1 << BA'");
py_run!(py, c, "assert 1 >> c == '1 >> BA'");
py_run!(py, c, "assert 1 & c == '1 & BA'");
py_run!(py, c, "assert 1 ^ c == '1 ^ BA'");
py_run!(py, c, "assert 1 | c == '1 | BA'");
py_run!(py, c, "assert 1 ** c == '1 ** BA'");
}

#[pyclass]
#[derive(Debug)]
struct Lhs2Rhs {}
struct LhsFellbackToRhs {}

#[pyproto]
impl PyNumberProtocol for Lhs2Rhs {
fn __add__(lhs: PyRef<Lhs2Rhs>, rhs: &PyAny) -> String {
impl PyNumberProtocol for LhsFellbackToRhs {
fn __add__(lhs: PyRef<Self>, rhs: &PyAny) -> String {
format!("{:?} + {:?}", lhs, rhs)
}
fn __sub__(lhs: PyRef<Lhs2Rhs>, rhs: &PyAny) -> String {

fn __sub__(lhs: PyRef<Self>, rhs: &PyAny) -> String {
format!("{:?} - {:?}", lhs, rhs)
}
fn __pow__(lhs: PyRef<Lhs2Rhs>, rhs: &PyAny, _mod: Option<usize>) -> String {

fn __mul__(lhs: PyRef<Self>, rhs: &PyAny) -> String {
format!("{:?} * {:?}", lhs, rhs)
}

fn __lshift__(lhs: PyRef<Self>, rhs: &PyAny) -> String {
format!("{:?} << {:?}", lhs, rhs)
}

fn __rshift__(lhs: PyRef<Self>, rhs: &PyAny) -> String {
format!("{:?} >> {:?}", lhs, rhs)
}

fn __and__(lhs: PyRef<Self>, rhs: &PyAny) -> String {
format!("{:?} & {:?}", lhs, rhs)
}

fn __xor__(lhs: PyRef<Self>, rhs: &PyAny) -> String {
format!("{:?} ^ {:?}", lhs, rhs)
}

fn __or__(lhs: PyRef<Self>, rhs: &PyAny) -> String {
format!("{:?} | {:?}", lhs, rhs)
}

fn __pow__(lhs: PyRef<Self>, rhs: &PyAny, _mod: Option<usize>) -> String {
format!("{:?} ** {:?}", lhs, rhs)
}
fn __matmul__(lhs: PyRef<Lhs2Rhs>, rhs: &PyAny) -> String {

fn __matmul__(lhs: PyRef<Self>, rhs: &PyAny) -> String {
format!("{:?} @ {:?}", lhs, rhs)
}

fn __radd__(&self, other: &PyAny) -> String {
format!("{:?} + RA", other)
}
fn __rsub__(&self, rhs: &PyAny) -> String {
format!("{:?} - RA", rhs)

fn __rsub__(&self, other: &PyAny) -> String {
format!("{:?} - RA", other)
}

fn __rmul__(&self, other: &PyAny) -> String {
format!("{:?} * RA", other)
}
fn __rpow__(&self, rhs: &PyAny, _mod: Option<usize>) -> String {
format!("{:?} ** RA", rhs)

fn __rlshift__(&self, other: &PyAny) -> String {
format!("{:?} << RA", other)
}
fn __rmatmul__(&self, rhs: &PyAny) -> String {
format!("{:?} @ RA", rhs)

fn __rrshift__(&self, other: &PyAny) -> String {
format!("{:?} >> RA", other)
}

fn __rand__(&self, other: &PyAny) -> String {
format!("{:?} & RA", other)
}

fn __rxor__(&self, other: &PyAny) -> String {
format!("{:?} ^ RA", other)
}

fn __ror__(&self, other: &PyAny) -> String {
format!("{:?} | RA", other)
}

fn __rpow__(&self, other: &PyAny, _mod: Option<&'p PyAny>) -> String {
format!("{:?} ** RA", other)
}

fn __rmatmul__(&self, other: &PyAny) -> String {
format!("{:?} @ RA", other)
}
}

#[pyproto]
impl PyObjectProtocol for Lhs2Rhs {
impl PyObjectProtocol for LhsFellbackToRhs {
fn __repr__(&self) -> &'static str {
"BA"
}
}

#[test]
fn lhs_fallbacked_to_rhs() {
fn lhs_fellback_to_rhs() {
let gil = Python::acquire_gil();
let py = gil.python();

let c = PyCell::new(py, Lhs2Rhs {}).unwrap();
let c = PyCell::new(py, LhsFellbackToRhs {}).unwrap();
// Fallbacked to RHS because of type mismatching
py_run!(py, c, "assert 1 + c == '1 + RA'");
py_run!(py, c, "assert 1 - c == '1 - RA'");
py_run!(py, c, "assert 1 * c == '1 * RA'");
py_run!(py, c, "assert 1 << c == '1 << RA'");
py_run!(py, c, "assert 1 >> c == '1 >> RA'");
py_run!(py, c, "assert 1 & c == '1 & RA'");
py_run!(py, c, "assert 1 ^ c == '1 ^ RA'");
py_run!(py, c, "assert 1 | c == '1 | RA'");
py_run!(py, c, "assert 1 ** c == '1 ** RA'");
}

Expand Down

0 comments on commit e385b7d

Please sign in to comment.