From f89e9a5e2a9454d382b9114f53e5fc8ea0362829 Mon Sep 17 00:00:00 2001 From: Georg Semmler Date: Tue, 30 Nov 2021 12:15:14 +0100 Subject: [PATCH] Make the mysql bind implementation more robust against possible aliasing issues This commit applies the same reasoning from the previous commit to the mysql bind implemenation. --- diesel/src/mysql/connection/bind.rs | 328 +++++++++++++++++++++------- 1 file changed, 252 insertions(+), 76 deletions(-) diff --git a/diesel/src/mysql/connection/bind.rs b/diesel/src/mysql/connection/bind.rs index 116be14e332d..d7ddc66556e4 100644 --- a/diesel/src/mysql/connection/bind.rs +++ b/diesel/src/mysql/connection/bind.rs @@ -3,6 +3,7 @@ use std::mem; use std::mem::MaybeUninit; use std::ops::Index; use std::os::raw as libc; +use std::ptr::NonNull; use super::stmt::MysqlFieldMetadata; use super::stmt::Statement; @@ -148,29 +149,106 @@ impl From for Flags { } } -#[derive(Debug, Clone)] +#[derive(Debug)] pub struct BindData { tpe: ffi::enum_field_types, - bytes: Vec, + bytes: Option>, length: libc::c_ulong, + capacity: usize, flags: Flags, is_null: ffi::my_bool, is_truncated: Option, } +// We need to write a manual clone impl +// as we need to clone the underlying buffer +// instead of just copying the pointer +impl Clone for BindData { + fn clone(&self) -> Self { + let (ptr, len, capacity) = if let Some(ptr) = self.bytes { + let slice = unsafe { + // We know that this points to a slice and the pointer is not null at this + // location + // The length pointer is valid as long as noone missuses `bind_for_truncated_data` + // as this is the only location that updates the length vield before the corresponding data are + // written. At the time of writting this comment the `BindData::bind_for_truncated_data` + // function is only called by `Binds::populate_dynamic_buffers` which ensures the corresponding + // invariant. + std::slice::from_raw_parts(ptr.as_ptr(), self.length as usize) + }; + let mut vec = slice.to_owned(); + let ptr = NonNull::new(vec.as_mut_ptr()); + let len = vec.len() as libc::c_ulong; + let capacity = vec.capacity(); + mem::forget(vec); + (ptr, len, capacity) + } else { + (None, 0, 0) + }; + Self { + tpe: self.tpe, + bytes: ptr, + length: len, + capacity, + flags: self.flags, + is_null: self.is_null, + is_truncated: self.is_truncated, + } + } +} + +impl Drop for BindData { + fn drop(&mut self) { + if let Some(bytes) = self.bytes { + std::mem::drop(unsafe { + // We know that this buffer was allocated by a vector, so constructing a vector from it is fine + // We know the correct capacity here + // We use 0 as lenght to prevent situations where the lenght is already updated but + // no date are already written as we could touch uninitialized memory otherwise + // Using 0 as lenght is fine as we don't need to call drop for `u8` + // (as there is no drop impl for primitive types) + Vec::from_raw_parts(bytes.as_ptr(), 0, self.capacity) + }); + self.bytes = None; + } + } +} + impl BindData { fn for_input((tpe, data): (MysqlType, Option>)) -> Self { - let is_null = if data.is_none() { 1 } else { 0 }; - let bytes = data.unwrap_or_default(); - let length = bytes.len() as libc::c_ulong; let (tpe, flags) = tpe.into(); - BindData { - tpe, - bytes, - length, - is_null, - is_truncated: None, - flags, + if let Some(mut bytes) = data { + bytes.shrink_to_fit(); + let len = bytes.len(); + let capacity = bytes.capacity(); + + let bytes = if len > 0 { + let ret = NonNull::new(bytes.as_mut_ptr()); + mem::forget(bytes); + ret + } else { + None + }; + + Self { + tpe, + bytes, + length: len as libc::c_ulong, + capacity, + flags, + is_null: 0, + is_truncated: None, + } + } else { + Self { + tpe, + bytes: None, + length: 0, + capacity: 0, + flags, + is_null: 1, + is_truncated: None, + } } } @@ -306,36 +384,39 @@ impl BindData { } else { (metadata.field_type(), metadata.flags()) }; - - let bytes = known_buffer_size_for_ffi_type(tpe) - .map(|len| vec![0; len]) - .unwrap_or_default(); - let length = bytes.len() as libc::c_ulong; - - BindData { - tpe, - bytes, - length, - is_null: 0, - is_truncated: Some(0), - flags, - } + Self::from_tpe_and_flags((tpe, flags)) } - #[cfg(test)] - fn for_test_output((tpe, flags): (ffi::enum_field_types, Flags)) -> Self { - let bytes = known_buffer_size_for_ffi_type(tpe) + fn from_tpe_and_flags((tpe, flags): (ffi::enum_field_types, Flags)) -> Self { + let mut bytes = known_buffer_size_for_ffi_type(tpe) .map(|len| vec![0; len]) .unwrap_or_default(); let length = bytes.len() as libc::c_ulong; - - BindData { - tpe, - bytes, - length, - is_null: 0, - is_truncated: Some(0), - flags, + let capacity = bytes.capacity(); + + if capacity > 0 { + let ptr = NonNull::new(bytes.as_mut_ptr()); + mem::forget(bytes); + + Self { + tpe, + bytes: ptr, + length, + capacity, + flags, + is_null: 0, + is_truncated: Some(0), + } + } else { + Self { + tpe, + bytes: None, + length, + capacity, + flags, + is_null: 0, + is_truncated: Some(0), + } } } @@ -351,8 +432,19 @@ impl BindData { if self.is_null() { None } else { + let data = self.bytes?; let tpe = (self.tpe, self.flags).into(); - Some(MysqlValue::new(&self.bytes, tpe)) + let slice = unsafe { + // We know that this points to a slice and the pointer is not null at this + // location + // The length pointer is valid as long as noone missuses `bind_for_truncated_data` + // as this is the only location that updates the length vield before the corresponding data are + // written. At the time of writting this comment the `BindData::bind_for_truncated_data` + // function is only called by `Binds::populate_dynamic_buffers` which ensures the corresponding + // invariant. + std::slice::from_raw_parts(data.as_ptr(), self.length as usize) + }; + Some(MysqlValue::new(slice, tpe)) } } @@ -363,10 +455,13 @@ impl BindData { fn update_buffer_length(&mut self) { use std::cmp::min; - let actual_bytes_in_buffer = min(self.bytes.capacity(), self.length as usize); - unsafe { self.bytes.set_len(actual_bytes_in_buffer) } + let actual_bytes_in_buffer = min(self.capacity, self.length as usize); + self.length = actual_bytes_in_buffer as libc::c_ulong; } + // This function is marked as unsafe as it returns a owned value + // containing a pointer with a lifetime coupled to self. + // Callers need to ensure that the returend value cannot outlive `self` unsafe fn mysql_bind(&mut self) -> ffi::MYSQL_BIND { use std::ptr::addr_of_mut; @@ -374,8 +469,12 @@ impl BindData { let ptr = bind.as_mut_ptr(); addr_of_mut!((*ptr).buffer_type).write(self.tpe); - addr_of_mut!((*ptr).buffer).write(self.bytes.as_mut_ptr() as *mut libc::c_void); - addr_of_mut!((*ptr).buffer_length).write(self.bytes.capacity() as libc::c_ulong); + addr_of_mut!((*ptr).buffer).write( + self.bytes + .map(|p| p.as_ptr()) + .unwrap_or(std::ptr::null_mut()) as *mut libc::c_void, + ); + addr_of_mut!((*ptr).buffer_length).write(self.capacity as libc::c_ulong); addr_of_mut!((*ptr).length).write(&mut self.length); addr_of_mut!((*ptr).is_null).write(&mut self.is_null); addr_of_mut!((*ptr).is_unsigned) @@ -397,22 +496,52 @@ impl BindData { /// this function is unsafe unless the binds are immediately rebound. unsafe fn bind_for_truncated_data(&mut self) -> Option<(ffi::MYSQL_BIND, usize)> { if self.is_truncated() { - let offset = self.bytes.capacity(); - let truncated_amount = self.length as usize - offset; - - debug_assert!( - truncated_amount > 0, - "output buffers were invalidated \ - without calling `mysql_stmt_bind_result`" - ); - self.bytes.set_len(offset); - self.bytes.reserve(truncated_amount); - self.bytes.set_len(self.length as usize); - - let mut bind = self.mysql_bind(); - bind.buffer = self.bytes[offset..].as_mut_ptr() as *mut libc::c_void; - bind.buffer_length = truncated_amount as libc::c_ulong; - Some((bind, offset)) + if let Some(bytes) = self.bytes { + let mut bytes = Vec::from_raw_parts(bytes.as_ptr(), self.capacity, self.capacity); + self.bytes = None; + + let offset = self.capacity; + let truncated_amount = self.length as usize - offset; + + debug_assert!( + truncated_amount > 0, + "output buffers were invalidated \ + without calling `mysql_stmt_bind_result`" + ); + + // reserve space for any missing byte + // we know the exact size here + bytes.reserve(truncated_amount); + self.capacity = bytes.capacity(); + self.bytes = NonNull::new(bytes.as_mut_ptr()); + mem::forget(bytes); + + let mut bind = self.mysql_bind(); + + if let Some(ptr) = self.bytes { + // Using offset is safe here as we have a u8 array (where std::mem::size_of:: == 1) + // and we have a buffer that has at least + bind.buffer = ptr.as_ptr().add(offset) as *mut libc::c_void; + bind.buffer_length = truncated_amount as libc::c_ulong; + } else { + bind.buffer_length = 0; + } + Some((bind, offset)) + } else { + // offset is zero here as we don't have a buffer yet + // we know the requested lenght here so we can just request + // the correct size + let mut vec = vec![0_u8; self.length as usize]; + self.capacity = vec.capacity(); + self.bytes = NonNull::new(vec.as_mut_ptr()); + mem::forget(vec); + + let bind = self.mysql_bind(); + // As we did not have a buffer before + // we couldn't have loaded any data yet, therefore + // request everything + Some((bind, 0)) + } } else { None } @@ -648,7 +777,10 @@ mod tests { { let meta = (bind.tpe, bind.flags).into(); dbg!(meta); - let value = MysqlValue::new(&bind.bytes, meta); + + let value = bind.value().expect("Is not null"); + let value = MysqlValue::new(value.as_bytes(), meta); + dbg!(T::from_sql(value)) } @@ -1119,7 +1251,7 @@ mod tests { ) -> BindData { let mut stmt: Statement = conn.raw_connection.prepare(query).unwrap(); - let bind = BindData::for_test_output(bind_tpe.into()); + let bind = BindData::from_tpe_and_flags(bind_tpe.into()); let mut binds = OutputBinds(Binds { data: vec![bind] }); @@ -1133,27 +1265,35 @@ mod tests { query: &'static str, conn: &MysqlConnection, id: i32, - (field, tpe): (Vec, impl Into<(ffi::enum_field_types, Flags)>), + (mut field, tpe): (Vec, impl Into<(ffi::enum_field_types, Flags)>), ) { let mut stmt = conn.raw_connection.prepare(query).unwrap(); let length = field.len() as _; let (tpe, flags) = tpe.into(); + let capacity = field.capacity(); + let ptr = NonNull::new(field.as_mut_ptr()); + mem::forget(field); let field_bind = BindData { tpe, - bytes: field, + bytes: ptr, + capacity, length, flags, is_null: 0, is_truncated: None, }; - let bytes = id.to_be_bytes().to_vec(); + let mut bytes = id.to_be_bytes().to_vec(); let length = bytes.len() as _; + let capacity = bytes.capacity(); + let ptr = NonNull::new(bytes.as_mut_ptr()); + mem::forget(bytes); let id_bind = BindData { tpe: ffi::enum_field_types::MYSQL_TYPE_LONG, - bytes, + bytes: ptr, + capacity, length, flags: Flags::empty(), is_null: 0, @@ -1222,7 +1362,10 @@ mod tests { to_value::(&json_col_as_text).unwrap(), "{\"key1\": \"value1\", \"key2\": \"value2\"}" ); - assert_eq!(json_col_as_json.bytes, json_col_as_text.bytes); + assert_eq!( + json_col_as_json.value().unwrap().as_bytes(), + json_col_as_text.value().unwrap().as_bytes() + ); conn.execute("DELETE FROM json_test").unwrap(); @@ -1270,7 +1413,10 @@ mod tests { to_value::(&json_col_as_text).unwrap(), "{\"abc\": 42}" ); - assert_eq!(json_col_as_json.bytes, json_col_as_text.bytes); + assert_eq!( + json_col_as_json.value().unwrap().as_bytes(), + json_col_as_text.value().unwrap().as_bytes() + ); conn.execute("DELETE FROM json_test").unwrap(); @@ -1314,7 +1460,10 @@ mod tests { to_value::(&json_col_as_text).unwrap(), "{\"abca\": 42}" ); - assert_eq!(json_col_as_json.bytes, json_col_as_text.bytes); + assert_eq!( + json_col_as_json.value().unwrap().as_bytes(), + json_col_as_text.value().unwrap().as_bytes() + ); } #[test] @@ -1370,7 +1519,10 @@ mod tests { to_value::(&enum_col_as_text).unwrap(), "green" ); - assert_eq!(enum_col_as_enum.bytes, enum_col_as_text.bytes); + assert_eq!( + enum_col_as_enum.value().unwrap().as_bytes(), + enum_col_as_text.value().unwrap().as_bytes() + ); } let enum_col_as_text = query_single_table( @@ -1389,7 +1541,10 @@ mod tests { to_value::(&enum_col_as_text).unwrap(), "green" ); - assert_eq!(enum_col_as_enum.bytes, enum_col_as_text.bytes); + assert_eq!( + enum_col_as_enum.value().unwrap().as_bytes(), + enum_col_as_text.value().unwrap().as_bytes() + ); conn.execute("DELETE FROM enum_test").unwrap(); @@ -1427,7 +1582,10 @@ mod tests { assert!(enum_col_as_text.flags.contains(Flags::ENUM_FLAG)); assert!(!enum_col_as_text.flags.contains(Flags::BINARY_FLAG)); assert_eq!(to_value::(&enum_col_as_text).unwrap(), "blue"); - assert_eq!(enum_col_as_enum.bytes, enum_col_as_text.bytes); + assert_eq!( + enum_col_as_enum.value().unwrap().as_bytes(), + enum_col_as_text.value().unwrap().as_bytes() + ); let enum_col_as_text = query_single_table( "SELECT enum_field FROM enum_test", @@ -1442,7 +1600,10 @@ mod tests { assert!(enum_col_as_text.flags.contains(Flags::ENUM_FLAG)); assert!(!enum_col_as_text.flags.contains(Flags::BINARY_FLAG)); assert_eq!(to_value::(&enum_col_as_text).unwrap(), "blue"); - assert_eq!(enum_col_as_enum.bytes, enum_col_as_text.bytes); + assert_eq!( + enum_col_as_enum.value().unwrap().as_bytes(), + enum_col_as_text.value().unwrap().as_bytes() + ); conn.execute("DELETE FROM enum_test").unwrap(); @@ -1483,7 +1644,10 @@ mod tests { assert!(enum_col_as_text.flags.contains(Flags::ENUM_FLAG)); assert!(!enum_col_as_text.flags.contains(Flags::BINARY_FLAG)); assert_eq!(to_value::(&enum_col_as_text).unwrap(), "red"); - assert_eq!(enum_col_as_enum.bytes, enum_col_as_text.bytes); + assert_eq!( + enum_col_as_enum.value().unwrap().as_bytes(), + enum_col_as_text.value().unwrap().as_bytes() + ); } #[test] @@ -1530,7 +1694,10 @@ mod tests { assert!(!set_col_as_text.flags.contains(Flags::ENUM_FLAG)); assert!(!set_col_as_text.flags.contains(Flags::BINARY_FLAG)); assert_eq!(to_value::(&set_col_as_text).unwrap(), "green"); - assert_eq!(set_col_as_set.bytes, set_col_as_text.bytes); + assert_eq!( + set_col_as_set.value().unwrap().as_bytes(), + set_col_as_text.value().unwrap().as_bytes() + ); } let set_col_as_text = query_single_table( "SELECT set_field FROM set_test", @@ -1545,7 +1712,10 @@ mod tests { assert!(!set_col_as_text.flags.contains(Flags::ENUM_FLAG)); assert!(!set_col_as_text.flags.contains(Flags::BINARY_FLAG)); assert_eq!(to_value::(&set_col_as_text).unwrap(), "green"); - assert_eq!(set_col_as_set.bytes, set_col_as_text.bytes); + assert_eq!( + set_col_as_set.value().unwrap().as_bytes(), + set_col_as_text.value().unwrap().as_bytes() + ); conn.execute("DELETE FROM set_test").unwrap(); @@ -1580,7 +1750,10 @@ mod tests { assert!(!set_col_as_text.flags.contains(Flags::ENUM_FLAG)); assert!(!set_col_as_text.flags.contains(Flags::BINARY_FLAG)); assert_eq!(to_value::(&set_col_as_text).unwrap(), "blue"); - assert_eq!(set_col_as_set.bytes, set_col_as_text.bytes); + assert_eq!( + set_col_as_set.value().unwrap().as_bytes(), + set_col_as_text.value().unwrap().as_bytes() + ); conn.execute("DELETE FROM set_test").unwrap(); @@ -1615,6 +1788,9 @@ mod tests { assert!(!set_col_as_text.flags.contains(Flags::ENUM_FLAG)); assert!(!set_col_as_text.flags.contains(Flags::BINARY_FLAG)); assert_eq!(to_value::(&set_col_as_text).unwrap(), "red"); - assert_eq!(set_col_as_set.bytes, set_col_as_text.bytes); + assert_eq!( + set_col_as_set.value().unwrap().as_bytes(), + set_col_as_text.value().unwrap().as_bytes() + ); } }