Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allowed CStr to be used as SPI commands #1864

Merged
merged 3 commits into from
Sep 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
use pgrx::prelude::*;

fn main() {}

#[pg_extern]
pub fn cast_function() {
Spi::connect(|client| {
let stmt = client.prepare(c"SELECT 1", None)?;

client.prepare(stmt, None);

Ok(())
});
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
error[E0277]: the trait bound `PreparedStatement<'_>: spi::query::PreparableQuery<'_>` is not satisfied
--> tests/compile-fail/spi-prepare-prepared-statement.rs:10:24
|
10 | client.prepare(stmt, None);
| ------- ^^^^ the trait `spi::query::PreparableQuery<'_>` is not implemented for `PreparedStatement<'_>`
| |
| required by a bound introduced by this call
|
= help: the following other types implement trait `spi::query::PreparableQuery<'conn>`:
&CStr
&CString
&std::string::String
&str
note: required by a bound in `SpiClient::<'conn>::prepare`
--> $WORKSPACE/pgrx/src/spi/client.rs
|
| pub fn prepare<Q: PreparableQuery<'conn>>(
| ^^^^^^^^^^^^^^^^^^^^^^ required by this bound in `SpiClient::<'conn>::prepare`
55 changes: 8 additions & 47 deletions pgrx/src/spi/client.rs
Original file line number Diff line number Diff line change
@@ -1,72 +1,33 @@
use std::ffi::CString;
use std::marker::PhantomData;
use std::ptr::NonNull;

use crate::pg_sys::{self, PgOid};
use crate::spi::{PreparedStatement, Query, Spi, SpiCursor, SpiError, SpiResult, SpiTupleTable};

use super::query::PreparableQuery;

// TODO: should `'conn` be invariant?
pub struct SpiClient<'conn> {
__marker: PhantomData<&'conn SpiConnection>,
}

impl<'conn> SpiClient<'conn> {
/// Prepares a statement that is valid for the lifetime of the client
///
/// # Panics
///
/// This function will panic if the supplied `query` string contained a NULL byte
pub fn prepare(
pub fn prepare<Q: PreparableQuery<'conn>>(
&self,
query: &str,
query: Q,
args: Option<Vec<PgOid>>,
) -> SpiResult<PreparedStatement<'conn>> {
self.make_prepare_statement(query, args, false)
query.prepare(self, args)
}

/// Prepares a mutating statement that is valid for the lifetime of the client
///
/// # Panics
///
/// This function will panic if the supplied `query` string contained a NULL byte
pub fn prepare_mut(
pub fn prepare_mut<Q: PreparableQuery<'conn>>(
&self,
query: &str,
args: Option<Vec<PgOid>>,
) -> SpiResult<PreparedStatement<'conn>> {
self.make_prepare_statement(query, args, true)
}

fn make_prepare_statement(
&self,
query: &str,
query: Q,
args: Option<Vec<PgOid>>,
mutating: bool,
) -> SpiResult<PreparedStatement<'conn>> {
let src = CString::new(query).expect("query contained a null byte");
let args = args.unwrap_or_default();
let nargs = args.len();

// SAFETY: all arguments are prepared above
let plan = unsafe {
pg_sys::SPI_prepare(
src.as_ptr(),
nargs as i32,
args.into_iter().map(PgOid::value).collect::<Vec<_>>().as_mut_ptr(),
)
};
Ok(PreparedStatement {
plan: NonNull::new(plan).ok_or_else(|| {
Spi::check_status(unsafe {
// SAFETY: no concurrent usage
pg_sys::SPI_result
})
.err()
.unwrap()
})?,
__marker: PhantomData,
mutating,
})
query.prepare_mut(self, args)
}

/// perform a SELECT statement
Expand Down
246 changes: 158 additions & 88 deletions pgrx/src/spi/query.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use std::ffi::CString;
use std::ffi::{CStr, CString};
use std::marker::PhantomData;
use std::ops::Deref;
use std::ptr::NonNull;
Expand Down Expand Up @@ -41,32 +41,84 @@ pub trait Query<'conn>: Sized {
) -> SpiResult<SpiCursor<'conn>>;
}

