Skip to content

Commit

Permalink
[Security solution] naturalLanguageToEsql Tool added to default ass…
Browse files Browse the repository at this point in the history
…istant graph (elastic#192042)
  • Loading branch information
stephmilovic committed Sep 18, 2024
1 parent d4ee1ca commit 798a26f
Show file tree
Hide file tree
Showing 19 changed files with 129 additions and 6 deletions.
1 change: 1 addition & 0 deletions x-pack/plugins/elastic_assistant/kibana.jsonc
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
"ml",
"taskManager",
"licensing",
"inference",
"spaces",
"security"
]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ export const createMockClients = () => {
getAIAssistantAnonymizationFieldsDataClient: dataClientMock.create(),
getSpaceId: jest.fn(),
getCurrentUser: jest.fn(),
inference: jest.fn(),
},
savedObjectsClient: core.savedObjects.client,

Expand Down Expand Up @@ -130,6 +131,7 @@ const createElasticAssistantRequestContextMock = (
getCurrentUser: jest.fn(),
getServerBasePath: jest.fn(),
getSpaceId: jest.fn(),
inference: { getClient: jest.fn() },
core: clients.core,
telemetry: clients.elasticAssistant.telemetry,
};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import type { LangChainTracer } from '@langchain/core/tracers/tracer_langchain';
import { ExecuteConnectorRequestBody, Message, Replacements } from '@kbn/elastic-assistant-common';
import { StreamResponseWithHeaders } from '@kbn/ml-response-stream/server';
import { PublicMethodsOf } from '@kbn/utility-types';
import type { InferenceServerStart } from '@kbn/inference-plugin/server';
import { ResponseBody } from '../types';
import type { AssistantTool } from '../../../types';
import { ElasticsearchStore } from '../elasticsearch_store/elasticsearch_store';
Expand Down Expand Up @@ -47,6 +48,7 @@ export interface AgentExecutorParams<T extends boolean> {
langChainMessages: BaseMessage[];
llmType?: string;
logger: Logger;
inference: InferenceServerStart;
onNewReplacements?: (newReplacements: Replacements) => void;
replacements: Replacements;
isStream?: T;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ export const callAssistantGraph: AgentExecutor<true | false> = async ({
dataClients,
esClient,
esStore,
inference,
langChainMessages,
llmType,
logger: parentLogger,
Expand Down Expand Up @@ -107,7 +108,9 @@ export const callAssistantGraph: AgentExecutor<true | false> = async ({
alertsIndexPattern,
anonymizationFields,
chain,
connectorId,
esClient,
inference,
isEnabledKnowledgeBase,
kbDataClient: dataClients?.kbDataClient,
logger,
Expand Down
1 change: 1 addition & 0 deletions x-pack/plugins/elastic_assistant/server/plugin.ts
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,7 @@ export class ElasticAssistantPlugin

return {
actions: plugins.actions,
inference: plugins.inference,
getRegisteredFeatures: (pluginName: string) => {
return appContextService.getRegisteredFeatures(pluginName);
},
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@ const formatAssistantToolParams = ({
ExecuteConnectorRequestBody | AttackDiscoveryPostRequestBody
>;
size: number;
}): AssistantToolParams => ({
}): Omit<AssistantToolParams, 'connectorId' | 'inference'> => ({
alertsIndexPattern,
anonymizationFields: [...(anonymizationFields ?? []), ...REQUIRED_FOR_ATTACK_DISCOVERY],
isEnabledKnowledgeBase: false, // not required for attack discovery
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ export const chatCompleteRoute = (
const ctx = await context.resolve(['core', 'elasticAssistant', 'licensing']);
const logger: Logger = ctx.elasticAssistant.logger;
telemetry = ctx.elasticAssistant.telemetry;
const inference = ctx.elasticAssistant.inference;

// Perform license and authenticated user checks
const checkResponse = performChecks({
Expand Down Expand Up @@ -195,6 +196,7 @@ export const chatCompleteRoute = (
context: ctx,
getElser,
logger,
inference,
messages: messages ?? [],
onLlmResponse,
onNewReplacements,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,8 @@ export const postEvaluateRoute = (
// Default ELSER model
const elserId = await getElser();

const inference = ctx.elasticAssistant.inference;

// Data clients
const anonymizationFieldsDataClient =
(await assistantContext.getAIAssistantAnonymizationFieldsDataClient()) ?? undefined;
Expand Down Expand Up @@ -260,6 +262,8 @@ export const postEvaluateRoute = (
alertsIndexPattern,
// onNewReplacements,
replacements,
inference,
connectorId: connector.id,
size,
};

Expand Down
4 changes: 4 additions & 0 deletions x-pack/plugins/elastic_assistant/server/routes/helpers.ts
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ import { AwaitedProperties, PublicMethodsOf } from '@kbn/utility-types';
import { ActionsClient } from '@kbn/actions-plugin/server';
import { AssistantFeatureKey } from '@kbn/elastic-assistant-common/impl/capabilities';
import { getLangSmithTracer } from '@kbn/langchain/server/tracers/langsmith';
import type { InferenceServerStart } from '@kbn/inference-plugin/server';
import { AIAssistantKnowledgeBaseDataClient } from '../ai_assistant_data_clients/knowledge_base';
import { FindResponse } from '../ai_assistant_data_clients/find';
import { EsPromptsSchema } from '../ai_assistant_data_clients/prompts/types';
Expand Down Expand Up @@ -321,6 +322,7 @@ export interface LangChainExecuteParams {
telemetry: AnalyticsServiceSetup;
actionTypeId: string;
connectorId: string;
inference: InferenceServerStart;
conversationId?: string;
context: AwaitedProperties<
Pick<ElasticAssistantRequestHandlerContext, 'elasticAssistant' | 'licensing' | 'core'>
Expand Down Expand Up @@ -349,6 +351,7 @@ export const langChainExecute = async ({
connectorId,
context,
actionsClient,
inference,
request,
logger,
conversationId,
Expand Down Expand Up @@ -418,6 +421,7 @@ export const langChainExecute = async ({
connectorId,
esClient,
esStore,
inference,
isStream,
llmType: getLlmType(actionTypeId),
langChainMessages,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@ export const postActionsConnectorExecuteRoute = (

// get the actions plugin start contract from the request context:
const actions = ctx.elasticAssistant.actions;
const inference = ctx.elasticAssistant.inference;
const actionsClient = await actions.getActionsClientWithRequest(request);

const conversationsDataClient =
Expand Down Expand Up @@ -132,6 +133,7 @@ export const postActionsConnectorExecuteRoute = (
context: ctx,
getElser,
logger,
inference,
messages: (newMessage ? [newMessage] : messages) ?? [],
onLlmResponse,
onNewReplacements,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,8 @@ export class RequestContextFactory implements IRequestContextFactory {
return appContextService.getRegisteredFeatures(pluginName);
},

inference: startPlugins.inference,

telemetry: core.analytics,

// Note: Due to plugin lifecycle and feature flag registration timing, we need to pass in the feature flag here
Expand Down
9 changes: 9 additions & 0 deletions x-pack/plugins/elastic_assistant/server/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ import {
ActionsClientSimpleChatModel,
} from '@kbn/langchain/server';

import type { InferenceServerStart } from '@kbn/inference-plugin/server';
import { AttackDiscoveryDataClient } from './ai_assistant_data_clients/attack_discovery';
import { AIAssistantConversationsDataClient } from './ai_assistant_data_clients/conversations';
import type { GetRegisteredFeatures, GetRegisteredTools } from './services/app_context';
Expand All @@ -64,6 +65,10 @@ export interface ElasticAssistantPluginStart {
* Actions plugin start contract.
*/
actions: ActionsPluginStart;
/**
* Inference plugin start contract.
*/
inference: InferenceServerStart;
/**
* Register features to be used by the elastic assistant.
*
Expand Down Expand Up @@ -104,6 +109,7 @@ export interface ElasticAssistantPluginSetupDependencies {
}
export interface ElasticAssistantPluginStartDependencies {
actions: ActionsPluginStart;
inference: InferenceServerStart;
spaces?: SpacesPluginStart;
security: SecurityServiceStart;
licensing: LicensingPluginStart;
Expand All @@ -125,6 +131,7 @@ export interface ElasticAssistantApiRequestHandlerContext {
getAttackDiscoveryDataClient: () => Promise<AttackDiscoveryDataClient | null>;
getAIAssistantPromptsDataClient: () => Promise<AIAssistantDataClient | null>;
getAIAssistantAnonymizationFieldsDataClient: () => Promise<AIAssistantDataClient | null>;
inference: InferenceServerStart;
telemetry: AnalyticsServiceSetup;
}
/**
Expand Down Expand Up @@ -228,7 +235,9 @@ export type AssistantToolLlm =
export interface AssistantToolParams {
alertsIndexPattern?: string;
anonymizationFields?: AnonymizationFieldResponse[];
inference?: InferenceServerStart;
isEnabledKnowledgeBase: boolean;
connectorId?: string;
chain?: RetrievalQAChain;
esClient: ElasticsearchClient;
kbDataClient?: AIAssistantKnowledgeBaseDataClient;
Expand Down
1 change: 1 addition & 0 deletions x-pack/plugins/elastic_assistant/tsconfig.json
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@
"@kbn/apm-utils",
"@kbn/std",
"@kbn/zod",
"@kbn/inference-plugin"
],
"exclude": [
"target/**/*",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,11 @@ export const allowedExperimentalValues = Object.freeze({
*/
assistantBedrockChat: true,

/**
* Enables the NaturalLanguageESQLTool and disables the ESQLKnowledgeBaseTool, introduced in `8.16.0`.
*/
assistantNaturalLanguageESQLTool: false,

/**
* Enables the Managed User section inside the new user details flyout.
*/
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

import { DynamicStructuredTool } from '@langchain/core/tools';
import { z } from '@kbn/zod';
import type { AssistantTool, AssistantToolParams } from '@kbn/elastic-assistant-plugin/server';
import { lastValueFrom } from 'rxjs';
import { naturalLanguageToEsql } from '@kbn/inference-plugin/server';
import { APP_UI_ID } from '../../../../common';

export type ESQLToolParams = AssistantToolParams;

const TOOL_NAME = 'NaturalLanguageESQLTool';

const toolDetails = {
id: 'nl-to-esql-tool',
name: TOOL_NAME,
description: `You MUST use the "${TOOL_NAME}" function when the user wants to:
- run any arbitrary query
- breakdown or filter ES|QL queries that are displayed on the current page
- convert queries from another language to ES|QL
- asks general questions about ES|QL
DO NOT UNDER ANY CIRCUMSTANCES generate ES|QL queries or explain anything about the ES|QL query language yourself.
DO NOT UNDER ANY CIRCUMSTANCES try to correct an ES|QL query yourself - always use the "${TOOL_NAME}" function for this.
Even if the "${TOOL_NAME}" function was used before that, follow it up with the "${TOOL_NAME}" function. If a query fails, do not attempt to correct it yourself. Again you should call the "${TOOL_NAME}" function,
even if it has been called before.`,
};

export const NL_TO_ESQL_TOOL: AssistantTool = {
...toolDetails,
sourceRegister: APP_UI_ID,
isSupported: (params: ESQLToolParams): params is ESQLToolParams => {
const { chain, isEnabledKnowledgeBase, modelExists } = params;
return isEnabledKnowledgeBase && modelExists && chain != null;
},
getTool(params: ESQLToolParams) {
if (!this.isSupported(params)) return null;

const { connectorId, inference, logger, request } = params as ESQLToolParams;
if (inference == null || connectorId == null) return null;

const callNaturalLanguageToEsql = async (question: string) => {
return lastValueFrom(
naturalLanguageToEsql({
client: inference.getClient({ request }),
connectorId,
input: question,
logger: {
debug: (source) => {
logger.debug(typeof source === 'function' ? source() : source);
},
},
})
);
};

return new DynamicStructuredTool({
name: toolDetails.name,
description: toolDetails.description,
schema: z.object({
question: z.string().describe(`The user's exact question about ESQL`),
}),
func: async (input) => {
const generateEvent = await callNaturalLanguageToEsql(input.question);
const answer = generateEvent.content ?? 'An error occurred in the tool';

logger.debug(`Received response from NL to ESQL tool: ${answer}`);
return answer;
},
tags: ['esql', 'query-generation', 'knowledge-base'],
// TODO: Remove after ZodAny is fixed https://github.com/langchain-ai/langchainjs/blob/main/langchain-core/src/tools.ts
}) as unknown as DynamicStructuredTool;
},
};
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ describe('getAssistantTools', () => {
});

it('should return an array of applicable tools', () => {
const tools = getAssistantTools();
const tools = getAssistantTools(true);

const minExpectedTools = 3; // 3 tools are currently implemented

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,17 +7,18 @@

import type { AssistantTool } from '@kbn/elastic-assistant-plugin/server';

import { ALERT_COUNTS_TOOL } from './alert_counts/alert_counts_tool';
import { ESQL_KNOWLEDGE_BASE_TOOL } from './esql_language_knowledge_base/esql_language_knowledge_base_tool';
import { NL_TO_ESQL_TOOL } from './esql_language_knowledge_base/nl_to_esql_tool';
import { ALERT_COUNTS_TOOL } from './alert_counts/alert_counts_tool';
import { OPEN_AND_ACKNOWLEDGED_ALERTS_TOOL } from './open_and_acknowledged_alerts/open_and_acknowledged_alerts_tool';
import { ATTACK_DISCOVERY_TOOL } from './attack_discovery/attack_discovery_tool';
import { KNOWLEDGE_BASE_RETRIEVAL_TOOL } from './knowledge_base/knowledge_base_retrieval_tool';
import { KNOWLEDGE_BASE_WRITE_TOOL } from './knowledge_base/knowledge_base_write_tool';

export const getAssistantTools = (): AssistantTool[] => [
export const getAssistantTools = (naturalLanguageESQLToolEnabled: boolean): AssistantTool[] => [
ALERT_COUNTS_TOOL,
ATTACK_DISCOVERY_TOOL,
ESQL_KNOWLEDGE_BASE_TOOL,
naturalLanguageESQLToolEnabled ? NL_TO_ESQL_TOOL : ESQL_KNOWLEDGE_BASE_TOOL,
KNOWLEDGE_BASE_RETRIEVAL_TOOL,
KNOWLEDGE_BASE_WRITE_TOOL,
OPEN_AND_ACKNOWLEDGED_ALERTS_TOOL,
Expand Down
5 changes: 4 additions & 1 deletion x-pack/plugins/security_solution/server/plugin.ts
Original file line number Diff line number Diff line change
Expand Up @@ -550,7 +550,10 @@ export class Plugin implements ISecuritySolutionPlugin {
this.licensing$ = plugins.licensing.license$;

// Assistant Tool and Feature Registration
plugins.elasticAssistant.registerTools(APP_UI_ID, getAssistantTools());
plugins.elasticAssistant.registerTools(
APP_UI_ID,
getAssistantTools(config.experimentalFeatures.assistantNaturalLanguageESQLTool)
);
plugins.elasticAssistant.registerFeatures(APP_UI_ID, {
assistantBedrockChat: config.experimentalFeatures.assistantBedrockChat,
assistantKnowledgeBaseByDefault: config.experimentalFeatures.assistantKnowledgeBaseByDefault,
Expand Down
1 change: 1 addition & 0 deletions x-pack/plugins/security_solution/tsconfig.json
Original file line number Diff line number Diff line change
Expand Up @@ -225,5 +225,6 @@
"@kbn/presentation-publishing",
"@kbn/entityManager-plugin",
"@kbn/entities-schema",
"@kbn/inference-plugin",
]
}

0 comments on commit 798a26f

Please sign in to comment.