Skip to content

Commit 5898740

Browse files
add new COMPLETION command
1 parent c6fb3f9 commit 5898740

File tree

3 files changed

+207
-17
lines changed

3 files changed

+207
-17
lines changed

docs/sphinx/esql.rst

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,10 @@ Commands
2525
:members:
2626
:exclude-members: __init__
2727

28+
.. autoclass:: elasticsearch.esql.esql.ChangePoint
29+
:members:
30+
:exclude-members: __init__
31+
2832
.. autoclass:: elasticsearch.esql.esql.Dissect
2933
:members:
3034
:exclude-members: __init__

elasticsearch/esql/esql.py

Lines changed: 103 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -134,6 +134,47 @@ def change_point(self, value: FieldType) -> "ChangePoint":
134134
"""
135135
return ChangePoint(self, value)
136136

137+
def completion(self, *prompt: ExpressionType, **named_prompt: ExpressionType) -> "Completion":
138+
"""The `COMPLETION` command allows you to send prompts and context to a Large
139+
Language Model (LLM) directly within your ES|QL queries, to perform text
140+
generation tasks.
141+
142+
:param prompt: The input text or expression used to prompt the LLM. This can
143+
be a string literal or a reference to a column containing text.
144+
:param named_prompt: The input text or expresion, given as a keyword argument.
145+
The argument name is used for the column name. If not
146+
specified, the results will be stored in a column named
147+
`completion`. If the specified column already exists, it
148+
will be overwritten with the new results.
149+
150+
Examples::
151+
152+
query1 = (
153+
ESQL.row(question="What is Elasticsearch?")
154+
.completion("question").with_("test_completion_model")
155+
.keep("question", "completion")
156+
)
157+
query2 = (
158+
ESQL.row(question="What is Elasticsearch?")
159+
.completion(answer="question").with_("test_completion_model")
160+
.keep("question", "answer")
161+
)
162+
query3 = (
163+
ESQL.from_("movies")
164+
.sort("rating DESC")
165+
.limit(10)
166+
.eval(prompt=\"\"\"CONCAT(
167+
"Summarize this movie using the following information: \\n",
168+
"Title: ", title, "\\n",
169+
"Synopsis: ", synopsis, "\\n",
170+
"Actors: ", MV_CONCAT(actors, ", "), "\\n",
171+
)\"\"\")
172+
.completion(summary="prompt").with_("test_completion_model")
173+
.keep("title", "summary", "rating")
174+
)
175+
"""
176+
return Completion(self, *prompt, **named_prompt)
177+
137178
def dissect(self, input: FieldType, pattern: str) -> "Dissect":
138179
"""``DISSECT`` enables you to extract structured data out of a string.
139180
@@ -306,43 +347,39 @@ def limit(self, max_number_of_rows: int) -> "Limit":
306347
"""
307348
return Limit(self, max_number_of_rows)
308349

309-
def lookup_join(self, lookup_index: IndexType, field: FieldType) -> "LookupJoin":
350+
def lookup_join(self, lookup_index: IndexType) -> "LookupJoin":
310351
"""`LOOKUP JOIN` enables you to add data from another index, AKA a 'lookup' index,
311352
to your ES|QL query results, simplifying data enrichment and analysis workflows.
312353
313354
:param lookup_index: The name of the lookup index. This must be a specific index
314355
name - wildcards, aliases, and remote cluster references are
315356
not supported. Indices used for lookups must be configured
316357
with the lookup index mode.
317-
:param field: The field to join on. This field must exist in both your current query
318-
results and in the lookup index. If the field contains multi-valued
319-
entries, those entries will not match anything (the added fields will
320-
contain null for those rows).
321358
322359
Examples::
323360
324361
query1 = (
325362
ESQL.from_("firewall_logs")
326-
.lookup_join("threat_list", "source.IP")
363+
.lookup_join("threat_list").on("source.IP")
327364
.where("threat_level IS NOT NULL")
328365
)
329366
query2 = (
330367
ESQL.from_("system_metrics")
331-
.lookup_join("host_inventory", "host.name")
332-
.lookup_join("ownerships", "host.name")
368+
.lookup_join("host_inventory").on("host.name")
369+
.lookup_join("ownerships").on("host.name")
333370
)
334371
query3 = (
335372
ESQL.from_("app_logs")
336-
.lookup_join("service_owners", "service_id")
373+
.lookup_join("service_owners").on("service_id")
337374
)
338375
query4 = (
339376
ESQL.from_("employees")
340377
.eval(language_code="languages")
341378
.where("emp_no >= 10091 AND emp_no < 10094")
342-
.lookup_join("languages_lookup", "language_code")
379+
.lookup_join("languages_lookup").on("language_code")
343380
)
344381
"""
345-
return LookupJoin(self, lookup_index, field)
382+
return LookupJoin(self, lookup_index)
346383

347384
def mv_expand(self, column: FieldType) -> "MvExpand":
348385
"""The `MV_EXPAND` processing command expands multivalued columns into one row per
@@ -635,6 +672,47 @@ def _render_internal(self) -> str:
635672
return f"CHANGE_POINT {self._value}{key}{names}"
636673

