From 3357d10b323504c62d6b135acce6523df3e1c3f2 Mon Sep 17 00:00:00 2001 From: Alex Gaynor Date: Thu, 4 Apr 2024 07:59:21 -0400 Subject: [PATCH] Convert `src/backend/ciphers.rs` to new pyo3 APIs --- src/rust/src/backend/ciphers.rs | 81 ++++++++++++++++++--------------- src/rust/src/backend/mod.rs | 2 +- 2 files changed, 46 insertions(+), 37 deletions(-) diff --git a/src/rust/src/backend/ciphers.rs b/src/rust/src/backend/ciphers.rs index 2cf97d7b8800..5677e0fbba3d 100644 --- a/src/rust/src/backend/ciphers.rs +++ b/src/rust/src/backend/ciphers.rs @@ -7,7 +7,7 @@ use crate::buf::{CffiBuf, CffiMutBuf}; use crate::error::{CryptographyError, CryptographyResult}; use crate::exceptions; use crate::types; -use pyo3::prelude::PyAnyMethods; +use pyo3::prelude::{PyAnyMethods, PyModuleMethods}; use pyo3::IntoPy; struct CipherContext { @@ -121,10 +121,10 @@ impl CipherContext { &mut self, py: pyo3::Python<'p>, buf: &[u8], - ) -> CryptographyResult<&'p pyo3::types::PyBytes> { + ) -> CryptographyResult> { let mut out_buf = vec![0; buf.len() + self.ctx.block_size()]; let n = self.update_into(py, buf, &mut out_buf)?; - Ok(pyo3::types::PyBytes::new(py, &out_buf[..n])) + Ok(pyo3::types::PyBytes::new_bound(py, &out_buf[..n])) } fn update_into( @@ -146,7 +146,11 @@ impl CipherContext { for chunk in buf.chunks(1 << 29) { // SAFETY: We ensure that outbuf is sufficiently large above. unsafe { - let n = if self.py_mode.as_ref(py).is_instance(types::XTS.get(py)?)? { + let n = if self + .py_mode + .bind(py) + .is_instance(&types::XTS.get_bound(py)?)? + { self.ctx.cipher_update_unchecked(chunk, Some(&mut out_buf[total_written..])).map_err(|_| { pyo3::exceptions::PyValueError::new_err( "In XTS mode you must supply at least a full block in the first update call. For AES this is 16 bytes." @@ -171,14 +175,14 @@ impl CipherContext { fn finalize<'p>( &mut self, py: pyo3::Python<'p>, - ) -> CryptographyResult<&'p pyo3::types::PyBytes> { + ) -> CryptographyResult> { let mut out_buf = vec![0; self.ctx.block_size()]; let n = self.ctx.cipher_final(&mut out_buf).or_else(|e| { if e.errors().is_empty() && self .py_mode - .as_ref(py) - .is_instance(types::MODE_WITH_AUTHENTICATION_TAG.get(py)?)? + .bind(py) + .is_instance(&types::MODE_WITH_AUTHENTICATION_TAG.get_bound(py)?)? { return Err(CryptographyError::from(exceptions::InvalidTag::new_err(()))); } @@ -188,7 +192,7 @@ impl CipherContext { ), )) })?; - Ok(pyo3::types::PyBytes::new(py, &out_buf[..n])) + Ok(pyo3::types::PyBytes::new_bound(py, &out_buf[..n])) } } @@ -233,7 +237,7 @@ impl PyCipherContext { &mut self, py: pyo3::Python<'p>, buf: CffiBuf<'_>, - ) -> CryptographyResult<&'p pyo3::types::PyBytes> { + ) -> CryptographyResult> { get_mut_ctx(self.ctx.as_mut())?.update(py, buf.as_bytes()) } @@ -249,7 +253,7 @@ impl PyCipherContext { fn finalize<'p>( &mut self, py: pyo3::Python<'p>, - ) -> CryptographyResult<&'p pyo3::types::PyBytes> { + ) -> CryptographyResult> { let result = get_mut_ctx(self.ctx.as_mut())?.finalize(py)?; self.ctx = None; Ok(result) @@ -262,7 +266,7 @@ impl PyAEADEncryptionContext { &mut self, py: pyo3::Python<'p>, buf: CffiBuf<'_>, - ) -> CryptographyResult<&'p pyo3::types::PyBytes> { + ) -> CryptographyResult> { let data = buf.as_bytes(); self.updated = true; @@ -314,16 +318,16 @@ impl PyAEADEncryptionContext { fn finalize<'p>( &mut self, py: pyo3::Python<'p>, - ) -> CryptographyResult<&'p pyo3::types::PyBytes> { + ) -> CryptographyResult> { let ctx = get_mut_ctx(self.ctx.as_mut())?; let result = ctx.finalize(py)?; // XXX: do not hard code 16 - let tag = pyo3::types::PyBytes::new_with(py, 16, |t| { + let tag = pyo3::types::PyBytes::new_bound_with(py, 16, |t| { ctx.ctx.tag(t).map_err(CryptographyError::from)?; Ok(()) })?; - self.tag = Some(tag.into_py(py)); + self.tag = Some(tag.unbind()); self.ctx = None; Ok(result) @@ -349,7 +353,7 @@ impl PyAEADDecryptionContext { &mut self, py: pyo3::Python<'p>, buf: CffiBuf<'_>, - ) -> CryptographyResult<&'p pyo3::types::PyBytes> { + ) -> CryptographyResult> { let data = buf.as_bytes(); self.updated = true; @@ -401,12 +405,12 @@ impl PyAEADDecryptionContext { fn finalize<'p>( &mut self, py: pyo3::Python<'p>, - ) -> CryptographyResult<&'p pyo3::types::PyBytes> { + ) -> CryptographyResult> { let ctx = get_mut_ctx(self.ctx.as_mut())?; if ctx .py_mode - .as_ref(py) + .bind(py) .getattr(pyo3::intern!(py, "tag"))? .is_none() { @@ -426,12 +430,12 @@ impl PyAEADDecryptionContext { &mut self, py: pyo3::Python<'p>, tag: &[u8], - ) -> CryptographyResult<&'p pyo3::types::PyBytes> { + ) -> CryptographyResult> { let ctx = get_mut_ctx(self.ctx.as_mut())?; if !ctx .py_mode - .as_ref(py) + .bind(py) .getattr(pyo3::intern!(py, "tag"))? .is_none() { @@ -444,7 +448,7 @@ impl PyAEADDecryptionContext { let min_tag_length = ctx .py_mode - .as_ref(py) + .bind(py) .getattr(pyo3::intern!(py, "_min_tag_length"))? .extract()?; // XXX: Do not hard code 16 @@ -506,8 +510,11 @@ fn create_decryption_ctx( let mut ctx = CipherContext::new(py, algorithm, mode.clone(), openssl::symm::Mode::Decrypt)?; if mode.is_instance(&types::MODE_WITH_AUTHENTICATION_TAG.get_bound(py)?)? { - if let Some(tag) = mode.getattr(pyo3::intern!(py, "tag"))?.extract()? { - ctx.ctx.set_tag(tag)?; + if let Some(tag) = mode + .getattr(pyo3::intern!(py, "tag"))? + .extract::>()? + { + ctx.ctx.set_tag(&tag)?; } Ok(PyAEADDecryptionContext { @@ -536,31 +543,33 @@ fn cipher_supported( } #[pyo3::prelude::pyfunction] -fn _advance(ctx: &pyo3::PyAny, n: u64) { - if let Ok(c) = ctx.downcast::>() { +fn _advance(ctx: pyo3::Bound<'_, pyo3::PyAny>, n: u64) { + if let Ok(c) = ctx.downcast::() { c.borrow_mut().bytes_remaining -= n; - } else if let Ok(c) = ctx.downcast::>() { + } else if let Ok(c) = ctx.downcast::() { c.borrow_mut().bytes_remaining -= n; } } #[pyo3::prelude::pyfunction] -fn _advance_aad(ctx: &pyo3::PyAny, n: u64) { - if let Ok(c) = ctx.downcast::>() { +fn _advance_aad(ctx: pyo3::Bound<'_, pyo3::PyAny>, n: u64) { + if let Ok(c) = ctx.downcast::() { c.borrow_mut().aad_bytes_remaining -= n; - } else if let Ok(c) = ctx.downcast::>() { + } else if let Ok(c) = ctx.downcast::() { c.borrow_mut().aad_bytes_remaining -= n; } } -pub(crate) fn create_module(py: pyo3::Python<'_>) -> pyo3::PyResult<&pyo3::prelude::PyModule> { - let m = pyo3::prelude::PyModule::new(py, "ciphers")?; - m.add_function(pyo3::wrap_pyfunction!(create_encryption_ctx, m)?)?; - m.add_function(pyo3::wrap_pyfunction!(create_decryption_ctx, m)?)?; - m.add_function(pyo3::wrap_pyfunction!(cipher_supported, m)?)?; - - m.add_function(pyo3::wrap_pyfunction!(_advance, m)?)?; - m.add_function(pyo3::wrap_pyfunction!(_advance_aad, m)?)?; +pub(crate) fn create_module( + py: pyo3::Python<'_>, +) -> pyo3::PyResult> { + let m = pyo3::prelude::PyModule::new_bound(py, "ciphers")?; + m.add_function(pyo3::wrap_pyfunction!(create_encryption_ctx, &m)?)?; + m.add_function(pyo3::wrap_pyfunction!(create_decryption_ctx, &m)?)?; + m.add_function(pyo3::wrap_pyfunction!(cipher_supported, &m)?)?; + + m.add_function(pyo3::wrap_pyfunction!(_advance, &m)?)?; + m.add_function(pyo3::wrap_pyfunction!(_advance_aad, &m)?)?; m.add_class::()?; m.add_class::()?; diff --git a/src/rust/src/backend/mod.rs b/src/rust/src/backend/mod.rs index 2b1592906a1f..10365558fc21 100644 --- a/src/rust/src/backend/mod.rs +++ b/src/rust/src/backend/mod.rs @@ -25,7 +25,7 @@ pub(crate) mod x448; pub(crate) fn add_to_module(module: &pyo3::prelude::PyModule) -> pyo3::PyResult<()> { module.add_submodule(aead::create_module(module.py())?.into_gil_ref())?; - module.add_submodule(ciphers::create_module(module.py())?)?; + module.add_submodule(ciphers::create_module(module.py())?.into_gil_ref())?; module.add_submodule(cmac::create_module(module.py())?)?; module.add_submodule(dh::create_module(module.py())?)?; module.add_submodule(dsa::create_module(module.py())?)?;