Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow omitting return type for #[pyproto] #998

Merged
merged 1 commit into from
Jun 23, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0.
- `#[pyproto]` is re-implemented without specialization. [#961](https://github.com/PyO3/pyo3/pull/961)
- `PyClassAlloc::alloc` is renamed to `PyClassAlloc::new`. [#990](https://github.com/PyO3/pyo3/pull/990)
- `#[pyproto]` methods can now have return value `T` or `PyResult<T>` (previously only `PyResult<T>` was supported). [#996](https://github.com/PyO3/pyo3/pull/996)
- `#[pyproto]` methods can now skip annotating the return type if it is `()`. [#998](https://github.com/PyO3/pyo3/pull/998)

### Removed
- Remove `ManagedPyRef` (unused, and needs specialization) [#930](https://github.com/PyO3/pyo3/pull/930)
Expand Down
2 changes: 1 addition & 1 deletion guide/src/class.md
Original file line number Diff line number Diff line change
Expand Up @@ -689,7 +689,7 @@ Python object behavior, you need to implement the specific trait for your struct
each protocol implementation block has to be annotated with the `#[pyproto]` attribute.

All `#[pyproto]` methods which can be defined below can return `T` instead of `PyResult<T>` if the
method implementation is infallible.
method implementation is infallible. In addition, if the return type is `()`, it can be omitted altogether.

### Basic object customization

Expand Down
23 changes: 10 additions & 13 deletions pyo3-derive-backend/src/func.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
// Copyright (c) 2017-present PyO3 Project and Contributors
use crate::utils::print_err;
use proc_macro2::{Span, TokenStream};
use quote::quote;
use quote::{quote, ToTokens};
use syn::Token;

// TODO:
Expand Down Expand Up @@ -75,21 +75,18 @@ pub(crate) fn impl_method_proto(
sig: &mut syn::Signature,
meth: &MethodProto,
) -> TokenStream {
if let MethodProto::Free { proto, .. } = meth {
let p: syn::Path = syn::parse_str(proto).unwrap();
return quote! {
impl<'p> #p<'p> for #cls {}
};
}

let ret_ty = &*if let syn::ReturnType::Type(_, ref ty) = sig.output {
ty.clone()
} else {
panic!("fn return type is not supported")
let ret_ty = match &sig.output {
syn::ReturnType::Default => quote! { () },
syn::ReturnType::Type(_, ty) => ty.to_token_stream(),
};

match *meth {
MethodProto::Free { .. } => unreachable!(),
MethodProto::Free { proto, .. } => {
let p: syn::Path = syn::parse_str(proto).unwrap();
quote! {
impl<'p> #p<'p> for #cls {}
}
}
MethodProto::Unary { proto, .. } => {
let p: syn::Path = syn::parse_str(proto).unwrap();

Expand Down
186 changes: 89 additions & 97 deletions tests/test_arithmetics.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,27 +18,27 @@ impl UnaryArithmetic {

#[pyproto]
impl PyObjectProtocol for UnaryArithmetic {
fn __repr__(&self) -> PyResult<String> {
Ok(format!("UA({})", self.inner))
fn __repr__(&self) -> String {
format!("UA({})", self.inner)
}
}

#[pyproto]
impl PyNumberProtocol for UnaryArithmetic {
fn __neg__(&self) -> PyResult<Self> {
Ok(Self::new(-self.inner))
fn __neg__(&self) -> Self {
Self::new(-self.inner)
}

fn __pos__(&self) -> PyResult<Self> {
Ok(Self::new(self.inner))
fn __pos__(&self) -> Self {
Self::new(self.inner)
}

fn __abs__(&self) -> PyResult<Self> {
Ok(Self::new(self.inner.abs()))
fn __abs__(&self) -> Self {
Self::new(self.inner.abs())
}

fn __round__(&self, _ndigits: Option<u32>) -> PyResult<Self> {
Ok(Self::new(self.inner.round()))
fn __round__(&self, _ndigits: Option<u32>) -> Self {
Self::new(self.inner.round())
}
}

Expand All @@ -60,8 +60,8 @@ struct BinaryArithmetic {}

#[pyproto]
impl PyObjectProtocol for BinaryArithmetic {
fn __repr__(&self) -> PyResult<&'static str> {
Ok("BA")
fn __repr__(&self) -> &'static str {
"BA"
}
}

Expand All @@ -72,56 +72,47 @@ struct InPlaceOperations {

#[pyproto]
impl PyObjectProtocol for InPlaceOperations {
fn __repr__(&self) -> PyResult<String> {
Ok(format!("IPO({:?})", self.value))
fn __repr__(&self) -> String {
format!("IPO({:?})", self.value)
}
}

#[pyproto]
impl PyNumberProtocol for InPlaceOperations {
fn __iadd__(&mut self, other: u32) -> PyResult<()> {
fn __iadd__(&mut self, other: u32) {
self.value += other;
Ok(())
}

fn __isub__(&mut self, other: u32) -> PyResult<()> {
fn __isub__(&mut self, other: u32) {
self.value -= other;
Ok(())
}

fn __imul__(&mut self, other: u32) -> PyResult<()> {
fn __imul__(&mut self, other: u32) {
self.value *= other;
Ok(())
}

fn __ilshift__(&mut self, other: u32) -> PyResult<()> {
fn __ilshift__(&mut self, other: u32) {
self.value <<= other;
Ok(())
}

fn __irshift__(&mut self, other: u32) -> PyResult<()> {
fn __irshift__(&mut self, other: u32) {
self.value >>= other;
Ok(())
}

fn __iand__(&mut self, other: u32) -> PyResult<()> {
fn __iand__(&mut self, other: u32) {
self.value &= other;
Ok(())
}

fn __ixor__(&mut self, other: u32) -> PyResult<()> {
fn __ixor__(&mut self, other: u32) {
self.value ^= other;
Ok(())
}

fn __ior__(&mut self, other: u32) -> PyResult<()> {
fn __ior__(&mut self, other: u32) {
self.value |= other;
Ok(())
}

fn __ipow__(&mut self, other: u32) -> PyResult<()> {
fn __ipow__(&mut self, other: u32) {
self.value = self.value.pow(other);
Ok(())
}
}

Expand Down Expand Up @@ -151,40 +142,40 @@ fn inplace_operations() {

#[pyproto]
impl PyNumberProtocol for BinaryArithmetic {
fn __add__(lhs: &PyAny, rhs: &PyAny) -> PyResult<String> {
Ok(format!("{:?} + {:?}", lhs, rhs))
fn __add__(lhs: &PyAny, rhs: &PyAny) -> String {
format!("{:?} + {:?}", lhs, rhs)
}

fn __sub__(lhs: &PyAny, rhs: &PyAny) -> PyResult<String> {
Ok(format!("{:?} - {:?}", lhs, rhs))
fn __sub__(lhs: &PyAny, rhs: &PyAny) -> String {
format!("{:?} - {:?}", lhs, rhs)
}

fn __mul__(lhs: &PyAny, rhs: &PyAny) -> PyResult<String> {
Ok(format!("{:?} * {:?}", lhs, rhs))
fn __mul__(lhs: &PyAny, rhs: &PyAny) -> String {
format!("{:?} * {:?}", lhs, rhs)
}

fn __lshift__(lhs: &PyAny, rhs: &PyAny) -> PyResult<String> {
Ok(format!("{:?} << {:?}", lhs, rhs))
fn __lshift__(lhs: &PyAny, rhs: &PyAny) -> String {
format!("{:?} << {:?}", lhs, rhs)
}

fn __rshift__(lhs: &PyAny, rhs: &PyAny) -> PyResult<String> {
Ok(format!("{:?} >> {:?}", lhs, rhs))
fn __rshift__(lhs: &PyAny, rhs: &PyAny) -> String {
format!("{:?} >> {:?}", lhs, rhs)
}

fn __and__(lhs: &PyAny, rhs: &PyAny) -> PyResult<String> {
Ok(format!("{:?} & {:?}", lhs, rhs))
fn __and__(lhs: &PyAny, rhs: &PyAny) -> String {
format!("{:?} & {:?}", lhs, rhs)
}

fn __xor__(lhs: &PyAny, rhs: &PyAny) -> PyResult<String> {
Ok(format!("{:?} ^ {:?}", lhs, rhs))
fn __xor__(lhs: &PyAny, rhs: &PyAny) -> String {
format!("{:?} ^ {:?}", lhs, rhs)
}

fn __or__(lhs: &PyAny, rhs: &PyAny) -> PyResult<String> {
Ok(format!("{:?} | {:?}", lhs, rhs))
fn __or__(lhs: &PyAny, rhs: &PyAny) -> String {
format!("{:?} | {:?}", lhs, rhs)
}

fn __pow__(lhs: &PyAny, rhs: &PyAny, mod_: Option<u32>) -> PyResult<String> {
Ok(format!("{:?} ** {:?} (mod: {:?})", lhs, rhs, mod_))
fn __pow__(lhs: &PyAny, rhs: &PyAny, mod_: Option<u32>) -> String {
format!("{:?} ** {:?} (mod: {:?})", lhs, rhs, mod_)
}
}

Expand Down Expand Up @@ -224,40 +215,40 @@ struct RhsArithmetic {}

#[pyproto]
impl PyNumberProtocol for RhsArithmetic {
fn __radd__(&self, other: &PyAny) -> PyResult<String> {
Ok(format!("{:?} + RA", other))
fn __radd__(&self, other: &PyAny) -> String {
format!("{:?} + RA", other)
}

fn __rsub__(&self, other: &PyAny) -> PyResult<String> {
Ok(format!("{:?} - RA", other))
fn __rsub__(&self, other: &PyAny) -> String {
format!("{:?} - RA", other)
}

fn __rmul__(&self, other: &PyAny) -> PyResult<String> {
Ok(format!("{:?} * RA", other))
fn __rmul__(&self, other: &PyAny) -> String {
format!("{:?} * RA", other)
}

fn __rlshift__(&self, other: &PyAny) -> PyResult<String> {
Ok(format!("{:?} << RA", other))
fn __rlshift__(&self, other: &PyAny) -> String {
format!("{:?} << RA", other)
}

fn __rrshift__(&self, other: &PyAny) -> PyResult<String> {
Ok(format!("{:?} >> RA", other))
fn __rrshift__(&self, other: &PyAny) -> String {
format!("{:?} >> RA", other)
}

fn __rand__(&self, other: &PyAny) -> PyResult<String> {
Ok(format!("{:?} & RA", other))
fn __rand__(&self, other: &PyAny) -> String {
format!("{:?} & RA", other)
}

fn __rxor__(&self, other: &PyAny) -> PyResult<String> {
Ok(format!("{:?} ^ RA", other))
fn __rxor__(&self, other: &PyAny) -> String {
format!("{:?} ^ RA", other)
}

fn __ror__(&self, other: &PyAny) -> PyResult<String> {
Ok(format!("{:?} | RA", other))
fn __ror__(&self, other: &PyAny) -> String {
format!("{:?} | RA", other)
}

fn __rpow__(&self, other: &PyAny, _mod: Option<&'p PyAny>) -> PyResult<String> {
Ok(format!("{:?} ** RA", other))
fn __rpow__(&self, other: &PyAny, _mod: Option<&'p PyAny>) -> String {
format!("{:?} ** RA", other)
}
}

Expand Down Expand Up @@ -292,35 +283,35 @@ struct LhsAndRhsArithmetic {}

#[pyproto]
impl PyNumberProtocol for LhsAndRhsArithmetic {
fn __radd__(&self, other: &PyAny) -> PyResult<String> {
Ok(format!("{:?} + RA", other))
fn __radd__(&self, other: &PyAny) -> String {
format!("{:?} + RA", other)
}

fn __rsub__(&self, other: &PyAny) -> PyResult<String> {
Ok(format!("{:?} - RA", other))
fn __rsub__(&self, other: &PyAny) -> String {
format!("{:?} - RA", other)
}

fn __rpow__(&self, other: &PyAny, _mod: Option<&'p PyAny>) -> PyResult<String> {
Ok(format!("{:?} ** RA", other))
fn __rpow__(&self, other: &PyAny, _mod: Option<&'p PyAny>) -> String {
format!("{:?} ** RA", other)
}

fn __add__(lhs: &PyAny, rhs: &PyAny) -> PyResult<String> {
Ok(format!("{:?} + {:?}", lhs, rhs))
fn __add__(lhs: &PyAny, rhs: &PyAny) -> String {
format!("{:?} + {:?}", lhs, rhs)
}

fn __sub__(lhs: &PyAny, rhs: &PyAny) -> PyResult<String> {
Ok(format!("{:?} - {:?}", lhs, rhs))
fn __sub__(lhs: &PyAny, rhs: &PyAny) -> String {
format!("{:?} - {:?}", lhs, rhs)
}

fn __pow__(lhs: &PyAny, rhs: &PyAny, _mod: Option<u32>) -> PyResult<String> {
Ok(format!("{:?} ** {:?}", lhs, rhs))
fn __pow__(lhs: &PyAny, rhs: &PyAny, _mod: Option<u32>) -> String {
format!("{:?} ** {:?}", lhs, rhs)
}
}

#[pyproto]
impl PyObjectProtocol for LhsAndRhsArithmetic {
fn __repr__(&self) -> PyResult<&'static str> {
Ok("BA")
fn __repr__(&self) -> &'static str {
"BA"
}
}

Expand All @@ -345,18 +336,18 @@ struct RichComparisons {}

#[pyproto]
impl PyObjectProtocol for RichComparisons {
fn __repr__(&self) -> PyResult<&'static str> {
Ok("RC")
fn __repr__(&self) -> &'static str {
"RC"
}

fn __richcmp__(&self, other: &PyAny, op: CompareOp) -> PyResult<String> {
fn __richcmp__(&self, other: &PyAny, op: CompareOp) -> String {
match op {
CompareOp::Lt => Ok(format!("{} < {:?}", self.__repr__().unwrap(), other)),
CompareOp::Le => Ok(format!("{} <= {:?}", self.__repr__().unwrap(), other)),
CompareOp::Eq => Ok(format!("{} == {:?}", self.__repr__().unwrap(), other)),
CompareOp::Ne => Ok(format!("{} != {:?}", self.__repr__().unwrap(), other)),
CompareOp::Gt => Ok(format!("{} > {:?}", self.__repr__().unwrap(), other)),
CompareOp::Ge => Ok(format!("{} >= {:?}", self.__repr__().unwrap(), other)),
CompareOp::Lt => format!("{} < {:?}", self.__repr__(), other),
CompareOp::Le => format!("{} <= {:?}", self.__repr__(), other),
CompareOp::Eq => format!("{} == {:?}", self.__repr__(), other),
CompareOp::Ne => format!("{} != {:?}", self.__repr__(), other),
CompareOp::Gt => format!("{} > {:?}", self.__repr__(), other),
CompareOp::Ge => format!("{} >= {:?}", self.__repr__(), other),
}
}
}
Expand All @@ -366,16 +357,17 @@ struct RichComparisons2 {}

#[pyproto]
impl PyObjectProtocol for RichComparisons2 {
fn __repr__(&self) -> PyResult<&'static str> {
Ok("RC2")
fn __repr__(&self) -> &'static str {
"RC2"
}

fn __richcmp__(&self, _other: &PyAny, op: CompareOp) -> PyResult<PyObject> {
fn __richcmp__(&self, _other: &PyAny, op: CompareOp) -> PyObject {
let gil = GILGuard::acquire();
let py = gil.python();
match op {
CompareOp::Eq => Ok(true.to_object(gil.python())),
CompareOp::Ne => Ok(false.to_object(gil.python())),
_ => Ok(gil.python().NotImplemented()),
CompareOp::Eq => true.into_py(py),
CompareOp::Ne => false.into_py(py),
_ => py.NotImplemented(),
}
}
}
Expand Down
Loading