Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Convert src/backend/ciphers.rs to new pyo3 APIs #10703

Merged
merged 1 commit into from
Apr 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
81 changes: 45 additions & 36 deletions src/rust/src/backend/ciphers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ use crate::buf::{CffiBuf, CffiMutBuf};
use crate::error::{CryptographyError, CryptographyResult};
use crate::exceptions;
use crate::types;
use pyo3::prelude::PyAnyMethods;
use pyo3::prelude::{PyAnyMethods, PyModuleMethods};
use pyo3::IntoPy;

struct CipherContext {
Expand Down Expand Up @@ -121,10 +121,10 @@ impl CipherContext {
&mut self,
py: pyo3::Python<'p>,
buf: &[u8],
) -> CryptographyResult<&'p pyo3::types::PyBytes> {
) -> CryptographyResult<pyo3::Bound<'p, pyo3::types::PyBytes>> {
let mut out_buf = vec![0; buf.len() + self.ctx.block_size()];
let n = self.update_into(py, buf, &mut out_buf)?;
Ok(pyo3::types::PyBytes::new(py, &out_buf[..n]))
Ok(pyo3::types::PyBytes::new_bound(py, &out_buf[..n]))
}

fn update_into(
Expand All @@ -146,7 +146,11 @@ impl CipherContext {
for chunk in buf.chunks(1 << 29) {
// SAFETY: We ensure that outbuf is sufficiently large above.
unsafe {
let n = if self.py_mode.as_ref(py).is_instance(types::XTS.get(py)?)? {
let n = if self
.py_mode
.bind(py)
.is_instance(&types::XTS.get_bound(py)?)?
{
self.ctx.cipher_update_unchecked(chunk, Some(&mut out_buf[total_written..])).map_err(|_| {
pyo3::exceptions::PyValueError::new_err(
"In XTS mode you must supply at least a full block in the first update call. For AES this is 16 bytes."
Expand All @@ -171,14 +175,14 @@ impl CipherContext {
fn finalize<'p>(
&mut self,
py: pyo3::Python<'p>,
) -> CryptographyResult<&'p pyo3::types::PyBytes> {
) -> CryptographyResult<pyo3::Bound<'p, pyo3::types::PyBytes>> {
let mut out_buf = vec![0; self.ctx.block_size()];
let n = self.ctx.cipher_final(&mut out_buf).or_else(|e| {
if e.errors().is_empty()
&& self
.py_mode
.as_ref(py)
.is_instance(types::MODE_WITH_AUTHENTICATION_TAG.get(py)?)?
.bind(py)
.is_instance(&types::MODE_WITH_AUTHENTICATION_TAG.get_bound(py)?)?
{
return Err(CryptographyError::from(exceptions::InvalidTag::new_err(())));
}
Expand All @@ -188,7 +192,7 @@ impl CipherContext {
),
))
})?;
Ok(pyo3::types::PyBytes::new(py, &out_buf[..n]))
Ok(pyo3::types::PyBytes::new_bound(py, &out_buf[..n]))
}
}

Expand Down Expand Up @@ -233,7 +237,7 @@ impl PyCipherContext {
&mut self,
py: pyo3::Python<'p>,
buf: CffiBuf<'_>,
) -> CryptographyResult<&'p pyo3::types::PyBytes> {
) -> CryptographyResult<pyo3::Bound<'p, pyo3::types::PyBytes>> {
get_mut_ctx(self.ctx.as_mut())?.update(py, buf.as_bytes())
}

Expand All @@ -249,7 +253,7 @@ impl PyCipherContext {
fn finalize<'p>(
&mut self,
py: pyo3::Python<'p>,
) -> CryptographyResult<&'p pyo3::types::PyBytes> {
) -> CryptographyResult<pyo3::Bound<'p, pyo3::types::PyBytes>> {
let result = get_mut_ctx(self.ctx.as_mut())?.finalize(py)?;
self.ctx = None;
Ok(result)
Expand All @@ -262,7 +266,7 @@ impl PyAEADEncryptionContext {
&mut self,
py: pyo3::Python<'p>,
buf: CffiBuf<'_>,
) -> CryptographyResult<&'p pyo3::types::PyBytes> {
) -> CryptographyResult<pyo3::Bound<'p, pyo3::types::PyBytes>> {
let data = buf.as_bytes();