637674

675+
class Completion(ESQLBase):
676+
"""Implementation of the ``COMPLETION`` processing command.
677+
678+
This class inherits from :class:`ESQLBase <elasticsearch.esql.esql.ESQLBase>`,
679+
to make it possible to chain all the commands that belong to an ES|QL query
680+
in a single expression.
681+
"""
682+
683+
def __init__(
684+
self, parent: ESQLBase, *prompt: ExpressionType, **named_prompt: ExpressionType
685+
):
686+
if len(prompt) + len(named_prompt) > 1:
687+
raise ValueError(
688+
"this method requires either one positional or one keyword argument only"
689+
)
690+
super().__init__(parent)
691+
self._prompt = prompt
692+
self._named_prompt = named_prompt
693+
self._inference_id: Optional[str] = None
694+
695+
def with_(self, inference_id: str) -> "Completion":
696+
"""Continuation of the `COMPLETION` command.
697+
698+
:param inference_id: The ID of the inference endpoint to use for the task. The
699+
inference endpoint must be configured with the completion
700+
task type.
701+
"""
702+
self._inference_id = inference_id
703+
return self
704+
705+
def _render_internal(self) -> str:
706+
if self._inference_id is None:
707+
raise ValueError("The completion command requires an inference ID")
708+
if self._named_prompt:
709+
column = list(self._named_prompt.keys())[0]
710+
prompt = list(self._named_prompt.values())[0]
711+
return f"COMPLETION {column} = {prompt} WITH {self._inference_id}"
712+
else:
713+
return f"COMPLETION {self._prompt[0]} WITH {self._inference_id}"
714+
715+
638716
class Dissect(ESQLBase):
639717
"""Implementation of the ``DISSECT`` processing command.
640718
@@ -861,12 +939,25 @@ class LookupJoin(ESQLBase):
861939
in a single expression.
862940
"""
863941

864-
def __init__(self, parent: ESQLBase, lookup_index: IndexType, field: FieldType):
942+
def __init__(self, parent: ESQLBase, lookup_index: IndexType):
865943
super().__init__(parent)
866944
self._lookup_index = lookup_index
945+
self._field = None
946+
947+
def on(self, field: FieldType) -> "LookupJoin":
948+
"""Continuation of the `LOOKUP_JOIN` command.
949+
950+
:param field: The field to join on. This field must exist in both your current query
951+
results and in the lookup index. If the field contains multi-valued
952+
entries, those entries will not match anything (the added fields will
953+
contain null for those rows).
954+
"""
867955
self._field = field
956+
return self
868957

869958
def _render_internal(self) -> str:
959+
if self._field is None:
960+
raise ValueError("Joins require a field to join on.")
870961
index = (
871962
self._lookup_index
872963
if isinstance(self._lookup_index, str)

test_elasticsearch/test_esql.py

Lines changed: 100 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,97 @@ def test_change_point():
7474
)
7575

7676

