Skip to content

Commit

Permalink
Refactor the string_or_struct logic and enables it for TermQuery. (#3906
Browse files Browse the repository at this point in the history
)

* Refactor the string_or_struct logic and enables it for TermQuery.

Closes #3900

* Allowing integers in term queries.
  • Loading branch information
fulmicoton authored Oct 3, 2023
1 parent 8d89c49 commit 91540c6
Show file tree
Hide file tree
Showing 9 changed files with 343 additions and 140 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -19,18 +19,16 @@

use serde::Deserialize;

use crate::elastic_query_dsl::match_query::{MatchQueryParams, MatchQueryParamsForDeserialization};
use super::StringOrStructForSerialization;
use crate::elastic_query_dsl::match_query::MatchQueryParams;
use crate::elastic_query_dsl::{default_max_expansions, ConvertableToQueryAst};
use crate::query_ast::{FullTextParams, FullTextQuery, QueryAst};
use crate::OneFieldMap;

/// `MatchBoolPrefixQuery` as defined in
/// <https://www.elastic.co/guide/en/elasticsearch/reference/current/query-dsl-match-bool-prefix-query.html>
#[derive(Deserialize, Clone, Eq, PartialEq, Debug)]
#[serde(
from = "OneFieldMap<MatchQueryParamsForDeserialization>",
into = "OneFieldMap<MatchQueryParams>"
)]
#[serde(from = "OneFieldMap<StringOrStructForSerialization<MatchQueryParams>>")]
pub(crate) struct MatchBoolPrefixQuery {
pub(crate) field: String,
pub(crate) params: MatchQueryParams,
Expand All @@ -54,8 +52,10 @@ impl ConvertableToQueryAst for MatchBoolPrefixQuery {
}
}

