From 7ca6b086ce5a42b3a4a2fcdbbf069bacb04e2c78 Mon Sep 17 00:00:00 2001 From: pgayvallet Date: Fri, 13 Sep 2024 10:54:03 +0200 Subject: [PATCH 1/6] start refactor --- .../tasks/nl_to_esql/doc_base/aliases.ts | 32 ++++++ .../nl_to_esql/doc_base/esql_doc_base.ts | 101 ++++++++++++++++++ .../server/tasks/nl_to_esql/doc_base/index.ts | 8 ++ .../tasks/nl_to_esql/doc_base/load_data.ts | 59 ++++++++++ .../server/tasks/nl_to_esql/index.ts | 65 +++++------ .../server/tasks/nl_to_esql/load_documents.ts | 55 ---------- 6 files changed, 227 insertions(+), 93 deletions(-) create mode 100644 x-pack/plugins/inference/server/tasks/nl_to_esql/doc_base/aliases.ts create mode 100644 x-pack/plugins/inference/server/tasks/nl_to_esql/doc_base/esql_doc_base.ts create mode 100644 x-pack/plugins/inference/server/tasks/nl_to_esql/doc_base/index.ts create mode 100644 x-pack/plugins/inference/server/tasks/nl_to_esql/doc_base/load_data.ts delete mode 100644 x-pack/plugins/inference/server/tasks/nl_to_esql/load_documents.ts diff --git a/x-pack/plugins/inference/server/tasks/nl_to_esql/doc_base/aliases.ts b/x-pack/plugins/inference/server/tasks/nl_to_esql/doc_base/aliases.ts new file mode 100644 index 0000000000000..29f07af2d1121 --- /dev/null +++ b/x-pack/plugins/inference/server/tasks/nl_to_esql/doc_base/aliases.ts @@ -0,0 +1,32 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +/** + * Sometimes the LLM request documentation by wrongly naming the command. + * This is mostly for the case for STATS. + */ +const aliases: Record = { + STATS: ['STATS_BY', 'BY', 'STATS...BY'], +}; + +const getAliasMap = () => { + return Object.entries(aliases).reduce>( + (aliasMap, [command, commandAliases]) => { + commandAliases.forEach((alias) => { + aliasMap[alias] = command; + }); + return aliasMap; + }, + {} + ); +}; + +const aliasMap = getAliasMap(); + +export const tryResolveAlias = (maybeAlias: string): string => { + return aliasMap[maybeAlias] ?? maybeAlias; +}; diff --git a/x-pack/plugins/inference/server/tasks/nl_to_esql/doc_base/esql_doc_base.ts b/x-pack/plugins/inference/server/tasks/nl_to_esql/doc_base/esql_doc_base.ts new file mode 100644 index 0000000000000..70044e00da484 --- /dev/null +++ b/x-pack/plugins/inference/server/tasks/nl_to_esql/doc_base/esql_doc_base.ts @@ -0,0 +1,101 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +import { loadData, type EsqlDocData, type EsqlDocEntry } from './load_data'; +import { tryResolveAlias } from './aliases'; + +interface GetDocsOptions { + /** + * If true (default), will include general ES|QL documentation entries + * such as the overview, syntax and operators page. + */ + addOverview?: boolean; + /** + * If true (default) will try to resolve aliases for commands. + */ + resolveAliases?: boolean; + + /** + * If true (default) will generate a fake doc page for missing keywords. + * Useful for the LLM to understand that the requested keyword does not exist. + */ + generateMissingKeywordDoc?: boolean; + + /** + * If true (default), additional documentation will be included to help the LLM. + * E.g. for STATS, BUCKET will be included. + */ + addSuggestions?: boolean; +} + +const overviewEntries = ['SYNTAX', 'OVERVIEW', 'OPERATORS']; + +export class EsqlDocumentBase { + private systemMessage: string; + private docRecords: Record; + + static async load(): Promise { + const data = await loadData(); + return new EsqlDocumentBase(data); + } + + constructor(rawData: EsqlDocData) { + this.systemMessage = rawData.systemMessage; + this.docRecords = rawData.docs; + } + + getSystemMessage() { + return this.systemMessage; + } + + getDocumentation( + keywords: string[], + { + generateMissingKeywordDoc = true, + addSuggestions = true, + addOverview = true, + resolveAliases = true, + }: GetDocsOptions = {} + ) { + keywords = keywords.map((raw) => { + let keyword = format(raw); + if (resolveAliases) { + keyword = tryResolveAlias(keyword); + } + return keyword; + }); + + if (addSuggestions) { + // TODO + } + + if (addOverview) { + keywords.push(...overviewEntries); + } + + return keywords.reduce>((results, keyword) => { + if (Object.hasOwn(this.docRecords, keyword)) { + results[keyword] = this.docRecords[keyword].data; + } else if (generateMissingKeywordDoc) { + results[keyword] = createDocForUnknownKeyword(keyword); + } + return results; + }, {}); + } +} + +const format = (keyword: string) => { + return keyword.replaceAll(' ', '').toUpperCase(); +}; + +const createDocForUnknownKeyword = (keyword: string) => { + return ` + ## ${keyword} + + There is no ${keyword} function or command in ES|QL. Do NOT use it. + `; +}; diff --git a/x-pack/plugins/inference/server/tasks/nl_to_esql/doc_base/index.ts b/x-pack/plugins/inference/server/tasks/nl_to_esql/doc_base/index.ts new file mode 100644 index 0000000000000..e498b799f577c --- /dev/null +++ b/x-pack/plugins/inference/server/tasks/nl_to_esql/doc_base/index.ts @@ -0,0 +1,8 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +export { EsqlDocumentBase } from './esql_doc_base'; diff --git a/x-pack/plugins/inference/server/tasks/nl_to_esql/doc_base/load_data.ts b/x-pack/plugins/inference/server/tasks/nl_to_esql/doc_base/load_data.ts new file mode 100644 index 0000000000000..340f06fd0fced --- /dev/null +++ b/x-pack/plugins/inference/server/tasks/nl_to_esql/doc_base/load_data.ts @@ -0,0 +1,59 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +import Path from 'path'; +import { keyBy } from 'lodash'; +import pLimit from 'p-limit'; +import { readdir, readFile } from 'fs/promises'; + +export interface EsqlDocEntry { + keyword: string; + data: string; +} + +export interface EsqlDocData { + systemMessage: string; + docs: Record; +} + +export const loadData = async (): Promise => { + const [systemMessage, docs] = await Promise.all([loadSystemMessage(), loadEsqlDocs()]); + return { + systemMessage, + docs, + }; +}; + +const loadSystemMessage = async () => { + return (await readFile(Path.join(__dirname, '../system_message.txt'))).toString('utf-8'); +}; + +const loadEsqlDocs = async (): Promise> => { + const dir = Path.join(__dirname, '../esql_docs'); + const files = (await readdir(dir)).filter((file) => Path.extname(file) === '.txt'); + + const limiter = pLimit(10); + + return keyBy( + await Promise.all( + files.map((file) => + limiter(async () => { + const data = (await readFile(Path.join(dir, file))).toString('utf-8'); + const filename = Path.basename(file, '.txt'); + + const keyword = filename.replace('esql-', '').replaceAll('-', '_').toUpperCase(); + + return { + keyword, + data, + }; + }) + ) + ), + 'keyword' + ); +}; diff --git a/x-pack/plugins/inference/server/tasks/nl_to_esql/index.ts b/x-pack/plugins/inference/server/tasks/nl_to_esql/index.ts index 2fcc204a9f47a..2620d57f243ff 100644 --- a/x-pack/plugins/inference/server/tasks/nl_to_esql/index.ts +++ b/x-pack/plugins/inference/server/tasks/nl_to_esql/index.ts @@ -6,7 +6,7 @@ */ import type { Logger } from '@kbn/logging'; -import { isEmpty, has } from 'lodash'; +import { isEmpty, once } from 'lodash'; import { Observable, from, map, merge, of, switchMap } from 'rxjs'; import { ToolSchema, generateFakeToolCallId, isChatCompletionMessageEvent } from '../../../common'; import { @@ -15,14 +15,14 @@ import { Message, MessageRole, } from '../../../common/chat_complete'; -import { ToolChoiceType, type ToolOptions } from '../../../common/chat_complete/tools'; +import { ToolChoiceType, type ToolOptions, ToolCall } from '../../../common/chat_complete/tools'; import { withoutTokenCountEvents } from '../../../common/chat_complete/without_token_count_events'; import { OutputCompleteEvent, OutputEventType } from '../../../common/output'; import { withoutOutputUpdateEvents } from '../../../common/output/without_output_update_events'; import { INLINE_ESQL_QUERY_REGEX } from '../../../common/tasks/nl_to_esql/constants'; import { correctCommonEsqlMistakes } from '../../../common/tasks/nl_to_esql/correct_common_esql_mistakes'; import type { InferenceClient } from '../../types'; -import { loadDocuments } from './load_documents'; +import { EsqlDocumentBase } from './doc_base'; type NlToEsqlTaskEvent = | OutputCompleteEvent< @@ -32,6 +32,8 @@ type NlToEsqlTaskEvent = | ChatCompletionChunkEvent | ChatCompletionMessageEvent; +const loadDocBase = once(() => EsqlDocumentBase.load()); + export function naturalLanguageToEsql({ client, connectorId, @@ -71,47 +73,18 @@ export function naturalLanguageToEsql({ const messages: Message[] = 'input' in rest ? [{ role: MessageRole.User, content: rest.input }] : rest.messages; - return from(loadDocuments()).pipe( - switchMap(([systemMessage, esqlDocs]) => { + return from(loadDocBase()).pipe( + switchMap((docBase) => { function askLlmToRespond({ documentationRequest: { commands, functions }, }: { documentationRequest: { commands?: string[]; functions?: string[] }; }): Observable> { - const keywords = [ - ...(commands ?? []), - ...(functions ?? []), - 'SYNTAX', - 'OVERVIEW', - 'OPERATORS', - ].map((keyword) => keyword.toUpperCase()); - - const requestedDocumentation = keywords.reduce>( - (documentation, keyword) => { - if (has(esqlDocs, keyword)) { - documentation[keyword] = esqlDocs[keyword].data; - } else { - documentation[keyword] = ` - ## ${keyword} - - There is no ${keyword} function or command in ES|QL. Do NOT try to use it. - `; - } - return documentation; - }, - {} - ); + const keywords = [...(commands ?? []), ...(functions ?? [])]; - const fakeRequestDocsToolCall = { - function: { - name: 'request_documentation', - arguments: { - commands, - functions, - }, - }, - toolCallId: generateFakeToolCallId(), - }; + const systemMessage = docBase.getSystemMessage(); + const requestedDocumentation = docBase.getDocumentation(keywords); + const fakeRequestDocsToolCall = createFakeTooCall(commands, functions); return merge( of< @@ -276,3 +249,19 @@ export function naturalLanguageToEsql({ }) ); } + +const createFakeTooCall = ( + commands: string[] | undefined, + functions: string[] | undefined +): ToolCall => { + return { + function: { + name: 'request_documentation', + arguments: { + commands, + functions, + }, + }, + toolCallId: generateFakeToolCallId(), + }; +}; diff --git a/x-pack/plugins/inference/server/tasks/nl_to_esql/load_documents.ts b/x-pack/plugins/inference/server/tasks/nl_to_esql/load_documents.ts deleted file mode 100644 index 73359d6c614df..0000000000000 --- a/x-pack/plugins/inference/server/tasks/nl_to_esql/load_documents.ts +++ /dev/null @@ -1,55 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. - */ - -import Path from 'path'; -import Fs from 'fs'; -import { keyBy, once } from 'lodash'; -import { promisify } from 'util'; -import pLimit from 'p-limit'; - -const readFile = promisify(Fs.readFile); -const readdir = promisify(Fs.readdir); - -const loadSystemMessage = once(async () => { - const data = await readFile(Path.join(__dirname, './system_message.txt')); - return data.toString('utf-8'); -}); - -const loadEsqlDocs = async () => { - const dir = Path.join(__dirname, './esql_docs'); - const files = (await readdir(dir)).filter((file) => Path.extname(file) === '.txt'); - - if (!files.length) { - return {}; - } - - const limiter = pLimit(10); - return keyBy( - await Promise.all( - files.map((file) => - limiter(async () => { - const data = (await readFile(Path.join(dir, file))).toString('utf-8'); - const filename = Path.basename(file, '.txt'); - - const keyword = filename - .replace('esql-', '') - .replace('agg-', '') - .replaceAll('-', '_') - .toUpperCase(); - - return { - keyword: keyword === 'STATS_BY' ? 'STATS' : keyword, - data, - }; - }) - ) - ), - 'keyword' - ); -}; - -export const loadDocuments = once(() => Promise.all([loadSystemMessage(), loadEsqlDocs()])); From 6137e68c1432f8655d56139fbd1d7495d391094a Mon Sep 17 00:00:00 2001 From: pgayvallet Date: Fri, 13 Sep 2024 13:50:42 +0200 Subject: [PATCH 2/6] refactor main workflow --- .../server/tasks/nl_to_esql/index.ts | 354 ++++++++++-------- 1 file changed, 204 insertions(+), 150 deletions(-) diff --git a/x-pack/plugins/inference/server/tasks/nl_to_esql/index.ts b/x-pack/plugins/inference/server/tasks/nl_to_esql/index.ts index 2620d57f243ff..4cbcb6aba7d9f 100644 --- a/x-pack/plugins/inference/server/tasks/nl_to_esql/index.ts +++ b/x-pack/plugins/inference/server/tasks/nl_to_esql/index.ts @@ -34,6 +34,27 @@ type NlToEsqlTaskEvent = const loadDocBase = once(() => EsqlDocumentBase.load()); +const requestDocumentationSchema = { + type: 'object', + properties: { + commands: { + type: 'array', + items: { + type: 'string', + }, + description: + 'ES|QL source and processing commands you want to analyze before generating the query.', + }, + functions: { + type: 'array', + items: { + type: 'string', + }, + description: 'ES|QL functions you want to analyze before generating the query.', + }, + }, +} satisfies ToolSchema; + export function naturalLanguageToEsql({ client, connectorId, @@ -47,64 +68,93 @@ export function naturalLanguageToEsql({ logger: Pick; } & TToolOptions & ({ input: string } | { messages: Message[] })): Observable> { - const hasTools = !isEmpty(tools) && toolChoice !== ToolChoiceType.none; - - const requestDocumentationSchema = { - type: 'object', - properties: { - commands: { - type: 'array', - items: { - type: 'string', - }, - description: - 'ES|QL source and processing commands you want to analyze before generating the query.', - }, - functions: { - type: 'array', - items: { - type: 'string', - }, - description: 'ES|QL functions you want to analyze before generating the query.', - }, - }, - } satisfies ToolSchema; - const messages: Message[] = 'input' in rest ? [{ role: MessageRole.User, content: rest.input }] : rest.messages; return from(loadDocBase()).pipe( switchMap((docBase) => { - function askLlmToRespond({ - documentationRequest: { commands, functions }, - }: { - documentationRequest: { commands?: string[]; functions?: string[] }; - }): Observable> { - const keywords = [...(commands ?? []), ...(functions ?? [])]; + const systemMessage = docBase.getSystemMessage(); - const systemMessage = docBase.getSystemMessage(); - const requestedDocumentation = docBase.getDocumentation(keywords); - const fakeRequestDocsToolCall = createFakeTooCall(commands, functions); + const askLlmToRespond = generateEsqlTask({ + connectorId, + chatCompleteApi: client.chatComplete, + messages, + docBase, + logger, + systemMessage, + toolOptions: { + tools, + toolChoice, + }, + }); - return merge( - of< - OutputCompleteEvent< - 'request_documentation', - { keywords: string[]; requestedDocumentation: Record } - > - >({ - type: OutputEventType.OutputComplete, - id: 'request_documentation', - output: { - keywords, - requestedDocumentation, + return requestDocumentation({ + connectorId, + outputApi: client.output, + messages, + system: systemMessage, + toolOptions: { + tools, + toolChoice, + }, + }).pipe( + switchMap((documentationEvent) => { + return askLlmToRespond({ + documentationRequest: { + commands: documentationEvent.output.commands, + functions: documentationEvent.output.functions, }, - content: '', - }), - client - .chatComplete({ - connectorId, - system: `${systemMessage} + }); + }) + ); + }) + ); +} + +const generateEsqlTask = ({ + chatCompleteApi, + connectorId, + systemMessage, + messages, + toolOptions: { tools, toolChoice }, + docBase, + logger, +}: { + connectorId: string; + systemMessage: string; + messages: Message[]; + toolOptions: ToolOptions; + chatCompleteApi: InferenceClient['chatComplete']; + docBase: EsqlDocumentBase; + logger: Pick; +}) => { + return function askLlmToRespond({ + documentationRequest: { commands, functions }, + }: { + documentationRequest: { commands?: string[]; functions?: string[] }; + }): Observable> { + const keywords = [...(commands ?? []), ...(functions ?? [])]; + const requestedDocumentation = docBase.getDocumentation(keywords); + const fakeRequestDocsToolCall = createFakeTooCall(commands, functions); + + return merge( + of< + OutputCompleteEvent< + 'request_documentation', + { keywords: string[]; requestedDocumentation: Record } + > + >({ + type: OutputEventType.OutputComplete, + id: 'request_documentation', + output: { + keywords, + requestedDocumentation, + }, + content: '', + }), + chatCompleteApi({ + connectorId, + system: `${systemMessage} # Current task @@ -117,106 +167,123 @@ export function naturalLanguageToEsql({ \`\`\` - When generating ES|QL, you must use commands and functions present on the - requested documentation, and follow the syntax as described in the documentation - and its examples. - - DO NOT UNDER ANY CIRCUMSTANCES use commands, functions, parameters, or syntaxes that are not - explicitly mentioned as supported capability by ES|QL, either in the system message or documentation. - assume that ONLY the set of capabilities described in the requested documentation is valid. - Do not try to guess parameters or syntax based on other query languages. + When generating ES|QL, it is VERY important that you only use commands and functions present in the + requested documentation, and follow the syntax as described in the documentation and its examples. + Assume that ONLY the set of capabilities described in the provided ES|QL documentation is valid, and + do not try to guess parameters or syntax based on other query languages. If what the user is asking for is not technically achievable with ES|QL's capabilities, just inform the user. DO NOT invent capabilities not described in the documentation just to provide - a positive answer to the user. E.g. LIMIT only has one parameter, do not assume you can add more. + a positive answer to the user. E.g. Pagination is not supported by the language, do not try to invent + workarounds based on other languages. When converting queries from one language to ES|QL, make sure that the functions are available and documented in ES|QL. E.g., for SPL's LEN, use LENGTH. For IF, use CASE. -`, - messages: messages.concat([ - { - role: MessageRole.Assistant, - content: null, - toolCalls: [fakeRequestDocsToolCall], - }, - { - role: MessageRole.Tool, - response: { - documentation: requestedDocumentation, - }, - toolCallId: fakeRequestDocsToolCall.toolCallId, - }, - ]), - toolChoice, - tools: { - ...tools, - request_documentation: { - description: 'Request additional documentation if needed', - schema: requestDocumentationSchema, - }, - }, - }) - .pipe( - withoutTokenCountEvents(), - map((generateEvent) => { - if (isChatCompletionMessageEvent(generateEvent)) { - const correctedContent = generateEvent.content?.replaceAll( - INLINE_ESQL_QUERY_REGEX, - (_match, query) => { - const correction = correctCommonEsqlMistakes(query); - if (correction.isCorrection) { - logger.debug( - `Corrected query, from: \n${correction.input}\nto:\n${correction.output}` - ); - } - return '```esql\n' + correction.output + '\n```'; - } - ); + `, + messages: [ + ...messages, + { + role: MessageRole.Assistant, + content: null, + toolCalls: [fakeRequestDocsToolCall], + }, + { + role: MessageRole.Tool, + response: { + documentation: requestedDocumentation, + }, + toolCallId: fakeRequestDocsToolCall.toolCallId, + }, + ], + toolChoice, + tools: { + ...tools, + request_documentation: { + description: 'Request additional ES|QL documentation if needed', + schema: requestDocumentationSchema, + }, + }, + }).pipe( + withoutTokenCountEvents(), + map((generateEvent) => { + if (isChatCompletionMessageEvent(generateEvent)) { + return { + ...generateEvent, + content: generateEvent.content + ? correctEsqlMistakes({ content: generateEvent.content, logger }) + : generateEvent.content, + }; + } - return { - ...generateEvent, - content: correctedContent, - }; - } + return generateEvent; + }), + switchMap((generateEvent) => { + if (isChatCompletionMessageEvent(generateEvent)) { + const onlyToolCall = + generateEvent.toolCalls.length === 1 ? generateEvent.toolCalls[0] : undefined; - return generateEvent; - }), - switchMap((generateEvent) => { - if (isChatCompletionMessageEvent(generateEvent)) { - const onlyToolCall = - generateEvent.toolCalls.length === 1 ? generateEvent.toolCalls[0] : undefined; + if (onlyToolCall?.function.name === 'request_documentation') { + const args = onlyToolCall.function.arguments; - if (onlyToolCall?.function.name === 'request_documentation') { - const args = onlyToolCall.function.arguments; + return askLlmToRespond({ + documentationRequest: { + commands: args.commands, + functions: args.functions, + }, + }); + } + } - return askLlmToRespond({ - documentationRequest: { - commands: args.commands, - functions: args.functions, - }, - }); - } - } + return of(generateEvent); + }) + ) + ); + }; +}; - return of(generateEvent); - }) - ) - ); - } +const correctEsqlMistakes = ({ + content, + logger, +}: { + content: string; + logger: Pick; +}) => { + return content.replaceAll(INLINE_ESQL_QUERY_REGEX, (_match, query) => { + const correction = correctCommonEsqlMistakes(query); + if (correction.isCorrection) { + logger.debug(`Corrected query, from: \n${correction.input}\nto:\n${correction.output}`); + } + return '```esql\n' + correction.output + '\n```'; + }); +}; - return client - .output('request_documentation', { - connectorId, - system: systemMessage, - previousMessages: messages, - input: `Based on the previous conversation, request documentation +const requestDocumentation = ({ + outputApi, + system, + messages, + connectorId, + toolOptions: { tools, toolChoice }, +}: { + outputApi: InferenceClient['output']; + system: string; + messages: Message[]; + connectorId: string; + toolOptions: ToolOptions; +}) => { + const hasTools = !isEmpty(tools) && toolChoice !== ToolChoiceType.none; + + return outputApi('request_documentation', { + connectorId, + system, + previousMessages: messages, + input: `Based on the previous conversation, request documentation from the ES|QL handbook to help you get the right information needed to generate a query. Examples for functions and commands: - Do you need to group data? Request \`STATS\`. - Extract data? Request \`DISSECT\` AND \`GROK\`. - Convert a column based on a set of conditionals? Request \`EVAL\` and \`CASE\`. + - Do you need to group data? Request \`STATS\`. + - Extract data? Request \`DISSECT\` AND \`GROK\`. + - Convert a column based on a set of conditionals? Request \`EVAL\` and \`CASE\`. ${ hasTools @@ -233,22 +300,9 @@ export function naturalLanguageToEsql({ : '' } `, - schema: requestDocumentationSchema, - }) - .pipe( - withoutOutputUpdateEvents(), - switchMap((documentationEvent) => { - return askLlmToRespond({ - documentationRequest: { - commands: documentationEvent.output.commands, - functions: documentationEvent.output.functions, - }, - }); - }) - ); - }) - ); -} + schema: requestDocumentationSchema, + }).pipe(withoutOutputUpdateEvents()); +}; const createFakeTooCall = ( commands: string[] | undefined, From 92aa58716491a152def26fa9b744290a64d420b8 Mon Sep 17 00:00:00 2001 From: pgayvallet Date: Fri, 13 Sep 2024 15:38:23 +0200 Subject: [PATCH 3/6] extract actions --- .../evaluation/scenarios/esql/index.spec.ts | 19 +- .../tasks/nl_to_esql/actions/generate_esql.ts | 184 ++++++++++ .../server/tasks/nl_to_esql/actions/index.ts | 9 + .../actions/request_documentation.ts | 59 ++++ .../server/tasks/nl_to_esql/actions/shared.ts | 29 ++ .../nl_to_esql/doc_base/esql_doc_base.ts | 5 +- .../server/tasks/nl_to_esql/index.ts | 316 +----------------- .../inference/server/tasks/nl_to_esql/task.ts | 66 ++++ .../server/tasks/nl_to_esql/types.ts | 31 ++ 9 files changed, 391 insertions(+), 327 deletions(-) create mode 100644 x-pack/plugins/inference/server/tasks/nl_to_esql/actions/generate_esql.ts create mode 100644 x-pack/plugins/inference/server/tasks/nl_to_esql/actions/index.ts create mode 100644 x-pack/plugins/inference/server/tasks/nl_to_esql/actions/request_documentation.ts create mode 100644 x-pack/plugins/inference/server/tasks/nl_to_esql/actions/shared.ts create mode 100644 x-pack/plugins/inference/server/tasks/nl_to_esql/task.ts create mode 100644 x-pack/plugins/inference/server/tasks/nl_to_esql/types.ts diff --git a/x-pack/plugins/inference/scripts/evaluation/scenarios/esql/index.spec.ts b/x-pack/plugins/inference/scripts/evaluation/scenarios/esql/index.spec.ts index 83868884e1429..3aeca67030366 100644 --- a/x-pack/plugins/inference/scripts/evaluation/scenarios/esql/index.spec.ts +++ b/x-pack/plugins/inference/scripts/evaluation/scenarios/esql/index.spec.ts @@ -8,11 +8,10 @@ /// import expect from '@kbn/expect'; -import { mapValues, pick } from 'lodash'; import { firstValueFrom, lastValueFrom, filter } from 'rxjs'; import { naturalLanguageToEsql } from '../../../../server/tasks/nl_to_esql'; import { chatClient, evaluationClient, logger } from '../../services'; -import { loadDocuments } from '../../../../server/tasks/nl_to_esql/load_documents'; +import { EsqlDocumentBase } from '../../../../server/tasks/nl_to_esql/doc_base'; import { isOutputCompleteEvent } from '../../../../common'; interface TestCase { @@ -113,13 +112,9 @@ const retrieveUsedCommands = async ({ const output = commandsListOutput.output; - const keywords = [ - ...(output.commands ?? []), - ...(output.functions ?? []), - 'SYNTAX', - 'OVERVIEW', - 'OPERATORS', - ].map((keyword) => keyword.toUpperCase()); + const keywords = [...(output.commands ?? []), ...(output.functions ?? [])].map((keyword) => + keyword.toUpperCase() + ); return keywords; }; @@ -140,15 +135,15 @@ async function evaluateEsqlQuery({ logger.debug(`Received response: ${answer}`); - const [systemMessage, esqlDocs] = await loadDocuments(); + const docBase = await EsqlDocumentBase.load(); const usedCommands = await retrieveUsedCommands({ question, answer, - esqlDescription: systemMessage, + esqlDescription: docBase.getSystemMessage(), }); - const requestedDocumentation = mapValues(pick(esqlDocs, usedCommands), ({ data }) => data); + const requestedDocumentation = docBase.getDocumentation(usedCommands); const evaluation = await evaluationClient.evaluate({ input: ` diff --git a/x-pack/plugins/inference/server/tasks/nl_to_esql/actions/generate_esql.ts b/x-pack/plugins/inference/server/tasks/nl_to_esql/actions/generate_esql.ts new file mode 100644 index 0000000000000..437665ec69a44 --- /dev/null +++ b/x-pack/plugins/inference/server/tasks/nl_to_esql/actions/generate_esql.ts @@ -0,0 +1,184 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +import { Observable, map, merge, of, switchMap } from 'rxjs'; +import type { Logger } from '@kbn/logging'; +import { ToolCall, ToolOptions } from '../../../../common/chat_complete/tools'; +import { + correctCommonEsqlMistakes, + generateFakeToolCallId, + isChatCompletionMessageEvent, + Message, + MessageRole, +} from '../../../../common'; +import { InferenceClient, withoutTokenCountEvents } from '../../..'; +import { OutputCompleteEvent, OutputEventType } from '../../../../common/output'; +import { INLINE_ESQL_QUERY_REGEX } from '../../../../common/tasks/nl_to_esql/constants'; +import { EsqlDocumentBase } from '../doc_base'; +import { requestDocumentationSchema } from './shared'; + +export const generateEsqlTask = ({ + chatCompleteApi, + connectorId, + systemMessage, + messages, + toolOptions: { tools, toolChoice }, + docBase, + logger, +}: { + connectorId: string; + systemMessage: string; + messages: Message[]; + toolOptions: ToolOptions; + chatCompleteApi: InferenceClient['chatComplete']; + docBase: EsqlDocumentBase; + logger: Pick; +}) => { + return function askLlmToRespond({ + documentationRequest: { commands, functions }, + }: { + documentationRequest: { commands?: string[]; functions?: string[] }; + }): Observable> { + const keywords = [...(commands ?? []), ...(functions ?? [])]; + const requestedDocumentation = docBase.getDocumentation(keywords); + const fakeRequestDocsToolCall = createFakeTooCall(commands, functions); + + return merge( + of< + OutputCompleteEvent< + 'request_documentation', + { keywords: string[]; requestedDocumentation: Record } + > + >({ + type: OutputEventType.OutputComplete, + id: 'request_documentation', + output: { + keywords, + requestedDocumentation, + }, + content: '', + }), + chatCompleteApi({ + connectorId, + system: `${systemMessage} + + # Current task + + Your current task is to respond to the user's question. If there is a tool + suitable for answering the user's question, use that tool, preferably + with a natural language reply included. + + Format any ES|QL query as follows: + \`\`\`esql + + \`\`\` + + When generating ES|QL, it is VERY important that you only use commands and functions present in the + requested documentation, and follow the syntax as described in the documentation and its examples. + Assume that ONLY the set of capabilities described in the provided ES|QL documentation is valid, and + do not try to guess parameters or syntax based on other query languages. + + If what the user is asking for is not technically achievable with ES|QL's capabilities, just inform + the user. DO NOT invent capabilities not described in the documentation just to provide + a positive answer to the user. E.g. Pagination is not supported by the language, do not try to invent + workarounds based on other languages. + + When converting queries from one language to ES|QL, make sure that the functions are available + and documented in ES|QL. E.g., for SPL's LEN, use LENGTH. For IF, use CASE. + `, + messages: [ + ...messages, + { + role: MessageRole.Assistant, + content: null, + toolCalls: [fakeRequestDocsToolCall], + }, + { + role: MessageRole.Tool, + response: { + documentation: requestedDocumentation, + }, + toolCallId: fakeRequestDocsToolCall.toolCallId, + }, + ], + toolChoice, + tools: { + ...tools, + request_documentation: { + description: 'Request additional ES|QL documentation if needed', + schema: requestDocumentationSchema, + }, + }, + }).pipe( + withoutTokenCountEvents(), + map((generateEvent) => { + if (isChatCompletionMessageEvent(generateEvent)) { + return { + ...generateEvent, + content: generateEvent.content + ? correctEsqlMistakes({ content: generateEvent.content, logger }) + : generateEvent.content, + }; + } + + return generateEvent; + }), + switchMap((generateEvent) => { + if (isChatCompletionMessageEvent(generateEvent)) { + const onlyToolCall = + generateEvent.toolCalls.length === 1 ? generateEvent.toolCalls[0] : undefined; + + if (onlyToolCall?.function.name === 'request_documentation') { + const args = onlyToolCall.function.arguments; + + return askLlmToRespond({ + documentationRequest: { + commands: args.commands, + functions: args.functions, + }, + }); + } + } + + return of(generateEvent); + }) + ) + ); + }; +}; + +const correctEsqlMistakes = ({ + content, + logger, +}: { + content: string; + logger: Pick; +}) => { + return content.replaceAll(INLINE_ESQL_QUERY_REGEX, (_match, query) => { + const correction = correctCommonEsqlMistakes(query); + if (correction.isCorrection) { + logger.debug(`Corrected query, from: \n${correction.input}\nto:\n${correction.output}`); + } + return '```esql\n' + correction.output + '\n```'; + }); +}; + +const createFakeTooCall = ( + commands: string[] | undefined, + functions: string[] | undefined +): ToolCall => { + return { + function: { + name: 'request_documentation', + arguments: { + commands, + functions, + }, + }, + toolCallId: generateFakeToolCallId(), + }; +}; diff --git a/x-pack/plugins/inference/server/tasks/nl_to_esql/actions/index.ts b/x-pack/plugins/inference/server/tasks/nl_to_esql/actions/index.ts new file mode 100644 index 0000000000000..ec1d54dd8a26b --- /dev/null +++ b/x-pack/plugins/inference/server/tasks/nl_to_esql/actions/index.ts @@ -0,0 +1,9 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +export { requestDocumentation } from './request_documentation'; +export { generateEsqlTask } from './generate_esql'; diff --git a/x-pack/plugins/inference/server/tasks/nl_to_esql/actions/request_documentation.ts b/x-pack/plugins/inference/server/tasks/nl_to_esql/actions/request_documentation.ts new file mode 100644 index 0000000000000..05f454c044d31 --- /dev/null +++ b/x-pack/plugins/inference/server/tasks/nl_to_esql/actions/request_documentation.ts @@ -0,0 +1,59 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +import { isEmpty } from 'lodash'; +import { InferenceClient, withoutOutputUpdateEvents } from '../../..'; +import { Message } from '../../../../common'; +import { ToolChoiceType, ToolOptions } from '../../../../common/chat_complete/tools'; +import { requestDocumentationSchema } from './shared'; + +export const requestDocumentation = ({ + outputApi, + system, + messages, + connectorId, + toolOptions: { tools, toolChoice }, +}: { + outputApi: InferenceClient['output']; + system: string; + messages: Message[]; + connectorId: string; + toolOptions: ToolOptions; +}) => { + const hasTools = !isEmpty(tools) && toolChoice !== ToolChoiceType.none; + + return outputApi('request_documentation', { + connectorId, + system, + previousMessages: messages, + input: `Based on the previous conversation, request documentation + from the ES|QL handbook to help you get the right information + needed to generate a query. + + Examples for functions and commands: + - Do you need to group data? Request \`STATS\`. + - Extract data? Request \`DISSECT\` AND \`GROK\`. + - Convert a column based on a set of conditionals? Request \`EVAL\` and \`CASE\`. + + ${ + hasTools + ? `### Tools + + The following tools will be available to be called in the step after this. + + \`\`\`json + ${JSON.stringify({ + tools, + toolChoice, + })} + \`\`\`` + : '' + } + `, + schema: requestDocumentationSchema, + }).pipe(withoutOutputUpdateEvents()); +}; diff --git a/x-pack/plugins/inference/server/tasks/nl_to_esql/actions/shared.ts b/x-pack/plugins/inference/server/tasks/nl_to_esql/actions/shared.ts new file mode 100644 index 0000000000000..f0fc796173b23 --- /dev/null +++ b/x-pack/plugins/inference/server/tasks/nl_to_esql/actions/shared.ts @@ -0,0 +1,29 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +import { ToolSchema } from '../../../../common'; + +export const requestDocumentationSchema = { + type: 'object', + properties: { + commands: { + type: 'array', + items: { + type: 'string', + }, + description: + 'ES|QL source and processing commands you want to analyze before generating the query.', + }, + functions: { + type: 'array', + items: { + type: 'string', + }, + description: 'ES|QL functions you want to analyze before generating the query.', + }, + }, +} satisfies ToolSchema; diff --git a/x-pack/plugins/inference/server/tasks/nl_to_esql/doc_base/esql_doc_base.ts b/x-pack/plugins/inference/server/tasks/nl_to_esql/doc_base/esql_doc_base.ts index 70044e00da484..b81f19b80f95e 100644 --- a/x-pack/plugins/inference/server/tasks/nl_to_esql/doc_base/esql_doc_base.ts +++ b/x-pack/plugins/inference/server/tasks/nl_to_esql/doc_base/esql_doc_base.ts @@ -5,6 +5,7 @@ * 2.0. */ +import { once } from 'lodash'; import { loadData, type EsqlDocData, type EsqlDocEntry } from './load_data'; import { tryResolveAlias } from './aliases'; @@ -32,6 +33,8 @@ interface GetDocsOptions { addSuggestions?: boolean; } +const loadDataOnce = once(loadData); + const overviewEntries = ['SYNTAX', 'OVERVIEW', 'OPERATORS']; export class EsqlDocumentBase { @@ -39,7 +42,7 @@ export class EsqlDocumentBase { private docRecords: Record; static async load(): Promise { - const data = await loadData(); + const data = await loadDataOnce(); return new EsqlDocumentBase(data); } diff --git a/x-pack/plugins/inference/server/tasks/nl_to_esql/index.ts b/x-pack/plugins/inference/server/tasks/nl_to_esql/index.ts index 4cbcb6aba7d9f..50854d3af7fd8 100644 --- a/x-pack/plugins/inference/server/tasks/nl_to_esql/index.ts +++ b/x-pack/plugins/inference/server/tasks/nl_to_esql/index.ts @@ -5,317 +5,5 @@ * 2.0. */ -import type { Logger } from '@kbn/logging'; -import { isEmpty, once } from 'lodash'; -import { Observable, from, map, merge, of, switchMap } from 'rxjs'; -import { ToolSchema, generateFakeToolCallId, isChatCompletionMessageEvent } from '../../../common'; -import { - ChatCompletionChunkEvent, - ChatCompletionMessageEvent, - Message, - MessageRole, -} from '../../../common/chat_complete'; -import { ToolChoiceType, type ToolOptions, ToolCall } from '../../../common/chat_complete/tools'; -import { withoutTokenCountEvents } from '../../../common/chat_complete/without_token_count_events'; -import { OutputCompleteEvent, OutputEventType } from '../../../common/output'; -import { withoutOutputUpdateEvents } from '../../../common/output/without_output_update_events'; -import { INLINE_ESQL_QUERY_REGEX } from '../../../common/tasks/nl_to_esql/constants'; -import { correctCommonEsqlMistakes } from '../../../common/tasks/nl_to_esql/correct_common_esql_mistakes'; -import type { InferenceClient } from '../../types'; -import { EsqlDocumentBase } from './doc_base'; - -type NlToEsqlTaskEvent = - | OutputCompleteEvent< - 'request_documentation', - { keywords: string[]; requestedDocumentation: Record } - > - | ChatCompletionChunkEvent - | ChatCompletionMessageEvent; - -const loadDocBase = once(() => EsqlDocumentBase.load()); - -const requestDocumentationSchema = { - type: 'object', - properties: { - commands: { - type: 'array', - items: { - type: 'string', - }, - description: - 'ES|QL source and processing commands you want to analyze before generating the query.', - }, - functions: { - type: 'array', - items: { - type: 'string', - }, - description: 'ES|QL functions you want to analyze before generating the query.', - }, - }, -} satisfies ToolSchema; - -export function naturalLanguageToEsql({ - client, - connectorId, - tools, - toolChoice, - logger, - ...rest -}: { - client: Pick; - connectorId: string; - logger: Pick; -} & TToolOptions & - ({ input: string } | { messages: Message[] })): Observable> { - const messages: Message[] = - 'input' in rest ? [{ role: MessageRole.User, content: rest.input }] : rest.messages; - - return from(loadDocBase()).pipe( - switchMap((docBase) => { - const systemMessage = docBase.getSystemMessage(); - - const askLlmToRespond = generateEsqlTask({ - connectorId, - chatCompleteApi: client.chatComplete, - messages, - docBase, - logger, - systemMessage, - toolOptions: { - tools, - toolChoice, - }, - }); - - return requestDocumentation({ - connectorId, - outputApi: client.output, - messages, - system: systemMessage, - toolOptions: { - tools, - toolChoice, - }, - }).pipe( - switchMap((documentationEvent) => { - return askLlmToRespond({ - documentationRequest: { - commands: documentationEvent.output.commands, - functions: documentationEvent.output.functions, - }, - }); - }) - ); - }) - ); -} - -const generateEsqlTask = ({ - chatCompleteApi, - connectorId, - systemMessage, - messages, - toolOptions: { tools, toolChoice }, - docBase, - logger, -}: { - connectorId: string; - systemMessage: string; - messages: Message[]; - toolOptions: ToolOptions; - chatCompleteApi: InferenceClient['chatComplete']; - docBase: EsqlDocumentBase; - logger: Pick; -}) => { - return function askLlmToRespond({ - documentationRequest: { commands, functions }, - }: { - documentationRequest: { commands?: string[]; functions?: string[] }; - }): Observable> { - const keywords = [...(commands ?? []), ...(functions ?? [])]; - const requestedDocumentation = docBase.getDocumentation(keywords); - const fakeRequestDocsToolCall = createFakeTooCall(commands, functions); - - return merge( - of< - OutputCompleteEvent< - 'request_documentation', - { keywords: string[]; requestedDocumentation: Record } - > - >({ - type: OutputEventType.OutputComplete, - id: 'request_documentation', - output: { - keywords, - requestedDocumentation, - }, - content: '', - }), - chatCompleteApi({ - connectorId, - system: `${systemMessage} - - # Current task - - Your current task is to respond to the user's question. If there is a tool - suitable for answering the user's question, use that tool, preferably - with a natural language reply included. - - Format any ES|QL query as follows: - \`\`\`esql - - \`\`\` - - When generating ES|QL, it is VERY important that you only use commands and functions present in the - requested documentation, and follow the syntax as described in the documentation and its examples. - Assume that ONLY the set of capabilities described in the provided ES|QL documentation is valid, and - do not try to guess parameters or syntax based on other query languages. - - If what the user is asking for is not technically achievable with ES|QL's capabilities, just inform - the user. DO NOT invent capabilities not described in the documentation just to provide - a positive answer to the user. E.g. Pagination is not supported by the language, do not try to invent - workarounds based on other languages. - - When converting queries from one language to ES|QL, make sure that the functions are available - and documented in ES|QL. E.g., for SPL's LEN, use LENGTH. For IF, use CASE. - `, - messages: [ - ...messages, - { - role: MessageRole.Assistant, - content: null, - toolCalls: [fakeRequestDocsToolCall], - }, - { - role: MessageRole.Tool, - response: { - documentation: requestedDocumentation, - }, - toolCallId: fakeRequestDocsToolCall.toolCallId, - }, - ], - toolChoice, - tools: { - ...tools, - request_documentation: { - description: 'Request additional ES|QL documentation if needed', - schema: requestDocumentationSchema, - }, - }, - }).pipe( - withoutTokenCountEvents(), - map((generateEvent) => { - if (isChatCompletionMessageEvent(generateEvent)) { - return { - ...generateEvent, - content: generateEvent.content - ? correctEsqlMistakes({ content: generateEvent.content, logger }) - : generateEvent.content, - }; - } - - return generateEvent; - }), - switchMap((generateEvent) => { - if (isChatCompletionMessageEvent(generateEvent)) { - const onlyToolCall = - generateEvent.toolCalls.length === 1 ? generateEvent.toolCalls[0] : undefined; - - if (onlyToolCall?.function.name === 'request_documentation') { - const args = onlyToolCall.function.arguments; - - return askLlmToRespond({ - documentationRequest: { - commands: args.commands, - functions: args.functions, - }, - }); - } - } - - return of(generateEvent); - }) - ) - ); - }; -}; - -const correctEsqlMistakes = ({ - content, - logger, -}: { - content: string; - logger: Pick; -}) => { - return content.replaceAll(INLINE_ESQL_QUERY_REGEX, (_match, query) => { - const correction = correctCommonEsqlMistakes(query); - if (correction.isCorrection) { - logger.debug(`Corrected query, from: \n${correction.input}\nto:\n${correction.output}`); - } - return '```esql\n' + correction.output + '\n```'; - }); -}; - -const requestDocumentation = ({ - outputApi, - system, - messages, - connectorId, - toolOptions: { tools, toolChoice }, -}: { - outputApi: InferenceClient['output']; - system: string; - messages: Message[]; - connectorId: string; - toolOptions: ToolOptions; -}) => { - const hasTools = !isEmpty(tools) && toolChoice !== ToolChoiceType.none; - - return outputApi('request_documentation', { - connectorId, - system, - previousMessages: messages, - input: `Based on the previous conversation, request documentation - from the ES|QL handbook to help you get the right information - needed to generate a query. - - Examples for functions and commands: - - Do you need to group data? Request \`STATS\`. - - Extract data? Request \`DISSECT\` AND \`GROK\`. - - Convert a column based on a set of conditionals? Request \`EVAL\` and \`CASE\`. - - ${ - hasTools - ? `### Tools - - The following tools will be available to be called in the step after this. - - \`\`\`json - ${JSON.stringify({ - tools, - toolChoice, - })} - \`\`\`` - : '' - } - `, - schema: requestDocumentationSchema, - }).pipe(withoutOutputUpdateEvents()); -}; - -const createFakeTooCall = ( - commands: string[] | undefined, - functions: string[] | undefined -): ToolCall => { - return { - function: { - name: 'request_documentation', - arguments: { - commands, - functions, - }, - }, - toolCallId: generateFakeToolCallId(), - }; -}; +export { naturalLanguageToEsql } from './task'; +export type { NlToEsqlTaskEvent, NlToEsqlTaskParams } from './types'; diff --git a/x-pack/plugins/inference/server/tasks/nl_to_esql/task.ts b/x-pack/plugins/inference/server/tasks/nl_to_esql/task.ts new file mode 100644 index 0000000000000..04b879351cc54 --- /dev/null +++ b/x-pack/plugins/inference/server/tasks/nl_to_esql/task.ts @@ -0,0 +1,66 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +import { once } from 'lodash'; +import { Observable, from, switchMap } from 'rxjs'; +import { Message, MessageRole } from '../../../common/chat_complete'; +import type { ToolOptions } from '../../../common/chat_complete/tools'; +import { EsqlDocumentBase } from './doc_base'; +import { requestDocumentation, generateEsqlTask } from './actions'; +import { NlToEsqlTaskParams, NlToEsqlTaskEvent } from './types'; + +const loadDocBase = once(() => EsqlDocumentBase.load()); + +export function naturalLanguageToEsql({ + client, + connectorId, + tools, + toolChoice, + logger, + ...rest +}: NlToEsqlTaskParams): Observable> { + return from(loadDocBase()).pipe( + switchMap((docBase) => { + const systemMessage = docBase.getSystemMessage(); + const messages: Message[] = + 'input' in rest ? [{ role: MessageRole.User, content: rest.input }] : rest.messages; + + const askLlmToRespond = generateEsqlTask({ + connectorId, + chatCompleteApi: client.chatComplete, + messages, + docBase, + logger, + systemMessage, + toolOptions: { + tools, + toolChoice, + }, + }); + + return requestDocumentation({ + connectorId, + outputApi: client.output, + messages, + system: systemMessage, + toolOptions: { + tools, + toolChoice, + }, + }).pipe( + switchMap((documentationEvent) => { + return askLlmToRespond({ + documentationRequest: { + commands: documentationEvent.output.commands, + functions: documentationEvent.output.functions, + }, + }); + }) + ); + }) + ); +} diff --git a/x-pack/plugins/inference/server/tasks/nl_to_esql/types.ts b/x-pack/plugins/inference/server/tasks/nl_to_esql/types.ts new file mode 100644 index 0000000000000..c460f029b147e --- /dev/null +++ b/x-pack/plugins/inference/server/tasks/nl_to_esql/types.ts @@ -0,0 +1,31 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +import type { Logger } from '@kbn/logging'; +import type { + ChatCompletionChunkEvent, + ChatCompletionMessageEvent, + Message, +} from '../../../common/chat_complete'; +import type { ToolOptions } from '../../../common/chat_complete/tools'; +import type { OutputCompleteEvent } from '../../../common/output'; +import type { InferenceClient } from '../../types'; + +export type NlToEsqlTaskEvent = + | OutputCompleteEvent< + 'request_documentation', + { keywords: string[]; requestedDocumentation: Record } + > + | ChatCompletionChunkEvent + | ChatCompletionMessageEvent; + +export type NlToEsqlTaskParams = { + client: Pick; + connectorId: string; + logger: Pick; +} & TToolOptions & + ({ input: string } | { messages: Message[] }); From 87f77c0ab0221b2403ff5520be6f0e91d8e2402a Mon Sep 17 00:00:00 2001 From: pgayvallet Date: Sat, 14 Sep 2024 20:20:29 +0200 Subject: [PATCH 4/6] fix types --- .../inference/server/tasks/nl_to_esql/actions/generate_esql.ts | 1 + 1 file changed, 1 insertion(+) diff --git a/x-pack/plugins/inference/server/tasks/nl_to_esql/actions/generate_esql.ts b/x-pack/plugins/inference/server/tasks/nl_to_esql/actions/generate_esql.ts index 437665ec69a44..8a111322a8de6 100644 --- a/x-pack/plugins/inference/server/tasks/nl_to_esql/actions/generate_esql.ts +++ b/x-pack/plugins/inference/server/tasks/nl_to_esql/actions/generate_esql.ts @@ -20,6 +20,7 @@ import { OutputCompleteEvent, OutputEventType } from '../../../../common/output' import { INLINE_ESQL_QUERY_REGEX } from '../../../../common/tasks/nl_to_esql/constants'; import { EsqlDocumentBase } from '../doc_base'; import { requestDocumentationSchema } from './shared'; +import type { NlToEsqlTaskEvent } from '../types'; export const generateEsqlTask = ({ chatCompleteApi, From f2ea19c7e40e5b5e8d4e5031032d6e4d0ca057ee Mon Sep 17 00:00:00 2001 From: pgayvallet Date: Tue, 17 Sep 2024 11:26:30 +0200 Subject: [PATCH 5/6] add basic support for suggestions --- .../nl_to_esql/doc_base/esql_doc_base.ts | 30 +++---------------- .../tasks/nl_to_esql/doc_base/suggestions.ts | 30 +++++++++++++++++++ .../server/tasks/nl_to_esql/doc_base/types.ts | 30 +++++++++++++++++++ 3 files changed, 64 insertions(+), 26 deletions(-) create mode 100644 x-pack/plugins/inference/server/tasks/nl_to_esql/doc_base/suggestions.ts create mode 100644 x-pack/plugins/inference/server/tasks/nl_to_esql/doc_base/types.ts diff --git a/x-pack/plugins/inference/server/tasks/nl_to_esql/doc_base/esql_doc_base.ts b/x-pack/plugins/inference/server/tasks/nl_to_esql/doc_base/esql_doc_base.ts index b81f19b80f95e..403fb2658d407 100644 --- a/x-pack/plugins/inference/server/tasks/nl_to_esql/doc_base/esql_doc_base.ts +++ b/x-pack/plugins/inference/server/tasks/nl_to_esql/doc_base/esql_doc_base.ts @@ -8,30 +8,8 @@ import { once } from 'lodash'; import { loadData, type EsqlDocData, type EsqlDocEntry } from './load_data'; import { tryResolveAlias } from './aliases'; - -interface GetDocsOptions { - /** - * If true (default), will include general ES|QL documentation entries - * such as the overview, syntax and operators page. - */ - addOverview?: boolean; - /** - * If true (default) will try to resolve aliases for commands. - */ - resolveAliases?: boolean; - - /** - * If true (default) will generate a fake doc page for missing keywords. - * Useful for the LLM to understand that the requested keyword does not exist. - */ - generateMissingKeywordDoc?: boolean; - - /** - * If true (default), additional documentation will be included to help the LLM. - * E.g. for STATS, BUCKET will be included. - */ - addSuggestions?: boolean; -} +import { getSuggestions } from './suggestions'; +import type { GetDocsOptions } from './types'; const loadDataOnce = once(loadData); @@ -73,14 +51,14 @@ export class EsqlDocumentBase { }); if (addSuggestions) { - // TODO + keywords.push(...getSuggestions(keywords)); } if (addOverview) { keywords.push(...overviewEntries); } - return keywords.reduce>((results, keyword) => { + return [...new Set(keywords)].reduce>((results, keyword) => { if (Object.hasOwn(this.docRecords, keyword)) { results[keyword] = this.docRecords[keyword].data; } else if (generateMissingKeywordDoc) { diff --git a/x-pack/plugins/inference/server/tasks/nl_to_esql/doc_base/suggestions.ts b/x-pack/plugins/inference/server/tasks/nl_to_esql/doc_base/suggestions.ts new file mode 100644 index 0000000000000..42ee960301b76 --- /dev/null +++ b/x-pack/plugins/inference/server/tasks/nl_to_esql/doc_base/suggestions.ts @@ -0,0 +1,30 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +type Suggestion = (keywords: string[]) => string[] | undefined; + +const suggestions: Suggestion[] = [ + (keywords) => { + if (keywords.includes('STATS') && keywords.includes('DATE_TRUNC')) { + return ['BUCKET']; + } + }, +]; + +/** + * Based on the list of keywords the model asked to get documentation for, + * Try to provide suggestion on other commands or keywords that may be useful. + * + * E.g. when requesting documentation for `STATS` and `DATE_TRUNC`, suggests `BUCKET` + * + */ +export const getSuggestions = (keywords: string[]): string[] => { + return suggestions.reduce((list, sugg) => { + list.push(...(sugg(keywords) ?? [])); + return list; + }, []); +}; diff --git a/x-pack/plugins/inference/server/tasks/nl_to_esql/doc_base/types.ts b/x-pack/plugins/inference/server/tasks/nl_to_esql/doc_base/types.ts new file mode 100644 index 0000000000000..b5b3a8475c5f5 --- /dev/null +++ b/x-pack/plugins/inference/server/tasks/nl_to_esql/doc_base/types.ts @@ -0,0 +1,30 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +export interface GetDocsOptions { + /** + * If true (default), will include general ES|QL documentation entries + * such as the overview, syntax and operators page. + */ + addOverview?: boolean; + /** + * If true (default) will try to resolve aliases for commands. + */ + resolveAliases?: boolean; + + /** + * If true (default) will generate a fake doc page for missing keywords. + * Useful for the LLM to understand that the requested keyword does not exist. + */ + generateMissingKeywordDoc?: boolean; + + /** + * If true (default), additional documentation will be included to help the LLM. + * E.g. for STATS, BUCKET will be included. + */ + addSuggestions?: boolean; +} From 4580b89ecc4fa3fa294996849404c663bb52e7be Mon Sep 17 00:00:00 2001 From: pgayvallet Date: Tue, 17 Sep 2024 15:05:22 +0200 Subject: [PATCH 6/6] fix conversion --- .../common/convert_messages_for_inference.ts | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/x-pack/plugins/observability_solution/observability_ai_assistant_app/common/convert_messages_for_inference.ts b/x-pack/plugins/observability_solution/observability_ai_assistant_app/common/convert_messages_for_inference.ts index 91d7f00467540..1dc8638626d0b 100644 --- a/x-pack/plugins/observability_solution/observability_ai_assistant_app/common/convert_messages_for_inference.ts +++ b/x-pack/plugins/observability_solution/observability_ai_assistant_app/common/convert_messages_for_inference.ts @@ -21,7 +21,7 @@ export function convertMessagesForInference(messages: Message[]): InferenceMessa inferenceMessages.push({ role: InferenceMessageRole.Assistant, content: message.message.content ?? null, - ...(message.message.function_call + ...(message.message.function_call?.name ? { toolCalls: [ {