From 798a26f93ce0501ed8fe72e6de94fd7454315d8e Mon Sep 17 00:00:00 2001 From: Steph Milovic Date: Wed, 18 Sep 2024 15:05:41 -0600 Subject: [PATCH] [Security solution] `naturalLanguageToEsql` Tool added to default assistant graph (#192042) --- x-pack/plugins/elastic_assistant/kibana.jsonc | 1 + .../server/__mocks__/request_context.ts | 2 + .../server/lib/langchain/executors/types.ts | 2 + .../graphs/default_assistant_graph/index.ts | 3 + .../elastic_assistant/server/plugin.ts | 1 + .../server/routes/attack_discovery/helpers.ts | 2 +- .../server/routes/chat/chat_complete_route.ts | 2 + .../server/routes/evaluate/post_evaluate.ts | 4 + .../server/routes/helpers.ts | 4 + .../routes/post_actions_connector_execute.ts | 2 + .../server/routes/request_context_factory.ts | 2 + .../plugins/elastic_assistant/server/types.ts | 9 +++ .../plugins/elastic_assistant/tsconfig.json | 1 + .../common/experimental_features.ts | 5 ++ .../nl_to_esql_tool.ts | 80 +++++++++++++++++++ .../server/assistant/tools/index.test.ts | 2 +- .../server/assistant/tools/index.ts | 7 +- .../security_solution/server/plugin.ts | 5 +- .../plugins/security_solution/tsconfig.json | 1 + 19 files changed, 129 insertions(+), 6 deletions(-) create mode 100644 x-pack/plugins/security_solution/server/assistant/tools/esql_language_knowledge_base/nl_to_esql_tool.ts diff --git a/x-pack/plugins/elastic_assistant/kibana.jsonc b/x-pack/plugins/elastic_assistant/kibana.jsonc index 9879ba274d209..8a3e0725c782a 100644 --- a/x-pack/plugins/elastic_assistant/kibana.jsonc +++ b/x-pack/plugins/elastic_assistant/kibana.jsonc @@ -13,6 +13,7 @@ "ml", "taskManager", "licensing", + "inference", "spaces", "security" ] diff --git a/x-pack/plugins/elastic_assistant/server/__mocks__/request_context.ts b/x-pack/plugins/elastic_assistant/server/__mocks__/request_context.ts index adface75dbced..6ae7ec9e4469b 100644 --- a/x-pack/plugins/elastic_assistant/server/__mocks__/request_context.ts +++ b/x-pack/plugins/elastic_assistant/server/__mocks__/request_context.ts @@ -45,6 +45,7 @@ export const createMockClients = () => { getAIAssistantAnonymizationFieldsDataClient: dataClientMock.create(), getSpaceId: jest.fn(), getCurrentUser: jest.fn(), + inference: jest.fn(), }, savedObjectsClient: core.savedObjects.client, @@ -130,6 +131,7 @@ const createElasticAssistantRequestContextMock = ( getCurrentUser: jest.fn(), getServerBasePath: jest.fn(), getSpaceId: jest.fn(), + inference: { getClient: jest.fn() }, core: clients.core, telemetry: clients.elasticAssistant.telemetry, }; diff --git a/x-pack/plugins/elastic_assistant/server/lib/langchain/executors/types.ts b/x-pack/plugins/elastic_assistant/server/lib/langchain/executors/types.ts index 060616d280efe..2395221ea14b3 100644 --- a/x-pack/plugins/elastic_assistant/server/lib/langchain/executors/types.ts +++ b/x-pack/plugins/elastic_assistant/server/lib/langchain/executors/types.ts @@ -14,6 +14,7 @@ import type { LangChainTracer } from '@langchain/core/tracers/tracer_langchain'; import { ExecuteConnectorRequestBody, Message, Replacements } from '@kbn/elastic-assistant-common'; import { StreamResponseWithHeaders } from '@kbn/ml-response-stream/server'; import { PublicMethodsOf } from '@kbn/utility-types'; +import type { InferenceServerStart } from '@kbn/inference-plugin/server'; import { ResponseBody } from '../types'; import type { AssistantTool } from '../../../types'; import { ElasticsearchStore } from '../elasticsearch_store/elasticsearch_store'; @@ -47,6 +48,7 @@ export interface AgentExecutorParams { langChainMessages: BaseMessage[]; llmType?: string; logger: Logger; + inference: InferenceServerStart; onNewReplacements?: (newReplacements: Replacements) => void; replacements: Replacements; isStream?: T; diff --git a/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/index.ts b/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/index.ts index 0222720d95e37..8cc676cd851a7 100644 --- a/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/index.ts +++ b/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/index.ts @@ -38,6 +38,7 @@ export const callAssistantGraph: AgentExecutor = async ({ dataClients, esClient, esStore, + inference, langChainMessages, llmType, logger: parentLogger, @@ -107,7 +108,9 @@ export const callAssistantGraph: AgentExecutor = async ({ alertsIndexPattern, anonymizationFields, chain, + connectorId, esClient, + inference, isEnabledKnowledgeBase, kbDataClient: dataClients?.kbDataClient, logger, diff --git a/x-pack/plugins/elastic_assistant/server/plugin.ts b/x-pack/plugins/elastic_assistant/server/plugin.ts index a8fc3a7de570c..4386b95c3fa7a 100755 --- a/x-pack/plugins/elastic_assistant/server/plugin.ts +++ b/x-pack/plugins/elastic_assistant/server/plugin.ts @@ -112,6 +112,7 @@ export class ElasticAssistantPlugin return { actions: plugins.actions, + inference: plugins.inference, getRegisteredFeatures: (pluginName: string) => { return appContextService.getRegisteredFeatures(pluginName); }, diff --git a/x-pack/plugins/elastic_assistant/server/routes/attack_discovery/helpers.ts b/x-pack/plugins/elastic_assistant/server/routes/attack_discovery/helpers.ts index 1de3b86e74deb..cccf37aff48e0 100644 --- a/x-pack/plugins/elastic_assistant/server/routes/attack_discovery/helpers.ts +++ b/x-pack/plugins/elastic_assistant/server/routes/attack_discovery/helpers.ts @@ -149,7 +149,7 @@ const formatAssistantToolParams = ({ ExecuteConnectorRequestBody | AttackDiscoveryPostRequestBody >; size: number; -}): AssistantToolParams => ({ +}): Omit => ({ alertsIndexPattern, anonymizationFields: [...(anonymizationFields ?? []), ...REQUIRED_FOR_ATTACK_DISCOVERY], isEnabledKnowledgeBase: false, // not required for attack discovery diff --git a/x-pack/plugins/elastic_assistant/server/routes/chat/chat_complete_route.ts b/x-pack/plugins/elastic_assistant/server/routes/chat/chat_complete_route.ts index b8f75e2376863..dd90241809015 100644 --- a/x-pack/plugins/elastic_assistant/server/routes/chat/chat_complete_route.ts +++ b/x-pack/plugins/elastic_assistant/server/routes/chat/chat_complete_route.ts @@ -66,6 +66,7 @@ export const chatCompleteRoute = ( const ctx = await context.resolve(['core', 'elasticAssistant', 'licensing']); const logger: Logger = ctx.elasticAssistant.logger; telemetry = ctx.elasticAssistant.telemetry; + const inference = ctx.elasticAssistant.inference; // Perform license and authenticated user checks const checkResponse = performChecks({ @@ -195,6 +196,7 @@ export const chatCompleteRoute = ( context: ctx, getElser, logger, + inference, messages: messages ?? [], onLlmResponse, onNewReplacements, diff --git a/x-pack/plugins/elastic_assistant/server/routes/evaluate/post_evaluate.ts b/x-pack/plugins/elastic_assistant/server/routes/evaluate/post_evaluate.ts index 6cc8853d119dd..27ea4eea46d45 100644 --- a/x-pack/plugins/elastic_assistant/server/routes/evaluate/post_evaluate.ts +++ b/x-pack/plugins/elastic_assistant/server/routes/evaluate/post_evaluate.ts @@ -150,6 +150,8 @@ export const postEvaluateRoute = ( // Default ELSER model const elserId = await getElser(); + const inference = ctx.elasticAssistant.inference; + // Data clients const anonymizationFieldsDataClient = (await assistantContext.getAIAssistantAnonymizationFieldsDataClient()) ?? undefined; @@ -260,6 +262,8 @@ export const postEvaluateRoute = ( alertsIndexPattern, // onNewReplacements, replacements, + inference, + connectorId: connector.id, size, }; diff --git a/x-pack/plugins/elastic_assistant/server/routes/helpers.ts b/x-pack/plugins/elastic_assistant/server/routes/helpers.ts index d457f9c88bf69..1e8acf4bee885 100644 --- a/x-pack/plugins/elastic_assistant/server/routes/helpers.ts +++ b/x-pack/plugins/elastic_assistant/server/routes/helpers.ts @@ -28,6 +28,7 @@ import { AwaitedProperties, PublicMethodsOf } from '@kbn/utility-types'; import { ActionsClient } from '@kbn/actions-plugin/server'; import { AssistantFeatureKey } from '@kbn/elastic-assistant-common/impl/capabilities'; import { getLangSmithTracer } from '@kbn/langchain/server/tracers/langsmith'; +import type { InferenceServerStart } from '@kbn/inference-plugin/server'; import { AIAssistantKnowledgeBaseDataClient } from '../ai_assistant_data_clients/knowledge_base'; import { FindResponse } from '../ai_assistant_data_clients/find'; import { EsPromptsSchema } from '../ai_assistant_data_clients/prompts/types'; @@ -321,6 +322,7 @@ export interface LangChainExecuteParams { telemetry: AnalyticsServiceSetup; actionTypeId: string; connectorId: string; + inference: InferenceServerStart; conversationId?: string; context: AwaitedProperties< Pick @@ -349,6 +351,7 @@ export const langChainExecute = async ({ connectorId, context, actionsClient, + inference, request, logger, conversationId, @@ -418,6 +421,7 @@ export const langChainExecute = async ({ connectorId, esClient, esStore, + inference, isStream, llmType: getLlmType(actionTypeId), langChainMessages, diff --git a/x-pack/plugins/elastic_assistant/server/routes/post_actions_connector_execute.ts b/x-pack/plugins/elastic_assistant/server/routes/post_actions_connector_execute.ts index 0988d9e5f8973..97ff073ecd5cc 100644 --- a/x-pack/plugins/elastic_assistant/server/routes/post_actions_connector_execute.ts +++ b/x-pack/plugins/elastic_assistant/server/routes/post_actions_connector_execute.ts @@ -92,6 +92,7 @@ export const postActionsConnectorExecuteRoute = ( // get the actions plugin start contract from the request context: const actions = ctx.elasticAssistant.actions; + const inference = ctx.elasticAssistant.inference; const actionsClient = await actions.getActionsClientWithRequest(request); const conversationsDataClient = @@ -132,6 +133,7 @@ export const postActionsConnectorExecuteRoute = ( context: ctx, getElser, logger, + inference, messages: (newMessage ? [newMessage] : messages) ?? [], onLlmResponse, onNewReplacements, diff --git a/x-pack/plugins/elastic_assistant/server/routes/request_context_factory.ts b/x-pack/plugins/elastic_assistant/server/routes/request_context_factory.ts index 3d004994b3236..e861fa6ffe279 100644 --- a/x-pack/plugins/elastic_assistant/server/routes/request_context_factory.ts +++ b/x-pack/plugins/elastic_assistant/server/routes/request_context_factory.ts @@ -79,6 +79,8 @@ export class RequestContextFactory implements IRequestContextFactory { return appContextService.getRegisteredFeatures(pluginName); }, + inference: startPlugins.inference, + telemetry: core.analytics, // Note: Due to plugin lifecycle and feature flag registration timing, we need to pass in the feature flag here diff --git a/x-pack/plugins/elastic_assistant/server/types.ts b/x-pack/plugins/elastic_assistant/server/types.ts index 6885b07a42c30..ca0010ae1e6b8 100755 --- a/x-pack/plugins/elastic_assistant/server/types.ts +++ b/x-pack/plugins/elastic_assistant/server/types.ts @@ -45,6 +45,7 @@ import { ActionsClientSimpleChatModel, } from '@kbn/langchain/server'; +import type { InferenceServerStart } from '@kbn/inference-plugin/server'; import { AttackDiscoveryDataClient } from './ai_assistant_data_clients/attack_discovery'; import { AIAssistantConversationsDataClient } from './ai_assistant_data_clients/conversations'; import type { GetRegisteredFeatures, GetRegisteredTools } from './services/app_context'; @@ -64,6 +65,10 @@ export interface ElasticAssistantPluginStart { * Actions plugin start contract. */ actions: ActionsPluginStart; + /** + * Inference plugin start contract. + */ + inference: InferenceServerStart; /** * Register features to be used by the elastic assistant. * @@ -104,6 +109,7 @@ export interface ElasticAssistantPluginSetupDependencies { } export interface ElasticAssistantPluginStartDependencies { actions: ActionsPluginStart; + inference: InferenceServerStart; spaces?: SpacesPluginStart; security: SecurityServiceStart; licensing: LicensingPluginStart; @@ -125,6 +131,7 @@ export interface ElasticAssistantApiRequestHandlerContext { getAttackDiscoveryDataClient: () => Promise; getAIAssistantPromptsDataClient: () => Promise; getAIAssistantAnonymizationFieldsDataClient: () => Promise; + inference: InferenceServerStart; telemetry: AnalyticsServiceSetup; } /** @@ -228,7 +235,9 @@ export type AssistantToolLlm = export interface AssistantToolParams { alertsIndexPattern?: string; anonymizationFields?: AnonymizationFieldResponse[]; + inference?: InferenceServerStart; isEnabledKnowledgeBase: boolean; + connectorId?: string; chain?: RetrievalQAChain; esClient: ElasticsearchClient; kbDataClient?: AIAssistantKnowledgeBaseDataClient; diff --git a/x-pack/plugins/elastic_assistant/tsconfig.json b/x-pack/plugins/elastic_assistant/tsconfig.json index c210253af04a4..747a58ed930d3 100644 --- a/x-pack/plugins/elastic_assistant/tsconfig.json +++ b/x-pack/plugins/elastic_assistant/tsconfig.json @@ -48,6 +48,7 @@ "@kbn/apm-utils", "@kbn/std", "@kbn/zod", + "@kbn/inference-plugin" ], "exclude": [ "target/**/*", diff --git a/x-pack/plugins/security_solution/common/experimental_features.ts b/x-pack/plugins/security_solution/common/experimental_features.ts index 121c8d6a97a1a..4147404e940c1 100644 --- a/x-pack/plugins/security_solution/common/experimental_features.ts +++ b/x-pack/plugins/security_solution/common/experimental_features.ts @@ -113,6 +113,11 @@ export const allowedExperimentalValues = Object.freeze({ */ assistantBedrockChat: true, + /** + * Enables the NaturalLanguageESQLTool and disables the ESQLKnowledgeBaseTool, introduced in `8.16.0`. + */ + assistantNaturalLanguageESQLTool: false, + /** * Enables the Managed User section inside the new user details flyout. */ diff --git a/x-pack/plugins/security_solution/server/assistant/tools/esql_language_knowledge_base/nl_to_esql_tool.ts b/x-pack/plugins/security_solution/server/assistant/tools/esql_language_knowledge_base/nl_to_esql_tool.ts new file mode 100644 index 0000000000000..b5dc209043d5d --- /dev/null +++ b/x-pack/plugins/security_solution/server/assistant/tools/esql_language_knowledge_base/nl_to_esql_tool.ts @@ -0,0 +1,80 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +import { DynamicStructuredTool } from '@langchain/core/tools'; +import { z } from '@kbn/zod'; +import type { AssistantTool, AssistantToolParams } from '@kbn/elastic-assistant-plugin/server'; +import { lastValueFrom } from 'rxjs'; +import { naturalLanguageToEsql } from '@kbn/inference-plugin/server'; +import { APP_UI_ID } from '../../../../common'; + +export type ESQLToolParams = AssistantToolParams; + +const TOOL_NAME = 'NaturalLanguageESQLTool'; + +const toolDetails = { + id: 'nl-to-esql-tool', + name: TOOL_NAME, + description: `You MUST use the "${TOOL_NAME}" function when the user wants to: + - run any arbitrary query + - breakdown or filter ES|QL queries that are displayed on the current page + - convert queries from another language to ES|QL + - asks general questions about ES|QL + + DO NOT UNDER ANY CIRCUMSTANCES generate ES|QL queries or explain anything about the ES|QL query language yourself. + DO NOT UNDER ANY CIRCUMSTANCES try to correct an ES|QL query yourself - always use the "${TOOL_NAME}" function for this. + + Even if the "${TOOL_NAME}" function was used before that, follow it up with the "${TOOL_NAME}" function. If a query fails, do not attempt to correct it yourself. Again you should call the "${TOOL_NAME}" function, + even if it has been called before.`, +}; + +export const NL_TO_ESQL_TOOL: AssistantTool = { + ...toolDetails, + sourceRegister: APP_UI_ID, + isSupported: (params: ESQLToolParams): params is ESQLToolParams => { + const { chain, isEnabledKnowledgeBase, modelExists } = params; + return isEnabledKnowledgeBase && modelExists && chain != null; + }, + getTool(params: ESQLToolParams) { + if (!this.isSupported(params)) return null; + + const { connectorId, inference, logger, request } = params as ESQLToolParams; + if (inference == null || connectorId == null) return null; + + const callNaturalLanguageToEsql = async (question: string) => { + return lastValueFrom( + naturalLanguageToEsql({ + client: inference.getClient({ request }), + connectorId, + input: question, + logger: { + debug: (source) => { + logger.debug(typeof source === 'function' ? source() : source); + }, + }, + }) + ); + }; + + return new DynamicStructuredTool({ + name: toolDetails.name, + description: toolDetails.description, + schema: z.object({ + question: z.string().describe(`The user's exact question about ESQL`), + }), + func: async (input) => { + const generateEvent = await callNaturalLanguageToEsql(input.question); + const answer = generateEvent.content ?? 'An error occurred in the tool'; + + logger.debug(`Received response from NL to ESQL tool: ${answer}`); + return answer; + }, + tags: ['esql', 'query-generation', 'knowledge-base'], + // TODO: Remove after ZodAny is fixed https://github.com/langchain-ai/langchainjs/blob/main/langchain-core/src/tools.ts + }) as unknown as DynamicStructuredTool; + }, +}; diff --git a/x-pack/plugins/security_solution/server/assistant/tools/index.test.ts b/x-pack/plugins/security_solution/server/assistant/tools/index.test.ts index 047c84ceddf3b..b64f34e4b6ee9 100644 --- a/x-pack/plugins/security_solution/server/assistant/tools/index.test.ts +++ b/x-pack/plugins/security_solution/server/assistant/tools/index.test.ts @@ -13,7 +13,7 @@ describe('getAssistantTools', () => { }); it('should return an array of applicable tools', () => { - const tools = getAssistantTools(); + const tools = getAssistantTools(true); const minExpectedTools = 3; // 3 tools are currently implemented diff --git a/x-pack/plugins/security_solution/server/assistant/tools/index.ts b/x-pack/plugins/security_solution/server/assistant/tools/index.ts index 0e5ea3a8f69d1..181e55353adc7 100644 --- a/x-pack/plugins/security_solution/server/assistant/tools/index.ts +++ b/x-pack/plugins/security_solution/server/assistant/tools/index.ts @@ -7,17 +7,18 @@ import type { AssistantTool } from '@kbn/elastic-assistant-plugin/server'; -import { ALERT_COUNTS_TOOL } from './alert_counts/alert_counts_tool'; import { ESQL_KNOWLEDGE_BASE_TOOL } from './esql_language_knowledge_base/esql_language_knowledge_base_tool'; +import { NL_TO_ESQL_TOOL } from './esql_language_knowledge_base/nl_to_esql_tool'; +import { ALERT_COUNTS_TOOL } from './alert_counts/alert_counts_tool'; import { OPEN_AND_ACKNOWLEDGED_ALERTS_TOOL } from './open_and_acknowledged_alerts/open_and_acknowledged_alerts_tool'; import { ATTACK_DISCOVERY_TOOL } from './attack_discovery/attack_discovery_tool'; import { KNOWLEDGE_BASE_RETRIEVAL_TOOL } from './knowledge_base/knowledge_base_retrieval_tool'; import { KNOWLEDGE_BASE_WRITE_TOOL } from './knowledge_base/knowledge_base_write_tool'; -export const getAssistantTools = (): AssistantTool[] => [ +export const getAssistantTools = (naturalLanguageESQLToolEnabled: boolean): AssistantTool[] => [ ALERT_COUNTS_TOOL, ATTACK_DISCOVERY_TOOL, - ESQL_KNOWLEDGE_BASE_TOOL, + naturalLanguageESQLToolEnabled ? NL_TO_ESQL_TOOL : ESQL_KNOWLEDGE_BASE_TOOL, KNOWLEDGE_BASE_RETRIEVAL_TOOL, KNOWLEDGE_BASE_WRITE_TOOL, OPEN_AND_ACKNOWLEDGED_ALERTS_TOOL, diff --git a/x-pack/plugins/security_solution/server/plugin.ts b/x-pack/plugins/security_solution/server/plugin.ts index a46863c78c25e..17f31718070b3 100644 --- a/x-pack/plugins/security_solution/server/plugin.ts +++ b/x-pack/plugins/security_solution/server/plugin.ts @@ -550,7 +550,10 @@ export class Plugin implements ISecuritySolutionPlugin { this.licensing$ = plugins.licensing.license$; // Assistant Tool and Feature Registration - plugins.elasticAssistant.registerTools(APP_UI_ID, getAssistantTools()); + plugins.elasticAssistant.registerTools( + APP_UI_ID, + getAssistantTools(config.experimentalFeatures.assistantNaturalLanguageESQLTool) + ); plugins.elasticAssistant.registerFeatures(APP_UI_ID, { assistantBedrockChat: config.experimentalFeatures.assistantBedrockChat, assistantKnowledgeBaseByDefault: config.experimentalFeatures.assistantKnowledgeBaseByDefault, diff --git a/x-pack/plugins/security_solution/tsconfig.json b/x-pack/plugins/security_solution/tsconfig.json index e6ec61c44d89e..e33ecb852b5b1 100644 --- a/x-pack/plugins/security_solution/tsconfig.json +++ b/x-pack/plugins/security_solution/tsconfig.json @@ -225,5 +225,6 @@ "@kbn/presentation-publishing", "@kbn/entityManager-plugin", "@kbn/entities-schema", + "@kbn/inference-plugin", ] }