Skip to content

Commit

Permalink
Fixed out of bounds reads for open_cursor (#1837)
Browse files Browse the repository at this point in the history
The problem arises when the prepared statement has less arguments than
the used plan since there's no parameter to pass the number of arguments
like it's for `SPI_execute_with_args`. Meanwhile the `execute` method
does a check.

Introduce a variant, `try_open_cursor`, that returns an SpiResult
instead.
  • Loading branch information
YohDeadfall authored Sep 6, 2024
1 parent 7755cc7 commit a772c49
Show file tree
Hide file tree
Showing 5 changed files with 165 additions and 43 deletions.
39 changes: 39 additions & 0 deletions pgrx-tests/src/tests/spi_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -243,6 +243,44 @@ mod tests {
})
}

#[pg_test]
#[should_panic]
fn test_cursor_prepared_statement_panics_none_args() -> Result<(), pgrx::spi::Error> {
test_cursor_prepared_statement_panics_impl(None)
}

#[pg_test]
#[should_panic]
fn test_cursor_prepared_statement_panics_less_args() -> Result<(), pgrx::spi::Error> {
test_cursor_prepared_statement_panics_impl(Some([].to_vec()))
}

#[pg_test]
#[should_panic]
fn test_cursor_prepared_statement_panics_more_args() -> Result<(), pgrx::spi::Error> {
test_cursor_prepared_statement_panics_impl(Some([None, None].to_vec()))
}

fn test_cursor_prepared_statement_panics_impl(
args: Option<Vec<Option<pg_sys::Datum>>>,
) -> Result<(), pgrx::spi::Error> {
Spi::connect(|mut client| {
client.update("CREATE TABLE tests.cursor_table (id int)", None, None)?;
client.update(
"INSERT INTO tests.cursor_table (id) \
SELECT i FROM generate_series(1, 10) AS t(i)",
None,
None,
)?;
let prepared = client.prepare(
"SELECT * FROM tests.cursor_table WHERE id = $1",
Some([PgBuiltInOids::INT4OID.oid()].to_vec()),
)?;
client.open_cursor(&prepared, args);
unreachable!();
})
}

