Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make __r*__ methods work with operators #839

Merged
merged 7 commits into from
Mar 30, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
223 changes: 111 additions & 112 deletions pyo3-derive-backend/src/defs.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,44 @@ pub struct Proto {
pub py_methods: &'static [PyMethod],
}

impl Proto {
pub(crate) fn get_proto<Q>(&self, query: Q) -> Option<&'static MethodProto>
where
Q: PartialEq<&'static str>,
davidhewitt marked this conversation as resolved.
Show resolved Hide resolved
{
self.methods.iter().find(|m| query == m.name())
}
pub(crate) fn get_method<Q>(&self, query: Q) -> Option<&'static PyMethod>
where
Q: PartialEq<&'static str>,
{
self.py_methods.iter().find(|m| query == m.name)
}
}

// TODO(kngwyu): Currently only __radd__-like methods use METH_COEXIST to prevent
// __add__-like methods from overriding them.
pub struct PyMethod {
pub name: &'static str,
pub proto: &'static str,
pub can_coexist: bool,
}

impl PyMethod {
const fn coexist(name: &'static str, proto: &'static str) -> Self {
PyMethod {
name,
proto,
can_coexist: true,
}
}
const fn new(name: &'static str, proto: &'static str) -> Self {
PyMethod {
name,
proto,
can_coexist: false,
}
}
}

pub const OBJECT: Proto = Proto {
Expand Down Expand Up @@ -73,18 +108,9 @@ pub const OBJECT: Proto = Proto {
},
],
py_methods: &[
PyMethod {
name: "__format__",
proto: "pyo3::class::basic::FormatProtocolImpl",
},
PyMethod {
name: "__bytes__",
proto: "pyo3::class::basic::BytesProtocolImpl",
},
PyMethod {
name: "__unicode__",
proto: "pyo3::class::basic::UnicodeProtocolImpl",
},
PyMethod::new("__format__", "pyo3::class::basic::FormatProtocolImpl"),
PyMethod::new("__bytes__", "pyo3::class::basic::BytesProtocolImpl"),
PyMethod::new("__unicode__", "pyo3::class::basic::UnicodeProtocolImpl"),
],
};

Expand Down Expand Up @@ -120,14 +146,14 @@ pub const ASYNC: Proto = Proto {
},
],
py_methods: &[
PyMethod {
name: "__aenter__",
proto: "pyo3::class::pyasync::PyAsyncAenterProtocolImpl",
},
PyMethod {
name: "__aexit__",
proto: "pyo3::class::pyasync::PyAsyncAexitProtocolImpl",
},
PyMethod::new(
"__aenter__",
"pyo3::class::pyasync::PyAsyncAenterProtocolImpl",
),
PyMethod::new(
"__aexit__",
"pyo3::class::pyasync::PyAsyncAexitProtocolImpl",
),
],
};

Expand Down Expand Up @@ -165,14 +191,14 @@ pub const CONTEXT: Proto = Proto {
},
],
py_methods: &[
PyMethod {
name: "__enter__",
proto: "pyo3::class::context::PyContextEnterProtocolImpl",
},
PyMethod {
name: "__exit__",
proto: "pyo3::class::context::PyContextExitProtocolImpl",
},
PyMethod::new(
"__enter__",
"pyo3::class::context::PyContextEnterProtocolImpl",
),
PyMethod::new(
"__exit__",
"pyo3::class::context::PyContextExitProtocolImpl",
),
],
};

Expand Down Expand Up @@ -222,14 +248,11 @@ pub const DESCR: Proto = Proto {
},
],
py_methods: &[
PyMethod {
name: "__del__",
proto: "pyo3::class::context::PyDescrDelProtocolImpl",
},
PyMethod {
name: "__set_name__",
proto: "pyo3::class::context::PyDescrNameProtocolImpl",
},
PyMethod::new("__del__", "pyo3::class::context::PyDescrDelProtocolImpl"),
PyMethod::new(
"__set_name__",
"pyo3::class::context::PyDescrNameProtocolImpl",
),
],
};

Expand Down Expand Up @@ -283,10 +306,10 @@ pub const MAPPING: Proto = Proto {
proto: "pyo3::class::mapping::PyMappingReversedProtocol",
},
],
py_methods: &[PyMethod {
name: "__reversed__",
proto: "pyo3::class::mapping::PyMappingReversedProtocolImpl",
}],
py_methods: &[PyMethod::new(
"__reversed__",
"pyo3::class::mapping::PyMappingReversedProtocolImpl",
)],
};

