Skip to content

Commit

Permalink
Simplify test cases where both __*__ and __r*__ are defined
Browse files Browse the repository at this point in the history
  • Loading branch information
kngwyu committed Aug 18, 2020
1 parent 9d8b1e7 commit e44c85b
Showing 1 changed file with 40 additions and 116 deletions.
156 changes: 40 additions & 116 deletions tests/test_arithmetics.rs
Original file line number Diff line number Diff line change
Expand Up @@ -280,124 +280,16 @@ fn rhs_arithmetic() {
}

#[pyclass]
struct LhsOverridesRhs {}
struct LhsAndRhs {}

#[pyproto]
impl PyNumberProtocol for LhsOverridesRhs {
fn __add__(lhs: &PyAny, rhs: &PyAny) -> String {
format!("{:?} + {:?}", lhs, rhs)
}

fn __sub__(lhs: &PyAny, rhs: &PyAny) -> String {
format!("{:?} - {:?}", lhs, rhs)
}

fn __mul__(lhs: &PyAny, rhs: &PyAny) -> String {
format!("{:?} * {:?}", lhs, rhs)
}

fn __lshift__(lhs: &PyAny, rhs: &PyAny) -> String {
format!("{:?} << {:?}", lhs, rhs)
impl std::fmt::Debug for LhsAndRhs {
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
write!(f, "LR")
}

fn __rshift__(lhs: &PyAny, rhs: &PyAny) -> String {
format!("{:?} >> {:?}", lhs, rhs)
}

fn __and__(lhs: &PyAny, rhs: &PyAny) -> String {
format!("{:?} & {:?}", lhs, rhs)
}

fn __xor__(lhs: &PyAny, rhs: &PyAny) -> String {
format!("{:?} ^ {:?}", lhs, rhs)
}

fn __or__(lhs: &PyAny, rhs: &PyAny) -> String {
format!("{:?} | {:?}", lhs, rhs)
}

fn __pow__(lhs: &PyAny, rhs: &PyAny, _mod: Option<&PyAny>) -> String {
format!("{:?} ** {:?}", lhs, rhs)
}

fn __radd__(&self, other: &PyAny) -> String {
format!("{:?} + RA", other)
}

fn __rsub__(&self, other: &PyAny) -> String {
format!("{:?} - RA", other)
}

fn __rmul__(&self, other: &PyAny) -> String {
format!("{:?} * RA", other)
}

fn __rlshift__(&self, other: &PyAny) -> String {
format!("{:?} << RA", other)
}

fn __rrshift__(&self, other: &PyAny) -> String {
format!("{:?} >> RA", other)
}

fn __rand__(&self, other: &PyAny) -> String {
format!("{:?} & RA", other)
}

fn __rxor__(&self, other: &PyAny) -> String {
format!("{:?} ^ RA", other)
}

fn __ror__(&self, other: &PyAny) -> String {
format!("{:?} | RA", other)
}

fn __rpow__(&self, other: &PyAny, _mod: Option<&'p PyAny>) -> String {
format!("{:?} ** RA", other)
}
}

#[pyproto]
impl PyObjectProtocol for LhsOverridesRhs {
fn __repr__(&self) -> &'static str {
"BA"
}
}

#[test]
fn lhs_overrides_rhs() {
let gil = Python::acquire_gil();
let py = gil.python();

let c = PyCell::new(py, LhsOverridesRhs {}).unwrap();
// Not overrided
py_run!(py, c, "assert c.__radd__(1) == '1 + RA'");
py_run!(py, c, "assert c.__rsub__(1) == '1 - RA'");
py_run!(py, c, "assert c.__rmul__(1) == '1 * RA'");
py_run!(py, c, "assert c.__rlshift__(1) == '1 << RA'");
py_run!(py, c, "assert c.__rrshift__(1) == '1 >> RA'");
py_run!(py, c, "assert c.__rand__(1) == '1 & RA'");
py_run!(py, c, "assert c.__rxor__(1) == '1 ^ RA'");
py_run!(py, c, "assert c.__ror__(1) == '1 | RA'");
py_run!(py, c, "assert c.__rpow__(1) == '1 ** RA'");
// Overrided
py_run!(py, c, "assert 1 + c == '1 + BA'");
py_run!(py, c, "assert 1 - c == '1 - BA'");
py_run!(py, c, "assert 1 * c == '1 * BA'");
py_run!(py, c, "assert 1 << c == '1 << BA'");
py_run!(py, c, "assert 1 >> c == '1 >> BA'");
py_run!(py, c, "assert 1 & c == '1 & BA'");
py_run!(py, c, "assert 1 ^ c == '1 ^ BA'");
py_run!(py, c, "assert 1 | c == '1 | BA'");
py_run!(py, c, "assert 1 ** c == '1 ** BA'");
}