self.updated = true;
Expand Down Expand Up @@ -314,16 +318,16 @@ impl PyAEADEncryptionContext {
fn finalize<'p>(
&mut self,
py: pyo3::Python<'p>,
) -> CryptographyResult<&'p pyo3::types::PyBytes> {
) -> CryptographyResult<pyo3::Bound<'p, pyo3::types::PyBytes>> {
let ctx = get_mut_ctx(self.ctx.as_mut())?;
let result = ctx.finalize(py)?;

// XXX: do not hard code 16
let tag = pyo3::types::PyBytes::new_with(py, 16, |t| {
let tag = pyo3::types::PyBytes::new_bound_with(py, 16, |t| {
ctx.ctx.tag(t).map_err(CryptographyError::from)?;
Ok(())
})?;
self.tag = Some(tag.into_py(py));
self.tag = Some(tag.unbind());
self.ctx = None;

Ok(result)
Expand All @@ -349,7 +353,7 @@ impl PyAEADDecryptionContext {
&mut self,
py: pyo3::Python<'p>,
buf: CffiBuf<'_>,
) -> CryptographyResult<&'p pyo3::types::PyBytes> {
) -> CryptographyResult<pyo3::Bound<'p, pyo3::types::PyBytes>> {
let data = buf.as_bytes();

self.updated = true;
Expand Down Expand Up @@ -401,12 +405,12 @@ impl PyAEADDecryptionContext {
fn finalize<'p>(
&mut self,
py: pyo3::Python<'p>,
) -> CryptographyResult<&'p pyo3::types::PyBytes> {
) -> CryptographyResult<pyo3::Bound<'p, pyo3::types::PyBytes>> {
let ctx = get_mut_ctx(self.ctx.as_mut())?;

if ctx
.py_mode
.as_ref(py)
.bind(py)
.getattr(pyo3::intern!(py, "tag"))?
.is_none()
{
Expand All @@ -426,12 +430,12 @@ impl PyAEADDecryptionContext {
&mut self,
py: pyo3::Python<'p>,
tag: &[u8],
) -> CryptographyResult<&'p pyo3::types::PyBytes> {
) -> CryptographyResult<pyo3::Bound<'p, pyo3::types::PyBytes>> {
let ctx = get_mut_ctx(self.ctx.as_mut())?;

