Skip to content

Commit

Permalink
add pyclass hash option (#4206)
Browse files Browse the repository at this point in the history
* add pyclass `hash` option

* add newsfragment

* require `frozen` option for `hash`

* simplify `hash` without `frozen` error message

Co-authored-by: David Hewitt <mail@davidhewitt.dev>

* require `eq` for `hash`

* prevent manual `__hash__` with `#pyo3(hash)`

* combine error messages

---------

Co-authored-by: David Hewitt <mail@davidhewitt.dev>
  • Loading branch information
Icxolu and davidhewitt authored Jun 1, 2024
1 parent 25c1db4 commit a7a5c10
Show file tree
Hide file tree
Showing 13 changed files with 328 additions and 5 deletions.
1 change: 1 addition & 0 deletions guide/pyclass-parameters.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
| <span style="white-space: pre">`freelist = N`</span> | Implements a [free list][params-2] of size N. This can improve performance for types that are often created and deleted in quick succession. Profile your code to see whether `freelist` is right for you. |
| <span style="white-space: pre">`frozen`</span> | Declares that your pyclass is immutable. It removes the borrow checker overhead when retrieving a shared reference to the Rust struct, but disables the ability to get a mutable reference. |
| `get_all` | Generates getters for all fields of the pyclass. |
| `hash` | Implements `__hash__` using the `Hash` implementation of the underlying Rust datatype. |
| `mapping` | Inform PyO3 that this class is a [`Mapping`][params-mapping], and so leave its implementation of sequence C-API slots empty. |
| <span style="white-space: pre">`module = "module_name"`</span> | Python code will see the class as being defined in this module. Defaults to `builtins`. |
| <span style="white-space: pre">`name = "python_name"`</span> | Sets the name that Python sees this class as. Defaults to the name of the Rust struct. |
Expand Down
13 changes: 13 additions & 0 deletions guide/src/class/object.md
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,19 @@ impl Number {
}
}
```
To implement `__hash__` using the Rust [`Hash`] trait implementation, the `hash` option can be used.
This option is only available for `frozen` classes, so that the hash cannot change accidentally when the object is mutated. If you need
a `__hash__` implementation for a mutable class, use the manual method from above. This option also requires `eq`: according to the
[Python docs](https://docs.python.org/3/reference/datamodel.html#object.__hash__), "If a class does not define an `__eq__()`
method it should not define a `__hash__()` operation either."
```rust
# use pyo3::prelude::*;
#
#[pyclass(frozen, eq, hash)]
#[derive(PartialEq, Hash)]
struct Number(i32);
```


> **Note**: When implementing `__hash__` and comparisons, it is important that the following property holds:
>
Expand Down
1 change: 1 addition & 0 deletions newsfragments/4206.added.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Added `#[pyclass(hash)]` option to implement `__hash__` in terms of the `Hash` implementation of the underlying Rust type.
1 change: 1 addition & 0 deletions pyo3-macros-backend/src/attributes.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ pub mod kw {
syn::custom_keyword!(frozen);
syn::custom_keyword!(get);
syn::custom_keyword!(get_all);
syn::custom_keyword!(hash);
syn::custom_keyword!(item);
syn::custom_keyword!(from_item_all);
syn::custom_keyword!(mapping);
Expand Down
48 changes: 47 additions & 1 deletion pyo3-macros-backend/src/pyclass.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ use crate::pyfunction::ConstructorAttribute;
use crate::pyimpl::{gen_py_const, PyClassMethodsType};
use crate::pymethod::{
impl_py_getter_def, impl_py_setter_def, MethodAndMethodDef, MethodAndSlotDef, PropertyType,
SlotDef, __GETITEM__, __INT__, __LEN__, __REPR__, __RICHCMP__,
SlotDef, __GETITEM__, __HASH__, __INT__, __LEN__, __REPR__, __RICHCMP__,
};
use crate::utils::Ctx;
use crate::utils::{self, apply_renaming_rule, PythonDoc};
Expand All @@ -21,6 +21,7 @@ use proc_macro2::{Ident, Span, TokenStream};
use quote::{format_ident, quote, quote_spanned};
use syn::ext::IdentExt;
use syn::parse::{Parse, ParseStream};
use syn::parse_quote_spanned;
use syn::punctuated::Punctuated;
use syn::{parse_quote, spanned::Spanned, Result, Token};

Expand Down Expand Up @@ -65,6 +66,7 @@ pub struct PyClassPyO3Options {
pub get_all: Option<kw::get_all>,
pub freelist: Option<FreelistAttribute>,
pub frozen: Option<kw::frozen>,
pub hash: Option<kw::hash>,
pub mapping: Option<kw::mapping>,
pub module: Option<ModuleAttribute>,
pub name: Option<NameAttribute>,
Expand All @@ -85,6 +87,7 @@ enum PyClassPyO3Option {
Freelist(FreelistAttribute),
Frozen(kw::frozen),
GetAll(kw::get_all),
Hash(kw::hash),
Mapping(kw::mapping),
Module(ModuleAttribute),
Name(NameAttribute),
Expand Down Expand Up @@ -115,6 +118,8 @@ impl Parse for PyClassPyO3Option {
input.parse().map(PyClassPyO3Option::Frozen)
} else if lookahead.peek(attributes::kw::get_all) {
input.parse().map(PyClassPyO3Option::GetAll)
} else if lookahead.peek(attributes::kw::hash) {
input.parse().map(PyClassPyO3Option::Hash)
} else if lookahead.peek(attributes::kw::mapping) {
input.parse().map(PyClassPyO3Option::Mapping)
} else if lookahead.peek(attributes::kw::module) {
Expand Down Expand Up @@ -180,6 +185,7 @@ impl PyClassPyO3Options {
PyClassPyO3Option::Freelist(freelist) => set_option!(freelist),
PyClassPyO3Option::Frozen(frozen) => set_option!(frozen),
PyClassPyO3Option::GetAll(get_all) => set_option!(get_all),
PyClassPyO3Option::Hash(hash) => set_option!(hash),
PyClassPyO3Option::Mapping(mapping) => set_option!(mapping),
PyClassPyO3Option::Module(module) => set_option!(module),
PyClassPyO3Option::Name(name) => set_option!(name),
Expand Down Expand Up @@ -363,8 +369,12 @@ fn impl_class(
let (default_richcmp, default_richcmp_slot) =
pyclass_richcmp(&args.options, &syn::parse_quote!(#cls), ctx)?;

let (default_hash, default_hash_slot) =
pyclass_hash(&args.options, &syn::parse_quote!(#cls), ctx)?;

let mut slots = Vec::new();
slots.extend(default_richcmp_slot);
slots.extend(default_hash_slot);

let py_class_impl = PyClassImplsBuilder::new(
cls,
Expand Down Expand Up @@ -393,6 +403,7 @@ fn impl_class(
#[allow(non_snake_case)]
impl #cls {
#default_richcmp
#default_hash
}
})
}
Expand Down Expand Up @@ -798,9 +809,11 @@ fn impl_simple_enum(

let (default_richcmp, default_richcmp_slot) =
pyclass_richcmp_simple_enum(&args.options, &ty, repr_type, ctx);
let (default_hash, default_hash_slot) = pyclass_hash(&args.options, &ty, ctx)?;

let mut default_slots = vec![default_repr_slot, default_int_slot];
default_slots.extend(default_richcmp_slot);
default_slots.extend(default_hash_slot);

let pyclass_impls = PyClassImplsBuilder::new(
cls,
Expand All @@ -827,6 +840,7 @@ fn impl_simple_enum(
#default_repr
#default_int
#default_richcmp
#default_hash
}
})
}
Expand Down Expand Up @@ -858,9 +872,11 @@ fn impl_complex_enum(
let pytypeinfo = impl_pytypeinfo(cls, &args, None, ctx);

let (default_richcmp, default_richcmp_slot) = pyclass_richcmp(&args.options, &ty, ctx)?;
let (default_hash, default_hash_slot) = pyclass_hash(&args.options, &ty, ctx)?;

let mut default_slots = vec![];
default_slots.extend(default_richcmp_slot);
default_slots.extend(default_hash_slot);

let impl_builder = PyClassImplsBuilder::new(
cls,
Expand Down Expand Up @@ -967,6 +983,7 @@ fn impl_complex_enum(
#[allow(non_snake_case)]
impl #cls {
#default_richcmp
#default_hash
}

#(#variant_cls_zsts)*
Expand Down Expand Up @@ -1783,6 +1800,35 @@ fn pyclass_richcmp(
}
}

/// Builds the default `__hash__` implementation requested through `#[pyclass(hash)]`.
///
/// Returns `(None, None)` when the `hash` option is absent. Otherwise returns the generated
/// method, which feeds `self` through its Rust `Hash` implementation using `DefaultHasher`,
/// together with the `__hash__` protocol slot produced for it.
///
/// # Errors
///
/// Fails with an error spanned at the `hash` option when it is used without `frozen`
/// and/or without `eq`.
fn pyclass_hash(
    options: &PyClassPyO3Options,
    cls: &syn::Type,
    ctx: &Ctx,
) -> Result<(Option<syn::ImplItemFn>, Option<MethodAndSlotDef>)> {
    // FIXME: Use hash.map(...).unzip() on MSRV >= 1.66
    let hash_kw = match options.hash {
        Some(kw) => kw,
        // Nothing to generate when the option was not requested.
        None => return Ok((None, None)),
    };

    // Validate the companion options before emitting any code.
    ensure_spanned!(
        options.frozen.is_some(), options.hash.span() => "The `hash` option requires the `frozen` option.";
        options.eq.is_some(), options.hash.span() => "The `hash` option requires the `eq` option.";
    );

    // Span the generated method at the `hash` keyword.
    let mut method = parse_quote_spanned! { hash_kw.span() =>
        fn __pyo3__generated____hash__(&self) -> u64 {
            let mut s = ::std::collections::hash_map::DefaultHasher::new();
            ::std::hash::Hash::hash(self, &mut s);
            ::std::hash::Hasher::finish(&s)
        }
    };

    // NOTE(review): the unwrap presumably cannot fail for this fixed, generated signature — confirm.
    let slot = generate_protocol_slot(cls, &mut method, &__HASH__, "__hash__", ctx).unwrap();

    Ok((Some(method), Some(slot)))
}

/// Implements most traits used by `#[pyclass]`.
///
/// Specifically, it implements traits that only depend on class name,
Expand Down
2 changes: 1 addition & 1 deletion pyo3-macros-backend/src/pymethod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -910,7 +910,7 @@ impl PropertyType<'_> {

const __STR__: SlotDef = SlotDef::new("Py_tp_str", "reprfunc");
pub const __REPR__: SlotDef = SlotDef::new("Py_tp_repr", "reprfunc");
const __HASH__: SlotDef = SlotDef::new("Py_tp_hash", "hashfunc")
pub const __HASH__: SlotDef = SlotDef::new("Py_tp_hash", "hashfunc")
.ret_ty(Ty::PyHashT)
.return_conversion(TokenGenerator(
|Ctx { pyo3_path }: &Ctx| quote! { #pyo3_path::callback::HashCallbackOutput },
Expand Down
15 changes: 14 additions & 1 deletion pyo3-macros-backend/src/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,20 @@ macro_rules! ensure_spanned {
if !($condition) {
bail_spanned!($span => $msg);
}
}
};
($($condition:expr, $span:expr => $msg:expr;)*) => {
if let Some(e) = [$(
(!($condition)).then(|| err_spanned!($span => $msg)),
)*]
.into_iter()
.flatten()
.reduce(|mut acc, e| {
acc.combine(e);
acc
}) {
return Err(e);
}
};
}

/// Check if the given type `ty` is `pyo3::Python`.
Expand Down
28 changes: 28 additions & 0 deletions tests/test_class_basics.rs
Original file line number Diff line number Diff line change
Expand Up @@ -200,6 +200,34 @@ fn class_with_object_field() {
});
}

// Test fixture: a frozen pyclass declared with the `eq` and `hash` options, deriving the
// `PartialEq` and `Hash` impls on the Rust side.
// (Plain `//` comments are used on purpose so no doc attribute is attached to the pyclass.)
#[pyclass(frozen, eq, hash)]
#[derive(PartialEq, Hash)]
struct ClassWithHash {
value: usize,
}

// Python's `hash(obj)` must equal the value obtained by hashing the same struct with
// `DefaultHasher` on the Rust side (cast to `isize`).
#[test]
fn class_with_hash() {
    use pyo3::types::IntoPyDict;
    use std::collections::hash_map::DefaultHasher;
    use std::hash::{Hash, Hasher};

    Python::with_gil(|py| {
        let instance = ClassWithHash { value: 42 };

        // Compute the expected hash before the value is moved into Python.
        let mut state = DefaultHasher::new();
        instance.hash(&mut state);
        let expected = state.finish() as isize;

        let obj = Py::new(py, instance).unwrap().into_any();
        let env = [("obj", obj), ("hsh", expected.into_py(py))].into_py_dict_bound(py);

        py_assert!(py, *env, "hash(obj) == hsh");
    });
}

#[pyclass(unsendable, subclass)]
struct UnsendableBase {
value: std::rc::Rc<usize>,
Expand Down
60 changes: 60 additions & 0 deletions tests/test_enum.rs
Original file line number Diff line number Diff line change
Expand Up @@ -220,3 +220,63 @@ fn test_renaming_all_enum_variants() {
);
});
}

// Test fixture: a field-less enum declared with the `frozen`, `eq`, `eq_int` and `hash`
// options, deriving the `PartialEq` and `Hash` impls on the Rust side.
// (Plain `//` comments are used on purpose so no doc attribute is attached to the pyclass.)
#[pyclass(frozen, eq, eq_int, hash)]
#[derive(PartialEq, Hash)]
enum SimpleEnumWithHash {
A,
B,
}

// Python's `hash(obj)` for a simple enum variant must equal the value obtained by hashing
// the same variant with `DefaultHasher` on the Rust side (cast to `isize`).
#[test]
fn test_simple_enum_with_hash() {
    use pyo3::types::IntoPyDict;
    use std::collections::hash_map::DefaultHasher;
    use std::hash::{Hash, Hasher};

    Python::with_gil(|py| {
        let variant = SimpleEnumWithHash::A;

        // Compute the expected hash before the value is moved into Python.
        let mut state = DefaultHasher::new();
        variant.hash(&mut state);
        let expected = state.finish() as isize;

        let obj = Py::new(py, variant).unwrap().into_any();
        let env = [("obj", obj), ("hsh", expected.into_py(py))].into_py_dict_bound(py);

        py_assert!(py, *env, "hash(obj) == hsh");
    });
}

// Test fixture: an enum with tuple and struct variants declared with the `eq` and `hash`
// options, deriving the `PartialEq` and `Hash` impls on the Rust side.
// NOTE(review): `frozen` is not spelled out here although `hash` is documented as requiring
// it — presumably enums with data-carrying variants are treated as frozen by the macro; confirm.
#[pyclass(eq, hash)]
#[derive(PartialEq, Hash)]
enum ComplexEnumWithHash {
A(u32),
B { msg: String },
}

// Python's `hash(obj)` for a data-carrying enum variant must equal the value obtained by
// hashing the same variant with `DefaultHasher` on the Rust side (cast to `isize`).
#[test]
fn test_complex_enum_with_hash() {
    use pyo3::types::IntoPyDict;
    use std::collections::hash_map::DefaultHasher;
    use std::hash::{Hash, Hasher};

    Python::with_gil(|py| {
        let variant = ComplexEnumWithHash::B {
            msg: String::from("Hello"),
        };

        // Compute the expected hash before the value is moved into Python.
        let mut state = DefaultHasher::new();
        variant.hash(&mut state);
        let expected = state.finish() as isize;

        let obj = Py::new(py, variant).unwrap().into_any();
        let env = [("obj", obj), ("hsh", expected.into_py(py))].into_py_dict_bound(py);

        py_assert!(py, *env, "hash(obj) == hsh");
    });
}
19 changes: 19 additions & 0 deletions tests/ui/invalid_pyclass_args.rs
Original file line number Diff line number Diff line change
Expand Up @@ -52,4 +52,23 @@ impl EqOptAndManualRichCmp {
#[pyclass(eq_int)]
struct NoEqInt {}

#[pyclass(frozen, eq, hash)]
#[derive(PartialEq)]
struct HashOptRequiresHash;

#[pyclass(hash)]
#[derive(Hash)]
struct HashWithoutFrozenAndEq;

#[pyclass(frozen, eq, hash)]
#[derive(PartialEq, Hash)]
struct HashOptAndManualHash {}

#[pymethods]
impl HashOptAndManualHash {
fn __hash__(&self) -> u64 {
todo!()
}
}

fn main() {}
Loading

0 comments on commit a7a5c10

Please sign in to comment.