From 5a6b29a2262e3ca9d539bc09ac20fc674e1a9c00 Mon Sep 17 00:00:00 2001 From: Yoh Deadfall Date: Tue, 17 Sep 2024 19:42:16 +0300 Subject: [PATCH 1/3] Allowed CStr to be used as SPI commands --- pgrx/src/spi/client.rs | 55 ++------- pgrx/src/spi/query.rs | 246 ++++++++++++++++++++++++++--------------- 2 files changed, 166 insertions(+), 135 deletions(-) diff --git a/pgrx/src/spi/client.rs b/pgrx/src/spi/client.rs index e32728d743..ccedde57f2 100644 --- a/pgrx/src/spi/client.rs +++ b/pgrx/src/spi/client.rs @@ -1,10 +1,11 @@ -use std::ffi::CString; use std::marker::PhantomData; use std::ptr::NonNull; use crate::pg_sys::{self, PgOid}; use crate::spi::{PreparedStatement, Query, Spi, SpiCursor, SpiError, SpiResult, SpiTupleTable}; +use super::query::PreparableQuery; + // TODO: should `'conn` be invariant? pub struct SpiClient<'conn> { __marker: PhantomData<&'conn SpiConnection>, @@ -12,61 +13,21 @@ pub struct SpiClient<'conn> { impl<'conn> SpiClient<'conn> { /// Prepares a statement that is valid for the lifetime of the client - /// - /// # Panics - /// - /// This function will panic if the supplied `query` string contained a NULL byte - pub fn prepare( + pub fn prepare>( &self, - query: &str, + query: Q, args: Option>, ) -> SpiResult> { - self.make_prepare_statement(query, args, false) + query.prepare(self, args) } /// Prepares a mutating statement that is valid for the lifetime of the client - /// - /// # Panics - /// - /// This function will panic if the supplied `query` string contained a NULL byte - pub fn prepare_mut( + pub fn prepare_mut>( &self, - query: &str, - args: Option>, - ) -> SpiResult> { - self.make_prepare_statement(query, args, true) - } - - fn make_prepare_statement( - &self, - query: &str, + query: Q, args: Option>, - mutating: bool, ) -> SpiResult> { - let src = CString::new(query).expect("query contained a null byte"); - let args = args.unwrap_or_default(); - let nargs = args.len(); - - // SAFETY: all arguments are prepared above - let plan = unsafe { - pg_sys::SPI_prepare( - src.as_ptr(), - nargs as i32, - args.into_iter().map(PgOid::value).collect::>().as_mut_ptr(), - ) - }; - Ok(PreparedStatement { - plan: NonNull::new(plan).ok_or_else(|| { - Spi::check_status(unsafe { - // SAFETY: no concurrent usage - pg_sys::SPI_result - }) - .err() - .unwrap() - })?, - __marker: PhantomData, - mutating, - }) + query.prepare_mut(self, args) } /// perform a SELECT statement diff --git a/pgrx/src/spi/query.rs b/pgrx/src/spi/query.rs index 2f4178bf4a..61262f02fe 100644 --- a/pgrx/src/spi/query.rs +++ b/pgrx/src/spi/query.rs @@ -1,4 +1,4 @@ -use std::ffi::CString; +use std::ffi::{CStr, CString}; use std::marker::PhantomData; use std::ops::Deref; use std::ptr::NonNull; @@ -41,32 +41,84 @@ pub trait Query<'conn>: Sized { ) -> SpiResult>; } -impl<'conn> Query<'conn> for &String { - type Arguments = Option)>>; - - fn execute( +/// A trait representing a query which can be prepared. +pub trait PreparableQuery<'conn>: Query<'conn> { + /// Prepares a query. + fn prepare( self, client: &SpiClient<'conn>, - limit: Option, - args: Self::Arguments, - ) -> SpiResult> { - self.as_str().execute(client, limit, args) - } + args: Option>, + ) -> SpiResult>; - fn try_open_cursor( + /// Prepares a query allowed to change data + fn prepare_mut( self, client: &SpiClient<'conn>, - args: Self::Arguments, - ) -> SpiResult> { - self.as_str().try_open_cursor(client, args) - } + args: Option>, + ) -> SpiResult>; } -fn prepare_datum(datum: Option) -> (pg_sys::Datum, std::os::raw::c_char) { - match datum { - Some(datum) => (datum, ' ' as std::os::raw::c_char), - None => (pg_sys::Datum::from(0usize), 'n' as std::os::raw::c_char), +fn execute<'conn>( + cmd: &CStr, + args: Option)>>, + limit: Option, +) -> SpiResult> { + // SAFETY: no concurrent access + unsafe { + pg_sys::SPI_tuptable = std::ptr::null_mut(); } + + let status_code = match args { + Some(args) => { + let nargs = args.len(); + let (mut argtypes, mut datums, nulls) = args_to_datums(args); + + // SAFETY: arguments are prepared above + unsafe { + pg_sys::SPI_execute_with_args( + cmd.as_ptr(), + nargs as i32, + argtypes.as_mut_ptr(), + datums.as_mut_ptr(), + nulls.as_ptr(), + Spi::is_xact_still_immutable(), + limit.unwrap_or(0), + ) + } + } + // SAFETY: arguments are prepared above + None => unsafe { + pg_sys::SPI_execute(cmd.as_ptr(), Spi::is_xact_still_immutable(), limit.unwrap_or(0)) + }, + }; + + SpiClient::prepare_tuple_table(status_code) +} + +fn open_cursor<'conn>( + cmd: &CStr, + args: Option)>>, +) -> SpiResult> { + let args = args.unwrap_or_default(); + let nargs = args.len(); + let (mut argtypes, mut datums, nulls) = args_to_datums(args); + + let ptr = unsafe { + // SAFETY: arguments are prepared above and SPI_cursor_open_with_args will never return + // the null pointer. It'll raise an ERROR if something is invalid for it to create the cursor + NonNull::new_unchecked(pg_sys::SPI_cursor_open_with_args( + std::ptr::null_mut(), // let postgres assign a name + cmd.as_ptr(), + nargs as i32, + argtypes.as_mut_ptr(), + datums.as_mut_ptr(), + nulls.as_ptr(), + Spi::is_xact_still_immutable(), + 0, + )) + }; + + Ok(SpiCursor { ptr, __marker: PhantomData }) } fn args_to_datums( @@ -87,84 +139,102 @@ fn args_to_datums( (argtypes, datums, nulls) } -impl<'conn> Query<'conn> for &str { - type Arguments = Option)>>; +fn prepare_datum(datum: Option) -> (pg_sys::Datum, std::os::raw::c_char) { + match datum { + Some(datum) => (datum, ' ' as std::os::raw::c_char), + None => (pg_sys::Datum::from(0usize), 'n' as std::os::raw::c_char), + } +} - /// # Panics - /// - /// This function will panic if somehow the specified query contains a null byte. - fn execute( - self, - _client: &SpiClient<'conn>, - limit: Option, - args: Self::Arguments, - ) -> SpiResult> { - // SAFETY: no concurrent access - unsafe { - pg_sys::SPI_tuptable = std::ptr::null_mut(); - } +fn prepare<'conn>( + cmd: &CStr, + args: Option>, + mutating: bool, +) -> SpiResult> { + let args = args.unwrap_or_default(); + + // SAFETY: all arguments are prepared above + let plan = unsafe { + pg_sys::SPI_prepare( + cmd.as_ptr(), + args.len() as i32, + args.into_iter().map(PgOid::value).collect::>().as_mut_ptr(), + ) + }; + Ok(PreparedStatement { + plan: NonNull::new(plan).ok_or_else(|| { + Spi::check_status(unsafe { + // SAFETY: no concurrent usage + pg_sys::SPI_result + }) + .err() + .unwrap() + })?, + __marker: PhantomData, + mutating, + }) +} - let src = CString::new(self).expect("query contained a null byte"); - let status_code = match args { - Some(args) => { - let nargs = args.len(); - let (mut argtypes, mut datums, nulls) = args_to_datums(args); - - // SAFETY: arguments are prepared above - unsafe { - pg_sys::SPI_execute_with_args( - src.as_ptr(), - nargs as i32, - argtypes.as_mut_ptr(), - datums.as_mut_ptr(), - nulls.as_ptr(), - Spi::is_xact_still_immutable(), - limit.unwrap_or(0), - ) - } +macro_rules! impl_prepared_query { + ($t:ident, $s:ident) => { + impl<'conn> Query<'conn> for &$t { + type Arguments = Option)>>; + + #[inline] + fn execute( + self, + _client: &SpiClient<'conn>, + limit: Option, + args: Self::Arguments, + ) -> SpiResult> { + execute($s(self).as_ref(), args, limit) } - // SAFETY: arguments are prepared above - None => unsafe { - pg_sys::SPI_execute( - src.as_ptr(), - Spi::is_xact_still_immutable(), - limit.unwrap_or(0), - ) - }, - }; - SpiClient::prepare_tuple_table(status_code) - } + #[inline] + fn try_open_cursor( + self, + _client: &SpiClient<'conn>, + args: Self::Arguments, + ) -> SpiResult> { + open_cursor($s(self).as_ref(), args) + } + } - fn try_open_cursor( - self, - _client: &SpiClient<'conn>, - args: Self::Arguments, - ) -> SpiResult> { - let src = CString::new(self).expect("query contained a null byte"); - let args = args.unwrap_or_default(); + impl<'conn> PreparableQuery<'conn> for &$t { + fn prepare( + self, + _client: &SpiClient<'conn>, + args: Option>, + ) -> SpiResult> { + prepare($s(self).as_ref(), args, false) + } - let nargs = args.len(); - let (mut argtypes, mut datums, nulls) = args_to_datums(args); + fn prepare_mut( + self, + _client: &SpiClient<'conn>, + args: Option>, + ) -> SpiResult> { + prepare($s(self).as_ref(), args, true) + } + } + }; +} - let ptr = unsafe { - // SAFETY: arguments are prepared above and SPI_cursor_open_with_args will never return - // the null pointer. It'll raise an ERROR if something is invalid for it to create the cursor - NonNull::new_unchecked(pg_sys::SPI_cursor_open_with_args( - std::ptr::null_mut(), // let postgres assign a name - src.as_ptr(), - nargs as i32, - argtypes.as_mut_ptr(), - datums.as_mut_ptr(), - nulls.as_ptr(), - Spi::is_xact_still_immutable(), - 0, - )) - }; - Ok(SpiCursor { ptr, __marker: PhantomData }) - } +#[inline] +fn pass_as_is(s: T) -> T { + s +} + +#[inline] +fn pass_with_conv>(s: T) -> CString { + CString::new(s.as_ref()).expect("query contained a null byte") } +impl_prepared_query!(CStr, pass_as_is); +impl_prepared_query!(CString, pass_as_is); +impl_prepared_query!(String, pass_with_conv); +impl_prepared_query!(str, pass_with_conv); + /// Client lifetime-bound prepared statement pub struct PreparedStatement<'conn> { pub(super) plan: NonNull, From 8d1d195c107540f16e3ee3c1bb3fddc403c282b0 Mon Sep 17 00:00:00 2001 From: Yoh Deadfall Date: Fri, 20 Sep 2024 21:40:09 +0300 Subject: [PATCH 2/3] Added UI test --- .../spi-prepare-prepared-statement.rs | 14 ++++++++++++++ .../spi-prepare-prepared-statement.stderr | 18 ++++++++++++++++++ 2 files changed, 32 insertions(+) create mode 100644 pgrx-tests/tests/compile-fail/spi-prepare-prepared-statement.rs create mode 100644 pgrx-tests/tests/compile-fail/spi-prepare-prepared-statement.stderr diff --git a/pgrx-tests/tests/compile-fail/spi-prepare-prepared-statement.rs b/pgrx-tests/tests/compile-fail/spi-prepare-prepared-statement.rs new file mode 100644 index 0000000000..96fdd2906e --- /dev/null +++ b/pgrx-tests/tests/compile-fail/spi-prepare-prepared-statement.rs @@ -0,0 +1,14 @@ +use pgrx::prelude::*; + +fn main() {} + +#[pg_extern] +pub fn cast_function() { + Spi::connect(|client| { + let stmt = client.prepare(c"SELECT 1", None)?; + + client.prepare(stmt, None); + + Ok(()) + }); +} diff --git a/pgrx-tests/tests/compile-fail/spi-prepare-prepared-statement.stderr b/pgrx-tests/tests/compile-fail/spi-prepare-prepared-statement.stderr new file mode 100644 index 0000000000..82fc66f569 --- /dev/null +++ b/pgrx-tests/tests/compile-fail/spi-prepare-prepared-statement.stderr @@ -0,0 +1,18 @@ +error[E0277]: the trait bound `PreparedStatement<'_>: spi::query::PreparableQuery<'_>` is not satisfied + --> tests/compile-fail/spi-prepare-prepared-statement.rs:10:24 + | +10 | client.prepare(stmt, None); + | ------- ^^^^ the trait `spi::query::PreparableQuery<'_>` is not implemented for `PreparedStatement<'_>` + | | + | required by a bound introduced by this call + | + = help: the following other types implement trait `spi::query::PreparableQuery<'conn>`: + &CStr + &CString + &std::string::String + &str +note: required by a bound in `SpiClient::<'conn>::prepare` + --> $WORKSPACE/pgrx/src/spi/client.rs + | + | pub fn prepare>( + | ^^^^^^^^^^^^^^^^^^^^^^ required by this bound in `SpiClient::<'conn>::prepare` From cbfce97233c5bb21c4fb171af85c296274962cdf Mon Sep 17 00:00:00 2001 From: Yoh Deadfall Date: Fri, 20 Sep 2024 23:29:35 +0300 Subject: [PATCH 3/3] Marked UI test as ignored --- ...red-statement.rs => spi-prepare-prepared-statement.rs.ignored} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename pgrx-tests/tests/compile-fail/{spi-prepare-prepared-statement.rs => spi-prepare-prepared-statement.rs.ignored} (100%) diff --git a/pgrx-tests/tests/compile-fail/spi-prepare-prepared-statement.rs b/pgrx-tests/tests/compile-fail/spi-prepare-prepared-statement.rs.ignored similarity index 100% rename from pgrx-tests/tests/compile-fail/spi-prepare-prepared-statement.rs rename to pgrx-tests/tests/compile-fail/spi-prepare-prepared-statement.rs.ignored