Skip to content

Commit

Permalink
Use type parameters to specify sampling method.
Browse files Browse the repository at this point in the history
  • Loading branch information
Drew Vogel committed Feb 6, 2024
1 parent f596788 commit 4ec897c
Show file tree
Hide file tree
Showing 6 changed files with 109 additions and 82 deletions.
8 changes: 8 additions & 0 deletions diesel/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -334,6 +334,14 @@ pub mod dsl {

#[doc(inline)]
pub use diesel_derives::auto_type;

#[cfg(feature = "postgres_backend")]
#[doc(inline)]
pub use crate::pg::expression::extensions::OnlyDsl;

#[cfg(feature = "postgres_backend")]
#[doc(inline)]
pub use crate::pg::expression::extensions::TablesampleDsl;
}

pub mod helper_types {
Expand Down
2 changes: 1 addition & 1 deletion diesel/src/pg/expression/extensions/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,4 +7,4 @@ mod tablesample_dsl;

pub use self::interval_dsl::IntervalDsl;
pub use self::only_dsl::OnlyDsl;
pub use self::tablesample_dsl::TablesampleDsl;
pub use self::tablesample_dsl::{BernoulliMethod, SystemMethod, TablesampleDsl};
53 changes: 42 additions & 11 deletions diesel/src/pg/expression/extensions/tablesample_dsl.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,27 @@
use crate::query_builder::Tablesample;
pub(crate) use crate::query_builder::{TablesampleMethod, TablesampleSeed};
pub(crate) use crate::query_builder::TablesampleMethod;
use crate::Table;
use std::marker::PhantomData;

#[derive(Clone, Copy, Debug)]
/// Used to specify the `BERNOULLI` sampling method.
///
/// This is a unit marker type: it is passed as the type parameter of
/// `tablesample`, e.g. `users::table.tablesample::<BernoulliMethod>(10, None)`.
pub struct BernoulliMethod;

impl TablesampleMethod for BernoulliMethod {
// Keyword written into the generated SQL for this sampling method.
fn method_name_sql() -> &'static str {
"BERNOULLI"
}
}

#[derive(Clone, Copy, Debug)]
/// Used to specify the `SYSTEM` sampling method.
///
/// This is a unit marker type: it is passed as the type parameter of
/// `tablesample`, e.g. `users::table.tablesample::<SystemMethod>(10, None)`.
pub struct SystemMethod;

impl TablesampleMethod for SystemMethod {
// Keyword written into the generated SQL for this sampling method.
fn method_name_sql() -> &'static str {
"SYSTEM"
}
}