if !ctx
.py_mode
.as_ref(py)
.bind(py)
.getattr(pyo3::intern!(py, "tag"))?
.is_none()
{
Expand All @@ -444,7 +448,7 @@ impl PyAEADDecryptionContext {

let min_tag_length = ctx
.py_mode
.as_ref(py)
.bind(py)
.getattr(pyo3::intern!(py, "_min_tag_length"))?
.extract()?;
// XXX: Do not hard code 16
Expand Down Expand Up @@ -506,8 +510,11 @@ fn create_decryption_ctx(
let mut ctx = CipherContext::new(py, algorithm, mode.clone(), openssl::symm::Mode::Decrypt)?;

if mode.is_instance(&types::MODE_WITH_AUTHENTICATION_TAG.get_bound(py)?)? {
if let Some(tag) = mode.getattr(pyo3::intern!(py, "tag"))?.extract()? {
ctx.ctx.set_tag(tag)?;
if let Some(tag) = mode
.getattr(pyo3::intern!(py, "tag"))?
.extract::<Option<pyo3::pybacked::PyBackedBytes>>()?
{
ctx.ctx.set_tag(&tag)?;
}

Ok(PyAEADDecryptionContext {
Expand Down Expand Up @@ -536,31 +543,33 @@ fn cipher_supported(
}

#[pyo3::prelude::pyfunction]
fn _advance(ctx: &pyo3::PyAny, n: u64) {
if let Ok(c) = ctx.downcast::<pyo3::PyCell<PyAEADEncryptionContext>>() {
fn _advance(ctx: pyo3::Bound<'_, pyo3::PyAny>, n: u64) {
if let Ok(c) = ctx.downcast::<PyAEADEncryptionContext>() {
c.borrow_mut().bytes_remaining -= n;
} else if let Ok(c) = ctx.downcast::<pyo3::PyCell<PyAEADDecryptionContext>>() {
} else if let Ok(c) = ctx.downcast::<PyAEADDecryptionContext>() {
c.borrow_mut().bytes_remaining -= n;
}
}

#[pyo3::prelude::pyfunction]
fn _advance_aad(ctx: &pyo3::PyAny, n: u64) {
if let Ok(c) = ctx.downcast::<pyo3::PyCell<PyAEADEncryptionContext>>() {
fn _advance_aad(ctx: pyo3::Bound<'_, pyo3::PyAny>, n: u64) {
if let Ok(c) = ctx.downcast::<PyAEADEncryptionContext>() {
c.borrow_mut().aad_bytes_remaining -= n;
} else if let Ok(c) = ctx.downcast::<pyo3::PyCell<PyAEADDecryptionContext>>() {
} else if let Ok(c) = ctx.downcast::<PyAEADDecryptionContext>() {
c.borrow_mut().aad_bytes_remaining -= n;
}
}

pub(crate) fn create_module(py: pyo3::Python<'_>) -> pyo3::PyResult<&pyo3::prelude::PyModule> {
let m = pyo3::prelude::PyModule::new(py, "ciphers")?;
m.add_function(pyo3::wrap_pyfunction!(create_encryption_ctx, m)?)?;
m.add_function(pyo3::wrap_pyfunction!(create_decryption_ctx, m)?)?;
m.add_function(pyo3::wrap_pyfunction!(cipher_supported, m)?)?;

m.add_function(pyo3::wrap_pyfunction!(_advance, m)?)?;
m.add_function(pyo3::wrap_pyfunction!(_advance_aad, m)?)?;
pub(crate) fn create_module(
py: pyo3::Python<'_>,
) -> pyo3::PyResult<pyo3::Bound<'_, pyo3::prelude::PyModule>> {
let m = pyo3::prelude::PyModule::new_bound(py, "ciphers")?;
m.add_function(pyo3::wrap_pyfunction!(create_encryption_ctx, &m)?)?;
m.add_function(pyo3::wrap_pyfunction!(create_decryption_ctx, &m)?)?;
m.add_function(pyo3::wrap_pyfunction!(cipher_supported, &m)?)?;

m.add_function(pyo3::wrap_pyfunction!(_advance, &m)?)?;
m.add_function(pyo3::wrap_pyfunction!(_advance_aad, &m)?)?;

m.add_class::<PyCipherContext>()?;
m.add_class::<PyAEADEncryptionContext>()?;
Expand Down
2 changes: 1 addition & 1 deletion src/rust/src/backend/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ pub(crate) mod x448;

pub(crate) fn add_to_module(module: &pyo3::prelude::PyModule) -> pyo3::PyResult<()> {
module.add_submodule(aead::create_module(module.py())?.into_gil_ref())?;
module.add_submodule(ciphers::create_module(module.py())?)?;
module.add_submodule(ciphers::create_module(module.py())?.into_gil_ref())?;
module.add_submodule(cmac::create_module(module.py())?)?;
module.add_submodule(dh::create_module(module.py())?)?;
module.add_submodule(dsa::create_module(module.py())?)?;
Expand Down