diff --git a/src/test_encode_decode/dav1d.rs b/src/test_encode_decode/dav1d.rs index b1847ad07b..5ad2933bc8 100644 --- a/src/test_encode_decode/dav1d.rs +++ b/src/test_encode_decode/dav1d.rs @@ -12,6 +12,7 @@ use crate::test_encode_decode::{compare_plane, DecodeResult, TestDecoder}; use crate::util::{CastFromPrimitive, Pixel}; use std::collections::VecDeque; use std::marker::PhantomData; +use std::os::raw::c_int; use std::{ mem::{self, MaybeUninit}, ptr, slice, @@ -50,32 +51,28 @@ impl TestDecoder for Dav1dDecoder { h: usize, bit_depth: usize, ) -> DecodeResult { let mut corrupted_count = 0; - unsafe { - let mut data: Dav1dData = mem::zeroed(); - let ptr = dav1d_data_create(&mut data, packet.len()); - ptr::copy_nonoverlapping(packet.as_ptr(), ptr, packet.len()); - let ret = dav1d_send_data(self.dec, &mut data); - debug!("Decoded. -> {}", ret); - if ret != 0 { - corrupted_count += 1; - } + let mut data = SafeDav1dData::new(packet); + let ret = data.send(self.dec); + debug!("Decoded. -> {}", ret); + if ret != 0 { + corrupted_count += 1; + } - if ret == 0 { - loop { - let mut pic: Dav1dPicture = mem::zeroed(); - debug!("Retrieving frame"); - let ret = dav1d_get_picture(self.dec, &mut pic); - debug!("Retrieved."); - if ret == -(EAGAIN as i32) { - return DecodeResult::Done; - } - if ret != 0 { - panic!("Decode fail"); - } - - let rec = rec_fifo.pop_front().unwrap(); - compare_pic(&pic, &rec, bit_depth, w, h); + if ret == 0 { + loop { + let mut pic = SafeDav1dPicture::default(); + debug!("Retrieving frame"); + let ret = pic.get(self.dec); + debug!("Retrieved."); + if ret == -(EAGAIN as i32) { + return DecodeResult::Done; + } + if ret != 0 { + panic!("Decode fail"); } + + let rec = rec_fifo.pop_front().unwrap(); + compare_pic(&pic.0, &rec, bit_depth, w, h); } } if corrupted_count > 0 { @@ -92,6 +89,49 @@ impl Drop for Dav1dDecoder { } } +struct SafeDav1dData(Dav1dData); + +impl SafeDav1dData { + fn new(packet: &[u8]) -> Self { + unsafe { + let mut data = Self { 0: mem::zeroed() }; + let ptr = dav1d_data_create(&mut data.0, packet.len()); + ptr::copy_nonoverlapping(packet.as_ptr(), ptr, packet.len()); + data + } + } + + fn send(&mut self, context: *mut Dav1dContext) -> c_int { + unsafe { dav1d_send_data(context, &mut self.0) } + } +} + +impl Drop for SafeDav1dData { + fn drop(&mut self) { + unsafe { dav1d_data_unref(&mut self.0) }; + } +} + +struct SafeDav1dPicture(Dav1dPicture); + +impl Default for SafeDav1dPicture { + fn default() -> Self { + Self { 0: unsafe { mem::zeroed() } } + } +} + +impl SafeDav1dPicture { + fn get(&mut self, context: *mut Dav1dContext) -> c_int { + unsafe { dav1d_get_picture(context, &mut self.0) } + } +} + +impl Drop for SafeDav1dPicture { + fn drop(&mut self) { + unsafe { dav1d_picture_unref(&mut self.0) } + } +} + fn compare_pic( pic: &Dav1dPicture, frame: &Frame, bit_depth: usize, width: usize, height: usize,