/// The `tablesample` method
///
Expand All @@ -11,23 +32,28 @@ use crate::Table;
/// supporting a wide variety of sampling methods.
///
/// Calling this function on a table (`mytable.tablesample(...)`) will result in the SQL
/// `FROM mytable TABLESAMPLE ...`.
/// `FROM mytable TABLESAMPLE ...` --
/// `mytable.tablesample(...)` can be used just like any table in diesel since it implements
/// [Table](crate::Table).
///
/// The `BernoulliMethod` and `SystemMethod` types can be used to indicate the sampling method for
/// a `TABLESAMPLE method(p)` clause where p is specified by the portion argument. The provided
/// percentage should be an integer between 0 and 100.
///
/// If the seed argument is `Some(f)` then f becomes the seed in `TABLESAMPLE ... REPEATABLE (f)`.
///
/// Example:
///
/// ```rust
/// # include!("../../../doctest_setup.rs");
/// # use schema::{posts, users};
/// # use diesel::dsl::*;
/// # use crate::pg::query_builder::{TablesampleMethod, TablesampleSeed};
/// # fn main() {
/// # let connection = &mut establish_connection();
/// let random_user_ids = users::table
/// .tablesample(TablesampleMethod::Bernoulli(10), TablesampleSeed::Auto)
/// .tablesample::<BernoulliMethod>(10, None)
/// .select((users::id))
/// .load::<i64>(connection);
/// .load::<i32>(connection);
/// # }
/// ```
/// Selects the ids for a random 10 percent of users.
Expand All @@ -38,25 +64,30 @@ use crate::Table;
/// # include!("../../../doctest_setup.rs");
/// # use schema::{posts, users};
/// # use diesel::dsl::*;
/// # use crate::query_builder::{TablesampleMethod, TablesampleSeed};
/// # fn main() {
/// # let connection = &mut establish_connection();
/// # let _ =
/// users::table
/// .tablesample(TablesampleMethod::Bernoulli(10), TablesampleSeed::Auto)
/// .inner_join(posts::table.only())
/// .tablesample::<BernoulliMethod>(10, Some(42.0))
/// .inner_join(posts::table)
/// .select((users::name, posts::title))
/// .load::<(String, String)>(connection);
/// # }
/// ```
/// That query selects all of the posts for a random 10 percent of users.
/// That query selects all of the posts for a random 10 percent of users, returning the same
/// results each time it is run due to the static seed of 42.0.
///
pub trait TablesampleDsl: Table {
/// See the trait-level docs.
fn tablesample(self, method: TablesampleMethod, seed: TablesampleSeed) -> Tablesample<Self> {
fn tablesample<TSM: TablesampleMethod>(
self,
portion: i16,
seed: Option<f64>,
) -> Tablesample<Self, TSM> {
Tablesample {
source: self,
method,
method: PhantomData,
portion,
seed,
}
}
Expand Down
97 changes: 39 additions & 58 deletions diesel/src/pg/query_builder/tablesample.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,49 +7,38 @@ use crate::{
sql_types::{Double, SmallInt},
JoinTo, SelectableExpression, Table,
};
use std::marker::PhantomData;

/// Indicates the sampling method for a `TABLESAMPLE method(n)` clause. The provided percentage
/// should be an integer between 0 and 100.
#[derive(Debug, Clone, Copy)]
pub enum TablesampleMethod {
/// Use the BERNOULLI sampling method. This is row-based, slower but more accurate.
Bernoulli(i16),

/// Use the SYSTEM sampling method. This is page-based, faster but less accurate.
System(i16),
}

/// Indicates the random number seed for a `TABLESAMPLE ... REPEATABLE(f)` clause.
#[derive(Debug, Clone, Copy)]
pub enum TablesampleSeed {
/// Have PostgreSQL generate an implied random number generator seed.
Auto,

/// Provide your own random number generator seed.
Repeatable(f64),
#[doc(hidden)]
// Implemented by the marker types that select a sampling method for a
// `TABLESAMPLE` clause; the method is chosen through a type parameter
// rather than a runtime value.
pub trait TablesampleMethod: Clone {
// Returns the SQL keyword naming the sampling method. The `QueryFragment`
// impl for `Tablesample` pushes it right after ` TABLESAMPLE `, followed by
// the parenthesised portion bind parameter.
fn method_name_sql() -> &'static str;
}

/// Represents a query with a `TABLESAMPLE` clause.
#[doc(hidden)]
#[derive(Debug, Clone, Copy)]
pub struct Tablesample<S> {
pub struct Tablesample<S, TSM> {
pub source: S,
pub method: TablesampleMethod,
pub seed: TablesampleSeed,
pub method: PhantomData<TSM>,
pub portion: i16,
pub seed: Option<f64>,
}

impl<S> QueryId for Tablesample<S>
impl<S, TSM> QueryId for Tablesample<S, TSM>
where
S: QueryId,
TSM: TablesampleMethod,
{
type QueryId = ();
const HAS_STATIC_QUERY_ID: bool = false;
}

impl<S> QuerySource for Tablesample<S>
impl<S, TSM> QuerySource for Tablesample<S, TSM>
where
S: Table + Clone,
<S as QuerySource>::DefaultSelection: ValidGrouping<()> + SelectableExpression<Tablesample<S>>,
TSM: TablesampleMethod,
<S as QuerySource>::DefaultSelection:
ValidGrouping<()> + SelectableExpression<Tablesample<S, TSM>>,
{
type FromClause = Self;
type DefaultSelection = <S as QuerySource>::DefaultSelection;
Expand All @@ -63,41 +52,33 @@ where
}
}

impl<S> QueryFragment<Pg> for Tablesample<S>
impl<S, TSM> QueryFragment<Pg> for Tablesample<S, TSM>
where
S: QueryFragment<Pg>,
TSM: TablesampleMethod,
{
fn walk_ast<'b>(&'b self, mut out: AstPass<'_, 'b, Pg>) -> QueryResult<()> {
self.source.walk_ast(out.reborrow())?;
out.push_sql(" TABLESAMPLE ");
match &self.method {
TablesampleMethod::Bernoulli(p) => {
out.push_sql("BERNOULLI(");
out.push_bind_param::<SmallInt, _>(p)?;
out.push_sql(")");
}
TablesampleMethod::System(p) => {
out.push_sql("SYSTEM(");
out.push_bind_param::<SmallInt, _>(p)?;
out.push_sql(")");
}
};
match &self.seed {
TablesampleSeed::Auto => { /* no-op, this is the default */ }
TablesampleSeed::Repeatable(f) => {
out.push_sql(" REPEATABLE(");
out.push_bind_param::<Double, _>(f)?;
out.push_sql(")");
}
out.push_sql(TSM::method_name_sql());
out.push_sql("(");
out.push_bind_param::<SmallInt, _>(&self.portion)?;
out.push_sql(")");
if let Some(f) = &self.seed {
out.push_sql(" REPEATABLE(");
out.push_bind_param::<Double, _>(f)?;
out.push_sql(")");
}
Ok(())
}
}

impl<S> AsQuery for Tablesample<S>
impl<S, TSM> AsQuery for Tablesample<S, TSM>
where
S: Table + Clone,
<S as QuerySource>::DefaultSelection: ValidGrouping<()> + SelectableExpression<Tablesample<S>>,
TSM: TablesampleMethod,
<S as QuerySource>::DefaultSelection:
ValidGrouping<()> + SelectableExpression<Tablesample<S, TSM>>,
{
type SqlType = <<Self as QuerySource>::DefaultSelection as Expression>::SqlType;
type Query = SelectStatement<FromClause<Self>>;
Expand All @@ -106,11 +87,12 @@ where
}
}

impl<S, T> JoinTo<T> for Tablesample<S>
impl<S, T, TSM> JoinTo<T> for Tablesample<S, TSM>
where
S: JoinTo<T>,
T: Table,
S: Table,
TSM: TablesampleMethod,
{
type FromClause = <S as JoinTo<T>>::FromClause;
type OnClause = <S as JoinTo<T>>::OnClause;
Expand All @@ -120,13 +102,15 @@ where
}
}

impl<S> Table for Tablesample<S>
impl<S, TSM> Table for Tablesample<S, TSM>
where
S: Table + Clone + AsQuery,
TSM: TablesampleMethod,

<S as Table>::PrimaryKey: SelectableExpression<Tablesample<S>>,
<S as Table>::AllColumns: SelectableExpression<Tablesample<S>>,
<S as QuerySource>::DefaultSelection: ValidGrouping<()> + SelectableExpression<Tablesample<S>>,
<S as Table>::PrimaryKey: SelectableExpression<Tablesample<S, TSM>>,
<S as Table>::AllColumns: SelectableExpression<Tablesample<S, TSM>>,
<S as QuerySource>::DefaultSelection:
ValidGrouping<()> + SelectableExpression<Tablesample<S, TSM>>,
{
type PrimaryKey = <S as Table>::PrimaryKey;
type AllColumns = <S as Table>::AllColumns;
Expand Down Expand Up @@ -169,20 +153,17 @@ mod test {
#[test]
fn test_generated_tablesample_sql() {
assert_sql!(
users::table.tablesample(TablesampleMethod::Bernoulli(10), TablesampleSeed::Auto),
users::table.tablesample::<BernoulliMethod>(10, None),
"\"users\" TABLESAMPLE BERNOULLI($1)"
);

assert_sql!(
users::table.tablesample(TablesampleMethod::System(10), TablesampleSeed::Auto),
users::table.tablesample::<SystemMethod>(10, None),
"\"users\" TABLESAMPLE SYSTEM($1)"
);

assert_sql!(
users::table.tablesample(
TablesampleMethod::System(10),
TablesampleSeed::Repeatable(42.0),
),
users::table.tablesample::<SystemMethod>(10, Some(42.0),),
"\"users\" TABLESAMPLE SYSTEM($1) REPEATABLE($2)"
);
}
Expand Down
2 changes: 1 addition & 1 deletion diesel/src/query_builder/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ pub(crate) use self::insert_statement::ColumnList;
pub use crate::pg::query_builder::only::Only;

#[cfg(feature = "postgres_backend")]
pub use crate::pg::query_builder::tablesample::{Tablesample, TablesampleMethod, TablesampleSeed};
pub use crate::pg::query_builder::tablesample::{Tablesample, TablesampleMethod};

use crate::backend::Backend;
use crate::result::QueryResult;
Expand Down
29 changes: 18 additions & 11 deletions diesel_derives/src/table.rs
Original file line number Diff line number Diff line change
Expand Up @@ -173,27 +173,32 @@ pub(crate) fn expand(input: TableDecl) -> TokenStream {
type Count = diesel::query_source::Once;
}

impl<S> diesel::JoinTo<diesel::query_builder::Tablesample<S>> for table
impl<S, TSM> diesel::JoinTo<diesel::query_builder::Tablesample<S, TSM>> for table
where
diesel::query_builder::Tablesample<S>: diesel::JoinTo<table>,
diesel::query_builder::Tablesample<S, TSM>: diesel::JoinTo<table>,
TSM: diesel::query_builder::TablesampleMethod
{
type FromClause = diesel::query_builder::Tablesample<S>;
type OnClause = <diesel::query_builder::Tablesample<S> as diesel::JoinTo<table>>::OnClause;
type FromClause = diesel::query_builder::Tablesample<S, TSM>;
type OnClause = <diesel::query_builder::Tablesample<S, TSM> as diesel::JoinTo<table>>::OnClause;

fn join_target(__diesel_internal_rhs: diesel::query_builder::Tablesample<S>) -> (Self::FromClause, Self::OnClause) {
let (_, __diesel_internal_on_clause) = diesel::query_builder::Tablesample::<S>::join_target(table);
fn join_target(__diesel_internal_rhs: diesel::query_builder::Tablesample<S, TSM>) -> (Self::FromClause, Self::OnClause) {
let (_, __diesel_internal_on_clause) = diesel::query_builder::Tablesample::<S, TSM>::join_target(table);
(__diesel_internal_rhs, __diesel_internal_on_clause)
}
}

impl diesel::query_source::AppearsInFromClause<diesel::query_builder::Tablesample<table>>
impl<TSM> diesel::query_source::AppearsInFromClause<diesel::query_builder::Tablesample<table, TSM>>
for table
where
TSM: diesel::query_builder::TablesampleMethod
{
type Count = diesel::query_source::Once;
}

impl diesel::query_source::AppearsInFromClause<table>
for diesel::query_builder::Tablesample<table>
impl<TSM> diesel::query_source::AppearsInFromClause<table>
for diesel::query_builder::Tablesample<table, TSM>
where
TSM: diesel::query_builder::TablesampleMethod
{
type Count = diesel::query_source::Once;
}
Expand Down Expand Up @@ -693,12 +698,14 @@ fn expand_column_def(column_def: &ColumnDef) -> TokenStream {
}
impl diesel::SelectableExpression<diesel::query_builder::Only<super::table>> for #column_name {}

impl diesel::query_source::AppearsInFromClause<diesel::query_builder::Tablesample<super::table>>
impl<TSM> diesel::query_source::AppearsInFromClause<diesel::query_builder::Tablesample<super::table, TSM>>
for #column_name
where TSM: diesel::query_builder::TablesampleMethod
{
type Count = diesel::query_source::Once;
}
impl diesel::SelectableExpression<diesel::query_builder::Tablesample<super::table>> for #column_name {}
impl<TSM> diesel::SelectableExpression<diesel::query_builder::Tablesample<super::table, TSM>>
for #column_name where TSM: diesel::query_builder::TablesampleMethod {}
})
} else {
None
Expand Down

0 comments on commit 4ec897c

Please sign in to comment.