#[pyclass]
#[derive(Debug)]
struct LhsFellbackToRhs {}

#[pyproto]
impl PyNumberProtocol for LhsFellbackToRhs {
impl PyNumberProtocol for LhsAndRhs {
fn __add__(lhs: PyRef<Self>, rhs: &PyAny) -> String {
format!("{:?} + {:?}", lhs, rhs)
}
Expand Down Expand Up @@ -480,19 +372,50 @@ impl PyNumberProtocol for LhsFellbackToRhs {
}

#[pyproto]
impl PyObjectProtocol for LhsFellbackToRhs {
impl PyObjectProtocol for LhsAndRhs {
fn __repr__(&self) -> &'static str {
"BA"
}
}

#[test]
fn lhs_overrides_rhs() {
let gil = Python::acquire_gil();
let py = gil.python();

let a = PyCell::new(py, LhsAndRhs {}).unwrap();
let b = PyCell::new(py, LhsAndRhs {}).unwrap();
// Not overrided
py_run!(py, a b, "assert a.__radd__(b) == 'BA + RA'");
py_run!(py, a b, "assert a.__rsub__(b) == 'BA - RA'");
py_run!(py, a b, "assert a.__rmul__(b) == 'BA * RA'");
py_run!(py, a b, "assert a.__rlshift__(b) == 'BA << RA'");
py_run!(py, a b, "assert a.__rrshift__(b) == 'BA >> RA'");
py_run!(py, a b, "assert a.__rand__(b) == 'BA & RA'");
py_run!(py, a b, "assert a.__rxor__(b) == 'BA ^ RA'");
py_run!(py, a b, "assert a.__ror__(b) == 'BA | RA'");
py_run!(py, a b, "assert a.__rpow__(b) == 'BA ** RA'");
py_run!(py, a b, "assert a.__rmatmul__(b) == 'BA @ RA'");
// Overrided
py_run!(py, a b, "assert a + b == 'LR + BA'");
py_run!(py, a b, "assert a - b == 'LR - BA'");
py_run!(py, a b, "assert a * b == 'LR * BA'");
py_run!(py, a b, "assert a << b == 'LR << BA'");
py_run!(py, a b, "assert a >> b == 'LR >> BA'");
py_run!(py, a b, "assert a & b == 'LR & BA'");
py_run!(py, a b, "assert a ^ b == 'LR ^ BA'");
py_run!(py, a b, "assert a | b == 'LR | BA'");
py_run!(py, a b, "assert a ** b == 'LR ** BA'");
py_run!(py, a b, "assert a @ b == 'LR @ BA'");
}

#[test]
fn lhs_fellback_to_rhs() {
let gil = Python::acquire_gil();
let py = gil.python();

let c = PyCell::new(py, LhsFellbackToRhs {}).unwrap();
// Fallbacked to RHS because of type mismatching
let c = PyCell::new(py, LhsAndRhs {}).unwrap();
// Fellback to RHS because of type mismatching
py_run!(py, c, "assert 1 + c == '1 + RA'");
py_run!(py, c, "assert 1 - c == '1 - RA'");
py_run!(py, c, "assert 1 * c == '1 * RA'");
Expand All @@ -502,6 +425,7 @@ fn lhs_fellback_to_rhs() {
py_run!(py, c, "assert 1 ^ c == '1 ^ RA'");
py_run!(py, c, "assert 1 | c == '1 | RA'");
py_run!(py, c, "assert 1 ** c == '1 ** RA'");
py_run!(py, c, "assert 1 @ c == '1 @ RA'");
}

#[pyclass]
Expand Down

0 comments on commit e44c85b

Please sign in to comment.