diff --git a/diesel/src/expression/array_comparison.rs b/diesel/src/expression/array_comparison.rs index a4b9461d0ba3..cdbb91aa832e 100644 --- a/diesel/src/expression/array_comparison.rs +++ b/diesel/src/expression/array_comparison.rs @@ -27,7 +27,10 @@ use std::marker::PhantomData; /// `IN` expression. /// /// The postgres backend provided a specialized implementation -/// by using `left = ANY(values)` as optimized variant instead. +/// by using `left = ANY(values)` as optimized variant instead +/// if this is possible. For cases where this is not possible +/// like for example if values is a vector of arrays we +/// generate an ordinary `IN` expression instead. #[derive(Debug, Copy, Clone, QueryId, ValidGrouping)] #[non_exhaustive] pub struct In { @@ -47,7 +50,10 @@ pub struct In { /// `NOT IN` expression.0 /// /// The postgres backend provided a specialized implementation -/// by using `left = ALL(values)` as optimized variant instead. +/// by using `left != ALL(values)` as optimized variant instead +/// if this is possible. For cases where this is not possible +/// like for example if values is a vector of arrays we +/// generate a ordinary `NOT IN` expression instead #[derive(Debug, Copy, Clone, QueryId, ValidGrouping)] #[non_exhaustive] pub struct NotIn { @@ -61,12 +67,46 @@ impl In { pub(crate) fn new(left: T, values: U) -> Self { In { left, values } } + + pub(crate) fn walk_ansi_ast<'b, DB>(&'b self, mut out: AstPass<'_, 'b, DB>) -> QueryResult<()> + where + DB: Backend, + T: QueryFragment, + U: QueryFragment + InExpression, + { + if self.values.is_empty() { + out.push_sql("1=0"); + } else { + self.left.walk_ast(out.reborrow())?; + out.push_sql(" IN ("); + self.values.walk_ast(out.reborrow())?; + out.push_sql(")"); + } + Ok(()) + } } impl NotIn { pub(crate) fn new(left: T, values: U) -> Self { NotIn { left, values } } + + pub(crate) fn walk_ansi_ast<'b, DB>(&'b self, mut out: AstPass<'_, 'b, DB>) -> QueryResult<()> + where + DB: Backend, + T: QueryFragment, + U: QueryFragment + InExpression, + { + if self.values.is_empty() { + out.push_sql("1=1"); + } else { + self.left.walk_ast(out.reborrow())?; + out.push_sql(" NOT IN ("); + self.values.walk_ast(out.reborrow())?; + out.push_sql(")"); + } + Ok(()) + } } impl Expression for In @@ -114,16 +154,8 @@ where T: QueryFragment, U: QueryFragment + InExpression, { - fn walk_ast<'b>(&'b self, mut out: AstPass<'_, 'b, DB>) -> QueryResult<()> { - if self.values.is_empty() { - out.push_sql("1=0"); - } else { - self.left.walk_ast(out.reborrow())?; - out.push_sql(" IN ("); - self.values.walk_ast(out.reborrow())?; - out.push_sql(")"); - } - Ok(()) + fn walk_ast<'b>(&'b self, out: AstPass<'_, 'b, DB>) -> QueryResult<()> { + self.walk_ansi_ast(out) } } @@ -145,16 +177,8 @@ where T: QueryFragment, U: QueryFragment + InExpression, { - fn walk_ast<'b>(&'b self, mut out: AstPass<'_, 'b, DB>) -> QueryResult<()> { - if self.values.is_empty() { - out.push_sql("1=1"); - } else { - self.left.walk_ast(out.reborrow())?; - out.push_sql(" NOT IN ("); - self.values.walk_ast(out.reborrow())?; - out.push_sql(")"); - } - Ok(()) + fn walk_ast<'b>(&'b self, out: AstPass<'_, 'b, DB>) -> QueryResult<()> { + self.walk_ansi_ast(out) } } @@ -217,6 +241,10 @@ pub trait InExpression { /// Returns `true` if self represents an empty collection /// Otherwise `false` is returned. fn is_empty(&self) -> bool; + + /// Returns `true` if the values clause represents + /// bind values and each bind value is a postgres array type + fn is_array(&self) -> bool; } impl AsInExpression @@ -306,6 +334,10 @@ where fn is_empty(&self) -> bool { self.values.is_empty() } + + fn is_array(&self) -> bool { + ST::IS_ARRAY + } } impl SelectableExpression for Many @@ -345,7 +377,18 @@ where ST: SingleValue, I: ToSql, { - fn walk_ast<'b>(&'b self, mut out: AstPass<'_, 'b, DB>) -> QueryResult<()> { + fn walk_ast<'b>(&'b self, out: AstPass<'_, 'b, DB>) -> QueryResult<()> { + self.walk_ansi_ast(out) + } +} + +impl Many { + pub(crate) fn walk_ansi_ast<'b, DB>(&'b self, mut out: AstPass<'_, 'b, DB>) -> QueryResult<()> + where + DB: Backend + HasSqlType, + ST: SingleValue, + I: ToSql, + { out.unsafe_to_cache_prepared(); let mut first = true; for value in &self.values { diff --git a/diesel/src/expression/subselect.rs b/diesel/src/expression/subselect.rs index 25c98ea0c5a1..de30ca43040d 100644 --- a/diesel/src/expression/subselect.rs +++ b/diesel/src/expression/subselect.rs @@ -42,6 +42,9 @@ impl InExpression for Subselect { fn is_empty(&self) -> bool { false } + fn is_array(&self) -> bool { + false + } } impl SelectableExpression for Subselect diff --git a/diesel/src/expression_methods/global_expression_methods.rs b/diesel/src/expression_methods/global_expression_methods.rs index 893867a0bf8f..6a0832d317b6 100644 --- a/diesel/src/expression_methods/global_expression_methods.rs +++ b/diesel/src/expression_methods/global_expression_methods.rs @@ -107,7 +107,9 @@ pub trait ExpressionMethods: Expression + Sized { /// query will use the cache (assuming the subquery /// itself is safe to cache). /// On PostgreSQL, this method automatically performs a `= ANY()` - /// query. + /// query if this is possible. For cases where this is not possible + /// like for example if values is a vector of arrays we + /// generate an ordinary `IN` expression instead. /// /// # Example /// @@ -149,7 +151,10 @@ pub trait ExpressionMethods: Expression + Sized { /// /// Queries using this method will not be /// placed in the prepared statement cache. On PostgreSQL, this - /// method automatically performs a `!= ALL()` query. + /// method automatically performs a `!= ALL()` query if this is possible. + /// For cases where this is not possible + /// like for example if values is a vector of arrays we + /// generate an ordinary `NOT IN` expression instead. /// /// # Example /// diff --git a/diesel/src/pg/expression/array.rs b/diesel/src/pg/expression/array.rs index 7e08c2c929bd..e9bea41984d7 100644 --- a/diesel/src/pg/expression/array.rs +++ b/diesel/src/pg/expression/array.rs @@ -178,9 +178,15 @@ where ST: SqlType, { type SqlType = ST; + fn is_empty(&self) -> bool { false } + + fn is_array(&self) -> bool { + // we want to use the `= ANY(_)` syntax + false + } } impl AsInExpression for ArrayLiteral @@ -189,6 +195,7 @@ where ST: SqlType, { type InExpression = Self; + fn as_in_expression(self) -> Self::InExpression { self } @@ -296,9 +303,15 @@ where ST: SqlType, { type SqlType = ST; + fn is_empty(&self) -> bool { false } + + fn is_array(&self) -> bool { + // we want to use the `= ANY(_)` syntax + false + } } impl AsInExpression for ArraySubselect @@ -307,6 +320,7 @@ where ST: SqlType, { type InExpression = Self; + fn as_in_expression(self) -> Self::InExpression { self } diff --git a/diesel/src/pg/query_builder/query_fragment_impls.rs b/diesel/src/pg/query_builder/query_fragment_impls.rs index 224977569fab..bfc44f18f748 100644 --- a/diesel/src/pg/query_builder/query_fragment_impls.rs +++ b/diesel/src/pg/query_builder/query_fragment_impls.rs @@ -66,10 +66,14 @@ where U: QueryFragment + InExpression, { fn walk_ast<'b>(&'b self, mut out: AstPass<'_, 'b, Pg>) -> QueryResult<()> { - self.left.walk_ast(out.reborrow())?; - out.push_sql(" = ANY("); - self.values.walk_ast(out.reborrow())?; - out.push_sql(")"); + if self.values.is_array() { + self.walk_ansi_ast(out)?; + } else { + self.left.walk_ast(out.reborrow())?; + out.push_sql(" = ANY("); + self.values.walk_ast(out.reborrow())?; + out.push_sql(")"); + } Ok(()) } } @@ -80,10 +84,14 @@ where U: QueryFragment + InExpression, { fn walk_ast<'b>(&'b self, mut out: AstPass<'_, 'b, Pg>) -> QueryResult<()> { - self.left.walk_ast(out.reborrow())?; - out.push_sql(" != ALL("); - self.values.walk_ast(out.reborrow())?; - out.push_sql(")"); + if self.values.is_array() { + self.walk_ansi_ast(out)?; + } else { + self.left.walk_ast(out.reborrow())?; + out.push_sql(" != ALL("); + self.values.walk_ast(out.reborrow())?; + out.push_sql(")"); + } Ok(()) } } @@ -92,10 +100,15 @@ impl QueryFragment for Many where ST: SingleValue, Vec: ToSql, Pg>, + I: ToSql, Pg: HasSqlType, { fn walk_ast<'b>(&'b self, mut out: AstPass<'_, 'b, Pg>) -> QueryResult<()> { - out.push_bind_param::, Vec>(&self.values) + if ST::IS_ARRAY { + self.walk_ansi_ast(out) + } else { + out.push_bind_param::, Vec>(&self.values) + } } } diff --git a/diesel/src/sql_types/mod.rs b/diesel/src/sql_types/mod.rs index 1aece629c942..84e330adb75f 100644 --- a/diesel/src/sql_types/mod.rs +++ b/diesel/src/sql_types/mod.rs @@ -676,6 +676,9 @@ pub trait SqlType: 'static { /// /// ['is_nullable`]: is_nullable type IsNull: OneIsNullable + OneIsNullable; + + #[doc(hidden)] + const IS_ARRAY: bool = false; } /// Is one value of `IsNull` nullable? diff --git a/diesel_derives/src/sql_type.rs b/diesel_derives/src/sql_type.rs index 2e6f00d65464..799f703ef1a0 100644 --- a/diesel_derives/src/sql_type.rs +++ b/diesel_derives/src/sql_type.rs @@ -11,18 +11,23 @@ pub fn derive(item: DeriveInput) -> Result { let model = Model::from_item(&item, true, false)?; let struct_name = &item.ident; + let generic_count = item.generics.params.len(); let (impl_generics, ty_generics, where_clause) = item.generics.split_for_impl(); let sqlite_tokens = sqlite_tokens(&item, &model); let mysql_tokens = mysql_tokens(&item, &model); let pg_tokens = pg_tokens(&item, &model); + let is_array = struct_name == "Array" && generic_count == 1; + Ok(wrap_in_dummy_mod(quote! { impl #impl_generics diesel::sql_types::SqlType for #struct_name #ty_generics #where_clause { type IsNull = diesel::sql_types::is_nullable::NotNull; + + const IS_ARRAY: bool = #is_array; } impl #impl_generics diesel::sql_types::SingleValue diff --git a/diesel_tests/tests/filter_operators.rs b/diesel_tests/tests/filter_operators.rs index 73f627dc123d..f7a0b8a768a9 100644 --- a/diesel_tests/tests/filter_operators.rs +++ b/diesel_tests/tests/filter_operators.rs @@ -323,6 +323,36 @@ fn filter_by_in_explicit_array() { ); } +#[test] +#[cfg(feature = "postgres")] +fn filter_array_by_in() { + use crate::schema::posts::dsl::*; + + let connection: &mut PgConnection = &mut connection(); + let tag_combinations_to_look_for: &[&[&str]] = &[&["foo"], &["foo", "bar"], &["baz"]]; + let result: Vec = posts + .filter(tags.eq_any(tag_combinations_to_look_for)) + .select(id) + .load(connection) + .unwrap(); + assert_eq!(result, &[] as &[i32]); +} + +#[test] +#[cfg(feature = "postgres")] +fn filter_array_by_not_in() { + use crate::schema::posts::dsl::*; + + let connection: &mut PgConnection = &mut connection(); + let tag_combinations_to_look_for: &[&[&str]] = &[&["foo"], &["foo", "bar"], &["baz"]]; + let result: Vec = posts + .filter(tags.ne_all(tag_combinations_to_look_for)) + .select(id) + .load(connection) + .unwrap(); + assert_eq!(result, &[] as &[i32]); +} + fn connection_with_3_users() -> TestConnection { let mut connection = connection_with_sean_and_tess_in_users_table(); diesel::sql_query("INSERT INTO users (id, name) VALUES (3, 'Jim')")