Skip to content

Commit

Permalink
Merge pull request #839 from kngwyu/radd-fix
Browse files Browse the repository at this point in the history
Make __r*__ methods work with operators
  • Loading branch information
kngwyu authored Mar 30, 2020
2 parents e0c1f82 + 1efe142 commit 85de698
Show file tree
Hide file tree
Showing 10 changed files with 754 additions and 201 deletions.
223 changes: 111 additions & 112 deletions pyo3-derive-backend/src/defs.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,44 @@ pub struct Proto {
pub py_methods: &'static [PyMethod],
}

impl Proto {
pub(crate) fn get_proto<Q>(&self, query: Q) -> Option<&'static MethodProto>
where
Q: PartialEq<&'static str>,
{
self.methods.iter().find(|m| query == m.name())
}
pub(crate) fn get_method<Q>(&self, query: Q) -> Option<&'static PyMethod>
where
Q: PartialEq<&'static str>,
{
self.py_methods.iter().find(|m| query == m.name)
}
}

// TODO(kngwyu): Currently only __radd__-like methods use METH_COEXIST to prevent
// __add__-like methods from overriding them.
pub struct PyMethod {
pub name: &'static str,
pub proto: &'static str,
pub can_coexist: bool,
}

impl PyMethod {
const fn coexist(name: &'static str, proto: &'static str) -> Self {
PyMethod {
name,
proto,
can_coexist: true,
}
}
const fn new(name: &'static str, proto: &'static str) -> Self {
PyMethod {
name,
proto,
can_coexist: false,
}
}
}

pub const OBJECT: Proto = Proto {
Expand Down Expand Up @@ -73,18 +108,9 @@ pub const OBJECT: Proto = Proto {
},
],
py_methods: &[
PyMethod {
name: "__format__",
proto: "pyo3::class::basic::FormatProtocolImpl",
},
PyMethod {
name: "__bytes__",
proto: "pyo3::class::basic::BytesProtocolImpl",
},
PyMethod {
name: "__unicode__",
proto: "pyo3::class::basic::UnicodeProtocolImpl",
},
PyMethod::new("__format__", "pyo3::class::basic::FormatProtocolImpl"),
PyMethod::new("__bytes__", "pyo3::class::basic::BytesProtocolImpl"),
PyMethod::new("__unicode__", "pyo3::class::basic::UnicodeProtocolImpl"),
],
};

Expand Down Expand Up @@ -120,14 +146,14 @@ pub const ASYNC: Proto = Proto {
},
],
py_methods: &[
PyMethod {
name: "__aenter__",
proto: "pyo3::class::pyasync::PyAsyncAenterProtocolImpl",
},
PyMethod {
name: "__aexit__",
proto: "pyo3::class::pyasync::PyAsyncAexitProtocolImpl",
},
PyMethod::new(
"__aenter__",
"pyo3::class::pyasync::PyAsyncAenterProtocolImpl",
),
PyMethod::new(
"__aexit__",
"pyo3::class::pyasync::PyAsyncAexitProtocolImpl",
),
],
};

Expand Down Expand Up @@ -165,14 +191,14 @@ pub const CONTEXT: Proto = Proto {
},
],
py_methods: &[
PyMethod {
name: "__enter__",
proto: "pyo3::class::context::PyContextEnterProtocolImpl",
},
PyMethod {
name: "__exit__",
proto: "pyo3::class::context::PyContextExitProtocolImpl",
},
PyMethod::new(
"__enter__",
"pyo3::class::context::PyContextEnterProtocolImpl",
),
PyMethod::new(
"__exit__",
"pyo3::class::context::PyContextExitProtocolImpl",
),
],
};

Expand Down Expand Up @@ -222,14 +248,11 @@ pub const DESCR: Proto = Proto {
},
],
py_methods: &[
PyMethod {
name: "__del__",
proto: "pyo3::class::context::PyDescrDelProtocolImpl",
},
PyMethod {
name: "__set_name__",
proto: "pyo3::class::context::PyDescrNameProtocolImpl",
},
PyMethod::new("__del__", "pyo3::class::context::PyDescrDelProtocolImpl"),
PyMethod::new(
"__set_name__",
"pyo3::class::context::PyDescrNameProtocolImpl",
),
],
};

