Skip to content

Commit

Permalink
fix a few postgres integer overflows
Browse files Browse the repository at this point in the history
  • Loading branch information
lovasoa committed Aug 16, 2024
1 parent 7ec50d7 commit eb7f6a4
Show file tree
Hide file tree
Showing 16 changed files with 74 additions and 57 deletions.
3 changes: 3 additions & 0 deletions sqlx-core/src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,9 @@ pub enum Error {
#[cfg(feature = "migrate")]
#[error("{0}")]
Migrate(#[source] Box<crate::migrate::MigrateError>),

#[error("integer overflow while converting to target type")]
IntegerOverflow(#[source] std::num::TryFromIntError),
}

impl StdError for Box<dyn DatabaseError> {}
Expand Down
4 changes: 2 additions & 2 deletions sqlx-core/src/postgres/arguments.rs
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ impl PgArguments {
}

for (offset, name) in type_holes {
let oid = conn.fetch_type_id_by_name(&*name).await?;
let oid = conn.fetch_type_id_by_name(name).await?;
buffer[*offset..(*offset + 4)].copy_from_slice(&oid.0.to_be_bytes());
}

Expand Down Expand Up @@ -134,7 +134,7 @@ impl PgArgumentBuffer {

// encode the value into our buffer
let len = if let IsNull::No = value.encode(self) {
(self.len() - offset - 4) as i32
i32::try_from(self.len() - offset - 4).expect("bind parameter too large")
} else {
// Write a -1 to indicate NULL
// NOTE: It is illegal for [encode] to write any data
Expand Down
10 changes: 5 additions & 5 deletions sqlx-core/src/postgres/migrate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -202,11 +202,11 @@ CREATE TABLE IF NOT EXISTS _sqlx_migrations (
.map_err(MigrateError::AccessMigrationMetadata)?;

if let Some(checksum) = checksum {
return if checksum == &*migration.checksum {
if checksum == *migration.checksum {
Ok(())
} else {
Err(MigrateError::VersionMismatch(migration.version))
};
}
} else {
Err(MigrateError::VersionMissing(migration.version))
}
Expand Down Expand Up @@ -257,7 +257,7 @@ CREATE TABLE IF NOT EXISTS _sqlx_migrations (
WHERE version = $2
"#,
)
.bind(elapsed.as_nanos() as i64)
.bind(i64::try_from(elapsed.as_nanos()).map_err(crate::error::Error::IntegerOverflow)?)
.bind(migration.version)
.execute(self)
.await?;
Expand Down Expand Up @@ -295,9 +295,9 @@ CREATE TABLE IF NOT EXISTS _sqlx_migrations (

async fn current_database(conn: &mut PgConnection) -> Result<String> {
// language=SQL
Ok(query_scalar("SELECT current_database()")
query_scalar("SELECT current_database()")
.fetch_one(conn)
.await?)
.await
}

// inspired from rails: https://github.com/rails/rails/blob/6e49cc77ab3d16c06e12f93158eaf3e507d4120e/activerecord/lib/active_record/migration.rb#L1308
Expand Down
6 changes: 4 additions & 2 deletions sqlx-core/src/postgres/types/array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -112,10 +112,12 @@ where
}
}

buf.extend(&(self.len() as i32).to_be_bytes()); // len
let len = i32::try_from(self.len()).unwrap_or(i32::MAX);

buf.extend(len.to_be_bytes()); // len
buf.extend(&1_i32.to_be_bytes()); // lower bound

for element in self.iter() {
for element in self.iter().take(usize::try_from(len).unwrap()) {
buf.encode(element);
}

Expand Down
32 changes: 20 additions & 12 deletions sqlx-core/src/postgres/types/bit_vec.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,16 @@ impl PgHasArrayType for BitVec {

impl Encode<'_, Postgres> for BitVec {
fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> IsNull {
buf.extend(&(self.len() as i32).to_be_bytes());
buf.extend(self.to_bytes());

if let Ok(len) = i32::try_from(self.len()) {
buf.extend(&len.to_be_bytes());
buf.extend_from_slice(&self.to_bytes());
} else {
debug_assert!(false, "BitVec length is too large to be encoded as i32.");
let len = i32::MAX;
buf.extend(&len.to_be_bytes());
let truncated = &self.to_bytes()[0..usize::try_from(i32::MAX).unwrap()];
buf.extend_from_slice(truncated);
};
IsNull::No
}

Expand All @@ -47,17 +54,18 @@ impl Decode<'_, Postgres> for BitVec {
match value.format() {
PgValueFormat::Binary => {
let mut bytes = value.as_bytes()?;
let len = bytes.get_i32();

if len < 0 {
Err(io::Error::new(
let len = if let Ok(len) = usize::try_from(bytes.get_i32()) {
len
} else {
return Err(io::Error::new(
io::ErrorKind::InvalidData,
"Negative VARBIT length.",
))?
}
)
.into());
};

// The smallest amount of data we can read is one byte
let bytes_len = (len as usize + 7) / 8;
let bytes_len = (len + 7) / 8;

if bytes.remaining() != bytes_len {
Err(io::Error::new(
Expand All @@ -66,12 +74,12 @@ impl Decode<'_, Postgres> for BitVec {
))?;
}

let mut bitvec = BitVec::from_bytes(&bytes);
let mut bitvec = BitVec::from_bytes(bytes);

// Chop off zeroes from the back. We get bits in bytes, so if
// our bitvec is not in full bytes, extra zeroes are added to
// the end.
while bitvec.len() > len as usize {
while bitvec.len() > len {
bitvec.pop();
}

Expand Down
5 changes: 3 additions & 2 deletions sqlx-core/src/postgres/types/chrono/date.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,9 @@ impl PgHasArrayType for NaiveDate {
impl Encode<'_, Postgres> for NaiveDate {
fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> IsNull {
// DATE is encoded as the days since epoch
let days = (*self - NaiveDate::from_ymd_opt(2000, 1, 1).unwrap()).num_days() as i32;
Encode::<Postgres>::encode(&days, buf)
let days = i32::try_from((*self - NaiveDate::from_ymd_opt(2000, 1, 1).unwrap()).num_days())
.unwrap_or(i32::MAX);
Encode::<Postgres>::encode(days, buf)
}

fn size_hint(&self) -> usize {
Expand Down
2 changes: 1 addition & 1 deletion sqlx-core/src/postgres/types/chrono/datetime.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ impl Encode<'_, Postgres> for NaiveDateTime {
.num_microseconds()
.unwrap_or_else(|| panic!("NaiveDateTime out of range for Postgres: {:?}", self));

Encode::<Postgres>::encode(&us, buf)
Encode::<Postgres>::encode(us, buf)
}

fn size_hint(&self) -> usize {
Expand Down
2 changes: 1 addition & 1 deletion sqlx-core/src/postgres/types/chrono/time.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ impl Encode<'_, Postgres> for NaiveTime {
// NOTE: panic! is on overflow and 1 day does not have enough micros to overflow
let us = (*self - NaiveTime::default()).num_microseconds().unwrap();

Encode::<Postgres>::encode(&us, buf)
Encode::<Postgres>::encode(us, buf)
}

fn size_hint(&self) -> usize {
Expand Down
10 changes: 5 additions & 5 deletions sqlx-core/src/postgres/types/decimal.rs
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ impl TryFrom<&'_ Decimal> for PgNumeric {
});
}

let scale = decimal.scale() as u16;
let scale = decimal.scale();

// A serialized version of the decimal number. The resulting byte array
// will have the following representation:
Expand All @@ -114,7 +114,7 @@ impl TryFrom<&'_ Decimal> for PgNumeric {
let remainder = 4 - groups_diff as u32;
let power = 10u32.pow(remainder as u32) as u128;

mantissa = mantissa * power;
mantissa *= power;
}

// Array to store max mantissa of Decimal in Postgres decimal format.
Expand All @@ -130,8 +130,8 @@ impl TryFrom<&'_ Decimal> for PgNumeric {
digits.reverse();

// Weight is number of digits on the left side of the decimal.
let digits_after_decimal = (scale + 3) as u16 / 4;
let weight = digits.len() as i16 - digits_after_decimal as i16 - 1;
let digits_after_decimal = u16::try_from(scale + 3)? / 4;
let weight = i16::try_from(digits.len())? - i16::try_from(digits_after_decimal)? - 1;

// Remove non-significant zeroes.
while let Some(&0) = digits.last() {
Expand All @@ -143,7 +143,7 @@ impl TryFrom<&'_ Decimal> for PgNumeric {
false => PgNumericSign::Positive,
true => PgNumericSign::Negative,
},
scale: scale as i16,
scale: i16::try_from(scale)?,
weight,
digits,
})
Expand Down
1 change: 1 addition & 0 deletions sqlx-core/src/postgres/types/ipnetwork.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ const AF_INET: u8 = 2;
const AF_INET: u8 = 0;

#[cfg(unix)]
#[allow(clippy::cast_possible_truncation)]
const AF_INET: u8 = libc::AF_INET as u8;

// https://github.com/postgres/postgres/blob/574925bfd0a8175f6e161936ea11d9695677ba09/src/include/utils/inet.h#L39
Expand Down
2 changes: 1 addition & 1 deletion sqlx-core/src/postgres/types/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -246,7 +246,7 @@ fn array_compatible<E: Type<Postgres> + ?Sized>(ty: &PgTypeInfo) -> bool {
// we require the declared type to be an _array_ with an
// element type that is acceptable
if let PgTypeKind::Array(element) = &ty.kind() {
return E::compatible(&element);
return E::compatible(element);
}

false
Expand Down
2 changes: 1 addition & 1 deletion sqlx-core/src/postgres/types/str.rs
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ impl Encode<'_, Postgres> for String {

impl<'r> Decode<'r, Postgres> for &'r str {
fn decode(value: PgValueRef<'r>) -> Result<Self, BoxDynError> {
Ok(value.as_str()?)
value.as_str()
}
}

Expand Down
4 changes: 2 additions & 2 deletions sqlx-core/src/postgres/types/time/date.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,8 @@ impl PgHasArrayType for Date {
impl Encode<'_, Postgres> for Date {
fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> IsNull {
// DATE is encoded as the days since epoch
let days = (*self - PG_EPOCH).whole_days() as i32;
Encode::<Postgres>::encode(&days, buf)
let days = i32::try_from((*self - PG_EPOCH).whole_days()).unwrap_or(i32::MAX);
Encode::<Postgres>::encode(days, buf)
}

fn size_hint(&self) -> usize {
Expand Down
11 changes: 6 additions & 5 deletions sqlx-core/src/postgres/types/time/datetime.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,9 @@ impl PgHasArrayType for OffsetDateTime {
impl Encode<'_, Postgres> for PrimitiveDateTime {
fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> IsNull {
// TIMESTAMP is encoded as the microseconds since the epoch
let us = (*self - PG_EPOCH.midnight()).whole_microseconds() as i64;
Encode::<Postgres>::encode(&us, buf)
let us =
i64::try_from((*self - PG_EPOCH.midnight()).whole_microseconds()).unwrap_or(i64::MAX);
Encode::<Postgres>::encode(us, buf)
}

fn size_hint(&self) -> usize {
Expand Down Expand Up @@ -71,10 +72,10 @@ impl<'r> Decode<'r, Postgres> for PrimitiveDateTime {
// This is given for timestamptz for some reason
// Postgres already guarantees this to always be UTC
if s.contains('+') {
PrimitiveDateTime::parse(&*s, &format_description!("[year]-[month]-[day] [hour]:[minute]:[second].[subsecond][offset_hour]"))?
PrimitiveDateTime::parse(&s, &format_description!("[year]-[month]-[day] [hour]:[minute]:[second].[subsecond][offset_hour]"))?
} else {
PrimitiveDateTime::parse(
&*s,
&s,
&format_description!(
"[year]-[month]-[day] [hour]:[minute]:[second].[subsecond]"
),
Expand All @@ -90,7 +91,7 @@ impl Encode<'_, Postgres> for OffsetDateTime {
let utc = self.to_offset(offset!(UTC));
let primitive = PrimitiveDateTime::new(utc.date(), utc.time());

Encode::<Postgres>::encode(&primitive, buf)
Encode::<Postgres>::encode(primitive, buf)
}

fn size_hint(&self) -> usize {
Expand Down
5 changes: 3 additions & 2 deletions sqlx-core/src/postgres/types/time/time.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,9 @@ impl PgHasArrayType for Time {
impl Encode<'_, Postgres> for Time {
fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> IsNull {
// TIME is encoded as the microseconds since midnight
let us = (*self - Time::MIDNIGHT).whole_microseconds() as i64;
Encode::<Postgres>::encode(&us, buf)
let us = i64::try_from((*self - Time::MIDNIGHT).whole_microseconds())
.expect("number of microseconds since midnight should fit in an i64");
Encode::<Postgres>::encode(us, buf)
}

fn size_hint(&self) -> usize {
Expand Down
32 changes: 16 additions & 16 deletions sqlx-core/src/postgres/value.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,23 +30,23 @@ pub struct PgValue {
}

impl<'r> PgValueRef<'r> {
pub(crate) fn get(buf: &mut &'r [u8], format: PgValueFormat, ty: PgTypeInfo) -> Self {
let mut element_len = buf.get_i32();

let element_val = if element_len == -1 {
element_len = 0;
None
pub(crate) fn get(buf: &mut &'r [u8], format: PgValueFormat, type_info: PgTypeInfo) -> Self {
if let Ok(element_len) = usize::try_from(buf.get_i32()) {
let value = Some(&buf[..(element_len)]);
buf.advance(element_len);
PgValueRef {
value,
row: None,
type_info,
format,
}
} else {
Some(&buf[..(element_len as usize)])
};

buf.advance(element_len as usize);

PgValueRef {
value: element_val,
row: None,
type_info: ty,
format,
PgValueRef {
value: None,
row: None,
type_info,
format,
}
}
}

Expand Down

0 comments on commit eb7f6a4

Please sign in to comment.