impl<'conn> Query<'conn> for &String {
type Arguments = Option<Vec<(PgOid, Option<pg_sys::Datum>)>>;

fn execute(
/// A trait representing a query which can be prepared.
pub trait PreparableQuery<'conn>: Query<'conn> {
/// Prepares a query.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This probably seems obvious but can you explain how this distinguishes from Query?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

To distinguish from the prepared statements which should not be used by the prepare methods of the SPI client.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hm, not entirely sure what the benefit is?

Copy link
Member

@workingjubilee workingjubilee Sep 19, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is probably obvious to you, it's just been ages since I've looked at the Query API and my main dread is it being cemented further, because I would rather it... not exist, tbh. Not in its current form anyway. It's... overly cutesy.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The problem that trait solves is prohibiting PreparedStatement<'conn> from being used with prepare and prepare_mut methods of the SqlClient<'conn>, so the following code will result in a compiler error:

let conn = SpiConnection::connect()?;
let client = conn.client();

let once = client.prepare(c"SELECT 1", None)?;
let twice = client.prepare(once, None);
//                 ------- ^^^^ the trait `PreparableQuery<'_>` is not implemented for `PreparedStatement<'_>`
//                              |
//                              required by a bound introduced by this call

Copy link
Contributor Author

@YohDeadfall YohDeadfall Sep 20, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I played with the CI and observed that my UI test fails only for PostgreSQL versions 12, 13 and 14, but not for others. That's totally weird. See #1870.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

...huh.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

...oh, I remember. They introduced a type named String in PostgreSQL. Can you defer the test to opening an issue instead?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes.

fn prepare(
self,
client: &SpiClient<'conn>,
limit: Option<libc::c_long>,
args: Self::Arguments,
) -> SpiResult<SpiTupleTable<'conn>> {
self.as_str().execute(client, limit, args)
}
args: Option<Vec<PgOid>>,
) -> SpiResult<PreparedStatement<'conn>>;

fn try_open_cursor(
/// Prepares a query allowed to change data
fn prepare_mut(
self,
client: &SpiClient<'conn>,
args: Self::Arguments,
) -> SpiResult<SpiCursor<'conn>> {
self.as_str().try_open_cursor(client, args)
}
args: Option<Vec<PgOid>>,
) -> SpiResult<PreparedStatement<'conn>>;
}

fn prepare_datum(datum: Option<pg_sys::Datum>) -> (pg_sys::Datum, std::os::raw::c_char) {
match datum {
Some(datum) => (datum, ' ' as std::os::raw::c_char),
None => (pg_sys::Datum::from(0usize), 'n' as std::os::raw::c_char),
fn execute<'conn>(
cmd: &CStr,
args: Option<Vec<(PgOid, Option<pg_sys::Datum>)>>,
limit: Option<libc::c_long>,
) -> SpiResult<SpiTupleTable<'conn>> {
// SAFETY: no concurrent access
unsafe {
pg_sys::SPI_tuptable = std::ptr::null_mut();
}

let status_code = match args {
Some(args) => {
let nargs = args.len();
let (mut argtypes, mut datums, nulls) = args_to_datums(args);

// SAFETY: arguments are prepared above
unsafe {
pg_sys::SPI_execute_with_args(
cmd.as_ptr(),
nargs as i32,
argtypes.as_mut_ptr(),
datums.as_mut_ptr(),
nulls.as_ptr(),
Spi::is_xact_still_immutable(),
limit.unwrap_or(0),
)
}
}
// SAFETY: arguments are prepared above
None => unsafe {
pg_sys::SPI_execute(cmd.as_ptr(), Spi::is_xact_still_immutable(), limit.unwrap_or(0))
},
};

SpiClient::prepare_tuple_table(status_code)
}

fn open_cursor<'conn>(
cmd: &CStr,
args: Option<Vec<(PgOid, Option<pg_sys::Datum>)>>,
) -> SpiResult<SpiCursor<'conn>> {
let args = args.unwrap_or_default();
let nargs = args.len();
let (mut argtypes, mut datums, nulls) = args_to_datums(args);

let ptr = unsafe {
// SAFETY: arguments are prepared above and SPI_cursor_open_with_args will never return
// the null pointer. It'll raise an ERROR if something is invalid for it to create the cursor
NonNull::new_unchecked(pg_sys::SPI_cursor_open_with_args(
std::ptr::null_mut(), // let postgres assign a name
cmd.as_ptr(),
nargs as i32,
argtypes.as_mut_ptr(),
datums.as_mut_ptr(),
nulls.as_ptr(),
Spi::is_xact_still_immutable(),
0,
))
};

Ok(SpiCursor { ptr, __marker: PhantomData })
}