Expand Down Expand Up @@ -283,10 +306,10 @@ pub const MAPPING: Proto = Proto {
proto: "pyo3::class::mapping::PyMappingReversedProtocol",
},
],
py_methods: &[PyMethod {
name: "__reversed__",
proto: "pyo3::class::mapping::PyMappingReversedProtocolImpl",
}],
py_methods: &[PyMethod::new(
"__reversed__",
"pyo3::class::mapping::PyMappingReversedProtocolImpl",
)],
};

pub const SEQ: Proto = Proto {
Expand Down Expand Up @@ -579,10 +602,9 @@ pub const NUM: Proto = Proto {
pyres: false,
proto: "pyo3::class::number::PyNumberIModProtocol",
},
MethodProto::Ternary {
MethodProto::Binary {
name: "__ipow__",
arg1: "Other",
arg2: "Modulo",
arg: "Other",
pyres: false,
proto: "pyo3::class::number::PyNumberIPowProtocol",
},
Expand Down Expand Up @@ -651,81 +673,58 @@ pub const NUM: Proto = Proto {
pyres: true,
proto: "pyo3::class::number::PyNumberFloatProtocol",
},
MethodProto::Unary {
name: "__round__",
pyres: true,
proto: "pyo3::class::number::PyNumberRoundProtocol",
},
MethodProto::Unary {
name: "__index__",
pyres: true,
proto: "pyo3::class::number::PyNumberIndexProtocol",
},
],
py_methods: &[
PyMethod {
name: "__radd__",
proto: "pyo3::class::number::PyNumberRAddProtocolImpl",
},
PyMethod {
name: "__rsub__",
proto: "pyo3::class::number::PyNumberRSubProtocolImpl",
},
PyMethod {
name: "__rmul__",
proto: "pyo3::class::number::PyNumberRMulProtocolImpl",
},
PyMethod {
name: "__rmatmul__",
proto: "pyo3::class::number::PyNumberRMatmulProtocolImpl",
},
PyMethod {
name: "__rtruediv__",
proto: "pyo3::class::number::PyNumberRTruedivProtocolImpl",
},
PyMethod {
name: "__rfloordiv__",
proto: "pyo3::class::number::PyNumberRFloordivProtocolImpl",
},
PyMethod {
name: "__rmod__",
proto: "pyo3::class::number::PyNumberRModProtocolImpl",
},
PyMethod {
name: "__rdivmod__",
proto: "pyo3::class::number::PyNumberRDivmodProtocolImpl",
},
PyMethod {
name: "__rpow__",
proto: "pyo3::class::number::PyNumberRPowProtocolImpl",
},
PyMethod {
name: "__rlshift__",
proto: "pyo3::class::number::PyNumberRLShiftProtocolImpl",
},
PyMethod {
name: "__rrshift__",
proto: "pyo3::class::number::PyNumberRRShiftProtocolImpl",
},
PyMethod {
name: "__rand__",
proto: "pyo3::class::number::PyNumberRAndProtocolImpl",
},
PyMethod {
name: "__rxor__",
proto: "pyo3::class::number::PyNumberRXorProtocolImpl",
},
PyMethod {
name: "__ror__",
proto: "pyo3::class::number::PyNumberROrProtocolImpl",
},
PyMethod {
name: "__complex__",
proto: "pyo3::class::number::PyNumberComplexProtocolImpl",
},
PyMethod {
MethodProto::Binary {
name: "__round__",
proto: "pyo3::class::number::PyNumberRoundProtocolImpl",
arg: "NDigits",
pyres: true,
proto: "pyo3::class::number::PyNumberRoundProtocol",
},
],
py_methods: &[
PyMethod::coexist("__radd__", "pyo3::class::number::PyNumberRAddProtocolImpl"),
PyMethod::coexist("__rsub__", "pyo3::class::number::PyNumberRSubProtocolImpl"),
PyMethod::coexist("__rmul__", "pyo3::class::number::PyNumberRMulProtocolImpl"),
PyMethod::coexist(
"__rmatmul__",
"pyo3::class::number::PyNumberRMatmulProtocolImpl",
),
PyMethod::coexist(
"__rtruediv__",
"pyo3::class::number::PyNumberRTruedivProtocolImpl",
),
PyMethod::coexist(
"__rfloordiv__",
"pyo3::class::number::PyNumberRFloordivProtocolImpl",
),
PyMethod::coexist("__rmod__", "pyo3::class::number::PyNumberRModProtocolImpl"),
PyMethod::coexist(
"__rdivmod__",
"pyo3::class::number::PyNumberRDivmodProtocolImpl",
),
PyMethod::coexist("__rpow__", "pyo3::class::number::PyNumberRPowProtocolImpl"),
PyMethod::coexist(
"__rlshift__",
"pyo3::class::number::PyNumberRLShiftProtocolImpl",
),
PyMethod::coexist(
"__rrshift__",
"pyo3::class::number::PyNumberRRShiftProtocolImpl",
),
PyMethod::coexist("__rand__", "pyo3::class::number::PyNumberRAndProtocolImpl"),
PyMethod::coexist("__rxor__", "pyo3::class::number::PyNumberRXorProtocolImpl"),
PyMethod::coexist("__ror__", "pyo3::class::number::PyNumberROrProtocolImpl"),
PyMethod::new(
"__complex__",
"pyo3::class::number::PyNumberComplexProtocolImpl",
),
PyMethod::new(
"__round__",
"pyo3::class::number::PyNumberRoundProtocolImpl",
),
],
};
2 changes: 1 addition & 1 deletion pyo3-derive-backend/src/func.rs
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ impl MethodProto {
}
}