impl From<OneFieldMap<MatchQueryParamsForDeserialization>> for MatchBoolPrefixQuery {
fn from(match_query_params: OneFieldMap<MatchQueryParamsForDeserialization>) -> Self {
impl From<OneFieldMap<StringOrStructForSerialization<MatchQueryParams>>> for MatchBoolPrefixQuery {
fn from(
match_query_params: OneFieldMap<StringOrStructForSerialization<MatchQueryParams>>,
) -> Self {
let OneFieldMap { field, value } = match_query_params;
MatchBoolPrefixQuery {
field,
Expand Down
77 changes: 14 additions & 63 deletions quickwit/quickwit-query/src/elastic_query_dsl/match_phrase_query.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,22 +17,16 @@
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <http://www.gnu.org/licenses/>.

use std::fmt;
use serde::Deserialize;

use serde::de::{self, MapAccess, Visitor};
use serde::{Deserialize, Deserializer};

use crate::elastic_query_dsl::ConvertableToQueryAst;
use crate::elastic_query_dsl::{ConvertableToQueryAst, StringOrStructForSerialization};
use crate::query_ast::{FullTextMode, FullTextParams, FullTextQuery, QueryAst};
use crate::{MatchAllOrNone, OneFieldMap};

/// `MatchPhraseQuery` as defined in
/// <https://www.elastic.co/guide/en/elasticsearch/reference/current/query-dsl-match-query-phrase.html>
#[derive(Deserialize, Clone, Eq, PartialEq, Debug)]
#[serde(
from = "OneFieldMap<MatchPhraseQueryParamsForDeserialization>",
into = "OneFieldMap<MatchPhraseQueryParams>"
)]
#[serde(from = "OneFieldMap<StringOrStructForSerialization<MatchPhraseQueryParams>>")]
pub(crate) struct MatchPhraseQuery {
pub(crate) field: String,
pub(crate) params: MatchPhraseQueryParams,
Expand Down Expand Up @@ -67,36 +61,12 @@ impl ConvertableToQueryAst for MatchPhraseQuery {
}
}

// --------------
//
// Below is the Deserialization code
// The difficulty here is to support the two following formats:
//
// `{"field": {"query": "my query", "default_operator": "OR"}}`
// `{"field": "my query"}`
//
// We don't use untagged enum to support this, in order to keep good errors.
//
// The code below is adapted from solution described here: https://serde.rs/string-or-struct.html

#[derive(Deserialize)]
#[serde(transparent)]
struct MatchPhraseQueryParamsForDeserialization {
#[serde(deserialize_with = "string_or_struct")]
inner: MatchPhraseQueryParams,
}

impl From<MatchPhraseQuery> for OneFieldMap<MatchPhraseQueryParams> {
fn from(match_phrase_query: MatchPhraseQuery) -> OneFieldMap<MatchPhraseQueryParams> {
OneFieldMap {
field: match_phrase_query.field,
value: match_phrase_query.params,
}
}
}

impl From<OneFieldMap<MatchPhraseQueryParamsForDeserialization>> for MatchPhraseQuery {
fn from(match_query_params: OneFieldMap<MatchPhraseQueryParamsForDeserialization>) -> Self {
impl From<OneFieldMap<StringOrStructForSerialization<MatchPhraseQueryParams>>>
for MatchPhraseQuery
{
fn from(
match_query_params: OneFieldMap<StringOrStructForSerialization<MatchPhraseQueryParams>>,
) -> Self {
let OneFieldMap { field, value } = match_query_params;
MatchPhraseQuery {
field,
Expand All @@ -105,36 +75,17 @@ impl From<OneFieldMap<MatchPhraseQueryParamsForDeserialization>> for MatchPhrase
}
}

struct MatchQueryParamsStringOrStructVisitor;

impl<'de> Visitor<'de> for MatchQueryParamsStringOrStructVisitor {
type Value = MatchPhraseQueryParams;

fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
formatter.write_str("string or map containing the parameters of a match query.")
}

fn visit_str<E>(self, query: &str) -> Result<Self::Value, E>
where E: serde::de::Error {
Ok(MatchPhraseQueryParams {
query: query.to_string(),
impl From<String> for MatchPhraseQueryParams {
fn from(query: String) -> MatchPhraseQueryParams {
MatchPhraseQueryParams {
query,
zero_terms_query: Default::default(),
analyzer: None,
slop: 0,
})
}

fn visit_map<M>(self, map: M) -> Result<MatchPhraseQueryParams, M::Error>
where M: MapAccess<'de> {
Deserialize::deserialize(de::value::MapAccessDeserializer::new(map))
}
}
}

fn string_or_struct<'de, D>(deserializer: D) -> Result<MatchPhraseQueryParams, D::Error>
where D: Deserializer<'de> {
deserializer.deserialize_any(MatchQueryParamsStringOrStructVisitor)
}

#[cfg(test)]
mod tests {
use super::*;
Expand Down
77 changes: 14 additions & 63 deletions quickwit/quickwit-query/src/elastic_query_dsl/match_query.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,22 +17,18 @@
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <http://www.gnu.org/licenses/>.

use std::fmt;
use serde::Deserialize;

use serde::de::{self, MapAccess, Visitor};
use serde::{Deserialize, Deserializer};

use crate::elastic_query_dsl::{ConvertableToQueryAst, ElasticQueryDslInner};
use crate::elastic_query_dsl::{
ConvertableToQueryAst, ElasticQueryDslInner, StringOrStructForSerialization,
};
use crate::query_ast::{FullTextParams, FullTextQuery, QueryAst};
use crate::{BooleanOperand, MatchAllOrNone, OneFieldMap};

/// `MatchQuery` as defined in
/// <https://www.elastic.co/guide/en/elasticsearch/reference/current/query-dsl-match-query.html>
#[derive(Deserialize, Clone, Eq, PartialEq, Debug)]
#[serde(
from = "OneFieldMap<MatchQueryParamsForDeserialization>",
into = "OneFieldMap<MatchQueryParams>"
)]
#[serde(from = "OneFieldMap<StringOrStructForSerialization<MatchQueryParams>>")]
pub struct MatchQuery {
pub(crate) field: String,
pub(crate) params: MatchQueryParams,
Expand Down Expand Up @@ -74,36 +70,10 @@ impl From<MatchQuery> for ElasticQueryDslInner {
}
}

// --------------
//
// Below is the Deserialization code
// The difficulty here is to support the two following formats:
//
// `{"field": {"query": "my query", "default_operator": "OR"}}`
// `{"field": "my query"}`
//
// We don't use untagged enum to support this, in order to keep good errors.
//
// The code below is adapted from solution described here: https://serde.rs/string-or-struct.html

#[derive(Deserialize)]
#[serde(transparent)]
pub(crate) struct MatchQueryParamsForDeserialization {
#[serde(deserialize_with = "string_or_struct")]
pub(crate) inner: MatchQueryParams,
}

impl From<MatchQuery> for OneFieldMap<MatchQueryParams> {
fn from(match_query: MatchQuery) -> OneFieldMap<MatchQueryParams> {
OneFieldMap {
field: match_query.field,
value: match_query.params,
}
}
}

impl From<OneFieldMap<MatchQueryParamsForDeserialization>> for MatchQuery {
fn from(match_query_params: OneFieldMap<MatchQueryParamsForDeserialization>) -> Self {
impl From<OneFieldMap<StringOrStructForSerialization<MatchQueryParams>>> for MatchQuery {
fn from(
match_query_params: OneFieldMap<StringOrStructForSerialization<MatchQueryParams>>,
) -> Self {
let OneFieldMap { field, value } = match_query_params;
MatchQuery {
field,
Expand All @@ -112,36 +82,17 @@ impl From<OneFieldMap<MatchQueryParamsForDeserialization>> for MatchQuery {
}
}

struct MatchQueryParamsStringOrStructVisitor;

impl<'de> Visitor<'de> for MatchQueryParamsStringOrStructVisitor {
type Value = MatchQueryParams;

fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
formatter.write_str("string or map containing the parameters of a match query.")
}

fn visit_str<E>(self, query: &str) -> Result<Self::Value, E>
where E: serde::de::Error {
Ok(MatchQueryParams {
query: query.to_string(),
impl From<String> for MatchQueryParams {
fn from(query: String) -> MatchQueryParams {
MatchQueryParams {
query,
zero_terms_query: Default::default(),
operator: Default::default(),
_lenient: false,
})
}

fn visit_map<M>(self, map: M) -> Result<MatchQueryParams, M::Error>
where M: MapAccess<'de> {
Deserialize::deserialize(de::value::MapAccessDeserializer::new(map))
}
}
}

fn string_or_struct<'de, D>(deserializer: D) -> Result<MatchQueryParams, D::Error>
where D: Deserializer<'de> {
deserializer.deserialize_any(MatchQueryParamsStringOrStructVisitor)
}

#[cfg(test)]
mod tests {
use super::*;
Expand Down
2 changes: 2 additions & 0 deletions quickwit/quickwit-query/src/elastic_query_dsl/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ mod one_field_map;
mod phrase_prefix_query;
mod query_string_query;
mod range_query;
mod string_or_struct;
mod term_query;
mod terms_query;

Expand All @@ -37,6 +38,7 @@ pub use one_field_map::OneFieldMap;
use phrase_prefix_query::MatchPhrasePrefixQuery;
pub(crate) use query_string_query::QueryStringQuery;
use range_query::RangeQuery;
pub(crate) use string_or_struct::StringOrStructForSerialization;
use term_query::TermQuery;

use crate::elastic_query_dsl::exists_query::ExistsQuery;
Expand Down
95 changes: 95 additions & 0 deletions quickwit/quickwit-query/src/elastic_query_dsl/string_or_struct.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
// Copyright (C) 2023 Quickwit, Inc.
//
// Quickwit is offered under the AGPL v3.0 and as commercial software.
// For commercial licensing, contact us at hello@quickwit.io.
//
// AGPL:
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as
// published by the Free Software Foundation, either version 3 of the
// License, or (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <http://www.gnu.org/licenses/>.

use std::fmt;
use std::marker::PhantomData;

use serde::de::{MapAccess, Visitor};
use serde::{de, Deserialize, Deserializer};

/// The point of `StringOrStructForSerialization` is to support
/// the two following formats for various queries.
///
/// `{"field": {"query": "my query", "default_operator": "OR"}}`
///
/// and the shorter.
/// `{"field": "my query"}`
///
/// If a integer is passed, we cast it to string. Floats are not supported.
///
/// We don't use untagged enum to support this, in order to keep good errors.
///
/// The code below is adapted from solution described here: https://serde.rs/string-or-struct.html
#[derive(Deserialize)]
#[serde(transparent)]
pub(crate) struct StringOrStructForSerialization<T>
where
T: From<String>,
for<'de2> T: Deserialize<'de2>,
{
#[serde(deserialize_with = "string_or_struct")]
pub inner: T,
}

struct StringOrStructVisitor<T> {
phantom_data: PhantomData<T>,
}

fn string_or_struct<'de, D, T>(deserializer: D) -> Result<T, D::Error>
where
D: Deserializer<'de>,
T: From<String> + Deserialize<'de>,
{
deserializer.deserialize_any(StringOrStructVisitor {
phantom_data: Default::default(),
})
}

impl<'de, T> Visitor<'de> for StringOrStructVisitor<T>
where
T: From<String>,
T: Deserialize<'de>,
{
type Value = T;

fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
let type_str = std::any::type_name::<T>();
formatter.write_str(&format!("string or map to deserialize {type_str}."))
}

fn visit_i64<E>(self, v: i64) -> Result<Self::Value, E>
where E: de::Error {
self.visit_str(&v.to_string())
}

fn visit_u64<E>(self, v: u64) -> Result<Self::Value, E>
where E: de::Error {
self.visit_str(&v.to_string())
}

fn visit_str<E>(self, query: &str) -> Result<Self::Value, E>
where E: serde::de::Error {
Ok(T::from(query.to_string()))
}

fn visit_map<M>(self, map: M) -> Result<T, M::Error>
where M: MapAccess<'de> {
Deserialize::deserialize(de::value::MapAccessDeserializer::new(map))
}
}
Loading

0 comments on commit 91540c6

Please sign in to comment.