pub const SEQ: Proto = Proto {
Expand Down Expand Up @@ -579,10 +602,9 @@ pub const NUM: Proto = Proto {
pyres: false,
proto: "pyo3::class::number::PyNumberIModProtocol",
},
MethodProto::Ternary {
MethodProto::Binary {
name: "__ipow__",
arg1: "Other",
arg2: "Modulo",
arg: "Other",
pyres: false,
proto: "pyo3::class::number::PyNumberIPowProtocol",
},
Expand Down Expand Up @@ -651,81 +673,58 @@ pub const NUM: Proto = Proto {
pyres: true,
proto: "pyo3::class::number::PyNumberFloatProtocol",
},
MethodProto::Unary {
name: "__round__",
pyres: true,
proto: "pyo3::class::number::PyNumberRoundProtocol",
},
MethodProto::Unary {
name: "__index__",
pyres: true,
proto: "pyo3::class::number::PyNumberIndexProtocol",
},
],
py_methods: &[
PyMethod {
name: "__radd__",
proto: "pyo3::class::number::PyNumberRAddProtocolImpl",
},
PyMethod {
name: "__rsub__",
proto: "pyo3::class::number::PyNumberRSubProtocolImpl",
},
PyMethod {
name: "__rmul__",
proto: "pyo3::class::number::PyNumberRMulProtocolImpl",
},
PyMethod {
name: "__rmatmul__",
proto: "pyo3::class::number::PyNumberRMatmulProtocolImpl",
},
PyMethod {
name: "__rtruediv__",
proto: "pyo3::class::number::PyNumberRTruedivProtocolImpl",
},
PyMethod {
name: "__rfloordiv__",
proto: "pyo3::class::number::PyNumberRFloordivProtocolImpl",
},
PyMethod {
name: "__rmod__",
proto: "pyo3::class::number::PyNumberRModProtocolImpl",
},
PyMethod {
name: "__rdivmod__",
proto: "pyo3::class::number::PyNumberRDivmodProtocolImpl",
},
PyMethod {
name: "__rpow__",
proto: "pyo3::class::number::PyNumberRPowProtocolImpl",
},
PyMethod {
name: "__rlshift__",
proto: "pyo3::class::number::PyNumberRLShiftProtocolImpl",
},
PyMethod {
name: "__rrshift__",
proto: "pyo3::class::number::PyNumberRRShiftProtocolImpl",
},
PyMethod {
name: "__rand__",
proto: "pyo3::class::number::PyNumberRAndProtocolImpl",
},
PyMethod {
name: "__rxor__",
proto: "pyo3::class::number::PyNumberRXorProtocolImpl",
},
PyMethod {
name: "__ror__",
proto: "pyo3::class::number::PyNumberROrProtocolImpl",
},
PyMethod {
name: "__complex__",
proto: "pyo3::class::number::PyNumberComplexProtocolImpl",
},
PyMethod {
MethodProto::Binary {
name: "__round__",
proto: "pyo3::class::number::PyNumberRoundProtocolImpl",
arg: "NDigits",
pyres: true,
proto: "pyo3::class::number::PyNumberRoundProtocol",
},
],
py_methods: &[
PyMethod::coexist("__radd__", "pyo3::class::number::PyNumberRAddProtocolImpl"),
PyMethod::coexist("__rsub__", "pyo3::class::number::PyNumberRSubProtocolImpl"),
PyMethod::coexist("__rmul__", "pyo3::class::number::PyNumberRMulProtocolImpl"),
PyMethod::coexist(
"__rmatmul__",
"pyo3::class::number::PyNumberRMatmulProtocolImpl",
),
PyMethod::coexist(
"__rtruediv__",
"pyo3::class::number::PyNumberRTruedivProtocolImpl",
),
PyMethod::coexist(
"__rfloordiv__",
"pyo3::class::number::PyNumberRFloordivProtocolImpl",
),
PyMethod::coexist("__rmod__", "pyo3::class::number::PyNumberRModProtocolImpl"),
PyMethod::coexist(
"__rdivmod__",
"pyo3::class::number::PyNumberRDivmodProtocolImpl",
),
PyMethod::coexist("__rpow__", "pyo3::class::number::PyNumberRPowProtocolImpl"),
PyMethod::coexist(
"__rlshift__",
"pyo3::class::number::PyNumberRLShiftProtocolImpl",
),
PyMethod::coexist(
"__rrshift__",
"pyo3::class::number::PyNumberRRShiftProtocolImpl",
),
PyMethod::coexist("__rand__", "pyo3::class::number::PyNumberRAndProtocolImpl"),
PyMethod::coexist("__rxor__", "pyo3::class::number::PyNumberRXorProtocolImpl"),
PyMethod::coexist("__ror__", "pyo3::class::number::PyNumberROrProtocolImpl"),
PyMethod::new(
"__complex__",
"pyo3::class::number::PyNumberComplexProtocolImpl",
),
PyMethod::new(
"__round__",
"pyo3::class::number::PyNumberRoundProtocolImpl",
),
],
};
2 changes: 1 addition & 1 deletion pyo3-derive-backend/src/func.rs
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ impl MethodProto {
}
}