77+
def test_completion():
78+
query = (
79+
ESQL.row(question="What is Elasticsearch?")
80+
.completion("question")
81+
.with_("test_completion_model")
82+
.keep("question", "completion")
83+
)
84+
assert (
85+
query.render()
86+
== """ROW question = "What is Elasticsearch?"
87+
| COMPLETION question WITH test_completion_model
88+
| KEEP question, completion"""
89+
)
90+
91+
query = (
92+
ESQL.row(question="What is Elasticsearch?")
93+
.completion(answer=E("question"))
94+
.with_("test_completion_model")
95+
.keep("question", "answer")
96+
)
97+
assert (
98+
query.render()
99+
== """ROW question = "What is Elasticsearch?"
100+
| COMPLETION answer = question WITH test_completion_model
101+
| KEEP question, answer"""
102+
)
103+
104+
query = (
105+
ESQL.from_("movies")
106+
.sort("rating DESC")
107+
.limit(10)
108+
.eval(
109+
prompt="""CONCAT(
110+
"Summarize this movie using the following information: \\n",
111+
"Title: ", title, "\\n",
112+
"Synopsis: ", synopsis, "\\n",
113+
"Actors: ", MV_CONCAT(actors, ", "), "\\n",
114+
)"""
115+
)
116+
.completion(summary="prompt")
117+
.with_("test_completion_model")
118+
.keep("title", "summary", "rating")
119+
)
120+
assert (
121+
query.render()
122+
== """FROM movies
123+
| SORT rating DESC
124+
| LIMIT 10
125+
| EVAL prompt = CONCAT(
126+
"Summarize this movie using the following information: \\n",
127+
"Title: ", title, "\\n",
128+
"Synopsis: ", synopsis, "\\n",
129+
"Actors: ", MV_CONCAT(actors, ", "), "\\n",
130+
)
131+
| COMPLETION summary = prompt WITH test_completion_model
132+
| KEEP title, summary, rating"""
133+
)
134+
135+
query = (
136+
ESQL.from_("movies")
137+
.sort("rating DESC")
138+
.limit(10)
139+
.eval(
140+
prompt=functions.concat(
141+
"Summarize this movie using the following information: \n",
142+
"Title: ",
143+
E("title"),
144+
"\n",
145+
"Synopsis: ",
146+
E("synopsis"),
147+
"\n",
148+
"Actors: ",
149+
functions.mv_concat(E("actors"), ", "),
150+
"\n",
151+
)
152+
)
153+
.completion(summary="prompt")
154+
.with_("test_completion_model")
155+
.keep("title", "summary", "rating")
156+
)
157+
assert (
158+
query.render()
159+
== """FROM movies
160+
| SORT rating DESC
161+
| LIMIT 10
162+
| EVAL prompt = CONCAT("Summarize this movie using the following information: \\n", "Title: ", title, "\\n", "Synopsis: ", synopsis, "\\n", "Actors: ", MV_CONCAT(actors, ", "), "\\n")
163+
| COMPLETION summary = prompt WITH test_completion_model
164+
| KEEP title, summary, rating"""
165+
)
166+
167+
77168
def test_dissect():
78169
query = (
79170
ESQL.row(a="2023-01-23T12:15:00.000Z - some text - 127.0.0.1")
@@ -260,7 +351,8 @@ def test_limit():
260351
def test_lookup_join():
261352
query = (
262353
ESQL.from_("firewall_logs")
263-
.lookup_join("threat_list", "source.IP")
354+
.lookup_join("threat_list")
355+
.on("source.IP")
264356
.where("threat_level IS NOT NULL")
265357
)
266358
assert (
@@ -272,8 +364,10 @@ def test_lookup_join():
272364

273365
query = (
274366
ESQL.from_("system_metrics")
275-
.lookup_join("host_inventory", "host.name")
276-
.lookup_join("ownerships", "host.name")
367+
.lookup_join("host_inventory")
368+
.on("host.name")
369+
.lookup_join("ownerships")
370+
.on("host.name")
277371
)
278372
assert (
279373
query.render()
@@ -282,7 +376,7 @@ def test_lookup_join():
282376
| LOOKUP JOIN ownerships ON host.name"""
283377
)
284378

285-
query = ESQL.from_("app_logs").lookup_join("service_owners", "service_id")
379+
query = ESQL.from_("app_logs").lookup_join("service_owners").on("service_id")
286380
assert (
287381
query.render()
288382
== """FROM app_logs
@@ -293,7 +387,8 @@ def test_lookup_join():
293387
ESQL.from_("employees")
294388
.eval(language_code="languages")
295389
.where(E("emp_no") >= 10091, E("emp_no") < 10094)
296-
.lookup_join("languages_lookup", "language_code")
390+
.lookup_join("languages_lookup")
391+
.on("language_code")
297392
)
298393
assert (
299394
query.render()

0 commit comments

Comments
 (0)