From 30f878d57fb82a26c803f3e8b00fa880e76bdd05 Mon Sep 17 00:00:00 2001 From: Alex Gaynor Date: Sat, 5 Mar 2022 10:28:47 -0500 Subject: [PATCH] Upgrade to pyo3 0.16 Rebased-by: H. Vetinari --- setup.cfg | 1 - setup.py | 2 +- src/rust/Cargo.lock | 72 ++++++++++--------------------- src/rust/Cargo.toml | 2 +- src/rust/src/x509/certificate.rs | 17 +++----- src/rust/src/x509/common.rs | 24 +++++------ src/rust/src/x509/crl.rs | 74 +++++++++++++------------------- src/rust/src/x509/csr.rs | 17 +++----- src/rust/src/x509/extensions.rs | 2 +- src/rust/src/x509/ocsp_req.rs | 4 +- src/rust/src/x509/ocsp_resp.rs | 10 ++--- src/rust/src/x509/sct.rs | 43 +++++++++---------- src/rust/src/x509/sign.rs | 12 +++--- 13 files changed, 115 insertions(+), 165 deletions(-) diff --git a/setup.cfg b/setup.cfg index 23a6a55bd9c6..236c77c5ef22 100644 --- a/setup.cfg +++ b/setup.cfg @@ -27,7 +27,6 @@ classifiers = Programming Language :: Python Programming Language :: Python :: 3 Programming Language :: Python :: 3 :: Only - Programming Language :: Python :: 3.6 Programming Language :: Python :: 3.7 Programming Language :: Python :: 3.8 Programming Language :: Python :: 3.9 diff --git a/setup.py b/setup.py index 24b9f102bbf0..557d4064fb23 100644 --- a/setup.py +++ b/setup.py @@ -49,7 +49,7 @@ features=( [] if platform.python_implementation() == "PyPy" - else ["pyo3/abi3-py36"] + else ["pyo3/abi3-py37"] ), rust_version=">=1.41.0", ) diff --git a/src/rust/Cargo.lock b/src/rust/Cargo.lock index b03f612c4d26..8f5b81055179 100644 --- a/src/rust/Cargo.lock +++ b/src/rust/Cargo.lock @@ -85,24 +85,10 @@ dependencies = [ [[package]] name = "indoc" -version = "0.3.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "47741a8bc60fb26eb8d6e0238bbb26d8575ff623fdc97b1a2c00c050b9684ed8" -dependencies = [ - "indoc-impl", - "proc-macro-hack", -] - -[[package]] -name = "indoc-impl" -version = "0.3.6" +version = "1.0.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ce046d161f000fffde5f432a0d034d0341dc152643b2598ed5bfce44c4f3a8f0" +checksum = "e7906a9fababaeacb774f72410e497a1d18de916322e33797bb2cd29baa23c9e" dependencies = [ - "proc-macro-hack", - "proc-macro2", - "quote", - "syn", "unindent", ] @@ -210,25 +196,6 @@ dependencies = [ "winapi", ] -[[package]] -name = "paste" -version = "0.1.18" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "45ca20c77d80be666aef2b45486da86238fabe33e38306bd3118fe4af33fa880" -dependencies = [ - "paste-impl", - "proc-macro-hack", -] - -[[package]] -name = "paste-impl" -version = "0.1.18" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d95a7db200b97ef370c8e6de0088252f7e0dfff7d047a28528e47456c0fc98b6" -dependencies = [ - "proc-macro-hack", -] - [[package]] name = "pem" version = "1.0.1" @@ -264,12 +231,6 @@ dependencies = [ "version_check", ] -[[package]] -name = "proc-macro-hack" -version = "0.5.19" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dbf0c48bc1d91375ae5c3cd81e3722dff1abcf81a30960240640d223f59fe0e5" - [[package]] name = "proc-macro2" version = "1.0.32" @@ -281,35 +242,46 @@ dependencies = [ [[package]] name = "pyo3" -version = "0.15.1" +version = "0.16.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7cf01dbf1c05af0a14c7779ed6f3aa9deac9c3419606ac9de537a2d649005720" +checksum = "9d1a3df45cb95bd954fac00bd9609062640fd7fb9e9946a660092c9e015421fb" dependencies = [ "cfg-if", "indoc", "libc", "parking_lot", - "paste", "pyo3-build-config", + "pyo3-ffi", "pyo3-macros", "unindent", ] [[package]] name = "pyo3-build-config" -version = "0.15.1" +version = "0.16.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dbf9e4d128bfbddc898ad3409900080d8d5095c379632fbbfbb9c8cfb1fb852b" +checksum = "386a68f0f5f2f9932815068bc8049a56989f2437d96dbb31d1fb11b63ce90364" dependencies = [ "once_cell", ] +[[package]] +name = "pyo3-ffi" +version = "0.16.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a9a4e2f74dc77eea5ce11d19f0afaeb632b6590f8cbb1d5ee2f1330b766803e8" +dependencies = [ + "libc", + "pyo3-build-config", +] + [[package]] name = "pyo3-macros" -version = "0.15.1" +version = "0.16.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "67701eb32b1f9a9722b4bc54b548ff9d7ebfded011c12daece7b9063be1fd755" +checksum = "dbff3a1579934968a53bcc78ac33663ed2577accda05d484097679cc8d28e52d" dependencies = [ + "proc-macro2", "pyo3-macros-backend", "quote", "syn", @@ -317,9 +289,9 @@ dependencies = [ [[package]] name = "pyo3-macros-backend" -version = "0.15.1" +version = "0.16.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f44f09e825ee49a105f2c7b23ebee50886a9aee0746f4dd5a704138a64b0218a" +checksum = "90644126c8c1ac7b47f794dd20a5729f8646b91c49edb31689c90f8cb3c33ea9" dependencies = [ "proc-macro2", "pyo3-build-config", diff --git a/src/rust/Cargo.toml b/src/rust/Cargo.toml index 617167d04429..7da28d8b8ff2 100644 --- a/src/rust/Cargo.toml +++ b/src/rust/Cargo.toml @@ -7,7 +7,7 @@ publish = false [dependencies] lazy_static = "1" -pyo3 = { version = "0.15.1" } +pyo3 = { version = "0.16" } asn1 = { version = "0.8.7", default-features = false, features = ["derive"] } pem = "1.0" chrono = { version = "0.4", default-features = false, features = ["alloc", "clock"] } diff --git a/src/rust/src/x509/certificate.rs b/src/rust/src/x509/certificate.rs index 4ab8d37025e6..920368eab8b0 100644 --- a/src/rust/src/x509/certificate.rs +++ b/src/rust/src/x509/certificate.rs @@ -82,8 +82,8 @@ pub(crate) struct Certificate { pub(crate) cached_extensions: Option, } -#[pyo3::prelude::pyproto] -impl pyo3::PyObjectProtocol for Certificate { +#[pyo3::prelude::pymethods] +impl Certificate { fn __hash__(&self) -> u64 { let mut hasher = DefaultHasher::new(); self.raw.borrow_value().hash(&mut hasher); @@ -92,7 +92,7 @@ impl pyo3::PyObjectProtocol for Certificate { fn __richcmp__( &self, - other: pyo3::PyRef, + other: pyo3::PyRef<'_, Certificate>, op: pyo3::basic::CompareOp, ) -> pyo3::PyResult { match op { @@ -112,10 +112,7 @@ impl pyo3::PyObjectProtocol for Certificate { let subject_repr = subject.repr()?.extract::<&str>()?; Ok(format!("", subject_repr)) } -} -#[pyo3::prelude::pymethods] -impl Certificate { fn __deepcopy__(slf: pyo3::PyRef<'_, Self>, _memo: pyo3::PyObject) -> pyo3::PyRef<'_, Self> { slf } @@ -157,9 +154,9 @@ impl Certificate { .getattr("Encoding")?; let result = asn1::write_single(self.raw.borrow_value()); - if encoding == encoding_class.getattr("DER")? { + if encoding.is(encoding_class.getattr("DER")?) { Ok(pyo3::types::PyBytes::new(py, &result)) - } else if encoding == encoding_class.getattr("PEM")? { + } else if encoding.is(encoding_class.getattr("PEM")?) { let pem = pem::encode_config( &pem::Pem { tag: "CERTIFICATE".to_string(), @@ -253,7 +250,7 @@ impl Certificate { let hash_alg = sig_oids_to_hash.get_item(self.signature_algorithm_oid(py)?); match hash_alg { Ok(data) => Ok(data), - Err(_) => Err(PyAsn1Error::from(pyo3::PyErr::from_instance( + Err(_) => Err(PyAsn1Error::from(pyo3::PyErr::from_value( py.import("cryptography.exceptions")?.call_method1( "UnsupportedAlgorithm", (format!( @@ -326,7 +323,7 @@ fn cert_version(py: pyo3::Python<'_>, version: u8) -> Result<&pyo3::PyAny, PyAsn match version { 0 => Ok(x509_module.getattr("Version")?.get_item("v1")?), 2 => Ok(x509_module.getattr("Version")?.get_item("v3")?), - _ => Err(PyAsn1Error::from(pyo3::PyErr::from_instance( + _ => Err(PyAsn1Error::from(pyo3::PyErr::from_value( x509_module .getattr("InvalidVersion")? .call1((format!("{} is not a valid X509 version", version), version))?, diff --git a/src/rust/src/x509/common.rs b/src/rust/src/x509/common.rs index e3b36c8a26dc..0c87edfa8add 100644 --- a/src/rust/src/x509/common.rs +++ b/src/rust/src/x509/common.rs @@ -108,9 +108,9 @@ pub(crate) fn encode_name_entry<'p>( let attr_type = py_name_entry.getattr("_type")?; let tag = attr_type.getattr("value")?.extract::()?; - let encoding = if attr_type == asn1_type.getattr("BMPString")? { + let encoding = if attr_type.is(asn1_type.getattr("BMPString")?) { "utf_16_be" - } else if attr_type == asn1_type.getattr("UniversalString")? { + } else if attr_type.is(asn1_type.getattr("UniversalString")?) { "utf_32_be" } else { "utf8" @@ -226,18 +226,18 @@ pub(crate) fn encode_general_name<'a>( let gn_module = py.import("cryptography.x509.general_name")?; let gn_type = gn.get_type().as_ref(); let gn_value = gn.getattr("value")?; - if gn_type == gn_module.getattr("DNSName")? { + if gn_type.is(gn_module.getattr("DNSName")?) { Ok(GeneralName::DNSName(UnvalidatedIA5String( gn_value.extract::<&str>()?, ))) - } else if gn_type == gn_module.getattr("RFC822Name")? { + } else if gn_type.is(gn_module.getattr("RFC822Name")?) { Ok(GeneralName::RFC822Name(UnvalidatedIA5String( gn_value.extract::<&str>()?, ))) - } else if gn_type == gn_module.getattr("DirectoryName")? { + } else if gn_type.is(gn_module.getattr("DirectoryName")?) { let name = encode_name(py, gn_value)?; Ok(GeneralName::DirectoryName(name)) - } else if gn_type == gn_module.getattr("OtherName")? { + } else if gn_type.is(gn_module.getattr("OtherName")?) { Ok(GeneralName::OtherName(OtherName { type_id: asn1::ObjectIdentifier::from_string( gn.getattr("type_id")? @@ -247,15 +247,15 @@ pub(crate) fn encode_general_name<'a>( .unwrap(), value: asn1::parse_single(gn_value.extract::<&[u8]>()?)?, })) - } else if gn_type == gn_module.getattr("UniformResourceIdentifier")? { + } else if gn_type.is(gn_module.getattr("UniformResourceIdentifier")?) { Ok(GeneralName::UniformResourceIdentifier( UnvalidatedIA5String(gn_value.extract::<&str>()?), )) - } else if gn_type == gn_module.getattr("IPAddress")? { + } else if gn_type.is(gn_module.getattr("IPAddress")?) { Ok(GeneralName::IPAddress( gn.call_method0("_packed")?.extract::<&[u8]>()?, )) - } else if gn_type == gn_module.getattr("RegisteredID")? { + } else if gn_type.is(gn_module.getattr("RegisteredID")?) { let oid = asn1::ObjectIdentifier::from_string( gn_value.getattr("dotted_string")?.extract::<&str>()?, ) @@ -458,7 +458,7 @@ pub(crate) fn parse_general_name( .to_object(py) } _ => { - return Err(PyAsn1Error::from(pyo3::PyErr::from_instance( + return Err(PyAsn1Error::from(pyo3::PyErr::from_value( x509_module.call_method1( "UnsupportedGeneralNameType", ("x400Address/EDIPartyName are not supported types",), @@ -556,7 +556,7 @@ pub(crate) fn parse_and_cache_extensions< x509_module.call_method1("ObjectIdentifier", (raw_ext.extn_id.to_string(),))?; if seen_oids.contains(&raw_ext.extn_id) { - return Err(pyo3::PyErr::from_instance(x509_module.call_method1( + return Err(pyo3::PyErr::from_value(x509_module.call_method1( "DuplicateExtension", ( format!("Duplicate {} extension found", raw_ext.extn_id), @@ -608,7 +608,7 @@ pub(crate) fn encode_extensions< .unwrap(); let ext_val = py_ext.getattr("value")?; - if unrecognized_extension_type.is_instance(ext_val)? { + if ext_val.is_instance(unrecognized_extension_type)? { exts.push(Extension { extn_id: oid, critical: py_ext.getattr("critical")?.extract()?, diff --git a/src/rust/src/x509/crl.rs b/src/rust/src/x509/crl.rs index e76fd740c819..9dfd5d3119d8 100644 --- a/src/rust/src/x509/crl.rs +++ b/src/rust/src/x509/crl.rs @@ -86,11 +86,11 @@ impl CertificateRevocationList { } } -#[pyo3::prelude::pyproto] -impl pyo3::PyObjectProtocol for CertificateRevocationList { +#[pyo3::prelude::pymethods] +impl CertificateRevocationList { fn __richcmp__( &self, - other: pyo3::PyRef, + other: pyo3::PyRef<'_, CertificateRevocationList>, op: pyo3::basic::CompareOp, ) -> pyo3::PyResult { match op { @@ -101,14 +101,26 @@ impl pyo3::PyObjectProtocol for CertificateRevocationList { )), } } -} -#[pyo3::prelude::pyproto] -impl pyo3::PyMappingProtocol for CertificateRevocationList { fn __len__(&self) -> usize { self.len() } + fn __iter__(&self) -> CRLIterator { + CRLIterator { + contents: OwnedCRLIteratorData::try_new(Arc::clone(&self.raw), |v| { + Ok::<_, ()>( + v.borrow_value() + .tbs_cert_list + .revoked_certificates + .as_ref() + .map(|v| v.unwrap_read().clone()), + ) + }) + .unwrap(), + } + } + fn __getitem__(&self, idx: &pyo3::PyAny) -> pyo3::PyResult { let gil = pyo3::Python::acquire_gil(); let py = gil.python(); @@ -122,7 +134,7 @@ impl pyo3::PyMappingProtocol for CertificateRevocationList { }); }); - if idx.is_instance::()? { + if idx.is_instance_of::()? { let indices = idx .downcast::()? .indices(self.len().try_into().unwrap())?; @@ -143,10 +155,7 @@ impl pyo3::PyMappingProtocol for CertificateRevocationList { Ok(pyo3::PyCell::new(py, self.revoked_cert(py, idx as usize)?)?.to_object(py)) } } -} -#[pyo3::prelude::pymethods] -impl CertificateRevocationList { fn fingerprint<'p>( &self, py: pyo3::Python<'p>, @@ -177,7 +186,7 @@ impl CertificateRevocationList { let exceptions_module = py.import("cryptography.exceptions")?; match oid_module.getattr("_SIG_OIDS_TO_HASH")?.get_item(oid) { Ok(v) => Ok(v), - Err(_) => Err(pyo3::PyErr::from_instance(exceptions_module.call_method1( + Err(_) => Err(pyo3::PyErr::from_value(exceptions_module.call_method1( "UnsupportedAlgorithm", (format!( "Signature algorithm OID:{} not recognized", @@ -208,9 +217,9 @@ impl CertificateRevocationList { .getattr("Encoding")?; let result = asn1::write_single(self.raw.borrow_value()); - if encoding == encoding_class.getattr("DER")? { + if encoding.is(encoding_class.getattr("DER")?) { Ok(pyo3::types::PyBytes::new(py, &result)) - } else if encoding == encoding_class.getattr("PEM")? { + } else if encoding.is(encoding_class.getattr("PEM")?) { let pem = pem::encode_config( &pem::Pem { tag: "X509 CRL".to_string(), @@ -391,24 +400,6 @@ impl CertificateRevocationList { } } -#[pyo3::prelude::pyproto] -impl pyo3::PyIterProtocol<'_> for CertificateRevocationList { - fn __iter__(slf: pyo3::PyRef<'p, Self>) -> CRLIterator { - CRLIterator { - contents: OwnedCRLIteratorData::try_new(Arc::clone(&slf.raw), |v| { - Ok::<_, ()>( - v.borrow_value() - .tbs_cert_list - .revoked_certificates - .as_ref() - .map(|v| v.unwrap_read().clone()), - ) - }) - .unwrap(), - } - } -} - #[ouroboros::self_referencing] struct OwnedCRLIteratorData { data: Arc, @@ -451,14 +442,18 @@ fn try_map_arc_data_mut_crl_iterator( }) } -#[pyo3::prelude::pyproto] -impl pyo3::PyIterProtocol<'_> for CRLIterator { - fn __iter__(slf: pyo3::PyRef<'p, Self>) -> pyo3::PyRef<'p, Self> { +#[pyo3::prelude::pymethods] +impl CRLIterator { + fn __len__(&self) -> usize { + self.contents.borrow_value().clone().map_or(0, |v| v.len()) + } + + fn __iter__(slf: pyo3::PyRef<'_, Self>) -> pyo3::PyRef<'_, Self> { slf } - fn __next__(mut slf: pyo3::PyRefMut<'p, Self>) -> Option { - let revoked = try_map_arc_data_mut_crl_iterator(&mut slf.contents, |_data, v| match v { + fn __next__(&mut self) -> Option { + let revoked = try_map_arc_data_mut_crl_iterator(&mut self.contents, |_data, v| match v { Some(v) => match v.next() { Some(revoked) => Ok(revoked), None => Err(()), @@ -473,13 +468,6 @@ impl pyo3::PyIterProtocol<'_> for CRLIterator { } } -#[pyo3::prelude::pyproto] -impl pyo3::PySequenceProtocol<'_> for CRLIterator { - fn __len__(&self) -> usize { - self.contents.borrow_value().clone().map_or(0, |v| v.len()) - } -} - #[derive(asn1::Asn1Read, asn1::Asn1Write, PartialEq, Hash)] struct RawCertificateRevocationList<'a> { tbs_cert_list: TBSCertList<'a>, diff --git a/src/rust/src/x509/csr.rs b/src/rust/src/x509/csr.rs index 268f4a3dc9b7..888c065a7f44 100644 --- a/src/rust/src/x509/csr.rs +++ b/src/rust/src/x509/csr.rs @@ -90,8 +90,8 @@ struct CertificateSigningRequest { cached_extensions: Option, } -#[pyo3::prelude::pyproto] -impl pyo3::basic::PyObjectProtocol for CertificateSigningRequest { +#[pyo3::prelude::pymethods] +impl CertificateSigningRequest { fn __hash__(&self) -> u64 { let mut hasher = DefaultHasher::new(); self.raw.borrow_data().hash(&mut hasher); @@ -100,7 +100,7 @@ impl pyo3::basic::PyObjectProtocol for CertificateSigningRequest { fn __richcmp__( &self, - other: pyo3::PyRef, + other: pyo3::PyRef<'_, CertificateSigningRequest>, op: pyo3::basic::CompareOp, ) -> pyo3::PyResult { match op { @@ -111,10 +111,7 @@ impl pyo3::basic::PyObjectProtocol for CertificateSigningRequest { )), } } -} -#[pyo3::prelude::pymethods] -impl CertificateSigningRequest { fn public_key<'p>(&self, py: pyo3::Python<'p>) -> pyo3::PyResult<&'p pyo3::PyAny> { // This makes an unnecessary copy. It'd be nice to get rid of it. let serialized = pyo3::types::PyBytes::new( @@ -159,7 +156,7 @@ impl CertificateSigningRequest { let hash_alg = sig_oids_to_hash.get_item(self.signature_algorithm_oid(py)?); match hash_alg { Ok(data) => Ok(data), - Err(_) => Err(PyAsn1Error::from(pyo3::PyErr::from_instance( + Err(_) => Err(PyAsn1Error::from(pyo3::PyErr::from_value( py.import("cryptography.exceptions")?.call_method1( "UnsupportedAlgorithm", (format!( @@ -189,9 +186,9 @@ impl CertificateSigningRequest { .getattr("Encoding")?; let result = asn1::write_single(self.raw.borrow_value()); - if encoding == encoding_class.getattr("DER")? { + if encoding.is(encoding_class.getattr("DER")?) { Ok(pyo3::types::PyBytes::new(py, &result)) - } else if encoding == encoding_class.getattr("PEM")? { + } else if encoding.is(encoding_class.getattr("PEM")?) { let pem = pem::encode_config( &pem::Pem { tag: "CERTIFICATE REQUEST".to_string(), @@ -252,7 +249,7 @@ impl CertificateSigningRequest { } } } - Err(pyo3::PyErr::from_instance( + Err(pyo3::PyErr::from_value( py.import("cryptography.x509")?.call_method1( "AttributeNotFound", (format!("No {} attribute was found", oid_str), oid), diff --git a/src/rust/src/x509/extensions.rs b/src/rust/src/x509/extensions.rs index 606566dd96e6..374aa507ff10 100644 --- a/src/rust/src/x509/extensions.rs +++ b/src/rust/src/x509/extensions.rs @@ -165,7 +165,7 @@ pub(crate) fn encode_extension( let mut qualifiers = vec![]; for py_qualifier in py_policy_qualifiers.iter()? { let py_qualifier = py_qualifier?; - let qualifier = if py_qualifier.is_instance::()? { + let qualifier = if py_qualifier.is_instance_of::()? { let cps_uri = match asn1::IA5String::new(py_qualifier.extract()?) { Some(s) => s, None => { diff --git a/src/rust/src/x509/ocsp_req.rs b/src/rust/src/x509/ocsp_req.rs index 57b1391c6076..08cd5331299a 100644 --- a/src/rust/src/x509/ocsp_req.rs +++ b/src/rust/src/x509/ocsp_req.rs @@ -83,7 +83,7 @@ impl OCSPRequest { Some(alg_name) => Ok(hashes.getattr(alg_name)?.call0()?), None => { let exceptions = py.import("cryptography.exceptions")?; - Err(PyAsn1Error::from(pyo3::PyErr::from_instance( + Err(PyAsn1Error::from(pyo3::PyErr::from_value( exceptions.getattr("UnsupportedAlgorithm")?.call1((format!( "Signature algorithm OID: {} not recognized", cert_id.hash_algorithm.oid @@ -132,7 +132,7 @@ impl OCSPRequest { .import("cryptography.hazmat.primitives.serialization")? .getattr("Encoding")? .getattr("DER")?; - if encoding != der { + if !encoding.is(der) { return Err(pyo3::exceptions::PyValueError::new_err( "The only allowed encoding value is Encoding.DER", )); diff --git a/src/rust/src/x509/ocsp_resp.rs b/src/rust/src/x509/ocsp_resp.rs index 66b5e2324f76..85832f714d39 100644 --- a/src/rust/src/x509/ocsp_resp.rs +++ b/src/rust/src/x509/ocsp_resp.rs @@ -177,7 +177,7 @@ impl OCSPResponse { "Signature algorithm OID: {} not recognized", self.requires_successful_response()?.signature_algorithm.oid ); - Err(PyAsn1Error::from(pyo3::PyErr::from_instance( + Err(PyAsn1Error::from(pyo3::PyErr::from_value( py.import("cryptography.exceptions")? .call_method1("UnsupportedAlgorithm", (exc_messsage,))?, ))) @@ -401,7 +401,7 @@ impl OCSPResponse { .import("cryptography.hazmat.primitives.serialization")? .getattr("Encoding")? .getattr("DER")?; - if encoding != der { + if !encoding.is(der) { return Err(pyo3::exceptions::PyValueError::new_err( "The only allowed encoding value is Encoding.DER", )); @@ -546,9 +546,9 @@ fn create_ocsp_basic_response<'p>( builder.getattr("_responder_id")?.extract()?; let py_cert_status = py_single_resp.getattr("_cert_status")?; - let cert_status = if py_cert_status == ocsp_mod.getattr("OCSPCertStatus")?.getattr("GOOD")? { + let cert_status = if py_cert_status.is(ocsp_mod.getattr("OCSPCertStatus")?.getattr("GOOD")?) { CertStatus::Good(()) - } else if py_cert_status == ocsp_mod.getattr("OCSPCertStatus")?.getattr("UNKNOWN")? { + } else if py_cert_status.is(ocsp_mod.getattr("OCSPCertStatus")?.getattr("UNKNOWN")?) { CertStatus::Unknown(()) } else { let revocation_reason = if !py_single_resp.getattr("_revocation_reason")?.is_none() { @@ -588,7 +588,7 @@ fn create_ocsp_basic_response<'p>( let borrowed_cert = responder_cert.borrow(); let responder_id = - if responder_encoding == ocsp_mod.getattr("OCSPResponderEncoding")?.getattr("HASH")? { + if responder_encoding.is(ocsp_mod.getattr("OCSPResponderEncoding")?.getattr("HASH")?) { let sha1 = py .import("cryptography.hazmat.primitives.hashes")? .getattr("SHA1")? diff --git a/src/rust/src/x509/sct.rs b/src/rust/src/x509/sct.rs index 4db550bca74f..a6454b717c69 100644 --- a/src/rust/src/x509/sct.rs +++ b/src/rust/src/x509/sct.rs @@ -59,6 +59,26 @@ pub(crate) struct Sct { #[pyo3::prelude::pymethods] impl Sct { + fn __richcmp__( + &self, + other: pyo3::PyRef<'_, Sct>, + op: pyo3::basic::CompareOp, + ) -> pyo3::PyResult { + match op { + pyo3::basic::CompareOp::Eq => Ok(self.sct_data == other.sct_data), + pyo3::basic::CompareOp::Ne => Ok(self.sct_data != other.sct_data), + _ => Err(pyo3::exceptions::PyTypeError::new_err( + "SCTs cannot be ordered", + )), + } + } + + fn __hash__(&self) -> u64 { + let mut hasher = DefaultHasher::new(); + self.sct_data.hash(&mut hasher); + hasher.finish() + } + #[getter] fn version<'p>(&self, py: pyo3::Python<'p>) -> pyo3::PyResult<&'p pyo3::PyAny> { py.import("cryptography.x509.certificate_transparency")? @@ -96,29 +116,6 @@ impl Sct { } } -#[pyo3::prelude::pyproto] -impl pyo3::PyObjectProtocol for Sct { - fn __richcmp__( - &self, - other: pyo3::PyRef, - op: pyo3::basic::CompareOp, - ) -> pyo3::PyResult { - match op { - pyo3::basic::CompareOp::Eq => Ok(self.sct_data == other.sct_data), - pyo3::basic::CompareOp::Ne => Ok(self.sct_data != other.sct_data), - _ => Err(pyo3::exceptions::PyTypeError::new_err( - "SCTs cannot be ordered", - )), - } - } - - fn __hash__(&self) -> u64 { - let mut hasher = DefaultHasher::new(); - self.sct_data.hash(&mut hasher); - hasher.finish() - } -} - pub(crate) fn parse_scts( py: pyo3::Python<'_>, data: &[u8], diff --git a/src/rust/src/x509/sign.rs b/src/rust/src/x509/sign.rs index e1579481808f..a09aecbc43b5 100644 --- a/src/rust/src/x509/sign.rs +++ b/src/rust/src/x509/sign.rs @@ -51,15 +51,15 @@ fn identify_key_type(py: pyo3::Python<'_>, private_key: &pyo3::PyAny) -> pyo3::P .getattr("Ed448PrivateKey")? .extract()?; - if rsa_private_key.is_instance(private_key)? { + if private_key.is_instance(rsa_private_key)? { Ok(KeyType::Rsa) - } else if dsa_key_type.is_instance(private_key)? { + } else if private_key.is_instance(dsa_key_type)? { Ok(KeyType::Dsa) - } else if ec_key_type.is_instance(private_key)? { + } else if private_key.is_instance(ec_key_type)? { Ok(KeyType::Ec) - } else if ed25519_key_type.is_instance(private_key)? { + } else if private_key.is_instance(ed25519_key_type)? { Ok(KeyType::Ed25519) - } else if ed448_key_type.is_instance(private_key)? { + } else if private_key.is_instance(ed448_key_type)? { Ok(KeyType::Ed448) } else { Err(pyo3::exceptions::PyTypeError::new_err( @@ -80,7 +80,7 @@ fn identify_hash_type( .import("cryptography.hazmat.primitives.hashes")? .getattr("HashAlgorithm")? .extract()?; - if !hash_algorithm_type.is_instance(hash_algorithm)? { + if !hash_algorithm.is_instance(hash_algorithm_type)? { return Err(pyo3::exceptions::PyTypeError::new_err( "Algorithm must be a registered hash algorithm.", ));