pub fn impl_method_proto(
pub(crate) fn impl_method_proto(
cls: &syn::Type,
sig: &mut syn::Signature,
meth: &MethodProto,
Expand Down
1 change: 0 additions & 1 deletion pyo3-derive-backend/src/pyclass.rs
Original file line number Diff line number Diff line change
Expand Up @@ -459,7 +459,6 @@ fn impl_descriptors(
.collect::<syn::Result<_>>()?;

Ok(quote! {

pyo3::inventory::submit! {
#![crate = pyo3] {
type ClsInventory = <#cls as pyo3::class::methods::PyMethodsInventoryDispatch>::InventoryType;
Expand Down
59 changes: 30 additions & 29 deletions pyo3-derive-backend/src/pyproto.rs
Original file line number Diff line number Diff line change
Expand Up @@ -63,39 +63,40 @@ fn impl_proto_impl(

for iimpl in impls.iter_mut() {
if let syn::ImplItem::Method(ref mut met) = iimpl {
for m in proto.methods {
if met.sig.ident == m.name() {
impl_method_proto(ty, &mut met.sig, m).to_tokens(&mut tokens);
}
if let Some(m) = proto.get_proto(&met.sig.ident) {
impl_method_proto(ty, &mut met.sig, m).to_tokens(&mut tokens);
}
for m in proto.py_methods {
if met.sig.ident == m.name {
let name = &met.sig.ident;
let proto: syn::Path = syn::parse_str(m.proto).unwrap();

let fn_spec = match FnSpec::parse(&met.sig, &mut met.attrs, false) {
Ok(fn_spec) => fn_spec,
Err(err) => return err.to_compile_error(),
};
let meth = pymethod::impl_proto_wrap(ty, &fn_spec);
if let Some(m) = proto.get_method(&met.sig.ident) {
let name = &met.sig.ident;
let proto: syn::Path = syn::parse_str(m.proto).unwrap();

py_methods.push(quote! {
impl #proto for #ty
{
#[inline]
fn #name() -> Option<pyo3::class::methods::PyMethodDef> {
#meth
let fn_spec = match FnSpec::parse(&met.sig, &mut met.attrs, false) {
Ok(fn_spec) => fn_spec,
Err(err) => return err.to_compile_error(),
};
let meth = pymethod::impl_proto_wrap(ty, &fn_spec);
let coexist = if m.can_coexist {
quote!(pyo3::ffi::METH_COEXIST)
} else {
quote!(0)
};
py_methods.push(quote! {
impl #proto for #ty
{
#[inline]
fn #name() -> Option<pyo3::class::methods::PyMethodDef> {
#meth

Some(pyo3::class::PyMethodDef {
ml_name: stringify!(#name),
ml_meth: pyo3::class::PyMethodType::PyCFunctionWithKeywords(__wrap),
ml_flags: pyo3::ffi::METH_VARARGS | pyo3::ffi::METH_KEYWORDS,
ml_doc: ""
})
}
Some(pyo3::class::PyMethodDef {
ml_name: stringify!(#name),
ml_meth: pyo3::class::PyMethodType::PyCFunctionWithKeywords(__wrap),
// We need METH_COEXIST here to prevent __add__ from overriding __radd__
ml_flags: pyo3::ffi::METH_VARARGS | pyo3::ffi::METH_KEYWORDS | #coexist,
ml_doc: ""
})
}
});
}
}
});
}
}
}
Expand Down
Loading