pub fn impl_method_proto(
pub(crate) fn impl_method_proto(
cls: &syn::Type,
sig: &mut syn::Signature,
meth: &MethodProto,
Expand Down
1 change: 0 additions & 1 deletion pyo3-derive-backend/src/pyclass.rs
Original file line number Diff line number Diff line change
Expand Up @@ -459,7 +459,6 @@ fn impl_descriptors(
.collect::<syn::Result<_>>()?;

Ok(quote! {

pyo3::inventory::submit! {
#![crate = pyo3] {
type ClsInventory = <#cls as pyo3::class::methods::PyMethodsInventoryDispatch>::InventoryType;
Expand Down
59 changes: 30 additions & 29 deletions pyo3-derive-backend/src/pyproto.rs
Original file line number Diff line number Diff line change
Expand Up @@ -63,39 +63,40 @@ fn impl_proto_impl(

for iimpl in impls.iter_mut() {
if let syn::ImplItem::Method(ref mut met) = iimpl {
for m in proto.methods {
if met.sig.ident == m.name() {
impl_method_proto(ty, &mut met.sig, m).to_tokens(&mut tokens);
}
if let Some(m) = proto.get_proto(&met.sig.ident) {
impl_method_proto(ty, &mut met.sig, m).to_tokens(&mut tokens);
}
for m in proto.py_methods {
if met.sig.ident == m.name {
let name = &met.sig.ident;
let proto: syn::Path = syn::parse_str(m.proto).unwrap();

let fn_spec = match FnSpec::parse(&met.sig, &mut met.attrs, false) {
Ok(fn_spec) => fn_spec,
Err(err) => return err.to_compile_error(),
};
let meth = pymethod::impl_proto_wrap(ty, &fn_spec);
if let Some(m) = proto.get_method(&met.sig.ident) {
let name = &met.sig.ident;
let proto: syn::Path = syn::parse_str(m.proto).unwrap();

py_methods.push(quote! {
impl #proto for #ty
{
#[inline]
fn #name() -> Option<pyo3::class::methods::PyMethodDef> {
#meth
let fn_spec = match FnSpec::parse(&met.sig, &mut met.attrs, false) {
Ok(fn_spec) => fn_spec,
Err(err) => return err.to_compile_error(),
};
let meth = pymethod::impl_proto_wrap(ty, &fn_spec);
let coexist = if m.can_coexist {
quote!(pyo3::ffi::METH_COEXIST)
} else {
quote!(0)
};
py_methods.push(quote! {
impl #proto for #ty
{
#[inline]
fn #name() -> Option<pyo3::class::methods::PyMethodDef> {
#meth

Some(pyo3::class::PyMethodDef {
ml_name: stringify!(#name),
ml_meth: pyo3::class::PyMethodType::PyCFunctionWithKeywords(__wrap),
ml_flags: pyo3::ffi::METH_VARARGS | pyo3::ffi::METH_KEYWORDS,
ml_doc: ""
})
}
Some(pyo3::class::PyMethodDef {
ml_name: stringify!(#name),
ml_meth: pyo3::class::PyMethodType::PyCFunctionWithKeywords(__wrap),
// We need METH_COEXIST here to prevent __add__ from overriding __radd__
ml_flags: pyo3::ffi::METH_VARARGS | pyo3::ffi::METH_KEYWORDS | #coexist,
ml_doc: ""
})
}
});
}
}
});
}
}
}
Expand Down
Loading

0 comments on commit 85de698

Please sign in to comment.