fn args_to_datums(
Expand All @@ -87,84 +139,102 @@ fn args_to_datums(
(argtypes, datums, nulls)
}

impl<'conn> Query<'conn> for &str {
type Arguments = Option<Vec<(PgOid, Option<pg_sys::Datum>)>>;
fn prepare_datum(datum: Option<pg_sys::Datum>) -> (pg_sys::Datum, std::os::raw::c_char) {
match datum {
Some(datum) => (datum, ' ' as std::os::raw::c_char),
None => (pg_sys::Datum::from(0usize), 'n' as std::os::raw::c_char),
}
}

/// # Panics
///
/// This function will panic if somehow the specified query contains a null byte.
fn execute(
self,
_client: &SpiClient<'conn>,
limit: Option<libc::c_long>,
args: Self::Arguments,
) -> SpiResult<SpiTupleTable<'conn>> {
// SAFETY: no concurrent access
unsafe {
pg_sys::SPI_tuptable = std::ptr::null_mut();
}
fn prepare<'conn>(
cmd: &CStr,
args: Option<Vec<PgOid>>,
mutating: bool,
) -> SpiResult<PreparedStatement<'conn>> {
let args = args.unwrap_or_default();

// SAFETY: all arguments are prepared above
let plan = unsafe {
pg_sys::SPI_prepare(
cmd.as_ptr(),
args.len() as i32,
args.into_iter().map(PgOid::value).collect::<Vec<_>>().as_mut_ptr(),
)
};
Ok(PreparedStatement {
plan: NonNull::new(plan).ok_or_else(|| {
Spi::check_status(unsafe {
// SAFETY: no concurrent usage
pg_sys::SPI_result
})
.err()
.unwrap()
})?,
__marker: PhantomData,
mutating,
})
}

let src = CString::new(self).expect("query contained a null byte");
let status_code = match args {
Some(args) => {
let nargs = args.len();
let (mut argtypes, mut datums, nulls) = args_to_datums(args);

// SAFETY: arguments are prepared above
unsafe {
pg_sys::SPI_execute_with_args(
src.as_ptr(),
nargs as i32,
argtypes.as_mut_ptr(),
datums.as_mut_ptr(),
nulls.as_ptr(),
Spi::is_xact_still_immutable(),
limit.unwrap_or(0),
)
}
macro_rules! impl_prepared_query {
($t:ident, $s:ident) => {
impl<'conn> Query<'conn> for &$t {
type Arguments = Option<Vec<(PgOid, Option<pg_sys::Datum>)>>;

#[inline]
fn execute(
self,
_client: &SpiClient<'conn>,
limit: Option<libc::c_long>,
args: Self::Arguments,
) -> SpiResult<SpiTupleTable<'conn>> {
execute($s(self).as_ref(), args, limit)
}
// SAFETY: arguments are prepared above
None => unsafe {
pg_sys::SPI_execute(
src.as_ptr(),
Spi::is_xact_still_immutable(),
limit.unwrap_or(0),
)
},
};

SpiClient::prepare_tuple_table(status_code)
}
#[inline]
fn try_open_cursor(
self,
_client: &SpiClient<'conn>,
args: Self::Arguments,
) -> SpiResult<SpiCursor<'conn>> {
open_cursor($s(self).as_ref(), args)
}
}

fn try_open_cursor(
self,
_client: &SpiClient<'conn>,
args: Self::Arguments,
) -> SpiResult<SpiCursor<'conn>> {
let src = CString::new(self).expect("query contained a null byte");
let args = args.unwrap_or_default();
impl<'conn> PreparableQuery<'conn> for &$t {
fn prepare(
self,
_client: &SpiClient<'conn>,
args: Option<Vec<PgOid>>,
) -> SpiResult<PreparedStatement<'conn>> {
prepare($s(self).as_ref(), args, false)
}

let nargs = args.len();
let (mut argtypes, mut datums, nulls) = args_to_datums(args);
fn prepare_mut(
self,
_client: &SpiClient<'conn>,
args: Option<Vec<PgOid>>,
) -> SpiResult<PreparedStatement<'conn>> {
prepare($s(self).as_ref(), args, true)
}
}
};
}

let ptr = unsafe {
// SAFETY: arguments are prepared above and SPI_cursor_open_with_args will never return
// the null pointer. It'll raise an ERROR if something is invalid for it to create the cursor
NonNull::new_unchecked(pg_sys::SPI_cursor_open_with_args(
std::ptr::null_mut(), // let postgres assign a name
src.as_ptr(),
nargs as i32,
argtypes.as_mut_ptr(),
datums.as_mut_ptr(),
nulls.as_ptr(),
Spi::is_xact_still_immutable(),
0,
))
};
Ok(SpiCursor { ptr, __marker: PhantomData })
}
#[inline]
fn pass_as_is<T>(s: T) -> T {
s
}

#[inline]
fn pass_with_conv<T: AsRef<str>>(s: T) -> CString {
CString::new(s.as_ref()).expect("query contained a null byte")
}

impl_prepared_query!(CStr, pass_as_is);
impl_prepared_query!(CString, pass_as_is);
impl_prepared_query!(String, pass_with_conv);
impl_prepared_query!(str, pass_with_conv);

/// Client lifetime-bound prepared statement
pub struct PreparedStatement<'conn> {
pub(super) plan: NonNull<pg_sys::_SPI_plan>,
Expand Down
Loading