Skip to content

Commit c7fc780

Browse files
add new COMPLETION command
1 parent c6fb3f9 commit c7fc780

File tree

3 files changed

+209
-17
lines changed

3 files changed

+209
-17
lines changed

docs/sphinx/esql.rst

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,10 @@ Commands
2525
:members:
2626
:exclude-members: __init__
2727

28+
.. autoclass:: elasticsearch.esql.esql.ChangePoint
29+
:members:
30+
:exclude-members: __init__
31+
2832
.. autoclass:: elasticsearch.esql.esql.Dissect
2933
:members:
3034
:exclude-members: __init__

elasticsearch/esql/esql.py

Lines changed: 105 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -134,6 +134,49 @@ def change_point(self, value: FieldType) -> "ChangePoint":
134134
"""
135135
return ChangePoint(self, value)
136136

137+
def completion(
138+
self, *prompt: ExpressionType, **named_prompt: ExpressionType
139+
) -> "Completion":
140+
"""The `COMPLETION` command allows you to send prompts and context to a Large
141+
Language Model (LLM) directly within your ES|QL queries, to perform text
142+
generation tasks.
143+
144+
:param prompt: The input text or expression used to prompt the LLM. This can
145+
be a string literal or a reference to a column containing text.
146+
:param named_prompt: The input text or expresion, given as a keyword argument.
147+
The argument name is used for the column name. If not
148+
specified, the results will be stored in a column named
149+
`completion`. If the specified column already exists, it
150+
will be overwritten with the new results.
151+
152+
Examples::
153+
154+
query1 = (
155+
ESQL.row(question="What is Elasticsearch?")
156+
.completion("question").with_("test_completion_model")
157+
.keep("question", "completion")
158+
)
159+
query2 = (
160+
ESQL.row(question="What is Elasticsearch?")
161+
.completion(answer="question").with_("test_completion_model")
162+
.keep("question", "answer")
163+
)
164+
query3 = (
165+
ESQL.from_("movies")
166+
.sort("rating DESC")
167+
.limit(10)
168+
.eval(prompt=\"\"\"CONCAT(
169+
"Summarize this movie using the following information: \\n",
170+
"Title: ", title, "\\n",
171+
"Synopsis: ", synopsis, "\\n",
172+
"Actors: ", MV_CONCAT(actors, ", "), "\\n",
173+
)\"\"\")
174+
.completion(summary="prompt").with_("test_completion_model")
175+
.keep("title", "summary", "rating")
176+
)
177+
"""
178+
return Completion(self, *prompt, **named_prompt)
179+
137180
def dissect(self, input: FieldType, pattern: str) -> "Dissect":
138181
"""``DISSECT`` enables you to extract structured data out of a string.
139182
@@ -306,43 +349,39 @@ def limit(self, max_number_of_rows: int) -> "Limit":
306349
"""
307350
return Limit(self, max_number_of_rows)
308351

309-
def lookup_join(self, lookup_index: IndexType, field: FieldType) -> "LookupJoin":
352+
def lookup_join(self, lookup_index: IndexType) -> "LookupJoin":
310353
"""`LOOKUP JOIN` enables you to add data from another index, AKA a 'lookup' index,
311354
to your ES|QL query results, simplifying data enrichment and analysis workflows.
312355
313356
:param lookup_index: The name of the lookup index. This must be a specific index
314357
name - wildcards, aliases, and remote cluster references are
315358
not supported. Indices used for lookups must be configured
316359
with the lookup index mode.
317-
:param field: The field to join on. This field must exist in both your current query
318-
results and in the lookup index. If the field contains multi-valued
319-
entries, those entries will not match anything (the added fields will
320-
contain null for those rows).
321360
322361
Examples::
323362
324363
query1 = (
325364
ESQL.from_("firewall_logs")
326-
.lookup_join("threat_list", "source.IP")
365+
.lookup_join("threat_list").on("source.IP")
327366
.where("threat_level IS NOT NULL")
328367
)
329368
query2 = (
330369
ESQL.from_("system_metrics")
331-
.lookup_join("host_inventory", "host.name")
332-
.lookup_join("ownerships", "host.name")
370+
.lookup_join("host_inventory").on("host.name")
371+
.lookup_join("ownerships").on("host.name")
333372
)
334373
query3 = (
335374
ESQL.from_("app_logs")
336-
.lookup_join("service_owners", "service_id")
375+
.lookup_join("service_owners").on("service_id")
337376
)
338377
query4 = (
339378
ESQL.from_("employees")
340379
.eval(language_code="languages")
341380
.where("emp_no >= 10091 AND emp_no < 10094")
342-
.lookup_join("languages_lookup", "language_code")
381+
.lookup_join("languages_lookup").on("language_code")
343382
)
344383
"""
345-
return LookupJoin(self, lookup_index, field)
384+
return LookupJoin(self, lookup_index)
346385

347386
def mv_expand(self, column: FieldType) -> "MvExpand":
348387
"""The `MV_EXPAND` processing command expands multivalued columns into one row per
@@ -635,6 +674,47 @@ def _render_internal(self) -> str:
635674
return f"CHANGE_POINT {self._value}{key}{names}"
636675

