Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Correct typing hints for the FunctionScore query #1960

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 2 additions & 6 deletions elasticsearch_dsl/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -612,9 +612,9 @@ class FunctionScore(Query):

name = "function_score"
_param_defs = {
"functions": {"type": "score_function", "multi": True},
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For my own understanding, how does this line get added, since we removed it in query.py.tpl?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This seemed a bit "magical" to me as well when I noticed it and made me smile. It is all done by logic that is already present in the generator.

In the original version, before this fix, I had to put this definition in the template because the code generator made no attempt to handle function scores.

But with this fix the generator recognizes FunctionScoreContainer and returns a DSL type for it. The return value has two components, the typing hint and the DSL type:

        return "ScoreFunction", {"type": "score_function"}

The first returned value is used in type hints. The second one (which is optional) is added by the generator itself to the _param_defs dictionary of the class. Once this type has been resolved, the generator will typically encounter an array_of wrapping it, which adds Sequence[...] to the type hint and "multi": True to the DSL type.

The intention is that at some point all of these will be generated, but some are still injected manually through the templates.

"query": {"type": "query"},
"filter": {"type": "query"},
"functions": {"type": "score_function", "multi": True},
}

def __init__(
Expand All @@ -623,11 +623,7 @@ def __init__(
boost_mode: Union[
Literal["multiply", "replace", "sum", "avg", "max", "min"], "DefaultType"
] = DEFAULT,
functions: Union[
Sequence["types.FunctionScoreContainer"],
Sequence[Dict[str, Any]],
"DefaultType",
] = DEFAULT,
functions: Union[Sequence[ScoreFunction], "DefaultType"] = DEFAULT,
max_boost: Union[float, "DefaultType"] = DEFAULT,
min_score: Union[float, "DefaultType"] = DEFAULT,
query: Union[Query, "DefaultType"] = DEFAULT,
Expand Down
70 changes: 1 addition & 69 deletions elasticsearch_dsl/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

from elastic_transport.client_utils import DEFAULT, DefaultType

from elasticsearch_dsl import Query, function
from elasticsearch_dsl import Query
from elasticsearch_dsl.document_base import InstrumentedField
from elasticsearch_dsl.utils import AttrDict

Expand Down Expand Up @@ -688,74 +688,6 @@ def __init__(
super().__init__(kwargs)


class FunctionScoreContainer(AttrDict[Any]):
    """
    Keyword-only attribute dictionary holding the settings of one scoring
    function entry. Only the arguments that are explicitly provided are
    stored; anything left at ``DEFAULT`` is omitted.

    :arg exp: Function that scores a document with an exponential decay,
        depending on the distance of a numeric field value of the document
        from an origin.
    :arg gauss: Function that scores a document with a normal decay,
        depending on the distance of a numeric field value of the document
        from an origin.
    :arg linear: Function that scores a document with a linear decay,
        depending on the distance of a numeric field value of the document
        from an origin.
    :arg field_value_factor: Function allows you to use a field from a
        document to influence the score. It’s similar to using the
        script_score function, however, it avoids the overhead of
        scripting.
    :arg random_score: Generates scores that are uniformly distributed
        from 0 up to but not including 1. In case you want scores to be
        reproducible, it is possible to provide a `seed` and `field`.
    :arg script_score: Enables you to wrap another query and customize the
        scoring of it optionally with a computation derived from other
        numeric field values in the doc using a script expression.
    :arg filter:
    :arg weight:
    """

    exp: Union[function.DecayFunction, DefaultType]
    gauss: Union[function.DecayFunction, DefaultType]
    linear: Union[function.DecayFunction, DefaultType]
    field_value_factor: Union[function.FieldValueFactorScore, DefaultType]
    random_score: Union[function.RandomScore, DefaultType]
    script_score: Union[function.ScriptScore, DefaultType]
    filter: Union[Query, DefaultType]
    weight: Union[float, DefaultType]

    def __init__(
        self,
        *,
        exp: Union[function.DecayFunction, DefaultType] = DEFAULT,
        gauss: Union[function.DecayFunction, DefaultType] = DEFAULT,
        linear: Union[function.DecayFunction, DefaultType] = DEFAULT,
        field_value_factor: Union[
            function.FieldValueFactorScore, DefaultType
        ] = DEFAULT,
        random_score: Union[function.RandomScore, DefaultType] = DEFAULT,
        script_score: Union[function.ScriptScore, DefaultType] = DEFAULT,
        filter: Union[Query, DefaultType] = DEFAULT,
        weight: Union[float, DefaultType] = DEFAULT,
        **kwargs: Any,
    ):
        # Pair each named argument with its key, in declaration order, so
        # the resulting dictionary keeps the same key ordering as before.
        named_arguments = (
            ("exp", exp),
            ("gauss", gauss),
            ("linear", linear),
            ("field_value_factor", field_value_factor),
            ("random_score", random_score),
            ("script_score", script_score),
            ("filter", filter),
            ("weight", weight),
        )
        for key, value in named_arguments:
            # DEFAULT is a sentinel: skip arguments the caller did not set.
            if value is DEFAULT:
                continue
            kwargs[key] = value
        super().__init__(kwargs)


class FuzzyQuery(AttrDict[Any]):
"""
:arg value: (required) Term you wish to find in the provided field.
Expand Down
27 changes: 27 additions & 0 deletions tests/test_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -562,6 +562,33 @@ def test_function_score_to_dict() -> None:
assert d == q.to_dict()


def test_function_score_class_based_to_dict() -> None:
    """A FunctionScore built from score function objects serializes to the
    expected dictionary."""
    base_query = query.Match(title="python")
    score_functions = [
        function.RandomScore(),
        function.FieldValueFactor(
            field="comment_count",
            filter=query.Term(tags="python"),
        ),
    ]
    q = query.FunctionScore(query=base_query, functions=score_functions)

    expected_functions = [
        {"random_score": {}},
        {
            "filter": {"term": {"tags": "python"}},
            "field_value_factor": {"field": "comment_count"},
        },
    ]
    expected = {
        "function_score": {
            "query": {"match": {"title": "python"}},
            "functions": expected_functions,
        }
    }
    assert expected == q.to_dict()


def test_function_score_with_single_function() -> None:
d = {
"function_score": {
Expand Down
6 changes: 6 additions & 0 deletions utils/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,6 +200,12 @@ def get_python_type(self, schema_type, for_response=False):
):
# QueryContainer maps to the DSL's Query class
return "Query", {"type": "query"}
elif (
type_name["namespace"] == "_types.query_dsl"
and type_name["name"] == "FunctionScoreContainer"
):
# FunctionScoreContainer maps to the DSL's ScoreFunction class
return "ScoreFunction", {"type": "score_function"}
elif (
type_name["namespace"] == "_types.aggregations"
and type_name["name"] == "Buckets"
Expand Down
1 change: 0 additions & 1 deletion utils/templates/query.py.tpl
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,6 @@ class {{ k.name }}({{ parent }}):
shortcut property. Until the code generator can support shortcut
properties directly that solution is added here #}
"filter": {"type": "query"},
"functions": {"type": "score_function", "multi": True},
{% endif %}
}
{% endif %}
Expand Down
2 changes: 1 addition & 1 deletion utils/templates/types.py.tpl
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ from typing import Any, Dict, Literal, Mapping, Sequence, Union
from elastic_transport.client_utils import DEFAULT, DefaultType

from elasticsearch_dsl.document_base import InstrumentedField
from elasticsearch_dsl import function, Query
from elasticsearch_dsl import Query
from elasticsearch_dsl.utils import AttrDict

PipeSeparatedFlags = str
Expand Down
Loading