#[pg_test]
fn test_cursor_by_name() -> Result<(), pgrx::spi::Error> {
let cursor_name = Spi::connect(|mut client| {
Expand Down Expand Up @@ -532,6 +570,7 @@ mod tests {
}

#[pg_test]
#[allow(deprecated)]
fn can_return_borrowed_str() -> Result<(), Box<dyn Error>> {
let res = Spi::connect(|c| {
let mut cursor = c.open_cursor("SELECT 'hello' FROM generate_series(1, 10000)", None);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ use pgrx::prelude::*;
use std::error::Error;

#[pg_test]
#[allow(deprecated)]
fn issue1209() -> Result<Option<String>, Box<dyn Error>> {
// create the cursor we actually care about
let mut res = Spi::connect(|c| {
Expand Down
26 changes: 13 additions & 13 deletions pgrx-tests/tests/compile-fail/escaping-spiclient-1209-cursor.stderr
Original file line number Diff line number Diff line change
@@ -1,42 +1,42 @@
error: lifetime may not live long enough
--> tests/compile-fail/escaping-spiclient-1209-cursor.rs:8:9
--> tests/compile-fail/escaping-spiclient-1209-cursor.rs:9:9
|
7 | let mut res = Spi::connect(|c| {
8 | let mut res = Spi::connect(|c| {
| -- return type of closure is SpiTupleTable<'2>
| |
| has type `SpiClient<'1>`
8 | / c.open_cursor("select 'hello world' from generate_series(1, 1000)", None)
9 | | .fetch(1000)
10 | | .unwrap()
9 | / c.open_cursor("select 'hello world' from generate_series(1, 1000)", None)
10 | | .fetch(1000)
11 | | .unwrap()
| |_____________________^ returning this value requires that `'1` must outlive `'2`

error[E0515]: cannot return value referencing temporary value
--> tests/compile-fail/escaping-spiclient-1209-cursor.rs:8:9
--> tests/compile-fail/escaping-spiclient-1209-cursor.rs:9:9
|
8 | c.open_cursor("select 'hello world' from generate_series(1, 1000)", None)
9 | c.open_cursor("select 'hello world' from generate_series(1, 1000)", None)
| ^------------------------------------------------------------------------
| |
| _________temporary value created here
| |
9 | | .fetch(1000)
10 | | .unwrap()
10 | | .fetch(1000)
11 | | .unwrap()
| |_____________________^ returns a value referencing data owned by the current function
|
= help: use `.collect()` to allocate the iterator

error: lifetime may not live long enough
--> tests/compile-fail/escaping-spiclient-1209-cursor.rs:15:26
--> tests/compile-fail/escaping-spiclient-1209-cursor.rs:16:26
|
15 | Spi::connect(|c| c.open_cursor("select 1", None).fetch(1).unwrap());
16 | Spi::connect(|c| c.open_cursor("select 1", None).fetch(1).unwrap());
| -- ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ returning this value requires that `'1` must outlive `'2`
| ||
| |return type of closure is SpiTupleTable<'2>
| has type `SpiClient<'1>`

error[E0515]: cannot return value referencing temporary value
--> tests/compile-fail/escaping-spiclient-1209-cursor.rs:15:26
--> tests/compile-fail/escaping-spiclient-1209-cursor.rs:16:26
|
15 | Spi::connect(|c| c.open_cursor("select 1", None).fetch(1).unwrap());
16 | Spi::connect(|c| c.open_cursor("select 1", None).fetch(1).unwrap());
| -------------------------------^^^^^^^^^^^^^^^^^^
| |
| returns a value referencing data owned by the current function
Expand Down
43 changes: 41 additions & 2 deletions pgrx/src/spi/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -117,22 +117,61 @@ impl<'conn> SpiClient<'conn> {
/// Rows may be then fetched using [`SpiCursor::fetch`].
///
/// See [`SpiCursor`] docs for usage details.
///
/// See [`try_open_cursor`][Self::try_open_cursor] which will return an [`SpiError`] rather than panicking.
///
/// # Panics
///
/// Panics if a cursor wasn't opened.
pub fn open_cursor<Q: Query<'conn>>(&self, query: Q, args: Q::Arguments) -> SpiCursor<'conn> {
query.open_cursor(self, args)
self.try_open_cursor(query, args).unwrap()
}

/// Set up a cursor that will execute the specified query
///
/// Rows may be then fetched using [`SpiCursor::fetch`].
///
/// See [`SpiCursor`] docs for usage details.
pub fn try_open_cursor<Q: Query<'conn>>(
&self,
query: Q,
args: Q::Arguments,
) -> SpiResult<SpiCursor<'conn>> {
query.try_open_cursor(self, args)
}

/// Set up a cursor that will execute the specified update (mutating) query
///
/// Rows may be then fetched using [`SpiCursor::fetch`].
///
/// See [`SpiCursor`] docs for usage details.
///
/// See [`try_open_cursor_mut`][Self::try_open_cursor_mut] which will return an [`SpiError`] rather than panicking.
///
/// # Panics
///
/// Panics if a cursor wasn't opened.
pub fn open_cursor_mut<Q: Query<'conn>>(
&mut self,
query: Q,
args: Q::Arguments,
) -> SpiCursor<'conn> {
Spi::mark_mutable();
query.open_cursor(self, args)
self.try_open_cursor_mut(query, args).unwrap()
}

/// Set up a cursor that will execute the specified update (mutating) query
///
/// Rows may be then fetched using [`SpiCursor::fetch`].
///
/// See [`SpiCursor`] docs for usage details.
pub fn try_open_cursor_mut<Q: Query<'conn>>(
&mut self,
query: Q,
args: Q::Arguments,
) -> SpiResult<SpiCursor<'conn>> {
Spi::mark_mutable();
query.try_open_cursor(self, args)
}

/// Find a cursor in transaction by name
Expand Down
99 changes: 71 additions & 28 deletions pgrx/src/spi/query.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,19 +13,32 @@ use crate::pg_sys::{self, PgOid};
/// Its primary purpose is to abstract away differences between
/// one-off statements and prepared statements, but it can potentially
/// be implemented for other types, provided they can be converted into a query.
pub trait Query<'conn> {
pub trait Query<'conn>: Sized {
type Arguments;

/// Execute a query given a client and other arguments
/// Execute a query given a client and other arguments.
fn execute(
self,
client: &SpiClient<'conn>,
limit: Option<libc::c_long>,
arguments: Self::Arguments,
) -> SpiResult<SpiTupleTable<'conn>>;

/// Open a cursor for the query
fn open_cursor(self, client: &SpiClient<'conn>, args: Self::Arguments) -> SpiCursor<'conn>;
/// Open a cursor for the query.
///
/// # Panics
///
/// Panics if a cursor wasn't opened.
fn open_cursor(self, client: &SpiClient<'conn>, args: Self::Arguments) -> SpiCursor<'conn> {
self.try_open_cursor(client, args).unwrap()
}

/// Tries to open cursor for the query.
fn try_open_cursor(
self,
client: &SpiClient<'conn>,
args: Self::Arguments,
) -> SpiResult<SpiCursor<'conn>>;
}

impl<'conn> Query<'conn> for &String {
Expand All @@ -40,8 +53,12 @@ impl<'conn> Query<'conn> for &String {
self.as_str().execute(client, limit, arguments)
}

fn open_cursor(self, client: &SpiClient<'conn>, args: Self::Arguments) -> SpiCursor<'conn> {
self.as_str().open_cursor(client, args)
fn try_open_cursor(
self,
client: &SpiClient<'conn>,
args: Self::Arguments,
) -> SpiResult<SpiCursor<'conn>> {
self.as_str().try_open_cursor(client, args)
}
}

Expand Down Expand Up @@ -119,7 +136,11 @@ impl<'conn> Query<'conn> for &str {
SpiClient::prepare_tuple_table(status_code)
}

fn open_cursor(self, _client: &SpiClient<'conn>, args: Self::Arguments) -> SpiCursor<'conn> {
fn try_open_cursor(
self,
_client: &SpiClient<'conn>,
args: Self::Arguments,
) -> SpiResult<SpiCursor<'conn>> {
let src = CString::new(self).expect("query contained a null byte");
let args = args.unwrap_or_default();

Expand All @@ -140,7 +161,7 @@ impl<'conn> Query<'conn> for &str {
0,
))
};
SpiCursor { ptr, __marker: PhantomData }
Ok(SpiCursor { ptr, __marker: PhantomData })
}
}

Expand Down Expand Up @@ -182,8 +203,12 @@ impl<'conn> Query<'conn> for &OwnedPreparedStatement {
(&self.0).execute(client, limit, arguments)
}

fn open_cursor(self, client: &SpiClient<'conn>, args: Self::Arguments) -> SpiCursor<'conn> {
(&self.0).open_cursor(client, args)
fn try_open_cursor(
self,
client: &SpiClient<'conn>,
args: Self::Arguments,
) -> SpiResult<SpiCursor<'conn>> {
(&self.0).try_open_cursor(client, args)
}
}

Expand All @@ -199,8 +224,12 @@ impl<'conn> Query<'conn> for OwnedPreparedStatement {
(&self.0).execute(client, limit, arguments)
}

fn open_cursor(self, client: &SpiClient<'conn>, args: Self::Arguments) -> SpiCursor<'conn> {
(&self.0).open_cursor(client, args)
fn try_open_cursor(
self,
client: &SpiClient<'conn>,
args: Self::Arguments,
) -> SpiResult<SpiCursor<'conn>> {
(&self.0).try_open_cursor(client, args)
}
}

Expand All @@ -221,6 +250,22 @@ impl<'conn> PreparedStatement<'conn> {
mutating: self.mutating,
})
}

fn args_to_datums(
&self,
args: <Self as Query<'conn>>::Arguments,
) -> SpiResult<(Vec<pg_sys::Datum>, Vec<std::os::raw::c_char>)> {
let args = args.unwrap_or_default();

let actual = args.len();
let expected = unsafe { pg_sys::SPI_getargcount(self.plan.as_ptr()) } as usize;

if expected == actual {
Ok(args.into_iter().map(prepare_datum).unzip())
} else {
Err(SpiError::PreparedStatementArgumentMismatch { expected, got: actual })
}
}
}

impl<'conn: 'stmt, 'stmt> Query<'conn> for &'stmt PreparedStatement<'conn> {
Expand All @@ -236,16 +281,8 @@ impl<'conn: 'stmt, 'stmt> Query<'conn> for &'stmt PreparedStatement<'conn> {
unsafe {
pg_sys::SPI_tuptable = std::ptr::null_mut();
}
let args = arguments.unwrap_or_default();
let nargs = args.len();

let expected = unsafe { pg_sys::SPI_getargcount(self.plan.as_ptr()) } as usize;

if nargs != expected {
return Err(SpiError::PreparedStatementArgumentMismatch { expected, got: nargs });
}

let (mut datums, mut nulls): (Vec<_>, Vec<_>) = args.into_iter().map(prepare_datum).unzip();
let (mut datums, mut nulls) = self.args_to_datums(arguments)?;

// SAFETY: all arguments are prepared above
let status_code = unsafe {
Expand All @@ -261,10 +298,12 @@ impl<'conn: 'stmt, 'stmt> Query<'conn> for &'stmt PreparedStatement<'conn> {
SpiClient::prepare_tuple_table(status_code)
}

fn open_cursor(self, _client: &SpiClient<'conn>, args: Self::Arguments) -> SpiCursor<'conn> {
let args = args.unwrap_or_default();

let (mut datums, nulls): (Vec<_>, Vec<_>) = args.into_iter().map(prepare_datum).unzip();
fn try_open_cursor(
self,
_client: &SpiClient<'conn>,
args: Self::Arguments,
) -> SpiResult<SpiCursor<'conn>> {
let (mut datums, nulls) = self.args_to_datums(args)?;

// SAFETY: arguments are prepared above and SPI_cursor_open will never return the null
// pointer. It'll raise an ERROR if something is invalid for it to create the cursor
Expand All @@ -277,7 +316,7 @@ impl<'conn: 'stmt, 'stmt> Query<'conn> for &'stmt PreparedStatement<'conn> {
!self.mutating && Spi::is_xact_still_immutable(),
))
};
SpiCursor { ptr, __marker: PhantomData }
Ok(SpiCursor { ptr, __marker: PhantomData })
}
}

Expand All @@ -293,7 +332,11 @@ impl<'conn> Query<'conn> for PreparedStatement<'conn> {
(&self).execute(client, limit, arguments)
}

fn open_cursor(self, client: &SpiClient<'conn>, args: Self::Arguments) -> SpiCursor<'conn> {
(&self).open_cursor(client, args)
fn try_open_cursor(
self,
client: &SpiClient<'conn>,
args: Self::Arguments,
) -> SpiResult<SpiCursor<'conn>> {
(&self).try_open_cursor(client, args)
}
}

0 comments on commit a772c49

Please sign in to comment.