637676

677+
class Completion(ESQLBase):
678+
"""Implementation of the ``COMPLETION`` processing command.
679+
680+
This class inherits from :class:`ESQLBase <elasticsearch.esql.esql.ESQLBase>`,
681+
to make it possible to chain all the commands that belong to an ES|QL query
682+
in a single expression.
683+
"""
684+
685+
def __init__(
686+
self, parent: ESQLBase, *prompt: ExpressionType, **named_prompt: ExpressionType
687+
):
688+
if len(prompt) + len(named_prompt) > 1:
689+
raise ValueError(
690+
"this method requires either one positional or one keyword argument only"
691+
)
692+
super().__init__(parent)
693+
self._prompt = prompt
694+
self._named_prompt = named_prompt
695+
self._inference_id: Optional[str] = None
696+
697+
def with_(self, inference_id: str) -> "Completion":
698+
"""Continuation of the `COMPLETION` command.
699+
700+
:param inference_id: The ID of the inference endpoint to use for the task. The
701+
inference endpoint must be configured with the completion
702+
task type.
703+
"""
704+
self._inference_id = inference_id
705+
return self
706+
707+
def _render_internal(self) -> str:
708+
if self._inference_id is None:
709+
raise ValueError("The completion command requires an inference ID")
710+
if self._named_prompt:
711+
column = list(self._named_prompt.keys())[0]
712+
prompt = list(self._named_prompt.values())[0]
713+
return f"COMPLETION {column} = {prompt} WITH {self._inference_id}"
714+
else:
715+
return f"COMPLETION {self._prompt[0]} WITH {self._inference_id}"
716+
717+
638718
class Dissect(ESQLBase):
639719
"""Implementation of the ``DISSECT`` processing command.
640720
@@ -861,12 +941,25 @@ class LookupJoin(ESQLBase):
861941
in a single expression.
862942
"""
863943

864-
def __init__(self, parent: ESQLBase, lookup_index: IndexType, field: FieldType):
944+
def __init__(self, parent: ESQLBase, lookup_index: IndexType):
865945
super().__init__(parent)
866946
self._lookup_index = lookup_index
947+
self._field = None
948+
949+
def on(self, field: FieldType) -> "LookupJoin":
950+
"""Continuation of the `LOOKUP_JOIN` command.
951+
952+
:param field: The field to join on. This field must exist in both your current query
953+
results and in the lookup index. If the field contains multi-valued
954+
entries, those entries will not match anything (the added fields will
955+
contain null for those rows).
956+
"""
867957
self._field = field
958+
return self
868959

869960
def _render_internal(self) -> str:
961+
if self._field is None:
962+
raise ValueError("Joins require a field to join on.")
870963
index = (
871964
self._lookup_index
872965
if isinstance(self._lookup_index, str)

test_elasticsearch/test_esql.py

Lines changed: 100 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,97 @@ def test_change_point():
7474
)
7575

7676

77+
def test_completion():
78+
query = (
79+
ESQL.row(question="What is Elasticsearch?")
80+
.completion("question")
81+
.with_("test_completion_model")
82+
.keep("question", "completion")
83+
)
84+
assert (
85+
query.render()
86+
== """ROW question = "What is Elasticsearch?"
87+
| COMPLETION question WITH test_completion_model
88+
| KEEP question, completion"""
89+
)
90+
91+
query = (
92+
ESQL.row(question="What is Elasticsearch?")
93+
.completion(answer=E("question"))
94+
.with_("test_completion_model")
95+
.keep("question", "answer")
96+
)
97+
assert (
98+
query.render()
99+
== """ROW question = "What is Elasticsearch?"
100+
| COMPLETION answer = question WITH test_completion_model
101+
| KEEP question, answer"""
102+
)
103+
104+
query = (
105+
ESQL.from_("movies")
106+
.sort("rating DESC")
107+
.limit(10)
108+
.eval(
109+
prompt="""CONCAT(
110+
"Summarize this movie using the following information: \\n",
111+
"Title: ", title, "\\n",
112+
"Synopsis: ", synopsis, "\\n",
113+
"Actors: ", MV_CONCAT(actors, ", "), "\\n",
114+
)"""
115+
)
116+
.completion(summary="prompt")
117+
.with_("test_completion_model")
118+
.keep("title", "summary", "rating")
119+
)
120+
assert (
121+
query.render()
122+
== """FROM movies
123+
| SORT rating DESC
124+
| LIMIT 10
125+
| EVAL prompt = CONCAT(
126+
"Summarize this movie using the following information: \\n",
127+
"Title: ", title, "\\n",
128+
"Synopsis: ", synopsis, "\\n",
129+
"Actors: ", MV_CONCAT(actors, ", "), "\\n",
130+
)
131+
| COMPLETION summary = prompt WITH test_completion_model
132+
| KEEP title, summary, rating"""
133+
)
134+
135+
query = (
136+
ESQL.from_("movies")
137+
.sort("rating DESC")
138+
.limit(10)
139+
.eval(
140+
prompt=functions.concat(
141+
"Summarize this movie using the following information: \n",
142+
"Title: ",
143+
E("title"),
144+
"\n",
145+
"Synopsis: ",
146+
E("synopsis"),
147+
"\n",
148+
"Actors: ",
149+
functions.mv_concat(E("actors"), ", "),
150+
"\n",
151+
)
152+
)
153+
.completion(summary="prompt")
154+
.with_("test_completion_model")
155+
.keep("title", "summary", "rating")
156+
)
157+
assert (
158+
query.render()
159+
== """FROM movies
160+
| SORT rating DESC
161+
| LIMIT 10
162+
| EVAL prompt = CONCAT("Summarize this movie using the following information: \\n", "Title: ", title, "\\n", "Synopsis: ", synopsis, "\\n", "Actors: ", MV_CONCAT(actors, ", "), "\\n")
163+
| COMPLETION summary = prompt WITH test_completion_model
164+
| KEEP title, summary, rating"""
165+
)
166+
167+
77168
def test_dissect():
78169
query = (
79170
ESQL.row(a="2023-01-23T12:15:00.000Z - some text - 127.0.0.1")
@@ -260,7 +351,8 @@ def test_limit():
260351
def test_lookup_join():
261352
query = (
262353
ESQL.from_("firewall_logs")
263-
.lookup_join("threat_list", "source.IP")
354+
.lookup_join("threat_list")
355+
.on("source.IP")
264356
.where("threat_level IS NOT NULL")
265357
)
266358
assert (
@@ -272,8 +364,10 @@ def test_lookup_join():
272364

273365
query = (
274366
ESQL.from_("system_metrics")
275-
.lookup_join("host_inventory", "host.name")
276-
.lookup_join("ownerships", "host.name")
367+
.lookup_join("host_inventory")
368+
.on("host.name")
369+
.lookup_join("ownerships")
370+
.on("host.name")
277371
)
278372
assert (
279373
query.render()
@@ -282,7 +376,7 @@ def test_lookup_join():
282376
| LOOKUP JOIN ownerships ON host.name"""
283377
)
284378

285-
query = ESQL.from_("app_logs").lookup_join("service_owners", "service_id")
379+
query = ESQL.from_("app_logs").lookup_join("service_owners").on("service_id")
286380
assert (
287381
query.render()
288382
== """FROM app_logs
@@ -293,7 +387,8 @@ def test_lookup_join():
293387
ESQL.from_("employees")
294388
.eval(language_code="languages")
295389
.where(E("emp_no") >= 10091, E("emp_no") < 10094)
296-
.lookup_join("languages_lookup", "language_code")
390+
.lookup_join("languages_lookup")
391+
.on("language_code")
297392
)
298393
assert (
299394
query.render()

0 commit comments

Comments
 (0)