From 65fdf44096a55aa73bc3eab827d44d4856098800 Mon Sep 17 00:00:00 2001 From: Patryk Kopycinski Date: Mon, 24 Jun 2024 13:31:07 +0200 Subject: [PATCH 01/55] BedrockChat --- package.json | 27 +- x-pack/packages/kbn-langchain/server/index.ts | 2 + .../server/language_models/bedrock_chat.ts | 120 +++++ .../server/language_models/index.ts | 1 + .../utils/bedrock/anthropic.ts | 226 ++++++++ .../language_models/utils/bedrock/index.ts | 486 ++++++++++++++++++ .../application/connector/methods/get/get.ts | 1 + .../execute_custom_llm_chain/index.ts | 17 +- .../server/lib/langchain/executors/types.ts | 2 + .../graphs/default_assistant_graph/index.ts | 23 +- .../routes/post_actions_connector_execute.ts | 22 +- .../elastic_assistant/server/routes/utils.ts | 14 + .../plugins/elastic_assistant/server/types.ts | 8 +- .../integration_assistant/kibana.jsonc | 1 + .../categorization/categorization.test.ts | 7 +- .../graphs/categorization/categorization.ts | 11 +- .../graphs/categorization/errors.test.ts | 7 +- .../server/graphs/categorization/errors.ts | 11 +- .../graphs/categorization/graph.test.ts | 7 +- .../server/graphs/categorization/graph.ts | 7 +- .../graphs/categorization/invalid.test.ts | 7 +- .../server/graphs/categorization/invalid.ts | 8 +- .../graphs/categorization/review.test.ts | 7 +- .../server/graphs/categorization/review.ts | 11 +- .../server/graphs/ecs/duplicates.test.ts | 7 +- .../server/graphs/ecs/duplicates.ts | 11 +- .../server/graphs/ecs/graph.test.ts | 7 +- .../server/graphs/ecs/graph.ts | 7 +- .../server/graphs/ecs/invalid.test.ts | 7 +- .../server/graphs/ecs/invalid.ts | 11 +- .../server/graphs/ecs/mapping.test.ts | 7 +- .../server/graphs/ecs/mapping.ts | 11 +- .../server/graphs/ecs/missing.test.ts | 7 +- .../server/graphs/ecs/missing.ts | 11 +- .../server/graphs/related/errors.test.ts | 7 +- .../server/graphs/related/errors.ts | 11 +- .../server/graphs/related/graph.test.ts | 7 +- .../server/graphs/related/graph.ts | 10 +- .../server/graphs/related/related.test.ts | 7 +- .../server/graphs/related/related.ts | 11 +- .../server/graphs/related/review.test.ts | 7 +- .../server/graphs/related/review.ts | 11 +- .../server/routes/categorization_routes.ts | 11 +- .../server/routes/ecs_routes.ts | 11 +- .../server/routes/related_routes.ts | 11 +- .../common/bedrock/constants.ts | 1 + .../stack_connectors/common/bedrock/schema.ts | 4 + .../server/connector_types/bedrock/bedrock.ts | 50 ++ yarn.lock | 196 ++++--- 49 files changed, 1168 insertions(+), 306 deletions(-) create mode 100644 x-pack/packages/kbn-langchain/server/language_models/bedrock_chat.ts create mode 100644 x-pack/packages/kbn-langchain/server/language_models/utils/bedrock/anthropic.ts create mode 100644 x-pack/packages/kbn-langchain/server/language_models/utils/bedrock/index.ts diff --git a/package.json b/package.json index 55eb4bfd3daec..693b6e80b77c1 100644 --- a/package.json +++ b/package.json @@ -80,7 +80,9 @@ "resolutions": { "**/@bazel/typescript/protobufjs": "6.11.4", "**/@hello-pangea/dnd": "16.6.0", - "**/@langchain/core": "0.2.3", + "**/@langchain/core": "0.2.9", + "**/@langchain/openai": "0.2.0", + "**/@smithy/util-utf8": "3.0.0", "**/@types/node": "20.10.5", "**/@typescript-eslint/utils": "5.62.0", "**/chokidar": "^3.5.3", @@ -88,6 +90,7 @@ "**/globule/minimatch": "^3.1.2", "**/hoist-non-react-statics": "^3.3.2", "**/isomorphic-fetch/node-fetch": "^2.6.7", + "**/langchain": "0.2.6", "**/react-intl/**/@types/react": "^17.0.45", "**/remark-parse/trim": "1.0.1", "**/sharp": "0.32.6", @@ -96,6 +99,8 @@ }, 
"dependencies": { "@appland/sql-parser": "^1.5.1", + "@aws-crypto/sha256-js": "^5.2.0", + "@aws-crypto/util": "^5.2.0", "@babel/runtime": "^7.24.4", "@cfworker/json-schema": "^1.12.7", "@dnd-kit/core": "^6.1.0", @@ -930,10 +935,10 @@ "@kbn/watcher-plugin": "link:x-pack/plugins/watcher", "@kbn/xstate-utils": "link:packages/kbn-xstate-utils", "@kbn/zod-helpers": "link:packages/kbn-zod-helpers", - "@langchain/community": "^0.2.4", - "@langchain/core": "0.2.3", - "@langchain/langgraph": "^0.0.23", - "@langchain/openai": "^0.0.34", + "@langchain/community": "^0.2.13", + "@langchain/core": "^0.2.9", + "@langchain/langgraph": "^0.0.24", + "@langchain/openai": "^0.2.0", "@langtrase/trace-attributes": "^3.0.8", "@launchdarkly/node-server-sdk": "^9.4.5", "@loaders.gl/core": "^3.4.7", @@ -955,9 +960,11 @@ "@paralleldrive/cuid2": "^2.2.2", "@reduxjs/toolkit": "1.9.7", "@slack/webhook": "^7.0.1", - "@smithy/eventstream-codec": "^3.0.0", - "@smithy/eventstream-serde-node": "^3.0.0", - "@smithy/types": "^3.0.0", + "@smithy/eventstream-codec": "^3.1.1", + "@smithy/eventstream-serde-node": "^3.0.3", + "@smithy/protocol-http": "^4.0.2", + "@smithy/signature-v4": "^3.1.1", + "@smithy/types": "^3.2.0", "@smithy/util-utf8": "^3.0.0", "@tanstack/react-query": "^4.29.12", "@tanstack/react-query-devtools": "^4.29.12", @@ -1070,8 +1077,8 @@ "jsonwebtoken": "^9.0.2", "jsts": "^1.6.2", "kea": "^2.6.0", - "langchain": "0.2.3", - "langsmith": "^0.1.30", + "langchain": "^0.2.6", + "langsmith": "^0.1.32", "launchdarkly-js-client-sdk": "^3.3.0", "launchdarkly-node-server-sdk": "^7.0.3", "load-json-file": "^6.2.0", diff --git a/x-pack/packages/kbn-langchain/server/index.ts b/x-pack/packages/kbn-langchain/server/index.ts index 1d52159951809..126a9f6bdbfc6 100644 --- a/x-pack/packages/kbn-langchain/server/index.ts +++ b/x-pack/packages/kbn-langchain/server/index.ts @@ -5,6 +5,7 @@ * 2.0. */ +import { ActionsClientBedrockChatModel } from './language_models/bedrock_chat'; import { ActionsClientChatOpenAI } from './language_models/chat_openai'; import { ActionsClientLlm } from './language_models/llm'; import { ActionsClientSimpleChatModel } from './language_models/simple_chat_model'; @@ -16,6 +17,7 @@ export { parseBedrockStream, parseGeminiResponse, getDefaultArguments, + ActionsClientBedrockChatModel, ActionsClientChatOpenAI, ActionsClientLlm, ActionsClientSimpleChatModel, diff --git a/x-pack/packages/kbn-langchain/server/language_models/bedrock_chat.ts b/x-pack/packages/kbn-langchain/server/language_models/bedrock_chat.ts new file mode 100644 index 0000000000000..6a28d8bd7606a --- /dev/null +++ b/x-pack/packages/kbn-langchain/server/language_models/bedrock_chat.ts @@ -0,0 +1,120 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */
+
+import {
+  BedrockChat as _BedrockChat,
+  convertMessagesToPromptAnthropic,
+} from '@langchain/community/chat_models/bedrock/web';
+import type { PluginStartContract as ActionsPluginStart } from '@kbn/actions-plugin/server';
+import { BaseMessage } from '@langchain/core/messages';
+import { BaseChatModelParams } from '@langchain/core/language_models/chat_models';
+import { Logger } from '@kbn/logging';
+import { KibanaRequest } from '@kbn/core/server';
+import { BaseBedrockInput, BedrockLLMInputOutputAdapter } from './utils/bedrock';
+
+export class ActionsClientBedrockChatModel extends _BedrockChat {
+  // Kibana variables
+  #actions: ActionsPluginStart;
+  #connectorId: string;
+  #logger: Logger;
+  #request: KibanaRequest;
+
+  constructor({
+    actions,
+    request,
+    connectorId,
+    logger,
+    ...params
+  }: {
+    actions: ActionsPluginStart;
+    connectorId: string;
+    logger: Logger;
+    request: KibanaRequest;
+  } & Partial<BaseBedrockInput> &
+    BaseChatModelParams) {
+    // Just to make the LangChain BedrockChat constructor happy
+    super({
+      ...params,
+      credentials: { accessKeyId: '', secretAccessKey: '' },
+    });
+
+    this.#actions = actions;
+    this.#request = request;
+    this.#connectorId = connectorId;
+    this.#logger = logger;
+  }
+
+  async _signedFetch(
+    messages: BaseMessage[],
+    options: this['ParsedCallOptions'],
+    fields: {
+      bedrockMethod: 'invoke' | 'invoke-with-response-stream';
+      endpointHost: string;
+      provider: string;
+    }
+  ) {
+    const { bedrockMethod, endpointHost, provider } = fields;
+    const {
+      max_tokens: maxTokens,
+      temperature,
+      stop,
+      modelKwargs,
+      guardrailConfig,
+      tools,
+    } = this.invocationParams(options);
+    const inputBody = this.usesMessagesApi
+      ? BedrockLLMInputOutputAdapter.prepareMessagesInput(
+          provider,
+          messages,
+          maxTokens,
+          temperature,
+          stop,
+          modelKwargs,
+          guardrailConfig,
+          tools,
+          this.#logger
+        )
+      : BedrockLLMInputOutputAdapter.prepareInput(
+          provider,
+          convertMessagesToPromptAnthropic(messages),
+          maxTokens,
+          temperature,
+          stop,
+          modelKwargs,
+          fields.bedrockMethod,
+          guardrailConfig
+        );
+
+    // Create an actions client from the authenticated request context:
+    const actionsClient = await this.#actions.getActionsClientWithRequest(this.#request);
+
+    const data = (await actionsClient.execute({
+      actionId: this.#connectorId,
+      params: {
+        subAction: 'invokeAIRaw',
+        subActionParams: {
+          bedrockMethod,
+          model: this.model,
+          endpointHost,
+          // the adapter builds snake_case keys (anthropic_version, stop_sequences, max_tokens)
+          anthropicVersion: inputBody.anthropic_version,
+          messages: inputBody.messages,
+          temperature: inputBody.temperature,
+          stopSequences: inputBody.stop_sequences,
+          system: inputBody.system,
+          maxTokens: inputBody.max_tokens,
+          signal: options.signal,
+          timeout: options.timeout,
+        },
+      },
+    })) as unknown as { status: string; data: Record<string, unknown> };
+
+    return {
+      ok: data.status === 'ok',
+      json: () => data.data,
+    };
+  }
+}
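+
+// A minimal usage sketch. The connector id and model id below are illustrative
+// assumptions, not values shipped with this patch; `actions`, `request`, and
+// `logger` come from the plugin's request handler context.
+//
+//   const llm = new ActionsClientBedrockChatModel({
+//     actions,
+//     request,
+//     connectorId: 'my-bedrock-connector',
+//     logger,
+//     model: 'anthropic.claude-3-sonnet-20240229-v1:0',
+//     streaming: false,
+//   });
+//   const response = await llm.invoke([new HumanMessage('Summarize this alert')]);
+//
+// `_signedFetch` above then routes the Bedrock call through the connector's
+// `invokeAIRaw` subaction instead of signing an AWS request directly.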
 */
 
+export { ActionsClientBedrockChatModel } from './bedrock_chat';
 export { ActionsClientChatOpenAI } from './chat_openai';
 export { ActionsClientLlm } from './llm';
 export { ActionsClientSimpleChatModel } from './simple_chat_model';
diff --git a/x-pack/packages/kbn-langchain/server/language_models/utils/bedrock/anthropic.ts b/x-pack/packages/kbn-langchain/server/language_models/utils/bedrock/anthropic.ts
new file mode 100644
index 0000000000000..a9b04dc4e0dcf
--- /dev/null
+++ b/x-pack/packages/kbn-langchain/server/language_models/utils/bedrock/anthropic.ts
@@ -0,0 +1,226 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+// origin: https://github.com/langchain-ai/langchainjs/blob/main/libs/langchain-community/src/utils/bedrock/anthropic.ts
+// Error: Package subpath './dist/utils/bedrock/anthropic' is not defined by "exports" in langchain/community/package.json
+
+import { Logger } from '@kbn/logging';
+import {
+  AIMessage,
+  BaseMessage,
+  HumanMessage,
+  MessageContent,
+  SystemMessage,
+  ToolMessage,
+  isAIMessage,
+} from '@langchain/core/messages';
+import { ToolCall } from '@langchain/core/messages/tool';
+
+// eslint-disable-next-line @typescript-eslint/no-explicit-any
+export function extractToolCalls(content: Array<Record<string, any>>) {
+  const toolCalls: ToolCall[] = [];
+  for (const block of content) {
+    if (block.type === 'tool_use') {
+      toolCalls.push({ name: block.name, args: block.input, id: block.id });
+    }
+  }
+  return toolCalls;
+}
+
+function _formatImage(imageUrl: string) {
+  const regex = /^data:(image\/.+);base64,(.+)$/;
+  const match = imageUrl.match(regex);
+  if (match === null) {
+    throw new Error(
+      [
+        'Anthropic only supports base64-encoded images currently.',
+        'Example: data:image/png;base64,/9j/4AAQSk...',
+      ].join('\n\n')
+    );
+  }
+  return {
+    type: 'base64',
+    media_type: match[1] ?? '',
+    data: match[2] ?? '',
+    // eslint-disable-next-line @typescript-eslint/no-explicit-any
+  } as any;
+}
+
+function _mergeMessages(messages: BaseMessage[]): Array<SystemMessage | HumanMessage | AIMessage> {
+  // Merge runs of human/tool messages into single human messages with content blocks.
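+  // Anthropic's messages API expects strictly alternating user/assistant turns, so
+  // tool results (modeled as user-supplied `tool_result` blocks) and consecutive
+  // human messages are collapsed into a single user message below.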
+  const merged: Array<SystemMessage | HumanMessage | AIMessage> = [];
+  for (const message of messages) {
+    if (message._getType() === 'tool') {
+      if (typeof message.content === 'string') {
+        merged.push(
+          new HumanMessage({
+            content: [
+              {
+                type: 'tool_result',
+                content: message.content,
+                tool_use_id: (message as ToolMessage).tool_call_id,
+              },
+            ],
+          })
+        );
+      } else {
+        merged.push(new HumanMessage({ content: message.content }));
+      }
+    } else {
+      const previousMessage = merged[merged.length - 1];
+      if (previousMessage?._getType() === 'human' && message._getType() === 'human') {
+        // eslint-disable-next-line @typescript-eslint/no-explicit-any
+        let combinedContent: Array<Record<string, any>>;
+        if (typeof previousMessage.content === 'string') {
+          combinedContent = [{ type: 'text', text: previousMessage.content }];
+        } else {
+          combinedContent = previousMessage.content;
+        }
+        if (typeof message.content === 'string') {
+          combinedContent.push({ type: 'text', text: message.content });
+        } else {
+          combinedContent = combinedContent.concat(message.content);
+        }
+        previousMessage.content = combinedContent;
+      } else {
+        merged.push(message as SystemMessage | HumanMessage | AIMessage);
+      }
+    }
+  }
+  return merged;
+}
+
+export function _convertLangChainToolCallToAnthropic(
+  toolCall: ToolCall
+  // eslint-disable-next-line @typescript-eslint/no-explicit-any
+): Record<string, any> {
+  if (toolCall.id === undefined) {
+    throw new Error(`Anthropic requires all tool calls to have an "id".`);
+  }
+  return {
+    type: 'tool_use',
+    id: toolCall.id,
+    name: toolCall.name,
+    input: toolCall.args,
+  };
+}
+
+function _formatContent(content: MessageContent) {
+  if (typeof content === 'string') {
+    return content;
+  } else {
+    const contentBlocks = content.map((contentPart) => {
+      if (contentPart.type === 'image_url') {
+        let source;
+        if (typeof contentPart.image_url === 'string') {
+          source = _formatImage(contentPart.image_url);
+        } else {
+          source = _formatImage(contentPart.image_url.url);
+        }
+        return {
+          type: 'image' as const, // Explicitly setting the type as "image"
+          source,
+        };
+      } else if (contentPart.type === 'text') {
+        // Assuming contentPart is of type MessageContentText here
+        return {
+          type: 'text' as const, // Explicitly setting the type as "text"
+          text: contentPart.text,
+        };
+      } else if (contentPart.type === 'tool_use' || contentPart.type === 'tool_result') {
+        // TODO: Fix when SDK types are fixed
+        return {
+          ...contentPart,
+          // eslint-disable-next-line @typescript-eslint/no-explicit-any
+        } as any;
+      } else {
+        throw new Error('Unsupported message content format');
+      }
+    });
+    return contentBlocks;
+  }
+}
+
+export function formatMessagesForAnthropic(
+  messages: BaseMessage[],
+  logger?: Logger
+): {
+  system?: string;
+  messages: Array<Record<string, unknown>>;
+} {
+  const mergedMessages = _mergeMessages(messages);
+  let system: string | undefined;
+  if (mergedMessages.length > 0 && mergedMessages[0]._getType() === 'system') {
+    if (typeof messages[0].content !== 'string') {
+      throw new Error('System message content must be a string.');
+    }
+    system = messages[0].content;
+  }
+  const conversationMessages = system !== undefined ? mergedMessages.slice(1) : mergedMessages;
+  const formattedMessages = conversationMessages.map((message) => {
+    let role;
+    if (message._getType() === 'human') {
+      role = 'user' as const;
+    } else if (message._getType() === 'ai') {
+      role = 'assistant' as const;
+    } else if (message._getType() === 'tool') {
+      role = 'user' as const;
+    } else if (message._getType() === 'system') {
+      throw new Error('System messages are only permitted as the first passed message.');
+    } else {
+      throw new Error(`Message type "${message._getType()}" is not supported.`);
+    }
+    if (isAIMessage(message) && !!message.tool_calls?.length) {
+      if (typeof message.content === 'string') {
+        if (message.content === '') {
+          return {
+            role,
+            content: message.tool_calls.map(_convertLangChainToolCallToAnthropic),
+          };
+        } else {
+          return {
+            role,
+            content: [
+              { type: 'text', text: message.content },
+              ...message.tool_calls.map(_convertLangChainToolCallToAnthropic),
+            ],
+          };
+        }
+      } else {
+        const { content } = message;
+        const hasMismatchedToolCalls = !message.tool_calls.every((toolCall) =>
+          content.find(
+            (contentPart) => contentPart.type === 'tool_use' && contentPart.id === toolCall.id
+          )
+        );
+        if (hasMismatchedToolCalls) {
+          logger?.warn(
+            `The "tool_calls" field on a message is only respected if content is a string.`
+          );
+        }
+        return {
+          role,
+          content: _formatContent(message.content),
+        };
+      }
+    } else {
+      return {
+        role,
+        content: _formatContent(message.content),
+      };
+    }
+  });
+  return {
+    messages: formattedMessages,
+    system,
+  };
+}
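+
+// Illustrative example (a sketch of the produced shape, not an exhaustive spec):
+//   formatMessagesForAnthropic([new SystemMessage('Be terse.'), new HumanMessage('Hi')])
+//   // => { system: 'Be terse.', messages: [{ role: 'user', content: 'Hi' }] }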
+
+export function isAnthropicTool(tool: unknown): tool is Record<string, unknown> {
+  if (typeof tool !== 'object' || !tool) return false;
+  return 'input_schema' in tool;
+}
diff --git a/x-pack/packages/kbn-langchain/server/language_models/utils/bedrock/index.ts b/x-pack/packages/kbn-langchain/server/language_models/utils/bedrock/index.ts
new file mode 100644
index 0000000000000..7cad55e1b6e06
--- /dev/null
+++ b/x-pack/packages/kbn-langchain/server/language_models/utils/bedrock/index.ts
@@ -0,0 +1,486 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+/* eslint-disable complexity */
+
+// origin: https://github.com/langchain-ai/langchainjs/blob/main/libs/langchain-community/src/utils/bedrock/index.ts
+// Error: Package subpath './dist/utils/bedrock' is not defined by "exports" in langchain/community/package.json
+
+import type { AwsCredentialIdentity, Provider } from '@aws-sdk/types';
+import { AIMessage, AIMessageChunk, BaseMessage } from '@langchain/core/messages';
+import { StructuredToolInterface } from '@langchain/core/tools';
+import { ChatGeneration, ChatGenerationChunk } from '@langchain/core/outputs';
+import { Logger } from '@kbn/logging';
+import { extractToolCalls, formatMessagesForAnthropic } from './anthropic';
+
+export type CredentialType = AwsCredentialIdentity | Provider<AwsCredentialIdentity>;
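+
+// Note: in Kibana these AWS credentials are placeholders; ActionsClientBedrockChatModel
+// passes empty values and defers authentication to the Bedrock connector through the
+// actions client (see bedrock_chat.ts).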
+
+/**
+ * Format messages for Cohere Command R and Command R+ via AWS Bedrock.
+ *
+ * @param messages The base messages to format as a prompt.
+ *
+ * @returns The formatted prompt for Cohere.
+ *
+ * `system`: user system prompts. Overrides the default preamble for search query generation. Has no effect on tool use generations.\
+ * `message`: (Required) Text input for the model to respond to.\
+ * `chatHistory`: A list of previous messages between the user and the model, meant to give the model conversational context for responding to the user's message.\
+ * The following are required fields.
+ *   - `role` - The role for the message. Valid values are USER or CHATBOT.\
+ *   - `message` – Text contents of the message.\
+ *
+ * The following is example JSON for the chat_history field.\
+ * "chat_history": [
+ *   {"role": "USER", "message": "Who discovered gravity?"},
+ *   {"role": "CHATBOT", "message": "The man who is widely credited with discovering gravity is Sir Isaac Newton"}]\
+ *
+ * docs: https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-cohere-command-r-plus.html
+ */
+function formatMessagesForCohere(messages: BaseMessage[]): {
+  system?: string;
+  message: string;
+  chatHistory: Array<Record<string, string>>;
+} {
+  const systemMessages = messages.filter((system) => system._getType() === 'system');
+
+  const system = systemMessages
+    .filter((m) => typeof m.content === 'string')
+    .map((m) => m.content)
+    .join('\n\n');
+
+  const conversationMessages = messages.filter((message) => message._getType() !== 'system');
+
+  const questionContent = conversationMessages.slice(-1);
+
+  if (!questionContent.length || questionContent[0]._getType() !== 'human') {
+    throw new Error('question message content must be a human message.');
+  }
+
+  if (typeof questionContent[0].content !== 'string') {
+    throw new Error('question message content must be a string.');
+  }
+
+  const formattedMessage = questionContent[0].content;
+
+  const formattedChatHistories = conversationMessages.slice(0, -1).map((message) => {
+    let role;
+    switch (message._getType()) {
+      case 'human':
+        role = 'USER' as const;
+        break;
+      case 'ai':
+        role = 'CHATBOT' as const;
+        break;
+      case 'system':
+        throw new Error('chat_history can not include system prompts.');
+      default:
+        throw new Error(`Message type "${message._getType()}" is not supported.`);
+    }
+
+    if (typeof message.content !== 'string') {
+      throw new Error('message content must be a string.');
+    }
+    return {
+      role,
+      message: message.content,
+    };
+  });
+
+  return {
+    chatHistory: formattedChatHistories,
+    message: formattedMessage,
+    system,
+  };
+}
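+
+// Illustrative example (sketch): for
+//   [new SystemMessage('Be brief.'), new HumanMessage('Hi'),
+//    new AIMessage('Hello!'), new HumanMessage('Who discovered gravity?')]
+// this returns
+//   {
+//     system: 'Be brief.',
+//     chatHistory: [
+//       { role: 'USER', message: 'Hi' },
+//       { role: 'CHATBOT', message: 'Hello!' },
+//     ],
+//     message: 'Who discovered gravity?',
+//   }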
+
+/** Bedrock models.
+    To authenticate, the AWS client uses the following methods to automatically load credentials:
+    https://boto3.amazonaws.com/v1/documentation/api/latest/guide/credentials.html
+    If a specific credential profile should be used, you must pass the name of the profile from the ~/.aws/credentials file that is to be used.
+    Make sure the credentials / roles used have the required policies to access the Bedrock service.
+*/
+export interface BaseBedrockInput {
+  /** Model to use.
+      For example, "amazon.titan-tg1-large", this is equivalent to the modelId property in the list-foundation-models api.
+  */
+  model: string;
+
+  /** The AWS region e.g. `us-west-2`.
+      Falls back to the AWS_DEFAULT_REGION env variable or the region specified in ~/.aws/config if not provided here.
+  */
+  region?: string;
+
+  /** AWS Credentials.
+      If no credentials are provided, the default credentials from `@aws-sdk/credential-provider-node` will be used.
+  */
+  credentials?: CredentialType;
+
+  /** Temperature. */
+  temperature?: number;
+
+  /** Max tokens. */
+  maxTokens?: number;
+
+  /** A custom fetch function for low-level access to the AWS API. Defaults to fetch(). */
+  fetchFn?: typeof fetch;
+
+  /** @deprecated Use endpointHost instead. Overrides the default endpoint url. */
+  endpointUrl?: string;
+
+  /** Override the default endpoint hostname. */
+  endpointHost?: string;
+
+  /**
+   * Optional additional stop sequences to pass to the model. Currently only supported for Anthropic and AI21.
+   * @deprecated Use .bind({ "stop": [...] }) instead
+   * */
+  stopSequences?: string[];
+
+  /** Additional kwargs to pass to the model. */
+  modelKwargs?: Record<string, unknown>;
+
+  /** Whether or not to stream responses */
+  streaming: boolean;
+
+  /** Trace settings for the Bedrock Guardrails. */
+  trace?: 'ENABLED' | 'DISABLED';
+
+  /** Identifier for the guardrail configuration. */
+  guardrailIdentifier?: string;
+
+  /** Version for the guardrail configuration. */
+  guardrailVersion?: string;
+
+  /** Required when Guardrail is in use. */
+  guardrailConfig?: {
+    tagSuffix: string;
+    streamProcessingMode: 'SYNCHRONOUS' | 'ASYNCHRONOUS';
+  };
+}
+
+interface Dict {
+  [key: string]: unknown;
+}
+
+/**
+ * A helper class used within the `Bedrock` class. It is responsible for
+ * preparing the input and output for the Bedrock service. It formats the
+ * input prompt based on the provider (e.g., "anthropic", "ai21",
+ * "amazon") and extracts the generated text from the service response.
+ */
+export class BedrockLLMInputOutputAdapter {
+  /** Adapter class to prepare the inputs from Langchain to a format
+  that LLM model expects. Also, provides a helper function to extract
+  the generated text from the model response. */
+
+  static prepareInput(
+    provider: string,
+    prompt: string,
+    maxTokens = 50,
+    temperature = 0,
+    stopSequences: string[] | undefined = undefined,
+    modelKwargs: Record<string, unknown> = {},
+    bedrockMethod: 'invoke' | 'invoke-with-response-stream' = 'invoke',
+    guardrailConfig:
+      | {
+          tagSuffix: string;
+          streamProcessingMode: 'SYNCHRONOUS' | 'ASYNCHRONOUS';
+        }
+      | undefined = undefined
+  ): Dict {
+    const inputBody: Dict = {};
+
+    if (provider === 'anthropic') {
+      inputBody.prompt = prompt;
+      inputBody.max_tokens_to_sample = maxTokens;
+      inputBody.temperature = temperature;
+      inputBody.stop_sequences = stopSequences;
+    } else if (provider === 'ai21') {
+      inputBody.prompt = prompt;
+      inputBody.maxTokens = maxTokens;
+      inputBody.temperature = temperature;
+      inputBody.stopSequences = stopSequences;
+    } else if (provider === 'meta') {
+      inputBody.prompt = prompt;
+      inputBody.max_gen_len = maxTokens;
+      inputBody.temperature = temperature;
+    } else if (provider === 'amazon') {
+      inputBody.inputText = prompt;
+      inputBody.textGenerationConfig = {
+        maxTokenCount: maxTokens,
+        temperature,
+      };
+    } else if (provider === 'cohere') {
+      inputBody.prompt = prompt;
+      inputBody.max_tokens = maxTokens;
+      inputBody.temperature = temperature;
+      inputBody.stop_sequences = stopSequences;
+      if (bedrockMethod === 'invoke-with-response-stream') {
+        inputBody.stream = true;
+      }
+    } else if (provider === 'mistral') {
+      inputBody.prompt = prompt;
+      inputBody.max_tokens = maxTokens;
+      inputBody.temperature = temperature;
+      inputBody.stop = stopSequences;
+    }
+
+    if (guardrailConfig && guardrailConfig.tagSuffix && guardrailConfig.streamProcessingMode) {
+      inputBody['amazon-bedrock-guardrailConfig'] = guardrailConfig;
+    }
+
+    return { ...inputBody, ...modelKwargs };
+  }
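+
+  // Illustrative example (sketch): prepareInput('anthropic', 'Human: Hi\n\nAssistant:', 256, 0)
+  // returns { prompt: 'Human: Hi\n\nAssistant:', max_tokens_to_sample: 256,
+  // temperature: 0, stop_sequences: undefined }.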
+
+  static prepareMessagesInput(
+    provider: string,
+    messages: BaseMessage[],
+    maxTokens = 1024,
+    temperature = 0,
+    stopSequences: string[] | undefined = undefined,
+    modelKwargs: Record<string, unknown> = {},
+    guardrailConfig:
+      | {
+          tagSuffix: string;
+          streamProcessingMode: 'SYNCHRONOUS' | 'ASYNCHRONOUS';
+        }
+      | undefined = undefined,
+    tools: Array<StructuredToolInterface | Record<string, unknown>> = [],
+    logger: Logger
+  ): Dict {
+    const inputBody: Dict = {};
+
+    if (provider === 'anthropic') {
+      const { system, messages: formattedMessages } = formatMessagesForAnthropic(messages, logger);
+      if (system !== undefined) {
+        inputBody.system = system;
+      }
+      inputBody.anthropic_version = 'bedrock-2023-05-31';
+      inputBody.messages = formattedMessages;
+      inputBody.max_tokens = maxTokens;
+      inputBody.temperature = temperature;
+      inputBody.stop_sequences = stopSequences;
+
+      if (tools.length > 0) {
+        inputBody.tools = tools;
+      }
+      return { ...inputBody, ...modelKwargs };
+    } else if (provider === 'cohere') {
+      const {
+        system,
+        message: formattedMessage,
+        chatHistory: formattedChatHistories,
+      } = formatMessagesForCohere(messages);
+
+      if (system !== undefined && system.length > 0) {
+        inputBody.preamble = system;
+      }
+      inputBody.message = formattedMessage;
+      inputBody.chat_history = formattedChatHistories;
+      inputBody.max_tokens = maxTokens;
+      inputBody.temperature = temperature;
+      inputBody.stop_sequences = stopSequences;
+    } else {
+      throw new Error('The messages API is currently only supported by Anthropic or Cohere');
+    }
+
+    if (guardrailConfig && guardrailConfig.tagSuffix && guardrailConfig.streamProcessingMode) {
+      inputBody['amazon-bedrock-guardrailConfig'] = guardrailConfig;
+    }
+
+    return { ...inputBody, ...modelKwargs };
+  }
+
+  /**
+   * Extracts the generated text from the service response.
+   * @param provider The provider name.
+   * @param responseBody The response body from the service.
+   * @returns The generated text.
+   */
+  // eslint-disable-next-line @typescript-eslint/no-explicit-any
+  static prepareOutput(provider: string, responseBody: any): string {
+    if (provider === 'anthropic') {
+      return responseBody.completion;
+    } else if (provider === 'ai21') {
+      return responseBody?.completions?.[0]?.data?.text ?? '';
+    } else if (provider === 'cohere') {
+      return responseBody?.generations?.[0]?.text ?? responseBody?.text ?? '';
+    } else if (provider === 'meta') {
+      return responseBody.generation;
+    } else if (provider === 'mistral') {
+      return responseBody?.outputs?.[0]?.text;
+    }
+
+    // I haven't been able to get a response with more than one result in it.
+    return responseBody.results?.[0]?.outputText;
+  }
+
+  static prepareMessagesOutput(
+    provider: string,
+    // eslint-disable-next-line @typescript-eslint/no-explicit-any
+    response: any
+  ): ChatGeneration | undefined {
+    const responseBody = response ??
{}; + if (provider === 'anthropic') { + if (responseBody.type === 'message_start') { + return parseMessage(responseBody.message, true); + } else if ( + responseBody.type === 'content_block_delta' && + responseBody.delta?.type === 'text_delta' && + typeof responseBody.delta?.text === 'string' + ) { + return new ChatGenerationChunk({ + message: new AIMessageChunk({ + content: responseBody.delta.text, + }), + text: responseBody.delta.text, + }); + } else if (responseBody.type === 'message_delta') { + return new ChatGenerationChunk({ + message: new AIMessageChunk({ content: '' }), + text: '', + generationInfo: { + ...responseBody.delta, + usage: responseBody.usage, + }, + }); + } else if ( + responseBody.type === 'message_stop' && + responseBody['amazon-bedrock-invocationMetrics'] !== undefined + ) { + return new ChatGenerationChunk({ + message: new AIMessageChunk({ content: '' }), + text: '', + generationInfo: { + 'amazon-bedrock-invocationMetrics': responseBody['amazon-bedrock-invocationMetrics'], + }, + }); + } else if (responseBody.type === 'message') { + return parseMessage(responseBody); + } else { + return undefined; + } + } else if (provider === 'cohere') { + if (responseBody.event_type === 'stream-start') { + return parseMessageCohere(responseBody.message, true); + } else if ( + responseBody.event_type === 'text-generation' && + typeof responseBody?.text === 'string' + ) { + return new ChatGenerationChunk({ + message: new AIMessageChunk({ + content: responseBody.text, + }), + text: responseBody.text, + }); + } else if (responseBody.event_type === 'search-queries-generation') { + return parseMessageCohere(responseBody); + } else if ( + responseBody.event_type === 'stream-end' && + responseBody.response !== undefined && + responseBody['amazon-bedrock-invocationMetrics'] !== undefined + ) { + return new ChatGenerationChunk({ + message: new AIMessageChunk({ content: '' }), + text: '', + generationInfo: { + response: responseBody.response, + 'amazon-bedrock-invocationMetrics': responseBody['amazon-bedrock-invocationMetrics'], + }, + }); + } else { + if ( + responseBody.finish_reason === 'COMPLETE' || + responseBody.finish_reason === 'MAX_TOKENS' + ) { + return parseMessageCohere(responseBody); + } else { + return undefined; + } + } + } else { + throw new Error('The messages API is currently only supported by Anthropic or Cohere.'); + } + } +} + +// eslint-disable-next-line @typescript-eslint/no-explicit-any +function parseMessage(responseBody: any, asChunk?: boolean): ChatGeneration { + const { content, id, ...generationInfo } = responseBody; + let parsedContent; + if (Array.isArray(content) && content.length === 1 && content[0].type === 'text') { + parsedContent = content[0].text; + } else if (Array.isArray(content) && content.length === 0) { + parsedContent = ''; + } else { + parsedContent = content; + } + if (asChunk) { + return new ChatGenerationChunk({ + message: new AIMessageChunk({ + content: parsedContent, + additional_kwargs: { id }, + }), + text: typeof parsedContent === 'string' ? parsedContent : '', + generationInfo, + }); + } else { + // TODO: we are throwing away here the text response, as the interface of this method returns only one + const toolCalls = extractToolCalls(responseBody.content); + + if (toolCalls.length > 0) { + return { + message: new AIMessage({ + content: '', + additional_kwargs: { id }, + tool_calls: toolCalls, + }), + text: typeof parsedContent === 'string' ? 
parsedContent : '', + generationInfo, + }; + } + + return { + message: new AIMessage({ + content: parsedContent, + additional_kwargs: { id }, + tool_calls: toolCalls, + }), + text: typeof parsedContent === 'string' ? parsedContent : '', + generationInfo, + }; + } +} + +function parseMessageCohere( + // eslint-disable-next-line @typescript-eslint/no-explicit-any + responseBody: any, + asChunk?: boolean +): ChatGeneration { + const { text, ...generationInfo } = responseBody; + let parsedContent = text; + if (typeof text !== 'string') { + parsedContent = ''; + } + if (asChunk) { + return new ChatGenerationChunk({ + message: new AIMessageChunk({ + content: parsedContent, + }), + text: parsedContent, + generationInfo, + }); + } else { + return { + message: new AIMessage({ + content: parsedContent, + }), + text: parsedContent, + generationInfo, + }; + } +} diff --git a/x-pack/plugins/actions/server/application/connector/methods/get/get.ts b/x-pack/plugins/actions/server/application/connector/methods/get/get.ts index 2d4a94f5615d7..35e8101757bc9 100644 --- a/x-pack/plugins/actions/server/application/connector/methods/get/get.ts +++ b/x-pack/plugins/actions/server/application/connector/methods/get/get.ts @@ -62,6 +62,7 @@ export async function get({ id, actionTypeId: foundInMemoryConnector.actionTypeId, name: foundInMemoryConnector.name, + config: foundInMemoryConnector.config, isPreconfigured: foundInMemoryConnector.isPreconfigured, isSystemAction: foundInMemoryConnector.isSystemAction, isDeprecated: isConnectorDeprecated(foundInMemoryConnector), diff --git a/x-pack/plugins/elastic_assistant/server/lib/langchain/execute_custom_llm_chain/index.ts b/x-pack/plugins/elastic_assistant/server/lib/langchain/execute_custom_llm_chain/index.ts index bcf39320f21cc..21b031b073722 100644 --- a/x-pack/plugins/elastic_assistant/server/lib/langchain/execute_custom_llm_chain/index.ts +++ b/x-pack/plugins/elastic_assistant/server/lib/langchain/execute_custom_llm_chain/index.ts @@ -12,12 +12,9 @@ import { ToolInterface } from '@langchain/core/tools'; import { streamFactory } from '@kbn/ml-response-stream/server'; import { transformError } from '@kbn/securitysolution-es-utils'; import { RetrievalQAChain } from 'langchain/chains'; -import { - getDefaultArguments, - ActionsClientChatOpenAI, - ActionsClientSimpleChatModel, -} from '@kbn/langchain/server'; +import { getDefaultArguments } from '@kbn/langchain/server'; import { MessagesPlaceholder } from '@langchain/core/prompts'; +import { getLlmClass, isToolCallingSupported } from '../../../routes/utils'; import { AgentExecutor } from '../executors/types'; import { APMTracer } from '../tracers/apm_tracer'; import { AssistantToolParams } from '../../../types'; @@ -45,13 +42,14 @@ export const callAgentExecutor: AgentExecutor = async ({ isStream = false, onLlmResponse, onNewReplacements, + model, + region, replacements, request, size, traceOptions, }) => { - const isOpenAI = llmType === 'openai'; - const llmClass = isOpenAI ? 
ActionsClientChatOpenAI : ActionsClientSimpleChatModel; + const llmClass = getLlmClass(llmType, isStream); const llm = new llmClass({ actions, @@ -59,9 +57,10 @@ export const callAgentExecutor: AgentExecutor = async ({ request, llmType, logger, + region, // possible client model override, // let this be undefined otherwise so the connector handles the model - model: request.body.model, + model, // ensure this is defined because we default to it in the language_models // This is where the LangSmith logs (Metadata > Invocation Params) are set temperature: getDefaultArguments(llmType).temperature, @@ -116,7 +115,7 @@ export const callAgentExecutor: AgentExecutor = async ({ handleParsingErrors: 'Try again, paying close attention to the allowed tool input', }; // isOpenAI check is not on agentType alone because typescript doesn't like - const executor = isOpenAI + const executor = isToolCallingSupported(llmType) ? await initializeAgentExecutorWithOptions(tools, llm, { agentType: 'openai-functions', ...executorArgs, diff --git a/x-pack/plugins/elastic_assistant/server/lib/langchain/executors/types.ts b/x-pack/plugins/elastic_assistant/server/lib/langchain/executors/types.ts index bd07099e312b3..ef5a28fa79bc6 100644 --- a/x-pack/plugins/elastic_assistant/server/lib/langchain/executors/types.ts +++ b/x-pack/plugins/elastic_assistant/server/lib/langchain/executors/types.ts @@ -48,6 +48,8 @@ export interface AgentExecutorParams { langChainMessages: BaseMessage[]; llmType?: string; logger: Logger; + model?: string; + region?: string; onNewReplacements?: (newReplacements: Replacements) => void; replacements: Replacements; isStream?: T; diff --git a/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/index.ts b/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/index.ts index 1e40f6b2fe127..8f1011d04f887 100644 --- a/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/index.ts +++ b/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/index.ts @@ -7,12 +7,9 @@ import { StructuredTool } from '@langchain/core/tools'; import { RetrievalQAChain } from 'langchain/chains'; -import { - getDefaultArguments, - ActionsClientChatOpenAI, - ActionsClientSimpleChatModel, -} from '@kbn/langchain/server'; +import { getDefaultArguments } from '@kbn/langchain/server'; import { createOpenAIFunctionsAgent, createStructuredChatAgent } from 'langchain/agents'; +import { getLlmClass, isToolCallingSupported } from '../../../../routes/utils'; import { AssistantToolParams } from '../../../../types'; import { AgentExecutor } from '../../executors/types'; import { openAIFunctionAgentPrompt, structuredChatAgentPrompt } from './prompts'; @@ -42,13 +39,14 @@ export const callAssistantGraph: AgentExecutor = async ({ onLlmResponse, onNewReplacements, replacements, + model, + region, request, size, traceOptions, }) => { const logger = parentLogger.get('defaultAssistantGraph'); - const isOpenAI = llmType === 'openai'; - const llmClass = isOpenAI ? 
ActionsClientChatOpenAI : ActionsClientSimpleChatModel; + const llmClass = getLlmClass(llmType, isStream); const llm = new llmClass({ actions, @@ -58,7 +56,8 @@ export const callAssistantGraph: AgentExecutor = async ({ logger, // possible client model override, // let this be undefined otherwise so the connector handles the model - model: request.body.model, + model, + region, // ensure this is defined because we default to it in the language_models // This is where the LangSmith logs (Metadata > Invocation Params) are set temperature: getDefaultArguments(llmType).temperature, @@ -68,7 +67,7 @@ export const callAssistantGraph: AgentExecutor = async ({ // failure could be due to bad connector, we should deliver that result to the client asap maxRetries: 0, }); - const model = llm; + const graphModel = llm; const messages = langChainMessages.slice(0, -1); // all but the last message const latestMessage = langChainMessages.slice(-1); // the last message @@ -76,7 +75,7 @@ export const callAssistantGraph: AgentExecutor = async ({ const modelExists = await esStore.isModelInstalled(); // Create a chain that uses the ELSER backed ElasticsearchStore, override k=10 for esql query generation for now - const chain = RetrievalQAChain.fromLLM(model, esStore.asRetriever(10)); + const chain = RetrievalQAChain.fromLLM(graphModel, esStore.asRetriever(10)); // Fetch any applicable tools that the source plugin may have registered const assistantToolParams: AssistantToolParams = { @@ -86,7 +85,7 @@ export const callAssistantGraph: AgentExecutor = async ({ esClient, isEnabledKnowledgeBase, kbDataClient: dataClients?.kbDataClient, - llm: model, + llm: graphModel, logger, modelExists, onNewReplacements, @@ -99,7 +98,7 @@ export const callAssistantGraph: AgentExecutor = async ({ (tool) => tool.getTool(assistantToolParams) ?? [] ); - const agentRunnable = isOpenAI + const agentRunnable = isToolCallingSupported(llmType) ? 
await createOpenAIFunctionsAgent({ llm, tools, diff --git a/x-pack/plugins/elastic_assistant/server/routes/post_actions_connector_execute.ts b/x-pack/plugins/elastic_assistant/server/routes/post_actions_connector_execute.ts index 197479fc24dd5..a7f4c992cc11d 100644 --- a/x-pack/plugins/elastic_assistant/server/routes/post_actions_connector_execute.ts +++ b/x-pack/plugins/elastic_assistant/server/routes/post_actions_connector_execute.ts @@ -356,6 +356,24 @@ export const postActionsConnectorExecuteRoute = ( kbDataClient, }; + const llmType = getLlmType(actionTypeId); + + const actionsClient = await actions.getActionsClientWithRequest(request); + + let region; + let model = request.body.model; + if (llmType === 'bedrock') { + try { + const connector = await actionsClient.get({ id: connectorId }); + region = connector.config?.apiUrl.split('.').reverse()[2]; + if (!model) { + model = connector.config?.defaultModel; + } + } catch (e) { + logger.error(`Failed to get region: ${e.message}`); + } + } + // Shared executor params const executorParams: AgentExecutorParams = { abortSignal, @@ -372,11 +390,13 @@ export const postActionsConnectorExecuteRoute = ( esClient, esStore, isStream: request.body.subAction !== 'invokeAI', - llmType: getLlmType(actionTypeId), + llmType, langChainMessages, logger, + model, onNewReplacements, onLlmResponse, + region, request, response, replacements: request.body.replacements, diff --git a/x-pack/plugins/elastic_assistant/server/routes/utils.ts b/x-pack/plugins/elastic_assistant/server/routes/utils.ts index 324d2fefa46b9..2c54c2a927031 100644 --- a/x-pack/plugins/elastic_assistant/server/routes/utils.ts +++ b/x-pack/plugins/elastic_assistant/server/routes/utils.ts @@ -13,6 +13,11 @@ import type { KibanaResponseFactory, CustomHttpResponseOptions, } from '@kbn/core/server'; +import { + ActionsClientChatOpenAI, + ActionsClientBedrockChatModel, + ActionsClientSimpleChatModel, +} from '@kbn/langchain/server'; import { CustomHttpRequestError } from './custom_http_request_error'; export interface OutputError { @@ -174,3 +179,12 @@ export const getLlmType = (actionTypeId: string): string | undefined => { }; return llmTypeDictionary[actionTypeId]; }; + +export const getLlmClass = (llmType?: string, isStreaming?: boolean) => + llmType === 'openai' + ? ActionsClientChatOpenAI + : llmType === 'bedrock' && !isStreaming + ? ActionsClientBedrockChatModel + : ActionsClientSimpleChatModel; + +export const isToolCallingSupported = (llmType?: string) => ['openai'].includes(llmType ?? 
''); diff --git a/x-pack/plugins/elastic_assistant/server/types.ts b/x-pack/plugins/elastic_assistant/server/types.ts index f12bacde983df..23f4e3c49303a 100755 --- a/x-pack/plugins/elastic_assistant/server/types.ts +++ b/x-pack/plugins/elastic_assistant/server/types.ts @@ -34,6 +34,7 @@ import { import { AnonymizationFieldResponse } from '@kbn/elastic-assistant-common/impl/schemas/anonymization_fields/bulk_crud_anonymization_fields_route.gen'; import { LicensingApiRequestHandlerContext } from '@kbn/licensing-plugin/server'; import { + ActionsClientBedrockChatModel, ActionsClientChatOpenAI, ActionsClientLlm, ActionsClientSimpleChatModel, @@ -206,6 +207,11 @@ export interface AssistantTool { getTool: (params: AssistantToolParams) => Tool | DynamicStructuredTool | null; } +export type AssistantToolLlm = + | ActionsClientBedrockChatModel + | ActionsClientChatOpenAI + | ActionsClientSimpleChatModel; + export interface AssistantToolParams { alertsIndexPattern?: string; anonymizationFields?: AnonymizationFieldResponse[]; @@ -214,7 +220,7 @@ export interface AssistantToolParams { esClient: ElasticsearchClient; kbDataClient?: AIAssistantKnowledgeBaseDataClient; langChainTimeout?: number; - llm?: ActionsClientLlm | ActionsClientChatOpenAI | ActionsClientSimpleChatModel; + llm?: ActionsClientLlm | AssistantToolLlm; logger: Logger; modelExists: boolean; onNewReplacements?: (newReplacements: Replacements) => void; diff --git a/x-pack/plugins/integration_assistant/kibana.jsonc b/x-pack/plugins/integration_assistant/kibana.jsonc index a70120d9cefba..bf52e0abcabf4 100644 --- a/x-pack/plugins/integration_assistant/kibana.jsonc +++ b/x-pack/plugins/integration_assistant/kibana.jsonc @@ -16,6 +16,7 @@ "triggersActionsUi", "actions", "stackConnectors", + "elasticAssistant" ], } } diff --git a/x-pack/plugins/integration_assistant/server/graphs/categorization/categorization.test.ts b/x-pack/plugins/integration_assistant/server/graphs/categorization/categorization.test.ts index 3ad0926297bbc..b30fa1d66a534 100644 --- a/x-pack/plugins/integration_assistant/server/graphs/categorization/categorization.test.ts +++ b/x-pack/plugins/integration_assistant/server/graphs/categorization/categorization.test.ts @@ -13,14 +13,11 @@ import { categorizationMockProcessors, categorizationExpectedHandlerResponse, } from '../../../__jest__/fixtures/categorization'; -import { - ActionsClientChatOpenAI, - ActionsClientSimpleChatModel, -} from '@kbn/langchain/server/language_models'; +import type { AssistantToolLlm } from '@kbn/elastic-assistant-plugin/server/types'; const mockLlm = new FakeLLM({ response: JSON.stringify(categorizationMockProcessors, null, 2), -}) as unknown as ActionsClientChatOpenAI | ActionsClientSimpleChatModel; +}) as unknown as AssistantToolLlm; const testState: CategorizationState = categorizationTestState; diff --git a/x-pack/plugins/integration_assistant/server/graphs/categorization/categorization.ts b/x-pack/plugins/integration_assistant/server/graphs/categorization/categorization.ts index ed1a88c3a1cfd..03f95d36ff5a4 100644 --- a/x-pack/plugins/integration_assistant/server/graphs/categorization/categorization.ts +++ b/x-pack/plugins/integration_assistant/server/graphs/categorization/categorization.ts @@ -4,20 +4,15 @@ * 2.0; you may not use this file except in compliance with the Elastic License * 2.0. 
*/ -import type { - ActionsClientChatOpenAI, - ActionsClientSimpleChatModel, -} from '@kbn/langchain/server/language_models'; + +import type { AssistantToolLlm } from '@kbn/elastic-assistant-plugin/server/types'; import { JsonOutputParser } from '@langchain/core/output_parsers'; import type { ESProcessorItem, Pipeline } from '../../../common'; import type { CategorizationState } from '../../types'; import { combineProcessors } from '../../util/processors'; import { CATEGORIZATION_MAIN_PROMPT } from './prompts'; -export async function handleCategorization( - state: CategorizationState, - model: ActionsClientChatOpenAI | ActionsClientSimpleChatModel -) { +export async function handleCategorization(state: CategorizationState, model: AssistantToolLlm) { const categorizationMainPrompt = CATEGORIZATION_MAIN_PROMPT; const outputParser = new JsonOutputParser(); const categorizationMainGraph = categorizationMainPrompt.pipe(model).pipe(outputParser); diff --git a/x-pack/plugins/integration_assistant/server/graphs/categorization/errors.test.ts b/x-pack/plugins/integration_assistant/server/graphs/categorization/errors.test.ts index 18d8c1842080a..93ca5030e5104 100644 --- a/x-pack/plugins/integration_assistant/server/graphs/categorization/errors.test.ts +++ b/x-pack/plugins/integration_assistant/server/graphs/categorization/errors.test.ts @@ -13,14 +13,11 @@ import { categorizationMockProcessors, categorizationExpectedHandlerResponse, } from '../../../__jest__/fixtures/categorization'; -import { - ActionsClientChatOpenAI, - ActionsClientSimpleChatModel, -} from '@kbn/langchain/server/language_models'; +import type { AssistantToolLlm } from '@kbn/elastic-assistant-plugin/server/types'; const mockLlm = new FakeLLM({ response: JSON.stringify(categorizationMockProcessors, null, 2), -}) as unknown as ActionsClientChatOpenAI | ActionsClientSimpleChatModel; +}) as unknown as AssistantToolLlm; const testState: CategorizationState = categorizationTestState; diff --git a/x-pack/plugins/integration_assistant/server/graphs/categorization/errors.ts b/x-pack/plugins/integration_assistant/server/graphs/categorization/errors.ts index d8cb7beedc9bf..8ce8792604af5 100644 --- a/x-pack/plugins/integration_assistant/server/graphs/categorization/errors.ts +++ b/x-pack/plugins/integration_assistant/server/graphs/categorization/errors.ts @@ -4,20 +4,15 @@ * 2.0; you may not use this file except in compliance with the Elastic License * 2.0. 
*/ -import type { - ActionsClientChatOpenAI, - ActionsClientSimpleChatModel, -} from '@kbn/langchain/server/language_models'; + +import type { AssistantToolLlm } from '@kbn/elastic-assistant-plugin/server/types'; import { JsonOutputParser } from '@langchain/core/output_parsers'; import type { ESProcessorItem, Pipeline } from '../../../common'; import type { CategorizationState } from '../../types'; import { combineProcessors } from '../../util/processors'; import { CATEGORIZATION_ERROR_PROMPT } from './prompts'; -export async function handleErrors( - state: CategorizationState, - model: ActionsClientChatOpenAI | ActionsClientSimpleChatModel -) { +export async function handleErrors(state: CategorizationState, model: AssistantToolLlm) { const categorizationErrorPrompt = CATEGORIZATION_ERROR_PROMPT; const outputParser = new JsonOutputParser(); diff --git a/x-pack/plugins/integration_assistant/server/graphs/categorization/graph.test.ts b/x-pack/plugins/integration_assistant/server/graphs/categorization/graph.test.ts index 4122d4540dbc0..7fd76e3bd7a60 100644 --- a/x-pack/plugins/integration_assistant/server/graphs/categorization/graph.test.ts +++ b/x-pack/plugins/integration_assistant/server/graphs/categorization/graph.test.ts @@ -25,14 +25,11 @@ import { handleCategorization } from './categorization'; import { handleErrors } from './errors'; import { handleInvalidCategorization } from './invalid'; import { testPipeline, combineProcessors } from '../../util'; -import { - ActionsClientChatOpenAI, - ActionsClientSimpleChatModel, -} from '@kbn/langchain/server/language_models'; +import type { AssistantToolLlm } from '@kbn/elastic-assistant-plugin/server/types'; const mockLlm = new FakeLLM({ response: "I'll callback later.", -}) as unknown as ActionsClientChatOpenAI | ActionsClientSimpleChatModel; +}) as unknown as AssistantToolLlm; jest.mock('./errors'); jest.mock('./review'); diff --git a/x-pack/plugins/integration_assistant/server/graphs/categorization/graph.ts b/x-pack/plugins/integration_assistant/server/graphs/categorization/graph.ts index 6834fcf892a9e..79795b3c39ecf 100644 --- a/x-pack/plugins/integration_assistant/server/graphs/categorization/graph.ts +++ b/x-pack/plugins/integration_assistant/server/graphs/categorization/graph.ts @@ -8,10 +8,7 @@ import type { IScopedClusterClient } from '@kbn/core-elasticsearch-server'; import type { StateGraphArgs } from '@langchain/langgraph'; import { StateGraph, END, START } from '@langchain/langgraph'; -import type { - ActionsClientChatOpenAI, - ActionsClientSimpleChatModel, -} from '@kbn/langchain/server/language_models'; +import type { AssistantToolLlm } from '@kbn/elastic-assistant-plugin/server/types'; import type { CategorizationState } from '../../types'; import { modifySamples, formatSamples } from '../../util/samples'; import { handleCategorization } from './categorization'; @@ -151,7 +148,7 @@ function chainRouter(state: CategorizationState): string { export async function getCategorizationGraph( client: IScopedClusterClient, - model: ActionsClientChatOpenAI | ActionsClientSimpleChatModel + model: AssistantToolLlm ) { const workflow = new StateGraph({ channels: graphState, diff --git a/x-pack/plugins/integration_assistant/server/graphs/categorization/invalid.test.ts b/x-pack/plugins/integration_assistant/server/graphs/categorization/invalid.test.ts index 10560137093d8..7c8e1fe1c4a62 100644 --- a/x-pack/plugins/integration_assistant/server/graphs/categorization/invalid.test.ts +++ 
b/x-pack/plugins/integration_assistant/server/graphs/categorization/invalid.test.ts @@ -13,14 +13,11 @@ import { categorizationMockProcessors, categorizationExpectedHandlerResponse, } from '../../../__jest__/fixtures/categorization'; -import { - ActionsClientChatOpenAI, - ActionsClientSimpleChatModel, -} from '@kbn/langchain/server/language_models'; +import type { AssistantToolLlm } from '@kbn/elastic-assistant-plugin/server/types'; const mockLlm = new FakeLLM({ response: JSON.stringify(categorizationMockProcessors, null, 2), -}) as unknown as ActionsClientChatOpenAI | ActionsClientSimpleChatModel; +}) as unknown as AssistantToolLlm; const testState: CategorizationState = categorizationTestState; diff --git a/x-pack/plugins/integration_assistant/server/graphs/categorization/invalid.ts b/x-pack/plugins/integration_assistant/server/graphs/categorization/invalid.ts index 413694b594518..b5ec203b54cef 100644 --- a/x-pack/plugins/integration_assistant/server/graphs/categorization/invalid.ts +++ b/x-pack/plugins/integration_assistant/server/graphs/categorization/invalid.ts @@ -4,10 +4,8 @@ * 2.0; you may not use this file except in compliance with the Elastic License * 2.0. */ -import type { - ActionsClientChatOpenAI, - ActionsClientSimpleChatModel, -} from '@kbn/langchain/server/language_models'; + +import type { AssistantToolLlm } from '@kbn/elastic-assistant-plugin/server/types'; import { JsonOutputParser } from '@langchain/core/output_parsers'; import type { ESProcessorItem, Pipeline } from '../../../common'; import type { CategorizationState } from '../../types'; @@ -17,7 +15,7 @@ import { CATEGORIZATION_VALIDATION_PROMPT } from './prompts'; export async function handleInvalidCategorization( state: CategorizationState, - model: ActionsClientChatOpenAI | ActionsClientSimpleChatModel + model: AssistantToolLlm ) { const categorizationInvalidPrompt = CATEGORIZATION_VALIDATION_PROMPT; diff --git a/x-pack/plugins/integration_assistant/server/graphs/categorization/review.test.ts b/x-pack/plugins/integration_assistant/server/graphs/categorization/review.test.ts index 7775b69c5b6a8..a053226b65afc 100644 --- a/x-pack/plugins/integration_assistant/server/graphs/categorization/review.test.ts +++ b/x-pack/plugins/integration_assistant/server/graphs/categorization/review.test.ts @@ -13,14 +13,11 @@ import { categorizationMockProcessors, categorizationExpectedHandlerResponse, } from '../../../__jest__/fixtures/categorization'; -import { - ActionsClientChatOpenAI, - ActionsClientSimpleChatModel, -} from '@kbn/langchain/server/language_models'; +import type { AssistantToolLlm } from '@kbn/elastic-assistant-plugin/server/types'; const mockLlm = new FakeLLM({ response: JSON.stringify(categorizationMockProcessors, null, 2), -}) as unknown as ActionsClientChatOpenAI | ActionsClientSimpleChatModel; +}) as unknown as AssistantToolLlm; const testState: CategorizationState = categorizationTestState; diff --git a/x-pack/plugins/integration_assistant/server/graphs/categorization/review.ts b/x-pack/plugins/integration_assistant/server/graphs/categorization/review.ts index 12b3880737237..df11a6ef8d4eb 100644 --- a/x-pack/plugins/integration_assistant/server/graphs/categorization/review.ts +++ b/x-pack/plugins/integration_assistant/server/graphs/categorization/review.ts @@ -4,10 +4,8 @@ * 2.0; you may not use this file except in compliance with the Elastic License * 2.0. 
*/ -import type { - ActionsClientChatOpenAI, - ActionsClientSimpleChatModel, -} from '@kbn/langchain/server/language_models'; + +import type { AssistantToolLlm } from '@kbn/elastic-assistant-plugin/server/types'; import { JsonOutputParser } from '@langchain/core/output_parsers'; import { CATEGORIZATION_REVIEW_PROMPT } from './prompts'; @@ -16,10 +14,7 @@ import type { CategorizationState } from '../../types'; import { combineProcessors } from '../../util/processors'; import { ECS_EVENT_TYPES_PER_CATEGORY } from './constants'; -export async function handleReview( - state: CategorizationState, - model: ActionsClientChatOpenAI | ActionsClientSimpleChatModel -) { +export async function handleReview(state: CategorizationState, model: AssistantToolLlm) { const categorizationReviewPrompt = CATEGORIZATION_REVIEW_PROMPT; const outputParser = new JsonOutputParser(); const categorizationReview = categorizationReviewPrompt.pipe(model).pipe(outputParser); diff --git a/x-pack/plugins/integration_assistant/server/graphs/ecs/duplicates.test.ts b/x-pack/plugins/integration_assistant/server/graphs/ecs/duplicates.test.ts index 9270b2453e261..2aa950f5a0591 100644 --- a/x-pack/plugins/integration_assistant/server/graphs/ecs/duplicates.test.ts +++ b/x-pack/plugins/integration_assistant/server/graphs/ecs/duplicates.test.ts @@ -9,14 +9,11 @@ import { FakeLLM } from '@langchain/core/utils/testing'; import { handleDuplicates } from './duplicates'; import type { EcsMappingState } from '../../types'; import { ecsTestState } from '../../../__jest__/fixtures/ecs_mapping'; -import { - ActionsClientChatOpenAI, - ActionsClientSimpleChatModel, -} from '@kbn/langchain/server/language_models'; +import type { AssistantToolLlm } from '@kbn/elastic-assistant-plugin/server/types'; const mockLlm = new FakeLLM({ response: '{ "message": "ll callback later."}', -}) as unknown as ActionsClientChatOpenAI | ActionsClientSimpleChatModel; +}) as unknown as AssistantToolLlm; const testState: EcsMappingState = ecsTestState; diff --git a/x-pack/plugins/integration_assistant/server/graphs/ecs/duplicates.ts b/x-pack/plugins/integration_assistant/server/graphs/ecs/duplicates.ts index fd11a660e75ab..8e576b3775a9c 100644 --- a/x-pack/plugins/integration_assistant/server/graphs/ecs/duplicates.ts +++ b/x-pack/plugins/integration_assistant/server/graphs/ecs/duplicates.ts @@ -4,18 +4,13 @@ * 2.0; you may not use this file except in compliance with the Elastic License * 2.0. 
*/ -import type { - ActionsClientChatOpenAI, - ActionsClientSimpleChatModel, -} from '@kbn/langchain/server/language_models'; + +import type { AssistantToolLlm } from '@kbn/elastic-assistant-plugin/server/types'; import { JsonOutputParser } from '@langchain/core/output_parsers'; import type { EcsMappingState } from '../../types'; import { ECS_DUPLICATES_PROMPT } from './prompts'; -export async function handleDuplicates( - state: EcsMappingState, - model: ActionsClientChatOpenAI | ActionsClientSimpleChatModel -) { +export async function handleDuplicates(state: EcsMappingState, model: AssistantToolLlm) { const ecsDuplicatesPrompt = ECS_DUPLICATES_PROMPT; const outputParser = new JsonOutputParser(); const ecsDuplicatesGraph = ecsDuplicatesPrompt.pipe(model).pipe(outputParser); diff --git a/x-pack/plugins/integration_assistant/server/graphs/ecs/graph.test.ts b/x-pack/plugins/integration_assistant/server/graphs/ecs/graph.test.ts index 0ae626924c349..41729e7e54c06 100644 --- a/x-pack/plugins/integration_assistant/server/graphs/ecs/graph.test.ts +++ b/x-pack/plugins/integration_assistant/server/graphs/ecs/graph.test.ts @@ -19,14 +19,11 @@ import { handleEcsMapping } from './mapping'; import { handleDuplicates } from './duplicates'; import { handleMissingKeys } from './missing'; import { handleInvalidEcs } from './invalid'; -import { - ActionsClientChatOpenAI, - ActionsClientSimpleChatModel, -} from '@kbn/langchain/server/language_models'; +import type { AssistantToolLlm } from '@kbn/elastic-assistant-plugin/server/types'; const mockLlm = new FakeLLM({ response: "I'll callback later.", -}) as unknown as ActionsClientChatOpenAI | ActionsClientSimpleChatModel; +}) as unknown as AssistantToolLlm; jest.mock('./mapping'); jest.mock('./duplicates'); diff --git a/x-pack/plugins/integration_assistant/server/graphs/ecs/graph.ts b/x-pack/plugins/integration_assistant/server/graphs/ecs/graph.ts index 2c8e7283d4728..b4a9c7b0dfd80 100644 --- a/x-pack/plugins/integration_assistant/server/graphs/ecs/graph.ts +++ b/x-pack/plugins/integration_assistant/server/graphs/ecs/graph.ts @@ -5,10 +5,7 @@ * 2.0. 
*/ -import type { - ActionsClientChatOpenAI, - ActionsClientSimpleChatModel, -} from '@kbn/langchain/server/language_models'; +import type { AssistantToolLlm } from '@kbn/elastic-assistant-plugin/server/types'; import type { StateGraphArgs } from '@langchain/langgraph'; import { END, START, StateGraph } from '@langchain/langgraph'; import type { EcsMappingState } from '../../types'; @@ -140,7 +137,7 @@ function chainRouter(state: EcsMappingState): string { return END; } -export async function getEcsGraph(model: ActionsClientChatOpenAI | ActionsClientSimpleChatModel) { +export async function getEcsGraph(model: AssistantToolLlm) { const workflow = new StateGraph({ channels: graphState, }) diff --git a/x-pack/plugins/integration_assistant/server/graphs/ecs/invalid.test.ts b/x-pack/plugins/integration_assistant/server/graphs/ecs/invalid.test.ts index ce1f76ce7a721..15da3809e2d97 100644 --- a/x-pack/plugins/integration_assistant/server/graphs/ecs/invalid.test.ts +++ b/x-pack/plugins/integration_assistant/server/graphs/ecs/invalid.test.ts @@ -9,14 +9,11 @@ import { FakeLLM } from '@langchain/core/utils/testing'; import { handleInvalidEcs } from './invalid'; import type { EcsMappingState } from '../../types'; import { ecsTestState } from '../../../__jest__/fixtures/ecs_mapping'; -import { - ActionsClientChatOpenAI, - ActionsClientSimpleChatModel, -} from '@kbn/langchain/server/language_models'; +import type { AssistantToolLlm } from '@kbn/elastic-assistant-plugin/server/types'; const mockLlm = new FakeLLM({ response: '{ "message": "ll callback later."}', -}) as unknown as ActionsClientChatOpenAI | ActionsClientSimpleChatModel; +}) as unknown as AssistantToolLlm; const testState: EcsMappingState = ecsTestState; diff --git a/x-pack/plugins/integration_assistant/server/graphs/ecs/invalid.ts b/x-pack/plugins/integration_assistant/server/graphs/ecs/invalid.ts index dcbba0ebe9d13..e06113135c910 100644 --- a/x-pack/plugins/integration_assistant/server/graphs/ecs/invalid.ts +++ b/x-pack/plugins/integration_assistant/server/graphs/ecs/invalid.ts @@ -4,18 +4,13 @@ * 2.0; you may not use this file except in compliance with the Elastic License * 2.0. 
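getEcsGraph above (and getRelatedGraph later in this patch) wires the handlers into a LangGraph StateGraph, with chainRouter deciding the next node after each run. Stripped to its shape, the wiring pattern looks roughly like this; the state, node, and router below are illustrative only:

import { END, START, StateGraph } from '@langchain/langgraph';
import type { StateGraphArgs } from '@langchain/langgraph';

interface DemoState {
  attempts: number;
}

// Channels declare how each state field is merged between node runs.
const channels: StateGraphArgs<DemoState>['channels'] = {
  attempts: { value: (_old: number, next: number) => next, default: () => 0 },
};

const workflow = new StateGraph<DemoState>({ channels })
  .addNode('step', async (state: DemoState) => ({ attempts: state.attempts + 1 }))
  .addEdge(START, 'step')
  // Stand-in for chainRouter: loop until a condition holds, then finish.
  .addConditionalEdges('step', (state: DemoState) => (state.attempts < 3 ? 'step' : END));

const app = workflow.compile();
await app.invoke({ attempts: 0 });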
*/ -import type { - ActionsClientChatOpenAI, - ActionsClientSimpleChatModel, -} from '@kbn/langchain/server/language_models'; + +import type { AssistantToolLlm } from '@kbn/elastic-assistant-plugin/server/types'; import { JsonOutputParser } from '@langchain/core/output_parsers'; import type { EcsMappingState } from '../../types'; import { ECS_INVALID_PROMPT } from './prompts'; -export async function handleInvalidEcs( - state: EcsMappingState, - model: ActionsClientChatOpenAI | ActionsClientSimpleChatModel -) { +export async function handleInvalidEcs(state: EcsMappingState, model: AssistantToolLlm) { const ecsInvalidEcsPrompt = ECS_INVALID_PROMPT; const outputParser = new JsonOutputParser(); const ecsInvalidEcsGraph = ecsInvalidEcsPrompt.pipe(model).pipe(outputParser); diff --git a/x-pack/plugins/integration_assistant/server/graphs/ecs/mapping.test.ts b/x-pack/plugins/integration_assistant/server/graphs/ecs/mapping.test.ts index dbbfc0608d010..4170505e458cd 100644 --- a/x-pack/plugins/integration_assistant/server/graphs/ecs/mapping.test.ts +++ b/x-pack/plugins/integration_assistant/server/graphs/ecs/mapping.test.ts @@ -9,14 +9,11 @@ import { FakeLLM } from '@langchain/core/utils/testing'; import { handleEcsMapping } from './mapping'; import type { EcsMappingState } from '../../types'; import { ecsTestState } from '../../../__jest__/fixtures/ecs_mapping'; -import { - ActionsClientChatOpenAI, - ActionsClientSimpleChatModel, -} from '@kbn/langchain/server/language_models'; +import type { AssistantToolLlm } from '@kbn/elastic-assistant-plugin/server/types'; const mockLlm = new FakeLLM({ response: '{ "message": "ll callback later."}', -}) as unknown as ActionsClientChatOpenAI | ActionsClientSimpleChatModel; +}) as unknown as AssistantToolLlm; const testState: EcsMappingState = ecsTestState; diff --git a/x-pack/plugins/integration_assistant/server/graphs/ecs/mapping.ts b/x-pack/plugins/integration_assistant/server/graphs/ecs/mapping.ts index 7ecb108659f45..98c9fe4eca82f 100644 --- a/x-pack/plugins/integration_assistant/server/graphs/ecs/mapping.ts +++ b/x-pack/plugins/integration_assistant/server/graphs/ecs/mapping.ts @@ -4,18 +4,13 @@ * 2.0; you may not use this file except in compliance with the Elastic License * 2.0. 
*/ -import type { - ActionsClientChatOpenAI, - ActionsClientSimpleChatModel, -} from '@kbn/langchain/server/language_models'; + +import type { AssistantToolLlm } from '@kbn/elastic-assistant-plugin/server/types'; import { JsonOutputParser } from '@langchain/core/output_parsers'; import type { EcsMappingState } from '../../types'; import { ECS_MAIN_PROMPT } from './prompts'; -export async function handleEcsMapping( - state: EcsMappingState, - model: ActionsClientChatOpenAI | ActionsClientSimpleChatModel -) { +export async function handleEcsMapping(state: EcsMappingState, model: AssistantToolLlm) { const ecsMainPrompt = ECS_MAIN_PROMPT; const outputParser = new JsonOutputParser(); const ecsMainGraph = ecsMainPrompt.pipe(model).pipe(outputParser); diff --git a/x-pack/plugins/integration_assistant/server/graphs/ecs/missing.test.ts b/x-pack/plugins/integration_assistant/server/graphs/ecs/missing.test.ts index b369d28b1e177..d283a6f3fe1c1 100644 --- a/x-pack/plugins/integration_assistant/server/graphs/ecs/missing.test.ts +++ b/x-pack/plugins/integration_assistant/server/graphs/ecs/missing.test.ts @@ -9,14 +9,11 @@ import { FakeLLM } from '@langchain/core/utils/testing'; import { handleMissingKeys } from './missing'; import type { EcsMappingState } from '../../types'; import { ecsTestState } from '../../../__jest__/fixtures/ecs_mapping'; -import { - ActionsClientChatOpenAI, - ActionsClientSimpleChatModel, -} from '@kbn/langchain/server/language_models'; +import type { AssistantToolLlm } from '@kbn/elastic-assistant-plugin/server/types'; const mockLlm = new FakeLLM({ response: '{ "message": "ll callback later."}', -}) as unknown as ActionsClientChatOpenAI | ActionsClientSimpleChatModel; +}) as unknown as AssistantToolLlm; const testState: EcsMappingState = ecsTestState; diff --git a/x-pack/plugins/integration_assistant/server/graphs/ecs/missing.ts b/x-pack/plugins/integration_assistant/server/graphs/ecs/missing.ts index d7f1f65b2b4ea..ca7f7501f4eef 100644 --- a/x-pack/plugins/integration_assistant/server/graphs/ecs/missing.ts +++ b/x-pack/plugins/integration_assistant/server/graphs/ecs/missing.ts @@ -4,18 +4,13 @@ * 2.0; you may not use this file except in compliance with the Elastic License * 2.0. 
*/ -import type { - ActionsClientChatOpenAI, - ActionsClientSimpleChatModel, -} from '@kbn/langchain/server/language_models'; + +import type { AssistantToolLlm } from '@kbn/elastic-assistant-plugin/server/types'; import { JsonOutputParser } from '@langchain/core/output_parsers'; import type { EcsMappingState } from '../../types'; import { ECS_MISSING_KEYS_PROMPT } from './prompts'; -export async function handleMissingKeys( - state: EcsMappingState, - model: ActionsClientChatOpenAI | ActionsClientSimpleChatModel -) { +export async function handleMissingKeys(state: EcsMappingState, model: AssistantToolLlm) { const ecsMissingPrompt = ECS_MISSING_KEYS_PROMPT; const outputParser = new JsonOutputParser(); const ecsMissingGraph = ecsMissingPrompt.pipe(model).pipe(outputParser); diff --git a/x-pack/plugins/integration_assistant/server/graphs/related/errors.test.ts b/x-pack/plugins/integration_assistant/server/graphs/related/errors.test.ts index 24dc4365dcbff..9f530d49fc6f3 100644 --- a/x-pack/plugins/integration_assistant/server/graphs/related/errors.test.ts +++ b/x-pack/plugins/integration_assistant/server/graphs/related/errors.test.ts @@ -13,14 +13,11 @@ import { relatedMockProcessors, relatedExpectedHandlerResponse, } from '../../../__jest__/fixtures/related'; -import { - ActionsClientChatOpenAI, - ActionsClientSimpleChatModel, -} from '@kbn/langchain/server/language_models'; +import type { AssistantToolLlm } from '@kbn/elastic-assistant-plugin/server/types'; const mockLlm = new FakeLLM({ response: JSON.stringify(relatedMockProcessors, null, 2), -}) as unknown as ActionsClientChatOpenAI | ActionsClientSimpleChatModel; +}) as unknown as AssistantToolLlm; const testState: RelatedState = relatedTestState; diff --git a/x-pack/plugins/integration_assistant/server/graphs/related/errors.ts b/x-pack/plugins/integration_assistant/server/graphs/related/errors.ts index 025422008c4dc..20c5f5e108226 100644 --- a/x-pack/plugins/integration_assistant/server/graphs/related/errors.ts +++ b/x-pack/plugins/integration_assistant/server/graphs/related/errors.ts @@ -4,20 +4,15 @@ * 2.0; you may not use this file except in compliance with the Elastic License * 2.0. 
*/ -import type { - ActionsClientChatOpenAI, - ActionsClientSimpleChatModel, -} from '@kbn/langchain/server/language_models'; + +import type { AssistantToolLlm } from '@kbn/elastic-assistant-plugin/server/types'; import { JsonOutputParser } from '@langchain/core/output_parsers'; import type { ESProcessorItem, Pipeline } from '../../../common'; import type { RelatedState } from '../../types'; import { combineProcessors } from '../../util/processors'; import { RELATED_ERROR_PROMPT } from './prompts'; -export async function handleErrors( - state: RelatedState, - model: ActionsClientChatOpenAI | ActionsClientSimpleChatModel -) { +export async function handleErrors(state: RelatedState, model: AssistantToolLlm) { const relatedErrorPrompt = RELATED_ERROR_PROMPT; const outputParser = new JsonOutputParser(); const relatedErrorGraph = relatedErrorPrompt.pipe(model).pipe(outputParser); diff --git a/x-pack/plugins/integration_assistant/server/graphs/related/graph.test.ts b/x-pack/plugins/integration_assistant/server/graphs/related/graph.test.ts index 40989e9733800..19cf49b989ea1 100644 --- a/x-pack/plugins/integration_assistant/server/graphs/related/graph.test.ts +++ b/x-pack/plugins/integration_assistant/server/graphs/related/graph.test.ts @@ -22,14 +22,11 @@ import { handleReview } from './review'; import { handleRelated } from './related'; import { handleErrors } from './errors'; import { testPipeline, combineProcessors } from '../../util'; -import { - ActionsClientChatOpenAI, - ActionsClientSimpleChatModel, -} from '@kbn/langchain/server/language_models'; +import type { AssistantToolLlm } from '@kbn/elastic-assistant-plugin/server/types'; const mockLlm = new FakeLLM({ response: "I'll callback later.", -}) as unknown as ActionsClientChatOpenAI | ActionsClientSimpleChatModel; +}) as unknown as AssistantToolLlm; jest.mock('./errors'); jest.mock('./review'); diff --git a/x-pack/plugins/integration_assistant/server/graphs/related/graph.ts b/x-pack/plugins/integration_assistant/server/graphs/related/graph.ts index 9b50c05889402..51a6f9583fe64 100644 --- a/x-pack/plugins/integration_assistant/server/graphs/related/graph.ts +++ b/x-pack/plugins/integration_assistant/server/graphs/related/graph.ts @@ -8,10 +8,7 @@ import type { IScopedClusterClient } from '@kbn/core-elasticsearch-server'; import type { StateGraphArgs } from '@langchain/langgraph'; import { StateGraph, END, START } from '@langchain/langgraph'; -import type { - ActionsClientChatOpenAI, - ActionsClientSimpleChatModel, -} from '@kbn/langchain/server/language_models'; +import type { AssistantToolLlm } from '@kbn/elastic-assistant-plugin/server/types'; import type { RelatedState } from '../../types'; import { modifySamples, formatSamples } from '../../util/samples'; import { handleValidatePipeline } from '../../util/graph'; @@ -137,10 +134,7 @@ function chainRouter(state: RelatedState): string { return END; } -export async function getRelatedGraph( - client: IScopedClusterClient, - model: ActionsClientChatOpenAI | ActionsClientSimpleChatModel -) { +export async function getRelatedGraph(client: IScopedClusterClient, model: AssistantToolLlm) { const workflow = new StateGraph({ channels: graphState }) .addNode('modelInput', modelInput) .addNode('modelOutput', modelOutput) diff --git a/x-pack/plugins/integration_assistant/server/graphs/related/related.test.ts b/x-pack/plugins/integration_assistant/server/graphs/related/related.test.ts index 3a741020fb530..b81de2b1025e0 100644 --- 
a/x-pack/plugins/integration_assistant/server/graphs/related/related.test.ts +++ b/x-pack/plugins/integration_assistant/server/graphs/related/related.test.ts @@ -13,14 +13,11 @@ import { relatedMockProcessors, relatedExpectedHandlerResponse, } from '../../../__jest__/fixtures/related'; -import { - ActionsClientChatOpenAI, - ActionsClientSimpleChatModel, -} from '@kbn/langchain/server/language_models'; +import type { AssistantToolLlm } from '@kbn/elastic-assistant-plugin/server/types'; const mockLlm = new FakeLLM({ response: JSON.stringify(relatedMockProcessors, null, 2), -}) as unknown as ActionsClientChatOpenAI | ActionsClientSimpleChatModel; +}) as unknown as AssistantToolLlm; const testState: RelatedState = relatedTestState; diff --git a/x-pack/plugins/integration_assistant/server/graphs/related/related.ts b/x-pack/plugins/integration_assistant/server/graphs/related/related.ts index 2c98381510d9b..0cd1a7f8251b1 100644 --- a/x-pack/plugins/integration_assistant/server/graphs/related/related.ts +++ b/x-pack/plugins/integration_assistant/server/graphs/related/related.ts @@ -4,20 +4,15 @@ * 2.0; you may not use this file except in compliance with the Elastic License * 2.0. */ -import type { - ActionsClientChatOpenAI, - ActionsClientSimpleChatModel, -} from '@kbn/langchain/server/language_models'; + +import type { AssistantToolLlm } from '@kbn/elastic-assistant-plugin/server/types'; import { JsonOutputParser } from '@langchain/core/output_parsers'; import type { ESProcessorItem, Pipeline } from '../../../common'; import type { RelatedState } from '../../types'; import { combineProcessors } from '../../util/processors'; import { RELATED_MAIN_PROMPT } from './prompts'; -export async function handleRelated( - state: RelatedState, - model: ActionsClientChatOpenAI | ActionsClientSimpleChatModel -) { +export async function handleRelated(state: RelatedState, model: AssistantToolLlm) { const relatedMainPrompt = RELATED_MAIN_PROMPT; const outputParser = new JsonOutputParser(); const relatedMainGraph = relatedMainPrompt.pipe(model).pipe(outputParser); diff --git a/x-pack/plugins/integration_assistant/server/graphs/related/review.test.ts b/x-pack/plugins/integration_assistant/server/graphs/related/review.test.ts index 475f0d72b988d..a814d25d0c3a2 100644 --- a/x-pack/plugins/integration_assistant/server/graphs/related/review.test.ts +++ b/x-pack/plugins/integration_assistant/server/graphs/related/review.test.ts @@ -13,14 +13,11 @@ import { relatedMockProcessors, relatedExpectedHandlerResponse, } from '../../../__jest__/fixtures/related'; -import { - ActionsClientChatOpenAI, - ActionsClientSimpleChatModel, -} from '@kbn/langchain/server/language_models'; +import type { AssistantToolLlm } from '@kbn/elastic-assistant-plugin/server/types'; const mockLlm = new FakeLLM({ response: JSON.stringify(relatedMockProcessors, null, 2), -}) as unknown as ActionsClientChatOpenAI | ActionsClientSimpleChatModel; +}) as unknown as AssistantToolLlm; const testState: RelatedState = relatedTestState; diff --git a/x-pack/plugins/integration_assistant/server/graphs/related/review.ts b/x-pack/plugins/integration_assistant/server/graphs/related/review.ts index 6c07079e18f48..2e76b822af2ae 100644 --- a/x-pack/plugins/integration_assistant/server/graphs/related/review.ts +++ b/x-pack/plugins/integration_assistant/server/graphs/related/review.ts @@ -4,20 +4,15 @@ * 2.0; you may not use this file except in compliance with the Elastic License * 2.0. 
*/ -import type { - ActionsClientChatOpenAI, - ActionsClientSimpleChatModel, -} from '@kbn/langchain/server/language_models'; + +import type { AssistantToolLlm } from '@kbn/elastic-assistant-plugin/server/types'; import { JsonOutputParser } from '@langchain/core/output_parsers'; import type { ESProcessorItem, Pipeline } from '../../../common'; import type { RelatedState } from '../../types'; import { combineProcessors } from '../../util/processors'; import { RELATED_REVIEW_PROMPT } from './prompts'; -export async function handleReview( - state: RelatedState, - model: ActionsClientChatOpenAI | ActionsClientSimpleChatModel -) { +export async function handleReview(state: RelatedState, model: AssistantToolLlm) { const relatedReviewPrompt = RELATED_REVIEW_PROMPT; const outputParser = new JsonOutputParser(); const relatedReviewGraph = relatedReviewPrompt.pipe(model).pipe(outputParser); diff --git a/x-pack/plugins/integration_assistant/server/routes/categorization_routes.ts b/x-pack/plugins/integration_assistant/server/routes/categorization_routes.ts index 6654898bd0232..b0befabd78384 100644 --- a/x-pack/plugins/integration_assistant/server/routes/categorization_routes.ts +++ b/x-pack/plugins/integration_assistant/server/routes/categorization_routes.ts @@ -7,10 +7,7 @@ import type { IKibanaResponse, IRouter } from '@kbn/core/server'; import { getRequestAbortedSignal } from '@kbn/data-plugin/server'; -import { - ActionsClientChatOpenAI, - ActionsClientSimpleChatModel, -} from '@kbn/langchain/server/language_models'; +import { getLlmType, getLlmClass } from '@kbn/elastic-assistant-plugin/server/routes/utils'; import { CATEGORIZATION_GRAPH_PATH, CategorizationRequestBody, @@ -59,15 +56,15 @@ export function registerCategorizationRoutes( )[0]; const abortSignal = getRequestAbortedSignal(req.events.aborted$); - const isOpenAI = connector.actionTypeId === '.gen-ai'; - const llmClass = isOpenAI ? ActionsClientChatOpenAI : ActionsClientSimpleChatModel; + const llmType = getLlmType(connector.actionTypeId); + const llmClass = getLlmClass(llmType); const model = new llmClass({ actions: actionsPlugin, connectorId: connector.id, request: req, logger, - llmType: isOpenAI ? 'openai' : 'bedrock', + llmType, model: connector.config?.defaultModel, temperature: 0.05, maxTokens: 4096, diff --git a/x-pack/plugins/integration_assistant/server/routes/ecs_routes.ts b/x-pack/plugins/integration_assistant/server/routes/ecs_routes.ts index ee461b94feba4..c53fcb49442ad 100644 --- a/x-pack/plugins/integration_assistant/server/routes/ecs_routes.ts +++ b/x-pack/plugins/integration_assistant/server/routes/ecs_routes.ts @@ -7,10 +7,7 @@ import type { IKibanaResponse, IRouter } from '@kbn/core/server'; import { getRequestAbortedSignal } from '@kbn/data-plugin/server'; -import { - ActionsClientChatOpenAI, - ActionsClientSimpleChatModel, -} from '@kbn/langchain/server/language_models'; +import { getLlmType, getLlmClass } from '@kbn/elastic-assistant-plugin/server/routes/utils'; import { ECS_GRAPH_PATH, EcsMappingRequestBody, EcsMappingResponse } from '../../common'; import { ROUTE_HANDLER_TIMEOUT } from '../constants'; import { getEcsGraph } from '../graphs/ecs'; @@ -50,15 +47,15 @@ export function registerEcsRoutes(router: IRouter connectorItem.actionTypeId === '.bedrock' )[0]; - const isOpenAI = connector.actionTypeId === '.gen-ai'; - const llmClass = isOpenAI ? 
ActionsClientChatOpenAI : ActionsClientSimpleChatModel; + const llmType = getLlmType(connector.actionTypeId); + const llmClass = getLlmClass(llmType); const abortSignal = getRequestAbortedSignal(req.events.aborted$); const model = new llmClass({ @@ -60,7 +57,7 @@ export function registerRelatedRoutes(router: IRouter { schema: InvokeAIActionParamsSchema, }); + this.registerSubAction({ + name: SUB_ACTION.INVOKE_AI_RAW, + method: 'invokeAIRaw', + schema: InvokeAIRawActionParamsSchema, + }); + this.registerSubAction({ name: SUB_ACTION.INVOKE_STREAM, method: 'invokeStream', @@ -320,6 +329,47 @@ The Kibana Connector in use may need to be reconfigured with an updated Amazon B }); return { message: res.completion.trim() }; } + + /** + * Non-streamed security solution AI Assistant requests + * Responsible for invoking the runApi method with the provided body. + * It then formats the response into a string + * @param messages An array of messages to be sent to the API + * @param model Optional model to be used for the API request. If not provided, the default model from the connector will be used. + * @returns an object with the response string as a property called message + */ + public async invokeAIRaw({ + bedrockMethod = 'invoke', + messages, + model, + stopSequences, + system, + temperature, + maxTokens, + signal, + timeout, + }: InvokeAIRawActionParams) { + // set model on per request basis + const currentModel = model ?? this.model; + const path = `/model/${currentModel}/${bedrockMethod}`; + const body = JSON.stringify( + formatBedrockBody({ messages, stopSequences, system, temperature, maxTokens }) + ); + const signed = this.signRequest(body, path, false); + const params = { + ...signed, + url: `${this.url}${path}`, + method: 'post' as Method, + data: body, + signal, + // give up to 2 minutes for response + timeout: timeout ?? 
DEFAULT_TIMEOUT_MS, + }; + + const response = await this.request({ ...params, responseSchema: RunActionRawResponseSchema }); + + return response.data; + } } const formatBedrockBody = ({ diff --git a/yarn.lock b/yarn.lock index d168b25927c98..21617305e5257 100644 --- a/yarn.lock +++ b/yarn.lock @@ -77,23 +77,32 @@ resolved "https://registry.yarnpkg.com/@assemblyscript/loader/-/loader-0.10.1.tgz#70e45678f06c72fa2e350e8553ec4a4d72b92e06" integrity sha512-H71nDOOL8Y7kWRLqf6Sums+01Q5msqBW2KhDUTemh1tvY04eSkSXrK0uj/4mmY0Xr16/3zyZmsrxN7CKuRbNRg== -"@aws-crypto/crc32@3.0.0": - version "3.0.0" - resolved "https://registry.yarnpkg.com/@aws-crypto/crc32/-/crc32-3.0.0.tgz#07300eca214409c33e3ff769cd5697b57fdd38fa" - integrity sha512-IzSgsrxUcsrejQbPVilIKy16kAT52EwB6zSaI+M3xxIhKh5+aldEyvI+z6erM7TCLB2BJsFrtHjp6/4/sr+3dA== +"@aws-crypto/crc32@5.2.0": + version "5.2.0" + resolved "https://registry.yarnpkg.com/@aws-crypto/crc32/-/crc32-5.2.0.tgz#cfcc22570949c98c6689cfcbd2d693d36cdae2e1" + integrity sha512-nLbCWqQNgUiwwtFsen1AdzAtvuLRsQS8rYgMuxCrdKf9kOssamGLuPwyTY9wyYblNr9+1XM8v6zoDTPPSIeANg== dependencies: - "@aws-crypto/util" "^3.0.0" + "@aws-crypto/util" "^5.2.0" "@aws-sdk/types" "^3.222.0" - tslib "^1.11.1" + tslib "^2.6.2" -"@aws-crypto/util@^3.0.0": - version "3.0.0" - resolved "https://registry.yarnpkg.com/@aws-crypto/util/-/util-3.0.0.tgz#1c7ca90c29293f0883468ad48117937f0fe5bfb0" - integrity sha512-2OJlpeJpCR48CC8r+uKVChzs9Iungj9wkZrl8Z041DWEWvyIHILYKCPNzJghKsivj+S3mLo6BVc7mBNzdxA46w== +"@aws-crypto/sha256-js@^5.2.0": + version "5.2.0" + resolved "https://registry.yarnpkg.com/@aws-crypto/sha256-js/-/sha256-js-5.2.0.tgz#c4fdb773fdbed9a664fc1a95724e206cf3860042" + integrity sha512-FFQQyu7edu4ufvIZ+OadFpHHOt+eSTBaYaki44c+akjg7qZg9oOQeLlk77F6tSYqjDAFClrHJk9tMf0HdVyOvA== dependencies: + "@aws-crypto/util" "^5.2.0" "@aws-sdk/types" "^3.222.0" - "@aws-sdk/util-utf8-browser" "^3.0.0" - tslib "^1.11.1" + tslib "^2.6.2" + +"@aws-crypto/util@^5.2.0": + version "5.2.0" + resolved "https://registry.yarnpkg.com/@aws-crypto/util/-/util-5.2.0.tgz#71284c9cffe7927ddadac793c14f14886d3876da" + integrity sha512-4RkU9EsI6ZpBve5fseQlGNUWKMa1RLPQ1dnjnQoe07ldfIzcsGb5hC5W0Dm7u423KWzawlrpbjXBrXCEv9zazQ== + dependencies: + "@aws-sdk/types" "^3.222.0" + "@smithy/util-utf8" "^2.0.0" + tslib "^2.6.2" "@aws-sdk/types@^3.222.0": version "3.577.0" @@ -103,13 +112,6 @@ "@smithy/types" "^3.0.0" tslib "^2.6.2" -"@aws-sdk/util-utf8-browser@^3.0.0": - version "3.259.0" - resolved "https://registry.yarnpkg.com/@aws-sdk/util-utf8-browser/-/util-utf8-browser-3.259.0.tgz#3275a6f5eb334f96ca76635b961d3c50259fd9ff" - integrity sha512-UvFa/vR+e19XookZF8RzFZBrw2EUkQWxiBW0yYQAhvk3C+QVGl0H3ouca8LDBlBfQKXwmW3huo/59H8rwb1wJw== - dependencies: - tslib "^2.3.1" - "@babel/cli@^7.24.1": version "7.24.1" resolved "https://registry.yarnpkg.com/@babel/cli/-/cli-7.24.1.tgz#2e11e071e32fe82850b4fe514f56b9c9e1c44911" @@ -6965,33 +6967,33 @@ resolved "https://registry.yarnpkg.com/@kwsites/promise-deferred/-/promise-deferred-1.1.1.tgz#8ace5259254426ccef57f3175bc64ed7095ed919" integrity sha512-GaHYm+c0O9MjZRu0ongGBRbinu8gVAMd2UZjji6jVmqKtZluZnptXGWhz1E8j8D2HJ3f/yMxKAUC0b+57wncIw== -"@langchain/community@^0.2.4": - version "0.2.4" - resolved "https://registry.yarnpkg.com/@langchain/community/-/community-0.2.4.tgz#fb5feb4f4a01a1b33adfd28ce7126d0dedb3e6d1" - integrity sha512-rwrPNQLyIe84TPqPYbYOfDA4G/ba1rdj7OtZg63dQmxIvNDOmUCh4xIQac2iuRUnM3o4Ben0Faa9qz+V5oPgIA== +"@langchain/community@^0.2.13": + version "0.2.13" + resolved 
"https://registry.yarnpkg.com/@langchain/community/-/community-0.2.13.tgz#7a327663d2a6006456ff136d2b16cb2b1a76b541" + integrity sha512-f0GZCGM5XP0r+H643GpUU4YelKHsUdhUY1Kb8rKpCoy8zgs1nUkiYDVylAf0ezwUOT4NYCEuwpw0jj8hQSLn1Q== dependencies: - "@langchain/core" "~0.2.0" - "@langchain/openai" "~0.0.28" + "@langchain/core" "~0.2.9" + "@langchain/openai" "~0.1.0" binary-extensions "^2.2.0" expr-eval "^2.0.2" flat "^5.0.2" js-yaml "^4.1.0" langchain "0.2.3" - langsmith "~0.1.1" + langsmith "~0.1.30" uuid "^9.0.0" zod "^3.22.3" zod-to-json-schema "^3.22.5" -"@langchain/core@0.2.3", "@langchain/core@>0.1.0 <0.3.0", "@langchain/core@>0.1.56 <0.3.0", "@langchain/core@>0.1.61 <0.3.0", "@langchain/core@~0.2.0": - version "0.2.3" - resolved "https://registry.yarnpkg.com/@langchain/core/-/core-0.2.3.tgz#7faa82f92b0c7843506e827a38bfcbb60f009d13" - integrity sha512-mVuFHSLpPQ4yOHNXeoSA3LnmIMuFmUiit5rvbYcPZqM6SrB2zCNN2nD4Ty5+3H5X4tYItDoSqsTuUNUQySXRQw== +"@langchain/core@0.2.9", "@langchain/core@>0.1.0 <0.3.0", "@langchain/core@>0.1.61 <0.3.0", "@langchain/core@>=0.2.8 <0.3.0", "@langchain/core@^0.2.9", "@langchain/core@~0.2.0", "@langchain/core@~0.2.9": + version "0.2.9" + resolved "https://registry.yarnpkg.com/@langchain/core/-/core-0.2.9.tgz#8a22585ef11d2ca8742a8bbfe77dd25baedc7779" + integrity sha512-pJshopBZqMNF020q0OrrO+vfApWTZUlZecRYMM7TWA5M8/zvEyU/mgA9DlzeRjjDmG6pwF6dIKVjpl6fIGVXlQ== dependencies: ansi-styles "^5.0.0" camelcase "6" decamelize "1.2.0" js-tiktoken "^1.0.12" - langsmith "~0.1.7" + langsmith "~0.1.30" ml-distance "^4.0.0" mustache "^4.2.0" p-queue "^6.6.2" @@ -7000,22 +7002,22 @@ zod "^3.22.4" zod-to-json-schema "^3.22.3" -"@langchain/langgraph@^0.0.23": - version "0.0.23" - resolved "https://registry.yarnpkg.com/@langchain/langgraph/-/langgraph-0.0.23.tgz#34b5ad5dc9fe644ee96bcfcf11197ec1d7f9e0e2" - integrity sha512-pXlcsBOseT5xdf9enUqbLQ/59LaZxgMI2dL2vFJ+EpcoK7bQnlzzhRtRPp+vubMyMeEKRoAXlaA9ObwpVi93CA== +"@langchain/langgraph@^0.0.24": + version "0.0.24" + resolved "https://registry.yarnpkg.com/@langchain/langgraph/-/langgraph-0.0.24.tgz#814e686dcef700d15a6f90e27e9c9f79d75faef4" + integrity sha512-fW9cnz62oKZFAlyO/4oEjXxthrqZPQtqyX4f7ttKEi0rJZKeuoohvOtnC8faq6nrtMtX9JpLixHjK0SgN7XN3g== dependencies: "@langchain/core" ">0.1.61 <0.3.0" uuid "^9.0.1" -"@langchain/openai@^0.0.34", "@langchain/openai@~0.0.28": - version "0.0.34" - resolved "https://registry.yarnpkg.com/@langchain/openai/-/openai-0.0.34.tgz#36c9bca0721ab9f7e5d40927e7c0429cacbd5b56" - integrity sha512-M+CW4oXle5fdoz2T2SwdOef8pl3/1XmUx1vjn2mXUVM/128aO0l23FMF0SNBsAbRV6P+p/TuzjodchJbi0Ht/A== +"@langchain/openai@0.2.0", "@langchain/openai@>=0.1.0 <0.3.0", "@langchain/openai@^0.2.0", "@langchain/openai@~0.1.0": + version "0.2.0" + resolved "https://registry.yarnpkg.com/@langchain/openai/-/openai-0.2.0.tgz#342e49d15b946fa01128d1bb81357e688e7cf567" + integrity sha512-gZd+0IOxpiKuh1m6KTT5vtUoOO72GEYyoU4+c6qAUucOEqQS0Vvz3lMGyNWLjK4x4Xpd+r8GAF5mj/jvghwP1A== dependencies: - "@langchain/core" ">0.1.56 <0.3.0" + "@langchain/core" ">=0.2.8 <0.3.0" js-tiktoken "^1.0.12" - openai "^4.41.1" + openai "^4.49.1" zod "^3.22.4" zod-to-json-schema "^3.22.3" @@ -8290,32 +8292,32 @@ "@types/node" ">=18.0.0" axios "^1.6.0" -"@smithy/eventstream-codec@^3.0.0": - version "3.0.0" - resolved "https://registry.yarnpkg.com/@smithy/eventstream-codec/-/eventstream-codec-3.0.0.tgz#81d30391220f73d41f432f65384b606d67673e46" - integrity sha512-PUtyEA0Oik50SaEFCZ0WPVtF9tz/teze2fDptW6WRXl+RrEenH8UbEjudOz8iakiMl3lE3lCVqYf2Y+znL8QFQ== 
+"@smithy/eventstream-codec@^3.1.1": + version "3.1.1" + resolved "https://registry.yarnpkg.com/@smithy/eventstream-codec/-/eventstream-codec-3.1.1.tgz#b47f30bf4ad791ac7981b9fff58e599d18269cf9" + integrity sha512-s29NxV/ng1KXn6wPQ4qzJuQDjEtxLdS0+g5PQFirIeIZrp66FXVJ5IpZRowbt/42zB5dY8TqJ0G0L9KkgtsEZg== dependencies: - "@aws-crypto/crc32" "3.0.0" - "@smithy/types" "^3.0.0" + "@aws-crypto/crc32" "5.2.0" + "@smithy/types" "^3.2.0" "@smithy/util-hex-encoding" "^3.0.0" tslib "^2.6.2" -"@smithy/eventstream-serde-node@^3.0.0": - version "3.0.0" - resolved "https://registry.yarnpkg.com/@smithy/eventstream-serde-node/-/eventstream-serde-node-3.0.0.tgz#6519523fbb429307be29b151b8ba35bcca2b6e64" - integrity sha512-baRPdMBDMBExZXIUAoPGm/hntixjt/VFpU6+VmCyiYJYzRHRxoaI1MN+5XE+hIS8AJ2GCHLMFEIOLzq9xx1EgQ== +"@smithy/eventstream-serde-node@^3.0.3": + version "3.0.3" + resolved "https://registry.yarnpkg.com/@smithy/eventstream-serde-node/-/eventstream-serde-node-3.0.3.tgz#51df0ca39f453d78a3d6607c1ac2e96cf900c824" + integrity sha512-v61Ftn7x/ubWFqH7GHFAL/RaU7QZImTbuV95DYugYYItzpO7KaHYEuO8EskCaBpZEfzOxhUGKm4teS9YUSt69Q== dependencies: - "@smithy/eventstream-serde-universal" "^3.0.0" - "@smithy/types" "^3.0.0" + "@smithy/eventstream-serde-universal" "^3.0.3" + "@smithy/types" "^3.2.0" tslib "^2.6.2" -"@smithy/eventstream-serde-universal@^3.0.0": - version "3.0.0" - resolved "https://registry.yarnpkg.com/@smithy/eventstream-serde-universal/-/eventstream-serde-universal-3.0.0.tgz#cb8441a73fbde4cbaa68e4a21236f658d914a073" - integrity sha512-HNFfShmotWGeAoW4ujP8meV9BZavcpmerDbPIjkJbxKbN8RsUcpRQ/2OyIxWNxXNH2GWCAxuSB7ynmIGJlQ3Dw== +"@smithy/eventstream-serde-universal@^3.0.3": + version "3.0.3" + resolved "https://registry.yarnpkg.com/@smithy/eventstream-serde-universal/-/eventstream-serde-universal-3.0.3.tgz#2ecac479ba84e10221b4b70545f3d7a223b5345e" + integrity sha512-YXYt3Cjhu9tRrahbTec2uOjwOSeCNfQurcWPGNEUspBhqHoA3KrDrVj+jGbCLWvwkwhzqDnnaeHAxm+IxAjOAQ== dependencies: - "@smithy/eventstream-codec" "^3.0.0" - "@smithy/types" "^3.0.0" + "@smithy/eventstream-codec" "^3.1.1" + "@smithy/types" "^3.2.0" tslib "^2.6.2" "@smithy/is-array-buffer@^3.0.0": @@ -8325,10 +8327,31 @@ dependencies: tslib "^2.6.2" -"@smithy/types@^3.0.0": - version "3.0.0" - resolved "https://registry.yarnpkg.com/@smithy/types/-/types-3.0.0.tgz#00231052945159c64ffd8b91e8909d8d3006cb7e" - integrity sha512-VvWuQk2RKFuOr98gFhjca7fkBS+xLLURT8bUjk5XQoV0ZLm7WPwWPPY3/AwzTLuUBDeoKDCthfe1AsTUWaSEhw== +"@smithy/protocol-http@^4.0.2": + version "4.0.2" + resolved "https://registry.yarnpkg.com/@smithy/protocol-http/-/protocol-http-4.0.2.tgz#502ed3116cb0f1e3f207881df965bac620ccb2da" + integrity sha512-X/90xNWIOqSR2tLUyWxVIBdatpm35DrL44rI/xoeBWUuanE0iyCXJpTcnqlOpnEzgcu0xCKE06+g70TTu2j7RQ== + dependencies: + "@smithy/types" "^3.2.0" + tslib "^2.6.2" + +"@smithy/signature-v4@^3.1.1": + version "3.1.1" + resolved "https://registry.yarnpkg.com/@smithy/signature-v4/-/signature-v4-3.1.1.tgz#4882aacb3260a47b8279b2ffc6a135e03e225260" + integrity sha512-2/vlG86Sr489XX8TA/F+VDA+P04ESef04pSz0wRtlQBExcSPjqO08rvrkcas2zLnJ51i+7ukOURCkgqixBYjSQ== + dependencies: + "@smithy/is-array-buffer" "^3.0.0" + "@smithy/types" "^3.2.0" + "@smithy/util-hex-encoding" "^3.0.0" + "@smithy/util-middleware" "^3.0.2" + "@smithy/util-uri-escape" "^3.0.0" + "@smithy/util-utf8" "^3.0.0" + tslib "^2.6.2" + +"@smithy/types@^3.0.0", "@smithy/types@^3.2.0": + version "3.2.0" + resolved "https://registry.yarnpkg.com/@smithy/types/-/types-3.2.0.tgz#1350fe8a50d5e35e12ffb34be46d946860b2b5ab" + integrity 
sha512-cKyeKAPazZRVqm7QPvcPD2jEIt2wqDPAL1KJKb0f/5I7uhollvsWZuZKLclmyP6a+Jwmr3OV3t+X0pZUUHS9BA== dependencies: tslib "^2.6.2" @@ -8347,7 +8370,22 @@ dependencies: tslib "^2.6.2" -"@smithy/util-utf8@^3.0.0": +"@smithy/util-middleware@^3.0.2": + version "3.0.2" + resolved "https://registry.yarnpkg.com/@smithy/util-middleware/-/util-middleware-3.0.2.tgz#6daeb9db060552d851801cd7a0afd68769e2f98b" + integrity sha512-7WW5SD0XVrpfqljBYzS5rLR+EiDzl7wCVJZ9Lo6ChNFV4VYDk37Z1QI5w/LnYtU/QKnSawYoHRd7VjSyC8QRQQ== + dependencies: + "@smithy/types" "^3.2.0" + tslib "^2.6.2" + +"@smithy/util-uri-escape@^3.0.0": + version "3.0.0" + resolved "https://registry.yarnpkg.com/@smithy/util-uri-escape/-/util-uri-escape-3.0.0.tgz#e43358a78bf45d50bb736770077f0f09195b6f54" + integrity sha512-LqR7qYLgZTD7nWLBecUi4aqolw8Mhza9ArpNEQ881MJJIU2sE5iHCK6TdyqqzcDLy0OPe10IY4T8ctVdtynubg== + dependencies: + tslib "^2.6.2" + +"@smithy/util-utf8@3.0.0", "@smithy/util-utf8@^2.0.0", "@smithy/util-utf8@^3.0.0": version "3.0.0" resolved "https://registry.yarnpkg.com/@smithy/util-utf8/-/util-utf8-3.0.0.tgz#1a6a823d47cbec1fd6933e5fc87df975286d9d6a" integrity sha512-rUeT12bxFnplYDe815GXbq/oixEGHfRFFtcTF3YdDi/JaENIM6aSYYLJydG83UNzLXeRI5K8abYd/8Sp/QM0kA== @@ -21710,20 +21748,20 @@ kuler@^2.0.0: resolved "https://registry.yarnpkg.com/kuler/-/kuler-2.0.0.tgz#e2c570a3800388fb44407e851531c1d670b061b3" integrity sha512-Xq9nH7KlWZmXAtodXDDRE7vs6DU1gTU8zYDHDiWLSip45Egwq3plLHzPn27NgvzL2r1LMPC1vdqh98sQxtqj4A== -langchain@0.2.3: - version "0.2.3" - resolved "https://registry.yarnpkg.com/langchain/-/langchain-0.2.3.tgz#c14bb05cf871b21bd63b84b3ab89580b1d62539f" - integrity sha512-T9xR7zd+Nj0oXy6WoYKmZLy0DlQiDLFPGYWdOXDxy+AvqlujoPdVQgDSpdqiOHvAjezrByAoKxoHCz5XMwTP/Q== +langchain@0.2.3, langchain@0.2.6, langchain@^0.2.6: + version "0.2.6" + resolved "https://registry.yarnpkg.com/langchain/-/langchain-0.2.6.tgz#22249707ba800c38ec9a6cca36e383a881227393" + integrity sha512-vDAJHGu/lA4pn3hkyzSC6RiaZhtj0ozfRyG8L6J2vCnXyJV/lgk9uGMP2x645EBrSozBMHJBng1UYeaUR/1fQQ== dependencies: "@langchain/core" "~0.2.0" - "@langchain/openai" "~0.0.28" + "@langchain/openai" ">=0.1.0 <0.3.0" "@langchain/textsplitters" "~0.0.0" binary-extensions "^2.2.0" js-tiktoken "^1.0.12" js-yaml "^4.1.0" jsonpointer "^5.0.1" langchainhub "~0.0.8" - langsmith "~0.1.7" + langsmith "~0.1.30" ml-distance "^4.0.0" openapi-types "^12.1.3" p-retry "4" @@ -21737,10 +21775,10 @@ langchainhub@~0.0.8: resolved "https://registry.yarnpkg.com/langchainhub/-/langchainhub-0.0.8.tgz#fd4b96dc795e22e36c1a20bad31b61b0c33d3110" integrity sha512-Woyb8YDHgqqTOZvWIbm2CaFDGfZ4NTSyXV687AG4vXEfoNo7cGQp7nhl7wL3ehenKWmNEmcxCLgOZzW8jE6lOQ== -langsmith@^0.1.30, langsmith@~0.1.1, langsmith@~0.1.7: - version "0.1.30" - resolved "https://registry.yarnpkg.com/langsmith/-/langsmith-0.1.30.tgz#3000e441605b26e15a87fb991a3929c944edbc0a" - integrity sha512-g8f10H1iiRjCweXJjgM3Y9xl6ApCa1OThDvc0BlSDLVrGVPy1on9wT39vAzYkeadC7oG48p7gfpGlYH3kLkJ9Q== +langsmith@^0.1.32, langsmith@~0.1.30: + version "0.1.32" + resolved "https://registry.yarnpkg.com/langsmith/-/langsmith-0.1.32.tgz#38938b0e8685522087b697b8200c488c6490c137" + integrity sha512-EUWHIH6fiOCGRYdzgwGoXwJxCMyUrL+bmUcxoVmkXoXoAGDOVinz8bqJLKbxotsQWqM64NKKsW85OTIutgNaMQ== dependencies: "@types/uuid" "^9.0.1" commander "^10.0.1" @@ -24353,10 +24391,10 @@ open@^8.0.9, open@^8.4.0, open@~8.4.0: is-docker "^2.1.1" is-wsl "^2.2.0" -openai@^4.24.1, openai@^4.41.1: - version "4.47.1" - resolved "https://registry.yarnpkg.com/openai/-/openai-4.47.1.tgz#1d23c7a8eb3d7bcdc69709cd905f4c9af0181dba" 
- integrity sha512-WWSxhC/69ZhYWxH/OBsLEirIjUcfpQ5+ihkXKp06hmeYXgBBIUCa9IptMzYx6NdkiOCsSGYCnTIsxaic3AjRCQ== +openai@^4.24.1, openai@^4.49.1: + version "4.51.0" + resolved "https://registry.yarnpkg.com/openai/-/openai-4.51.0.tgz#8ab08bba2441375e8e4ce6161f9ac987d2b2c157" + integrity sha512-UKuWc3/qQyklqhHM8CbdXCv0Z0obap6T0ECdcO5oATQxAbKE5Ky3YCXFQY207z+eGG6ez4U9wvAcuMygxhmStg== dependencies: "@types/node" "^18.11.18" "@types/node-fetch" "^2.6.4" @@ -30418,7 +30456,7 @@ tslib@2.6.2, tslib@^2.0.0, tslib@^2.0.1, tslib@^2.0.3, tslib@^2.1.0, tslib@^2.3. resolved "https://registry.yarnpkg.com/tslib/-/tslib-2.6.2.tgz#703ac29425e7b37cd6fd456e92404d46d1f3e4ae" integrity sha512-AEYxH93jGFPn/a2iVAwW87VuUIkR1FVUKB77NwMF7nBTDkDrrT/Hpt/IrCJ0QXhW27jTBDcf5ZY7w6RiqTMw2Q== -tslib@^1.10.0, tslib@^1.11.1, tslib@^1.8.1, tslib@^1.9.0, tslib@^1.9.3: +tslib@^1.10.0, tslib@^1.8.1, tslib@^1.9.0, tslib@^1.9.3: version "1.14.1" resolved "https://registry.yarnpkg.com/tslib/-/tslib-1.14.1.tgz#cf2d38bdc34a134bcaf1091c41f6619e2f672d00" integrity sha512-Xni35NKzjgMrwevysHTCArtLDpPvye8zV/0E4EyYn43P7/7qvQwPh9BGkHewbMulVntbigmcT7rdX3BNo9wRJg== From 538b846469bb4c9bc7ed2a5c273240922276521b Mon Sep 17 00:00:00 2001 From: kibanamachine <42973632+kibanamachine@users.noreply.github.com> Date: Mon, 24 Jun 2024 12:07:36 +0000 Subject: [PATCH 02/55] [CI] Auto-commit changed files from 'node scripts/lint_ts_projects --fix' --- x-pack/plugins/integration_assistant/tsconfig.json | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/x-pack/plugins/integration_assistant/tsconfig.json b/x-pack/plugins/integration_assistant/tsconfig.json index 9a18a1ca1794b..92cdd82c56958 100644 --- a/x-pack/plugins/integration_assistant/tsconfig.json +++ b/x-pack/plugins/integration_assistant/tsconfig.json @@ -15,7 +15,6 @@ "kbn_references": [ "@kbn/core", "@kbn/config-schema", - "@kbn/langchain", "@kbn/core-elasticsearch-server", "@kbn/actions-plugin", "@kbn/data-plugin", @@ -30,6 +29,7 @@ "@kbn/spaces-plugin", "@kbn/triggers-actions-ui-plugin", "@kbn/shared-ux-router", - "@kbn/zod-helpers" + "@kbn/zod-helpers", + "@kbn/elastic-assistant-plugin" ] } From 34ff58b808755868023685b860168f5aa1420ac1 Mon Sep 17 00:00:00 2001 From: Patryk Kopycinski Date: Tue, 25 Jun 2024 14:22:06 +0200 Subject: [PATCH 03/55] test --- .../server/language_models/bedrock_chat.ts | 6 ++++++ .../execute_custom_llm_chain/index.ts | 18 +++++++++++++++++- .../elastic_assistant/server/routes/utils.ts | 2 +- .../stack_connectors/common/bedrock/types.ts | 2 ++ .../server/connector_types/bedrock/bedrock.ts | 17 +++++++++++++++++ 5 files changed, 43 insertions(+), 2 deletions(-) diff --git a/x-pack/packages/kbn-langchain/server/language_models/bedrock_chat.ts b/x-pack/packages/kbn-langchain/server/language_models/bedrock_chat.ts index 6a28d8bd7606a..2925762caf547 100644 --- a/x-pack/packages/kbn-langchain/server/language_models/bedrock_chat.ts +++ b/x-pack/packages/kbn-langchain/server/language_models/bedrock_chat.ts @@ -112,6 +112,12 @@ export class ActionsClientBedrockChatModel extends _BedrockChat { }, })) as unknown as Promise; + if (bedrockMethod === 'invoke-with-response-stream') { + return { + body: data.data, + }; + } + return { ok: data.status === 'ok', json: () => data.data, diff --git a/x-pack/plugins/elastic_assistant/server/lib/langchain/execute_custom_llm_chain/index.ts b/x-pack/plugins/elastic_assistant/server/lib/langchain/execute_custom_llm_chain/index.ts index 21b031b073722..95914bcac1052 100644 --- 
a/x-pack/plugins/elastic_assistant/server/lib/langchain/execute_custom_llm_chain/index.ts +++ b/x-pack/plugins/elastic_assistant/server/lib/langchain/execute_custom_llm_chain/index.ts @@ -172,6 +172,9 @@ export const callAgentExecutor: AgentExecutor = async ({ let message = ''; let tokenParentRunId = ''; + let finalOutputIndex = -1; + const finalOutputStartToken = '"action":"FinalAnswer","action_input":"'; + const finalOutputStopRegex = /(? = async ({ tokenParentRunId = parentRunId; } if (payload.length && !didEnd && tokenParentRunId === parentRunId) { - push({ payload, type: 'content' }); + const finalOutputEndIndex = payload.search(finalOutputStopRegex); + const currentOutput = message.replace(/\s/g, ''); + + if (currentOutput.includes(finalOutputStartToken)) { + finalOutputIndex = currentOutput.indexOf(finalOutputStartToken); + } + + if (finalOutputIndex > -1) { + push({ payload, type: 'content' }); + } + + if (finalOutputIndex > -1 && finalOutputEndIndex > -1) { + didEnd = true; + } // store message in case of error message += payload; } diff --git a/x-pack/plugins/elastic_assistant/server/routes/utils.ts b/x-pack/plugins/elastic_assistant/server/routes/utils.ts index 2c54c2a927031..b4ba5c435166e 100644 --- a/x-pack/plugins/elastic_assistant/server/routes/utils.ts +++ b/x-pack/plugins/elastic_assistant/server/routes/utils.ts @@ -183,7 +183,7 @@ export const getLlmType = (actionTypeId: string): string | undefined => { export const getLlmClass = (llmType?: string, isStreaming?: boolean) => llmType === 'openai' ? ActionsClientChatOpenAI - : llmType === 'bedrock' && !isStreaming + : llmType === 'bedrock' ? ActionsClientBedrockChatModel : ActionsClientSimpleChatModel; diff --git a/x-pack/plugins/stack_connectors/common/bedrock/types.ts b/x-pack/plugins/stack_connectors/common/bedrock/types.ts index 1256831fc7fa0..1847a25893d06 100644 --- a/x-pack/plugins/stack_connectors/common/bedrock/types.ts +++ b/x-pack/plugins/stack_connectors/common/bedrock/types.ts @@ -15,6 +15,7 @@ import { RunActionResponseSchema, InvokeAIActionParamsSchema, InvokeAIActionResponseSchema, + InvokeAIRawActionParamsSchema, StreamingResponseSchema, RunApiLatestResponseSchema, } from './schema'; @@ -23,6 +24,7 @@ export type Config = TypeOf; export type Secrets = TypeOf; export type RunActionParams = TypeOf; export type InvokeAIActionParams = TypeOf; +export type InvokeAIRawActionParams = TypeOf; export type InvokeAIActionResponse = TypeOf; export type RunApiLatestResponse = TypeOf; export type RunActionResponse = TypeOf; diff --git a/x-pack/plugins/stack_connectors/server/connector_types/bedrock/bedrock.ts b/x-pack/plugins/stack_connectors/server/connector_types/bedrock/bedrock.ts index a03e6bbff8e71..1d129dfcdf278 100644 --- a/x-pack/plugins/stack_connectors/server/connector_types/bedrock/bedrock.ts +++ b/x-pack/plugins/stack_connectors/server/connector_types/bedrock/bedrock.ts @@ -349,6 +349,23 @@ The Kibana Connector in use may need to be reconfigured with an updated Amazon B signal, timeout, }: InvokeAIRawActionParams) { + if (bedrockMethod === 'invoke-with-response-stream') { + const body = JSON.stringify( + formatBedrockBody({ messages, stopSequences, system, temperature }) + ); + // set model on per request basis + const path = `/model/${model ?? 
this.model}/invoke-with-response-stream`; + const signed = this.signRequest(body, path, true); + + const res = await fetch(`${this.url}${path}`, { + headers: signed.headers, + body: signed.body, + method: 'POST', + }); + + return res.body; + } + // set model on per request basis const currentModel = model ?? this.model; const path = `/model/${currentModel}/${bedrockMethod}`; From b87106f0e07385a26ffa595214d9b4691112cdfc Mon Sep 17 00:00:00 2001 From: Patryk Kopycinski Date: Sat, 29 Jun 2024 20:41:48 +0200 Subject: [PATCH 04/55] cleanup --- .../server/language_models/bedrock_chat.ts | 149 +++--- .../utils/bedrock/anthropic.ts | 226 -------- .../language_models/utils/bedrock/index.ts | 486 ------------------ .../application/connector/methods/get/get.ts | 1 - .../execute_custom_llm_chain/index.ts | 42 +- .../server/lib/langchain/executors/types.ts | 2 - .../graphs/default_assistant_graph/index.ts | 40 +- .../routes/post_actions_connector_execute.ts | 24 +- .../elastic_assistant/server/routes/utils.ts | 2 - .../common/bedrock/constants.ts | 1 - .../stack_connectors/common/bedrock/schema.ts | 2 - .../stack_connectors/common/bedrock/types.ts | 2 - .../server/connector_types/bedrock/bedrock.ts | 68 +-- 13 files changed, 108 insertions(+), 937 deletions(-) delete mode 100644 x-pack/packages/kbn-langchain/server/language_models/utils/bedrock/anthropic.ts delete mode 100644 x-pack/packages/kbn-langchain/server/language_models/utils/bedrock/index.ts diff --git a/x-pack/packages/kbn-langchain/server/language_models/bedrock_chat.ts b/x-pack/packages/kbn-langchain/server/language_models/bedrock_chat.ts index 2925762caf547..76a282a751714 100644 --- a/x-pack/packages/kbn-langchain/server/language_models/bedrock_chat.ts +++ b/x-pack/packages/kbn-langchain/server/language_models/bedrock_chat.ts @@ -5,16 +5,15 @@ * 2.0. 
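The invoke-with-response-stream branch added above signs the request with SigV4 and returns the raw, undecoded response body instead of a parsed completion, leaving event-stream decoding to the caller. A hypothetical invocation of the new invokeAIRaw sub action (the connector id and prompt are illustrative, and the ActionsClient type import is an assumption):

import type { ActionsClient } from '@kbn/actions-plugin/server';

async function streamFromBedrock(actionsClient: ActionsClient, connectorId: string) {
  // With bedrockMethod 'invoke' the sub action resolves with the parsed
  // response body instead of a stream.
  return actionsClient.execute({
    actionId: connectorId,
    params: {
      subAction: 'invokeAIRaw',
      subActionParams: {
        bedrockMethod: 'invoke-with-response-stream',
        messages: [{ role: 'user', content: 'Summarize the last 24 hours of alerts.' }],
        temperature: 0,
        maxTokens: 4096,
      },
    },
  });
}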
*/ -import { - BedrockChat as _BedrockChat, - convertMessagesToPromptAnthropic, -} from '@langchain/community/chat_models/bedrock/web'; +import { BedrockChat as _BedrockChat } from '@langchain/community/chat_models/bedrock/web'; import type { PluginStartContract as ActionsPluginStart } from '@kbn/actions-plugin/server'; -import { BaseMessage } from '@langchain/core/messages'; import { BaseChatModelParams } from '@langchain/core/language_models/chat_models'; import { Logger } from '@kbn/logging'; import { KibanaRequest } from '@kbn/core/server'; -import { BaseBedrockInput, BedrockLLMInputOutputAdapter } from './utils/bedrock'; +import { Readable } from 'stream'; + +export const DEFAULT_BEDROCK_MODEL = 'anthropic.claude-3-5-sonnet-20240620-v1:0'; +export const DEFAULT_BEDROCK_REGION = 'us-east-1'; export class ActionsClientBedrockChatModel extends _BedrockChat { // Kibana variables @@ -34,93 +33,75 @@ export class ActionsClientBedrockChatModel extends _BedrockChat { connectorId: string; logger: Logger; request: KibanaRequest; - } & Partial & - BaseChatModelParams) { + } & BaseChatModelParams) { // Just to make Langchain BedrockChat happy super({ ...params, credentials: { accessKeyId: '', secretAccessKey: '' }, - }); + usesMessagesApi: true, + // only needed to force BedrockChat to use messages api for Claude v2 + model: DEFAULT_BEDROCK_MODEL, + region: DEFAULT_BEDROCK_REGION, + fetchFn: async (url, options) => { + // create an actions client from the authenticated request context: + const actionsClient = await this.#actions.getActionsClientWithRequest(this.#request); - this.#actions = actions; - this.#request = request; - this.#connectorId = connectorId; - this.#logger = logger; - } + const inputBody = JSON.parse(options?.body); - async _signedFetch( - messages: BaseMessage[], - options: this['ParsedCallOptions'], - fields: { - bedrockMethod: 'invoke' | 'invoke-with-response-stream'; - endpointHost: string; - provider: string; - } - ) { - const { bedrockMethod, endpointHost, provider } = fields; - const { - max_tokens: maxTokens, - temperature, - stop, - modelKwargs, - guardrailConfig, - tools, - } = this.invocationParams(options); - const inputBody = this.usesMessagesApi - ? 
BedrockLLMInputOutputAdapter.prepareMessagesInput( - provider, - messages, - maxTokens, - temperature, - stop, - modelKwargs, - guardrailConfig, - tools, - this.#logger - ) - : BedrockLLMInputOutputAdapter.prepareInput( - provider, - convertMessagesToPromptAnthropic(messages), - maxTokens, - temperature, - stop, - modelKwargs, - fields.bedrockMethod, - guardrailConfig - ); + if (this.streaming) { + const data = await actionsClient.execute({ + actionId: this.#connectorId, + params: { + subAction: 'invokeStream', + subActionParams: { + // bedrockMethod: 'invoke-with-response-stream', + model: this.model, + // endpointHost: this.endpointHost, + // anthropicVersion: inputBody.anthropicVersion, + messages: inputBody.messages, + temperature: inputBody.temperature, + stopSequences: inputBody.stopSequences, + system: inputBody.system, + maxTokens: inputBody.maxTokens, + }, + }, + }); - // create an actions client from the authenticated request context: - const actionsClient = await this.#actions.getActionsClientWithRequest(this.#request); + console.error('data', data); - const data = (await actionsClient.execute({ - actionId: this.#connectorId, - params: { - subAction: 'invokeAIRaw', - subActionParams: { - bedrockMethod, - model: this.model, - endpointHost, - anthropicVersion: inputBody.anthropicVersion, - messages: inputBody.messages, - temperature: inputBody.temperature, - stopSequences: inputBody.stopSequences, - system: inputBody.system, - maxTokens: inputBody.maxTokens, - signal: options.signal, - timeout: options.timeout, - }, - }, - })) as unknown as Promise; + return { + body: Readable.toWeb(data.data), + }; + } - if (bedrockMethod === 'invoke-with-response-stream') { - return { - body: data.data, - }; - } + const data = await actionsClient.execute({ + actionId: this.#connectorId, + params: { + subAction: 'invokeAI', + subActionParams: { + model: this.model, + messages: inputBody.messages, + temperature: inputBody.temperature, + stopSequences: inputBody.stopSequences, + system: inputBody.system, + maxTokens: inputBody.maxTokens, + }, + }, + }); - return { - ok: data.status === 'ok', - json: () => data.data, - }; + return { + ok: data.status === 'ok', + json: () => ({ + content: data.data.message, + type: 'message', + }), + }; + }, + }); + + this.#actions = actions; + this.#request = request; + this.#connectorId = connectorId; + this.#logger = logger; } } diff --git a/x-pack/packages/kbn-langchain/server/language_models/utils/bedrock/anthropic.ts b/x-pack/packages/kbn-langchain/server/language_models/utils/bedrock/anthropic.ts deleted file mode 100644 index a9b04dc4e0dcf..0000000000000 --- a/x-pack/packages/kbn-langchain/server/language_models/utils/bedrock/anthropic.ts +++ /dev/null @@ -1,226 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. 
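After this rewrite, ActionsClientBedrockChatModel no longer signs Bedrock requests itself: the fetchFn override reroutes every HTTP call LangChain makes into the connector's invokeAI and invokeStream sub actions, so AWS credentials never leave the connector. A hypothetical construction, mirroring how the routes earlier in the series build their models (all inputs come from a route handler's context, and forwarding generation params through ...params is an assumption based on that route code):

import type { KibanaRequest } from '@kbn/core/server';
import type { Logger } from '@kbn/logging';
import type { PluginStartContract as ActionsPluginStart } from '@kbn/actions-plugin/server';
import { ActionsClientBedrockChatModel } from '@kbn/langchain/server';

function buildBedrockModel(
  actions: ActionsPluginStart,
  request: KibanaRequest,
  logger: Logger,
  connectorId: string
) {
  return new ActionsClientBedrockChatModel({
    actions,
    connectorId,
    request,
    logger,
    // Generation params are passed through to the underlying BedrockChat.
    temperature: 0.05,
  });
}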
- */ - -// origin: https://github.com/langchain-ai/langchainjs/blob/main/libs/langchain-community/src/utils/bedrock/anthropic.ts -// Error: Package subpath './dist/utils/bedrock/anthropic' is not defined by "exports" in langchain/community/package.json - -import { Logger } from '@kbn/logging'; -import { - AIMessage, - BaseMessage, - HumanMessage, - MessageContent, - SystemMessage, - ToolMessage, - isAIMessage, -} from '@langchain/core/messages'; -import { ToolCall } from '@langchain/core/messages/tool'; - -// eslint-disable-next-line @typescript-eslint/no-explicit-any -export function extractToolCalls(content: Array>) { - const toolCalls: ToolCall[] = []; - for (const block of content) { - if (block.type === 'tool_use') { - toolCalls.push({ name: block.name, args: block.input, id: block.id }); - } - } - return toolCalls; -} - -function _formatImage(imageUrl: string) { - const regex = /^data:(image\/.+);base64,(.+)$/; - const match = imageUrl.match(regex); - if (match === null) { - throw new Error( - [ - 'Anthropic only supports base64-encoded images currently.', - 'Example: data:image/png;base64,/9j/4AAQSk...', - ].join('\n\n') - ); - } - return { - type: 'base64', - media_type: match[1] ?? '', - data: match[2] ?? '', - // eslint-disable-next-line @typescript-eslint/no-explicit-any - } as any; -} - -function _mergeMessages(messages: BaseMessage[]): Array { - // Merge runs of human/tool messages into single human messages with content blocks. - const merged: HumanMessage[] = []; - for (const message of messages) { - if (message._getType() === 'tool') { - if (typeof message.content === 'string') { - merged.push( - new HumanMessage({ - content: [ - { - type: 'tool_result', - content: message.content, - tool_use_id: (message as ToolMessage).tool_call_id, - }, - ], - }) - ); - } else { - merged.push(new HumanMessage({ content: message.content })); - } - } else { - const previousMessage = merged[merged.length - 1]; - if (previousMessage?._getType() === 'human' && message._getType() === 'human') { - // eslint-disable-next-line @typescript-eslint/no-explicit-any - let combinedContent: Array>; - if (typeof previousMessage.content === 'string') { - combinedContent = [{ type: 'text', text: previousMessage.content }]; - } else { - combinedContent = previousMessage.content; - } - if (typeof message.content === 'string') { - combinedContent.push({ type: 'text', text: message.content }); - } else { - combinedContent = combinedContent.concat(message.content); - } - previousMessage.content = combinedContent; - } else { - merged.push(message); - } - } - } - return merged; -} - -export function _convertLangChainToolCallToAnthropic( - toolCall: ToolCall - // eslint-disable-next-line @typescript-eslint/no-explicit-any -): Record { - if (toolCall.id === undefined) { - throw new Error(`Anthropic requires all tool calls to have an "id".`); - } - return { - type: 'tool_use', - id: toolCall.id, - name: toolCall.name, - input: toolCall.args, - }; -} - -function _formatContent(content: MessageContent) { - if (typeof content === 'string') { - return content; - } else { - const contentBlocks = content.map((contentPart) => { - if (contentPart.type === 'image_url') { - let source; - if (typeof contentPart.image_url === 'string') { - source = _formatImage(contentPart.image_url); - } else { - source = _formatImage(contentPart.image_url.url); - } - return { - type: 'image' as const, // Explicitly setting the type as "image" - source, - }; - } else if (contentPart.type === 'text') { - // Assuming contentPart is of type 
MessageContentText here - return { - type: 'text' as const, // Explicitly setting the type as "text" - text: contentPart.text, - }; - } else if (contentPart.type === 'tool_use' || contentPart.type === 'tool_result') { - // TODO: Fix when SDK types are fixed - return { - ...contentPart, - // eslint-disable-next-line @typescript-eslint/no-explicit-any - } as any; - } else { - throw new Error('Unsupported message content format'); - } - }); - return contentBlocks; - } -} - -export function formatMessagesForAnthropic( - messages: BaseMessage[], - logger?: Logger -): { - system?: string; - messages: Array>; -} { - const mergedMessages = _mergeMessages(messages); - let system: string | undefined; - if (mergedMessages.length > 0 && mergedMessages[0]._getType() === 'system') { - if (typeof messages[0].content !== 'string') { - throw new Error('System message content must be a string.'); - } - system = messages[0].content; - } - const conversationMessages = system !== undefined ? mergedMessages.slice(1) : mergedMessages; - const formattedMessages = conversationMessages.map((message) => { - let role; - if (message._getType() === 'human') { - role = 'user' as const; - } else if (message._getType() === 'ai') { - role = 'assistant' as const; - } else if (message._getType() === 'tool') { - role = 'user' as const; - } else if (message._getType() === 'system') { - throw new Error('System messages are only permitted as the first passed message.'); - } else { - throw new Error(`Message type "${message._getType()}" is not supported.`); - } - if (isAIMessage(message) && !!message.tool_calls?.length) { - if (typeof message.content === 'string') { - if (message.content === '') { - return { - role, - content: message.tool_calls.map(_convertLangChainToolCallToAnthropic), - }; - } else { - return { - role, - content: [ - { type: 'text', text: message.content }, - ...message.tool_calls.map(_convertLangChainToolCallToAnthropic), - ], - }; - } - } else { - const { content } = message; - const hasMismatchedToolCalls = !message.tool_calls.every((toolCall) => - content.find( - (contentPart) => contentPart.type === 'tool_use' && contentPart.id === toolCall.id - ) - ); - if (hasMismatchedToolCalls) { - logger?.warn( - `The "tool_calls" field on a message is only respected if content is a string.` - ); - } - return { - role, - content: _formatContent(message.content), - }; - } - } else { - return { - role, - content: _formatContent(message.content), - }; - } - }); - return { - messages: formattedMessages, - system, - }; -} - -export function isAnthropicTool(tool: unknown): tool is Record { - if (typeof tool !== 'object' || !tool) return false; - return 'input_schema' in tool; -} diff --git a/x-pack/packages/kbn-langchain/server/language_models/utils/bedrock/index.ts b/x-pack/packages/kbn-langchain/server/language_models/utils/bedrock/index.ts deleted file mode 100644 index 7cad55e1b6e06..0000000000000 --- a/x-pack/packages/kbn-langchain/server/language_models/utils/bedrock/index.ts +++ /dev/null @@ -1,486 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. 
- */ - -/* eslint-disable complexity */ - -// origin: https://github.com/langchain-ai/langchainjs/blob/main/libs/langchain-community/src/utils/bedrock/index.ts -// // Error: Package subpath './dist/utils/bedrock' is not defined by "exports" in langchain/community/package.json - -import type { AwsCredentialIdentity, Provider } from '@aws-sdk/types'; -import { AIMessage, AIMessageChunk, BaseMessage } from '@langchain/core/messages'; -import { StructuredToolInterface } from '@langchain/core/tools'; -import { ChatGeneration, ChatGenerationChunk } from '@langchain/core/outputs'; -import { Logger } from '@kbn/logging'; -import { extractToolCalls, formatMessagesForAnthropic } from './anthropic'; - -export type CredentialType = AwsCredentialIdentity | Provider; - -/** - * format messages for Cohere Command-R and CommandR+ via AWS Bedrock. - * - * @param messages messages The base messages to format as a prompt. - * - * @returns The formatted prompt for Cohere. - * - * `system`: user system prompts. Overrides the default preamble for search query generation. Has no effect on tool use generations.\ - * `message`: (Required) Text input for the model to respond to.\ - * `chatHistory`: A list of previous messages between the user and the model, meant to give the model conversational context for responding to the user's message.\ - * The following are required fields. - * - `role` - The role for the message. Valid values are USER or CHATBOT.\ - * - `message` – Text contents of the message.\ - * - * The following is example JSON for the chat_history field.\ - * "chat_history": [ - * {"role": "USER", "message": "Who discovered gravity?"}, - * {"role": "CHATBOT", "message": "The man who is widely credited with discovering gravity is Sir Isaac Newton"}]\ - * - * docs: https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-cohere-command-r-plus.html - */ -function formatMessagesForCohere(messages: BaseMessage[]): { - system?: string; - message: string; - chatHistory: Array>; -} { - const systemMessages = messages.filter((system) => system._getType() === 'system'); - - const system = systemMessages - .filter((m) => typeof m.content === 'string') - .map((m) => m.content) - .join('\n\n'); - - const conversationMessages = messages.filter((message) => message._getType() !== 'system'); - - const questionContent = conversationMessages.slice(-1); - - if (!questionContent.length || questionContent[0]._getType() !== 'human') { - throw new Error('question message content must be a human message.'); - } - - if (typeof questionContent[0].content !== 'string') { - throw new Error('question message content must be a string.'); - } - - const formattedMessage = questionContent[0].content; - - const formattedChatHistories = conversationMessages.slice(0, -1).map((message) => { - let role; - switch (message._getType()) { - case 'human': - role = 'USER' as const; - break; - case 'ai': - role = 'CHATBOT' as const; - break; - case 'system': - throw new Error('chat_history can not include system prompts.'); - default: - throw new Error(`Message type "${message._getType()}" is not supported.`); - } - - if (typeof message.content !== 'string') { - throw new Error('message content must be a string.'); - } - return { - role, - message: message.content, - }; - }); - - return { - chatHistory: formattedChatHistories, - message: formattedMessage, - system, - }; -} - -/** Bedrock models. 
- To authenticate, the AWS client uses the following methods to automatically load credentials: - https://boto3.amazonaws.com/v1/documentation/api/latest/guide/credentials.html - If a specific credential profile should be used, you must pass the name of the profile from the ~/.aws/credentials file that is to be used. - Make sure the credentials / roles used have the required policies to access the Bedrock service. -*/ -export interface BaseBedrockInput { - /** Model to use. - For example, "amazon.titan-tg1-large", this is equivalent to the modelId property in the list-foundation-models api. - */ - model: string; - - /** The AWS region e.g. `us-west-2`. - Fallback to AWS_DEFAULT_REGION env variable or region specified in ~/.aws/config in case it is not provided here. - */ - region?: string; - - /** AWS Credentials. - If no credentials are provided, the default credentials from `@aws-sdk/credential-provider-node` will be used. - */ - credentials?: CredentialType; - - /** Temperature. */ - temperature?: number; - - /** Max tokens. */ - maxTokens?: number; - - /** A custom fetch function for low-level access to AWS API. Defaults to fetch(). */ - fetchFn?: typeof fetch; - - /** @deprecated Use endpointHost instead Override the default endpoint url. */ - endpointUrl?: string; - - /** Override the default endpoint hostname. */ - endpointHost?: string; - - /** - * Optional additional stop sequences to pass to the model. Currently only supported for Anthropic and AI21. - * @deprecated Use .bind({ "stop": [...] }) instead - * */ - stopSequences?: string[]; - - /** Additional kwargs to pass to the model. */ - modelKwargs?: Record; - - /** Whether or not to stream responses */ - streaming: boolean; - - /** Trace settings for the Bedrock Guardrails. */ - trace?: 'ENABLED' | 'DISABLED'; - - /** Identifier for the guardrail configuration. */ - guardrailIdentifier?: string; - - /** Version for the guardrail configuration. */ - guardrailVersion?: string; - - /** Required when Guardrail is in use. */ - guardrailConfig?: { - tagSuffix: string; - streamProcessingMode: 'SYNCHRONOUS' | 'ASYNCHRONOUS'; - }; -} - -interface Dict { - [key: string]: unknown; -} - -/** - * A helper class used within the `Bedrock` class. It is responsible for - * preparing the input and output for the Bedrock service. It formats the - * input prompt based on the provider (e.g., "anthropic", "ai21", - * "amazon") and extracts the generated text from the service response. - */ -export class BedrockLLMInputOutputAdapter { - /** Adapter class to prepare the inputs from Langchain to a format - that LLM model expects. Also, provides a helper function to extract - the generated text from the model response. 
*/ - - static prepareInput( - provider: string, - prompt: string, - maxTokens = 50, - temperature = 0, - stopSequences: string[] | undefined = undefined, - modelKwargs: Record = {}, - bedrockMethod: 'invoke' | 'invoke-with-response-stream' = 'invoke', - guardrailConfig: - | { - tagSuffix: string; - streamProcessingMode: 'SYNCHRONOUS' | 'ASYNCHRONOUS'; - } - | undefined = undefined - ): Dict { - const inputBody: Dict = {}; - - if (provider === 'anthropic') { - inputBody.prompt = prompt; - inputBody.max_tokens_to_sample = maxTokens; - inputBody.temperature = temperature; - inputBody.stop_sequences = stopSequences; - } else if (provider === 'ai21') { - inputBody.prompt = prompt; - inputBody.maxTokens = maxTokens; - inputBody.temperature = temperature; - inputBody.stopSequences = stopSequences; - } else if (provider === 'meta') { - inputBody.prompt = prompt; - inputBody.max_gen_len = maxTokens; - inputBody.temperature = temperature; - } else if (provider === 'amazon') { - inputBody.inputText = prompt; - inputBody.textGenerationConfig = { - maxTokenCount: maxTokens, - temperature, - }; - } else if (provider === 'cohere') { - inputBody.prompt = prompt; - inputBody.max_tokens = maxTokens; - inputBody.temperature = temperature; - inputBody.stop_sequences = stopSequences; - if (bedrockMethod === 'invoke-with-response-stream') { - inputBody.stream = true; - } - } else if (provider === 'mistral') { - inputBody.prompt = prompt; - inputBody.max_tokens = maxTokens; - inputBody.temperature = temperature; - inputBody.stop = stopSequences; - } - - if (guardrailConfig && guardrailConfig.tagSuffix && guardrailConfig.streamProcessingMode) { - inputBody['amazon-bedrock-guardrailConfig'] = guardrailConfig; - } - - return { ...inputBody, ...modelKwargs }; - } - - static prepareMessagesInput( - provider: string, - messages: BaseMessage[], - maxTokens = 1024, - temperature = 0, - stopSequences: string[] | undefined = undefined, - modelKwargs: Record = {}, - guardrailConfig: - | { - tagSuffix: string; - streamProcessingMode: 'SYNCHRONOUS' | 'ASYNCHRONOUS'; - } - | undefined = undefined, - tools: Array> = [], - logger: Logger - ): Dict { - const inputBody: Dict = {}; - - if (provider === 'anthropic') { - const { system, messages: formattedMessages } = formatMessagesForAnthropic(messages, logger); - if (system !== undefined) { - inputBody.system = system; - } - inputBody.anthropic_version = 'bedrock-2023-05-31'; - inputBody.messages = formattedMessages; - inputBody.max_tokens = maxTokens; - inputBody.temperature = temperature; - inputBody.stop_sequences = stopSequences; - - if (tools.length > 0) { - inputBody.tools = tools; - } - return { ...inputBody, ...modelKwargs }; - } else if (provider === 'cohere') { - const { - system, - message: formattedMessage, - chatHistory: formattedChatHistories, - } = formatMessagesForCohere(messages); - - if (system !== undefined && system.length > 0) { - inputBody.preamble = system; - } - inputBody.message = formattedMessage; - inputBody.chat_history = formattedChatHistories; - inputBody.max_tokens = maxTokens; - inputBody.temperature = temperature; - inputBody.stop_sequences = stopSequences; - } else { - throw new Error('The messages API is currently only supported by Anthropic or Cohere'); - } - - if (guardrailConfig && guardrailConfig.tagSuffix && guardrailConfig.streamProcessingMode) { - inputBody['amazon-bedrock-guardrailConfig'] = guardrailConfig; - } - - return { ...inputBody, ...modelKwargs }; - } - - /** - * Extracts the generated text from the service response. 
- * @param provider The provider name. - * @param responseBody The response body from the service. - * @returns The generated text. - */ - // eslint-disable-next-line @typescript-eslint/no-explicit-any - static prepareOutput(provider: string, responseBody: any): string { - if (provider === 'anthropic') { - return responseBody.completion; - } else if (provider === 'ai21') { - return responseBody?.completions?.[0]?.data?.text ?? ''; - } else if (provider === 'cohere') { - return responseBody?.generations?.[0]?.text ?? responseBody?.text ?? ''; - } else if (provider === 'meta') { - return responseBody.generation; - } else if (provider === 'mistral') { - return responseBody?.outputs?.[0]?.text; - } - - // I haven't been able to get a response with more than one result in it. - return responseBody.results?.[0]?.outputText; - } - - static prepareMessagesOutput( - provider: string, - // eslint-disable-next-line @typescript-eslint/no-explicit-any - response: any - ): ChatGeneration | undefined { - const responseBody = response ?? {}; - if (provider === 'anthropic') { - if (responseBody.type === 'message_start') { - return parseMessage(responseBody.message, true); - } else if ( - responseBody.type === 'content_block_delta' && - responseBody.delta?.type === 'text_delta' && - typeof responseBody.delta?.text === 'string' - ) { - return new ChatGenerationChunk({ - message: new AIMessageChunk({ - content: responseBody.delta.text, - }), - text: responseBody.delta.text, - }); - } else if (responseBody.type === 'message_delta') { - return new ChatGenerationChunk({ - message: new AIMessageChunk({ content: '' }), - text: '', - generationInfo: { - ...responseBody.delta, - usage: responseBody.usage, - }, - }); - } else if ( - responseBody.type === 'message_stop' && - responseBody['amazon-bedrock-invocationMetrics'] !== undefined - ) { - return new ChatGenerationChunk({ - message: new AIMessageChunk({ content: '' }), - text: '', - generationInfo: { - 'amazon-bedrock-invocationMetrics': responseBody['amazon-bedrock-invocationMetrics'], - }, - }); - } else if (responseBody.type === 'message') { - return parseMessage(responseBody); - } else { - return undefined; - } - } else if (provider === 'cohere') { - if (responseBody.event_type === 'stream-start') { - return parseMessageCohere(responseBody.message, true); - } else if ( - responseBody.event_type === 'text-generation' && - typeof responseBody?.text === 'string' - ) { - return new ChatGenerationChunk({ - message: new AIMessageChunk({ - content: responseBody.text, - }), - text: responseBody.text, - }); - } else if (responseBody.event_type === 'search-queries-generation') { - return parseMessageCohere(responseBody); - } else if ( - responseBody.event_type === 'stream-end' && - responseBody.response !== undefined && - responseBody['amazon-bedrock-invocationMetrics'] !== undefined - ) { - return new ChatGenerationChunk({ - message: new AIMessageChunk({ content: '' }), - text: '', - generationInfo: { - response: responseBody.response, - 'amazon-bedrock-invocationMetrics': responseBody['amazon-bedrock-invocationMetrics'], - }, - }); - } else { - if ( - responseBody.finish_reason === 'COMPLETE' || - responseBody.finish_reason === 'MAX_TOKENS' - ) { - return parseMessageCohere(responseBody); - } else { - return undefined; - } - } - } else { - throw new Error('The messages API is currently only supported by Anthropic or Cohere.'); - } - } -} - -// eslint-disable-next-line @typescript-eslint/no-explicit-any -function parseMessage(responseBody: any, asChunk?: boolean): 
ChatGeneration { - const { content, id, ...generationInfo } = responseBody; - let parsedContent; - if (Array.isArray(content) && content.length === 1 && content[0].type === 'text') { - parsedContent = content[0].text; - } else if (Array.isArray(content) && content.length === 0) { - parsedContent = ''; - } else { - parsedContent = content; - } - if (asChunk) { - return new ChatGenerationChunk({ - message: new AIMessageChunk({ - content: parsedContent, - additional_kwargs: { id }, - }), - text: typeof parsedContent === 'string' ? parsedContent : '', - generationInfo, - }); - } else { - // TODO: we are throwing away here the text response, as the interface of this method returns only one - const toolCalls = extractToolCalls(responseBody.content); - - if (toolCalls.length > 0) { - return { - message: new AIMessage({ - content: '', - additional_kwargs: { id }, - tool_calls: toolCalls, - }), - text: typeof parsedContent === 'string' ? parsedContent : '', - generationInfo, - }; - } - - return { - message: new AIMessage({ - content: parsedContent, - additional_kwargs: { id }, - tool_calls: toolCalls, - }), - text: typeof parsedContent === 'string' ? parsedContent : '', - generationInfo, - }; - } -} - -function parseMessageCohere( - // eslint-disable-next-line @typescript-eslint/no-explicit-any - responseBody: any, - asChunk?: boolean -): ChatGeneration { - const { text, ...generationInfo } = responseBody; - let parsedContent = text; - if (typeof text !== 'string') { - parsedContent = ''; - } - if (asChunk) { - return new ChatGenerationChunk({ - message: new AIMessageChunk({ - content: parsedContent, - }), - text: parsedContent, - generationInfo, - }); - } else { - return { - message: new AIMessage({ - content: parsedContent, - }), - text: parsedContent, - generationInfo, - }; - } -} diff --git a/x-pack/plugins/actions/server/application/connector/methods/get/get.ts b/x-pack/plugins/actions/server/application/connector/methods/get/get.ts index 35e8101757bc9..2d4a94f5615d7 100644 --- a/x-pack/plugins/actions/server/application/connector/methods/get/get.ts +++ b/x-pack/plugins/actions/server/application/connector/methods/get/get.ts @@ -62,7 +62,6 @@ export async function get({ id, actionTypeId: foundInMemoryConnector.actionTypeId, name: foundInMemoryConnector.name, - config: foundInMemoryConnector.config, isPreconfigured: foundInMemoryConnector.isPreconfigured, isSystemAction: foundInMemoryConnector.isSystemAction, isDeprecated: isConnectorDeprecated(foundInMemoryConnector), diff --git a/x-pack/plugins/elastic_assistant/server/lib/langchain/execute_custom_llm_chain/index.ts b/x-pack/plugins/elastic_assistant/server/lib/langchain/execute_custom_llm_chain/index.ts index 95914bcac1052..e75a063ddf7dc 100644 --- a/x-pack/plugins/elastic_assistant/server/lib/langchain/execute_custom_llm_chain/index.ts +++ b/x-pack/plugins/elastic_assistant/server/lib/langchain/execute_custom_llm_chain/index.ts @@ -14,7 +14,7 @@ import { transformError } from '@kbn/securitysolution-es-utils'; import { RetrievalQAChain } from 'langchain/chains'; import { getDefaultArguments } from '@kbn/langchain/server'; import { MessagesPlaceholder } from '@langchain/core/prompts'; -import { getLlmClass, isToolCallingSupported } from '../../../routes/utils'; +import { getLlmClass } from '../../../routes/utils'; import { AgentExecutor } from '../executors/types'; import { APMTracer } from '../tracers/apm_tracer'; import { AssistantToolParams } from '../../../types'; @@ -42,8 +42,6 @@ export const callAgentExecutor: AgentExecutor = async ({ 
isStream = false, onLlmResponse, onNewReplacements, - model, - region, replacements, request, size, @@ -57,10 +55,9 @@ export const callAgentExecutor: AgentExecutor = async ({ request, llmType, logger, - region, // possible client model override, // let this be undefined otherwise so the connector handles the model - model, + model: request.body.model, // ensure this is defined because we default to it in the language_models // This is where the LangSmith logs (Metadata > Invocation Params) are set temperature: getDefaultArguments(llmType).temperature, @@ -115,23 +112,24 @@ export const callAgentExecutor: AgentExecutor = async ({ handleParsingErrors: 'Try again, paying close attention to the allowed tool input', }; // isOpenAI check is not on agentType alone because typescript doesn't like - const executor = isToolCallingSupported(llmType) - ? await initializeAgentExecutorWithOptions(tools, llm, { - agentType: 'openai-functions', - ...executorArgs, - }) - : await initializeAgentExecutorWithOptions(tools, llm, { - agentType: 'structured-chat-zero-shot-react-description', - ...executorArgs, - returnIntermediateSteps: false, - agentArgs: { - // this is important to help LangChain correctly format tool input - humanMessageTemplate: `Remember, when you have enough information, always prefix your final JSON output with "Final Answer:"\n\nQuestion: {input}\n\n{agent_scratchpad}.`, - memoryPrompts: [new MessagesPlaceholder('chat_history')], - suffix: - 'Begin! Reminder to ALWAYS use the above format, and to use tools if appropriate.', - }, - }); + const executor = + llmType === 'openai' + ? await initializeAgentExecutorWithOptions(tools, llm, { + agentType: 'openai-functions', + ...executorArgs, + }) + : await initializeAgentExecutorWithOptions(tools, llm, { + agentType: 'structured-chat-zero-shot-react-description', + ...executorArgs, + returnIntermediateSteps: false, + agentArgs: { + // this is important to help LangChain correctly format tool input + humanMessageTemplate: `Remember, when you have enough information, always prefix your final JSON output with "Final Answer:"\n\nQuestion: {input}\n\n{agent_scratchpad}.`, + memoryPrompts: [new MessagesPlaceholder('chat_history')], + suffix: + 'Begin! Reminder to ALWAYS use the above format, and to use tools if appropriate.', + }, + }); // Sets up tracer for tracing executions to APM. See x-pack/plugins/elastic_assistant/server/lib/langchain/tracers/README.mdx // If LangSmith env vars are set, executions will be traced there as well. 
See https://docs.smith.langchain.com/tracing diff --git a/x-pack/plugins/elastic_assistant/server/lib/langchain/executors/types.ts b/x-pack/plugins/elastic_assistant/server/lib/langchain/executors/types.ts index ef5a28fa79bc6..bd07099e312b3 100644 --- a/x-pack/plugins/elastic_assistant/server/lib/langchain/executors/types.ts +++ b/x-pack/plugins/elastic_assistant/server/lib/langchain/executors/types.ts @@ -48,8 +48,6 @@ export interface AgentExecutorParams { langChainMessages: BaseMessage[]; llmType?: string; logger: Logger; - model?: string; - region?: string; onNewReplacements?: (newReplacements: Replacements) => void; replacements: Replacements; isStream?: T; diff --git a/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/index.ts b/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/index.ts index 8f1011d04f887..2631b0e50ba1e 100644 --- a/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/index.ts +++ b/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/index.ts @@ -9,7 +9,7 @@ import { StructuredTool } from '@langchain/core/tools'; import { RetrievalQAChain } from 'langchain/chains'; import { getDefaultArguments } from '@kbn/langchain/server'; import { createOpenAIFunctionsAgent, createStructuredChatAgent } from 'langchain/agents'; -import { getLlmClass, isToolCallingSupported } from '../../../../routes/utils'; +import { getLlmClass } from '../../../../routes/utils'; import { AssistantToolParams } from '../../../../types'; import { AgentExecutor } from '../../executors/types'; import { openAIFunctionAgentPrompt, structuredChatAgentPrompt } from './prompts'; @@ -39,8 +39,6 @@ export const callAssistantGraph: AgentExecutor = async ({ onLlmResponse, onNewReplacements, replacements, - model, - region, request, size, traceOptions, @@ -56,8 +54,7 @@ export const callAssistantGraph: AgentExecutor = async ({ logger, // possible client model override, // let this be undefined otherwise so the connector handles the model - model, - region, + model: request.body.model, // ensure this is defined because we default to it in the language_models // This is where the LangSmith logs (Metadata > Invocation Params) are set temperature: getDefaultArguments(llmType).temperature, @@ -67,7 +64,7 @@ export const callAssistantGraph: AgentExecutor = async ({ // failure could be due to bad connector, we should deliver that result to the client asap maxRetries: 0, }); - const graphModel = llm; + const model = llm; const messages = langChainMessages.slice(0, -1); // all but the last message const latestMessage = langChainMessages.slice(-1); // the last message @@ -75,7 +72,7 @@ export const callAssistantGraph: AgentExecutor = async ({ const modelExists = await esStore.isModelInstalled(); // Create a chain that uses the ELSER backed ElasticsearchStore, override k=10 for esql query generation for now - const chain = RetrievalQAChain.fromLLM(graphModel, esStore.asRetriever(10)); + const chain = RetrievalQAChain.fromLLM(model, esStore.asRetriever(10)); // Fetch any applicable tools that the source plugin may have registered const assistantToolParams: AssistantToolParams = { @@ -85,7 +82,7 @@ export const callAssistantGraph: AgentExecutor = async ({ esClient, isEnabledKnowledgeBase, kbDataClient: dataClients?.kbDataClient, - llm: graphModel, + llm: model, logger, modelExists, onNewReplacements, @@ -98,19 +95,20 @@ export const callAssistantGraph: AgentExecutor = async ({ (tool) => 
tool.getTool(assistantToolParams) ?? [] ); - const agentRunnable = isToolCallingSupported(llmType) - ? await createOpenAIFunctionsAgent({ - llm, - tools, - prompt: openAIFunctionAgentPrompt, - streamRunnable: isStream, - }) - : await createStructuredChatAgent({ - llm, - tools, - prompt: structuredChatAgentPrompt, - streamRunnable: isStream, - }); + const agentRunnable = + llmType === 'openai' + ? await createOpenAIFunctionsAgent({ + llm, + tools, + prompt: openAIFunctionAgentPrompt, + streamRunnable: isStream, + }) + : await createStructuredChatAgent({ + llm, + tools, + prompt: structuredChatAgentPrompt, + streamRunnable: isStream, + }); const apmTracer = new APMTracer({ projectName: traceOptions?.projectName ?? 'default' }, logger); diff --git a/x-pack/plugins/elastic_assistant/server/routes/post_actions_connector_execute.ts b/x-pack/plugins/elastic_assistant/server/routes/post_actions_connector_execute.ts index a7f4c992cc11d..0254ff159d053 100644 --- a/x-pack/plugins/elastic_assistant/server/routes/post_actions_connector_execute.ts +++ b/x-pack/plugins/elastic_assistant/server/routes/post_actions_connector_execute.ts @@ -5,6 +5,8 @@ * 2.0. */ +/* eslint-disable complexity */ + import { IRouter, Logger } from '@kbn/core/server'; import { transformError } from '@kbn/securitysolution-es-utils'; import { getRequestAbortedSignal } from '@kbn/data-plugin/server'; @@ -356,24 +358,6 @@ export const postActionsConnectorExecuteRoute = ( kbDataClient, }; - const llmType = getLlmType(actionTypeId); - - const actionsClient = await actions.getActionsClientWithRequest(request); - - let region; - let model = request.body.model; - if (llmType === 'bedrock') { - try { - const connector = await actionsClient.get({ id: connectorId }); - region = connector.config?.apiUrl.split('.').reverse()[2]; - if (!model) { - model = connector.config?.defaultModel; - } - } catch (e) { - logger.error(`Failed to get region: ${e.message}`); - } - } - // Shared executor params const executorParams: AgentExecutorParams = { abortSignal, @@ -390,13 +374,11 @@ export const postActionsConnectorExecuteRoute = ( esClient, esStore, isStream: request.body.subAction !== 'invokeAI', - llmType, + llmType: getLlmType(actionTypeId), langChainMessages, logger, - model, onNewReplacements, onLlmResponse, - region, request, response, replacements: request.body.replacements, diff --git a/x-pack/plugins/elastic_assistant/server/routes/utils.ts b/x-pack/plugins/elastic_assistant/server/routes/utils.ts index b4ba5c435166e..395da909cb786 100644 --- a/x-pack/plugins/elastic_assistant/server/routes/utils.ts +++ b/x-pack/plugins/elastic_assistant/server/routes/utils.ts @@ -186,5 +186,3 @@ export const getLlmClass = (llmType?: string, isStreaming?: boolean) => : llmType === 'bedrock' ? ActionsClientBedrockChatModel : ActionsClientSimpleChatModel; - -export const isToolCallingSupported = (llmType?: string) => ['openai'].includes(llmType ?? 
'');
diff --git a/x-pack/plugins/stack_connectors/common/bedrock/constants.ts b/x-pack/plugins/stack_connectors/common/bedrock/constants.ts
index f3b133dd783f6..e2414f46dd985 100644
--- a/x-pack/plugins/stack_connectors/common/bedrock/constants.ts
+++ b/x-pack/plugins/stack_connectors/common/bedrock/constants.ts
@@ -17,7 +17,6 @@ export const BEDROCK_CONNECTOR_ID = '.bedrock';
 export enum SUB_ACTION {
   RUN = 'run',
   INVOKE_AI = 'invokeAI',
-  INVOKE_AI_RAW = 'invokeAIRaw',
   INVOKE_STREAM = 'invokeStream',
   DASHBOARD = 'getDashboard',
   TEST = 'test',
diff --git a/x-pack/plugins/stack_connectors/common/bedrock/schema.ts b/x-pack/plugins/stack_connectors/common/bedrock/schema.ts
index d88948484ec2a..e300c8f21c408 100644
--- a/x-pack/plugins/stack_connectors/common/bedrock/schema.ts
+++ b/x-pack/plugins/stack_connectors/common/bedrock/schema.ts
@@ -27,8 +27,6 @@ export const RunActionParamsSchema = schema.object({
   timeout: schema.maybe(schema.number()),
 });
 
-export const InvokeAIRawActionParamsSchema = schema.any();
-
 export const InvokeAIActionParamsSchema = schema.object({
   messages: schema.arrayOf(
     schema.object({
diff --git a/x-pack/plugins/stack_connectors/common/bedrock/types.ts b/x-pack/plugins/stack_connectors/common/bedrock/types.ts
index 1847a25893d06..1256831fc7fa0 100644
--- a/x-pack/plugins/stack_connectors/common/bedrock/types.ts
+++ b/x-pack/plugins/stack_connectors/common/bedrock/types.ts
@@ -15,7 +15,6 @@ import {
   RunActionResponseSchema,
   InvokeAIActionParamsSchema,
   InvokeAIActionResponseSchema,
-  InvokeAIRawActionParamsSchema,
   StreamingResponseSchema,
   RunApiLatestResponseSchema,
 } from './schema';
@@ -24,7 +23,6 @@ export type Config = TypeOf<typeof ConfigSchema>;
 export type Secrets = TypeOf<typeof SecretsSchema>;
 export type RunActionParams = TypeOf<typeof RunActionParamsSchema>;
 export type InvokeAIActionParams = TypeOf<typeof InvokeAIActionParamsSchema>;
-export type InvokeAIRawActionParams = TypeOf<typeof InvokeAIRawActionParamsSchema>;
 export type InvokeAIActionResponse = TypeOf<typeof InvokeAIActionResponseSchema>;
 export type RunApiLatestResponse = TypeOf<typeof RunApiLatestResponseSchema>;
 export type RunActionResponse = TypeOf<typeof RunActionResponseSchema>;
diff --git a/x-pack/plugins/stack_connectors/server/connector_types/bedrock/bedrock.ts b/x-pack/plugins/stack_connectors/server/connector_types/bedrock/bedrock.ts
index 1d129dfcdf278..616c63b82dcb5 100644
--- a/x-pack/plugins/stack_connectors/server/connector_types/bedrock/bedrock.ts
+++ b/x-pack/plugins/stack_connectors/server/connector_types/bedrock/bedrock.ts
@@ -15,10 +15,8 @@ import { initDashboard } from '../lib/gen_ai/create_gen_ai_dashboard';
 import {
   RunActionParamsSchema,
   InvokeAIActionParamsSchema,
-  InvokeAIRawActionParamsSchema,
   StreamingResponseSchema,
   RunActionResponseSchema,
-  RunActionRawResponseSchema,
   RunApiLatestResponseSchema,
 } from '../../../common/bedrock/schema';
 import {
@@ -28,7 +26,6 @@ import {
   RunActionResponse,
   InvokeAIActionParams,
   InvokeAIActionResponse,
-  InvokeAIRawActionParams,
   RunApiLatestResponse,
 } from '../../../common/bedrock/types';
 import {
@@ -88,12 +85,6 @@ export class BedrockConnector extends SubActionConnector<Config, Secrets> {
       schema: InvokeAIActionParamsSchema,
     });
 
-    this.registerSubAction({
-      name: SUB_ACTION.INVOKE_AI_RAW,
-      method: 'invokeAIRaw',
-      schema: InvokeAIRawActionParamsSchema,
-    });
-
     this.registerSubAction({
       name: SUB_ACTION.INVOKE_STREAM,
       method: 'invokeStream',
@@ -327,65 +318,8 @@ The Kibana Connector in use may need to be reconfigured with an updated Amazon B
       signal,
       timeout,
     });
-    return { message: res.completion.trim() };
-  }
 
-  /**
-   * Non-streamed security solution AI Assistant requests
-   * Responsible for invoking the runApi method with the provided body. 
- * It then formats the response into a string - * @param messages An array of messages to be sent to the API - * @param model Optional model to be used for the API request. If not provided, the default model from the connector will be used. - * @returns an object with the response string as a property called message - */ - public async invokeAIRaw({ - bedrockMethod = 'invoke', - messages, - model, - stopSequences, - system, - temperature, - maxTokens, - signal, - timeout, - }: InvokeAIRawActionParams) { - if (bedrockMethod === 'invoke-with-response-stream') { - const body = JSON.stringify( - formatBedrockBody({ messages, stopSequences, system, temperature }) - ); - // set model on per request basis - const path = `/model/${model ?? this.model}/invoke-with-response-stream`; - const signed = this.signRequest(body, path, true); - - const res = await fetch(`${this.url}${path}`, { - headers: signed.headers, - body: signed.body, - method: 'POST', - }); - - return res.body; - } - - // set model on per request basis - const currentModel = model ?? this.model; - const path = `/model/${currentModel}/${bedrockMethod}`; - const body = JSON.stringify( - formatBedrockBody({ messages, stopSequences, system, temperature, maxTokens }) - ); - const signed = this.signRequest(body, path, false); - const params = { - ...signed, - url: `${this.url}${path}`, - method: 'post' as Method, - data: body, - signal, - // give up to 2 minutes for response - timeout: timeout ?? DEFAULT_TIMEOUT_MS, - }; - - const response = await this.request({ ...params, responseSchema: RunActionRawResponseSchema }); - - return response.data; + return { message: res.completion.trim() }; } } From e1434bae501e779e59e04004291e8bf7337e335d Mon Sep 17 00:00:00 2001 From: Patryk Kopycinski Date: Sat, 29 Jun 2024 23:43:29 +0200 Subject: [PATCH 05/55] add FF --- .../impl/capabilities/index.ts | 1 + .../server/language_models/bedrock_chat.ts | 40 ++++--------- .../execute_custom_llm_chain/index.ts | 39 ++++++------ .../server/lib/langchain/executors/types.ts | 1 + .../graphs/default_assistant_graph/helpers.ts | 59 ++++++++++++++++++- .../graphs/default_assistant_graph/index.ts | 31 +++++----- .../nodes/execute_tools.ts | 11 +++- .../server/routes/evaluate/post_evaluate.ts | 3 + .../routes/post_actions_connector_execute.ts | 3 + .../elastic_assistant/server/routes/utils.ts | 4 +- .../common/experimental_features.ts | 5 ++ .../security_solution/server/plugin.ts | 1 + .../stack_connectors/common/bedrock/schema.ts | 2 - .../server/connector_types/bedrock/bedrock.ts | 1 - 14 files changed, 129 insertions(+), 72 deletions(-) diff --git a/x-pack/packages/kbn-elastic-assistant-common/impl/capabilities/index.ts b/x-pack/packages/kbn-elastic-assistant-common/impl/capabilities/index.ts index c1c101fd74cd8..1e759df2819ed 100644 --- a/x-pack/packages/kbn-elastic-assistant-common/impl/capabilities/index.ts +++ b/x-pack/packages/kbn-elastic-assistant-common/impl/capabilities/index.ts @@ -21,4 +21,5 @@ export type AssistantFeatureKey = keyof AssistantFeatures; export const defaultAssistantFeatures = Object.freeze({ assistantKnowledgeBaseByDefault: false, assistantModelEvaluation: false, + assistantBedrockChat: false, }); diff --git a/x-pack/packages/kbn-langchain/server/language_models/bedrock_chat.ts b/x-pack/packages/kbn-langchain/server/language_models/bedrock_chat.ts index 76a282a751714..50560abe6471c 100644 --- a/x-pack/packages/kbn-langchain/server/language_models/bedrock_chat.ts +++ b/x-pack/packages/kbn-langchain/server/language_models/bedrock_chat.ts 
@@ -16,12 +16,6 @@ export const DEFAULT_BEDROCK_MODEL = 'anthropic.claude-3-5-sonnet-20240620-v1:0' export const DEFAULT_BEDROCK_REGION = 'us-east-1'; export class ActionsClientBedrockChatModel extends _BedrockChat { - // Kibana variables - #actions: ActionsPluginStart; - #connectorId: string; - #logger: Logger; - #request: KibanaRequest; - constructor({ actions, request, @@ -38,26 +32,21 @@ export class ActionsClientBedrockChatModel extends _BedrockChat { super({ ...params, credentials: { accessKeyId: '', secretAccessKey: '' }, - usesMessagesApi: true, // only needed to force BedrockChat to use messages api for Claude v2 model: DEFAULT_BEDROCK_MODEL, region: DEFAULT_BEDROCK_REGION, fetchFn: async (url, options) => { // create an actions client from the authenticated request context: - const actionsClient = await this.#actions.getActionsClientWithRequest(this.#request); + const actionsClient = await actions.getActionsClientWithRequest(request); - const inputBody = JSON.parse(options?.body); + const inputBody = JSON.parse(options?.body as string); if (this.streaming) { - const data = await actionsClient.execute({ - actionId: this.#connectorId, + const data = (await actionsClient.execute({ + actionId: connectorId, params: { subAction: 'invokeStream', subActionParams: { - // bedrockMethod: 'invoke-with-response-stream', - model: this.model, - // endpointHost: this.endpointHost, - // anthropicVersion: inputBody.anthropicVersion, messages: inputBody.messages, temperature: inputBody.temperature, stopSequences: inputBody.stopSequences, @@ -65,21 +54,19 @@ export class ActionsClientBedrockChatModel extends _BedrockChat { maxTokens: inputBody.maxTokens, }, }, - }); - - console.error('data', data); + })) as { data: Readable }; return { body: Readable.toWeb(data.data), - }; + } as unknown as Response; } - const data = await actionsClient.execute({ - actionId: this.#connectorId, + const data = (await actionsClient.execute({ + actionId: connectorId, params: { subAction: 'invokeAI', subActionParams: { - model: this.model, + // model: this.model, messages: inputBody.messages, temperature: inputBody.temperature, stopSequences: inputBody.stopSequences, @@ -87,7 +74,7 @@ export class ActionsClientBedrockChatModel extends _BedrockChat { maxTokens: inputBody.maxTokens, }, }, - }); + })) as { status: string; data: { message: string } }; return { ok: data.status === 'ok', @@ -95,13 +82,8 @@ export class ActionsClientBedrockChatModel extends _BedrockChat { content: data.data.message, type: 'message', }), - }; + } as unknown as Response; }, }); - - this.#actions = actions; - this.#request = request; - this.#connectorId = connectorId; - this.#logger = logger; } } diff --git a/x-pack/plugins/elastic_assistant/server/lib/langchain/execute_custom_llm_chain/index.ts b/x-pack/plugins/elastic_assistant/server/lib/langchain/execute_custom_llm_chain/index.ts index e75a063ddf7dc..a4f7bf2c16431 100644 --- a/x-pack/plugins/elastic_assistant/server/lib/langchain/execute_custom_llm_chain/index.ts +++ b/x-pack/plugins/elastic_assistant/server/lib/langchain/execute_custom_llm_chain/index.ts @@ -33,6 +33,7 @@ export const callAgentExecutor: AgentExecutor = async ({ anonymizationFields, isEnabledKnowledgeBase, assistantTools = [], + bedrockChatEnabled, connectorId, esClient, esStore, @@ -47,7 +48,8 @@ export const callAgentExecutor: AgentExecutor = async ({ size, traceOptions, }) => { - const llmClass = getLlmClass(llmType, isStream); + const isOpenAI = llmType === 'openai'; + const llmClass = getLlmClass(llmType, 
bedrockChatEnabled);
 
   const llm = new llmClass({
     actions,
@@ -112,24 +114,23 @@ export const callAgentExecutor: AgentExecutor<true | false> = async ({
     handleParsingErrors: 'Try again, paying close attention to the allowed tool input',
   };
   // isOpenAI check is not on agentType alone because typescript doesn't like
-  const executor =
-    llmType === 'openai'
-      ? await initializeAgentExecutorWithOptions(tools, llm, {
-          agentType: 'openai-functions',
-          ...executorArgs,
-        })
-      : await initializeAgentExecutorWithOptions(tools, llm, {
-          agentType: 'structured-chat-zero-shot-react-description',
-          ...executorArgs,
-          returnIntermediateSteps: false,
-          agentArgs: {
-            // this is important to help LangChain correctly format tool input
-            humanMessageTemplate: `Remember, when you have enough information, always prefix your final JSON output with "Final Answer:"\n\nQuestion: {input}\n\n{agent_scratchpad}.`,
-            memoryPrompts: [new MessagesPlaceholder('chat_history')],
-            suffix:
-              'Begin! Reminder to ALWAYS use the above format, and to use tools if appropriate.',
-          },
-        });
+  const executor = isOpenAI
+    ? await initializeAgentExecutorWithOptions(tools, llm, {
+        agentType: 'openai-functions',
+        ...executorArgs,
+      })
+    : await initializeAgentExecutorWithOptions(tools, llm, {
+        agentType: 'structured-chat-zero-shot-react-description',
+        ...executorArgs,
+        returnIntermediateSteps: false,
+        agentArgs: {
+          // this is important to help LangChain correctly format tool input
+          humanMessageTemplate: `Remember, when you have enough information, always prefix your final JSON output with "Final Answer:"\n\nQuestion: {input}\n\n{agent_scratchpad}.`,
+          memoryPrompts: [new MessagesPlaceholder('chat_history')],
+          suffix:
+            'Begin! Reminder to ALWAYS use the above format, and to use tools if appropriate.',
+        },
+      });
 
   // Sets up tracer for tracing executions to APM. See x-pack/plugins/elastic_assistant/server/lib/langchain/tracers/README.mdx
   // If LangSmith env vars are set, executions will be traced there as well. See https://docs.smith.langchain.com/tracing
diff --git a/x-pack/plugins/elastic_assistant/server/lib/langchain/executors/types.ts b/x-pack/plugins/elastic_assistant/server/lib/langchain/executors/types.ts
index bd07099e312b3..0a30a985b88bb 100644
--- a/x-pack/plugins/elastic_assistant/server/lib/langchain/executors/types.ts
+++ b/x-pack/plugins/elastic_assistant/server/lib/langchain/executors/types.ts
@@ -38,6 +38,7 @@ export interface AgentExecutorParams<T extends boolean> {
   alertsIndexPattern?: string;
   actions: ActionsPluginStart;
   anonymizationFields?: AnonymizationFieldResponse[];
+  bedrockChatEnabled: boolean;
   isEnabledKnowledgeBase: boolean;
   assistantTools?: AssistantTool[];
   connectorId: string;
diff --git a/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/helpers.ts b/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/helpers.ts
index 383b3e9f5cee8..44404785e7593 100644
--- a/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/helpers.ts
+++ b/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/helpers.ts
@@ -85,6 +85,11 @@ export const streamGraph = async ({
     version: 'v1',
   });
 
+  let message = '';
+  let finalOutputIndex = -1;
+  const finalOutputStartToken = '"action":"FinalAnswer","action_input":"';
+  const finalOutputStopRegex = /(?<!\\)"/;
+
   const processEvent = async () =>
{ try { const { value, done } = await stream.next(); @@ -93,8 +98,7 @@ export const streamGraph = async ({ const event = value; if (event.event === 'on_llm_stream') { const chunk = event.data?.chunk; - // TODO: For Bedrock streaming support, override `handleLLMNewToken` in callbacks, - // TODO: or maybe we can update ActionsClientSimpleChatModel to handle this `on_llm_stream` event + if (event.name === 'ActionsClientChatOpenAI') { const msg = chunk.message; @@ -109,9 +113,58 @@ export const streamGraph = async ({ } } } + + if (event.name === 'ActionsClientBedrockChatModel') { + const msg = chunk; + + if (msg) { + if (msg.tool_call_chunks && msg.tool_call_chunks.length > 0) { + /* empty */ + } else if (!didEnd) { + if (msg.response_metadata.finish_reason === 'stop') { + handleStreamEnd(finalMessage); + } else { + const finalOutputEndIndex = msg.content.search(finalOutputStopRegex); + const currentOutput = message.replace(/\s/g, ''); + + if (currentOutput.includes(finalOutputStartToken)) { + finalOutputIndex = currentOutput.indexOf(finalOutputStartToken); + } + + if (finalOutputIndex > -1 && finalOutputEndIndex > -1) { + didEnd = true; + handleStreamEnd(finalMessage); + return; + } + + if (finalOutputIndex > -1) { + finalMessage += msg.content; + push({ payload: msg.content, type: 'content' }); + } + + message += msg.content; + } + } + } + } + } else if (event.event === 'on_llm_end') { + if (event.name === 'ActionsClientChatOpenAI') { + const generations = event.data.output?.generations[0]; + if (generations && generations[0]?.generationInfo.finish_reason === 'stop') { + handleStreamEnd(finalMessage); + } + } + + if (event.name === 'ActionsClientBedrockChatModel') { + const generations = event.data.output?.generations[0]; + + if (generations && generations[0]?.generationInfo.stop_reason === 'end_turn') { + handleStreamEnd(finalMessage); + } + } } - await processEvent(); + processEvent(); } catch (err) { // if I throw an error here, it crashes the server. Not sure how to get around that. // If I put await on this function the error works properly, but when there is not an error diff --git a/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/index.ts b/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/index.ts index 2631b0e50ba1e..c62a122be5a67 100644 --- a/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/index.ts +++ b/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/index.ts @@ -27,6 +27,7 @@ export const callAssistantGraph: AgentExecutor = async ({ anonymizationFields, isEnabledKnowledgeBase, assistantTools = [], + bedrockChatEnabled, connectorId, conversationId, dataClients, @@ -44,7 +45,8 @@ export const callAssistantGraph: AgentExecutor = async ({ traceOptions, }) => { const logger = parentLogger.get('defaultAssistantGraph'); - const llmClass = getLlmClass(llmType, isStream); + const isOpenAI = llmType === 'openai'; + const llmClass = getLlmClass(llmType, bedrockChatEnabled); const llm = new llmClass({ actions, @@ -95,20 +97,19 @@ export const callAssistantGraph: AgentExecutor = async ({ (tool) => tool.getTool(assistantToolParams) ?? [] ); - const agentRunnable = - llmType === 'openai' - ? 
await createOpenAIFunctionsAgent({ - llm, - tools, - prompt: openAIFunctionAgentPrompt, - streamRunnable: isStream, - }) - : await createStructuredChatAgent({ - llm, - tools, - prompt: structuredChatAgentPrompt, - streamRunnable: isStream, - }); + const agentRunnable = isOpenAI + ? await createOpenAIFunctionsAgent({ + llm, + tools, + prompt: openAIFunctionAgentPrompt, + streamRunnable: isStream, + }) + : await createStructuredChatAgent({ + llm, + tools, + prompt: structuredChatAgentPrompt, + streamRunnable: isStream, + }); const apmTracer = new APMTracer({ projectName: traceOptions?.projectName ?? 'default' }, logger); diff --git a/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/nodes/execute_tools.ts b/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/nodes/execute_tools.ts index b42455e14f6f1..bfcea2f231995 100644 --- a/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/nodes/execute_tools.ts +++ b/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/nodes/execute_tools.ts @@ -38,7 +38,16 @@ export const executeTools = async ({ config, logger, state, tools }: ExecuteTool if (!agentAction || 'returnValues' in agentAction) { throw new Error('Agent has not been run yet'); } - const out = await toolExecutor.invoke(agentAction, config); + + let out; + try { + out = await toolExecutor.invoke(agentAction, config); + } catch (err) { + return { + steps: [{ action: agentAction, observation: JSON.stringify(`Error: ${err}`, null, 2) }], + }; + } + return { steps: [{ action: agentAction, observation: JSON.stringify(out, null, 2) }], }; diff --git a/x-pack/plugins/elastic_assistant/server/routes/evaluate/post_evaluate.ts b/x-pack/plugins/elastic_assistant/server/routes/evaluate/post_evaluate.ts index ef1950b5e90ad..7193106d51444 100644 --- a/x-pack/plugins/elastic_assistant/server/routes/evaluate/post_evaluate.ts +++ b/x-pack/plugins/elastic_assistant/server/routes/evaluate/post_evaluate.ts @@ -162,6 +162,8 @@ export const postEvaluateRoute = ( // Setup with kbDataClient if `enableKnowledgeBaseByDefault` FF is enabled const enableKnowledgeBaseByDefault = assistantContext.getRegisteredFeatures(pluginName).assistantKnowledgeBaseByDefault; + const bedrockChatEnabled = + assistantContext.getRegisteredFeatures(pluginName).assistantBedrockChat; const kbDataClient = enableKnowledgeBaseByDefault ? (await assistantContext.getAIAssistantKnowledgeBaseDataClient(false)) ?? undefined : undefined; @@ -195,6 +197,7 @@ export const postEvaluateRoute = ( actions, isEnabledKnowledgeBase: true, assistantTools, + bedrockChatEnabled, connectorId, esClient, esStore, diff --git a/x-pack/plugins/elastic_assistant/server/routes/post_actions_connector_execute.ts b/x-pack/plugins/elastic_assistant/server/routes/post_actions_connector_execute.ts index 0254ff159d053..0c283717fc56d 100644 --- a/x-pack/plugins/elastic_assistant/server/routes/post_actions_connector_execute.ts +++ b/x-pack/plugins/elastic_assistant/server/routes/post_actions_connector_execute.ts @@ -335,6 +335,8 @@ export const postActionsConnectorExecuteRoute = ( // Setup with kbDataClient if `assistantKnowledgeBaseByDefault` FF is enabled const enableKnowledgeBaseByDefault = assistantContext.getRegisteredFeatures(pluginName).assistantKnowledgeBaseByDefault; + const bedrockChatEnabled = + assistantContext.getRegisteredFeatures(pluginName).assistantBedrockChat; const kbDataClient = enableKnowledgeBaseByDefault ? 
(await assistantContext.getAIAssistantKnowledgeBaseDataClient(false)) ?? undefined : undefined; @@ -366,6 +368,7 @@ export const postActionsConnectorExecuteRoute = ( ? transformESSearchToAnonymizationFields(anonymizationFieldsRes.data) : undefined, actions, + bedrockChatEnabled, isEnabledKnowledgeBase: request.body.isEnabledKnowledgeBase ?? false, assistantTools, connectorId, diff --git a/x-pack/plugins/elastic_assistant/server/routes/utils.ts b/x-pack/plugins/elastic_assistant/server/routes/utils.ts index 395da909cb786..72aa5218ac6ce 100644 --- a/x-pack/plugins/elastic_assistant/server/routes/utils.ts +++ b/x-pack/plugins/elastic_assistant/server/routes/utils.ts @@ -180,9 +180,9 @@ export const getLlmType = (actionTypeId: string): string | undefined => { return llmTypeDictionary[actionTypeId]; }; -export const getLlmClass = (llmType?: string, isStreaming?: boolean) => +export const getLlmClass = (llmType?: string, bedrockChatEnabled?: boolean) => llmType === 'openai' ? ActionsClientChatOpenAI - : llmType === 'bedrock' + : llmType === 'bedrock' && bedrockChatEnabled ? ActionsClientBedrockChatModel : ActionsClientSimpleChatModel; diff --git a/x-pack/plugins/security_solution/common/experimental_features.ts b/x-pack/plugins/security_solution/common/experimental_features.ts index 53c5bdd8a657e..dafbc54e700b8 100644 --- a/x-pack/plugins/security_solution/common/experimental_features.ts +++ b/x-pack/plugins/security_solution/common/experimental_features.ts @@ -127,6 +127,11 @@ export const allowedExperimentalValues = Object.freeze({ */ assistantKnowledgeBaseByDefault: false, + /** + * Enables the Assistant BedrockChat Langchain model, introduced in `8.15.0`. + */ + assistantBedrockChat: false, + /** * Enables the Managed User section inside the new user details flyout. * To see this section you also need expandableFlyoutDisabled flag set to false. 
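
A minimal sketch of what the flag does to model selection once the pieces above are wired together (illustrative only — the literal `llmType` value, flag value, and import path are assumptions; `getLlmClass` and the model class names are the ones changed in this series):

import { getLlmClass } from './routes/utils'; // path assumed relative to the plugin's server directory

// Assumed inputs: llmType normally comes from getLlmType(actionTypeId), and
// bedrockChatEnabled from the registered `assistantBedrockChat` feature flag.
const llmType = 'bedrock';
const bedrockChatEnabled = true;

// With the flag on, Bedrock connectors resolve to the new BedrockChat-based
// ActionsClientBedrockChatModel; with it off, they fall back to
// ActionsClientSimpleChatModel. OpenAI connectors always resolve to
// ActionsClientChatOpenAI, regardless of the flag.
const LlmClass = getLlmClass(llmType, bedrockChatEnabled);
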
diff --git a/x-pack/plugins/security_solution/server/plugin.ts b/x-pack/plugins/security_solution/server/plugin.ts index 5b5b833dd2d4b..c0dedaf565e86 100644 --- a/x-pack/plugins/security_solution/server/plugin.ts +++ b/x-pack/plugins/security_solution/server/plugin.ts @@ -567,6 +567,7 @@ export class Plugin implements ISecuritySolutionPlugin { // Assistant Tool and Feature Registration plugins.elasticAssistant.registerTools(APP_UI_ID, getAssistantTools()); plugins.elasticAssistant.registerFeatures(APP_UI_ID, { + assistantBedrockChat: config.experimentalFeatures.assistantBedrockChat, assistantKnowledgeBaseByDefault: config.experimentalFeatures.assistantKnowledgeBaseByDefault, assistantModelEvaluation: config.experimentalFeatures.assistantModelEvaluation, }); diff --git a/x-pack/plugins/stack_connectors/common/bedrock/schema.ts b/x-pack/plugins/stack_connectors/common/bedrock/schema.ts index e300c8f21c408..bf35aa6bb8e0d 100644 --- a/x-pack/plugins/stack_connectors/common/bedrock/schema.ts +++ b/x-pack/plugins/stack_connectors/common/bedrock/schema.ts @@ -65,8 +65,6 @@ export const RunApiLatestResponseSchema = schema.object( { unknowns: 'allow' } ); -export const RunActionRawResponseSchema = schema.any(); - export const RunActionResponseSchema = schema.object( { completion: schema.string(), diff --git a/x-pack/plugins/stack_connectors/server/connector_types/bedrock/bedrock.ts b/x-pack/plugins/stack_connectors/server/connector_types/bedrock/bedrock.ts index 616c63b82dcb5..8b05c30a5b0cb 100644 --- a/x-pack/plugins/stack_connectors/server/connector_types/bedrock/bedrock.ts +++ b/x-pack/plugins/stack_connectors/server/connector_types/bedrock/bedrock.ts @@ -318,7 +318,6 @@ The Kibana Connector in use may need to be reconfigured with an updated Amazon B signal, timeout, }); - return { message: res.completion.trim() }; } } From cf3d54723f5ec697b35db685842b04d83830903b Mon Sep 17 00:00:00 2001 From: Patryk Kopycinski Date: Sat, 29 Jun 2024 23:50:33 +0200 Subject: [PATCH 06/55] cleanup --- .../server/language_models/bedrock_chat.ts | 2 -- .../execute_custom_llm_chain/index.ts | 24 +++++++++++-------- 2 files changed, 14 insertions(+), 12 deletions(-) diff --git a/x-pack/packages/kbn-langchain/server/language_models/bedrock_chat.ts b/x-pack/packages/kbn-langchain/server/language_models/bedrock_chat.ts index 50560abe6471c..c0beb8dc52cc6 100644 --- a/x-pack/packages/kbn-langchain/server/language_models/bedrock_chat.ts +++ b/x-pack/packages/kbn-langchain/server/language_models/bedrock_chat.ts @@ -28,7 +28,6 @@ export class ActionsClientBedrockChatModel extends _BedrockChat { logger: Logger; request: KibanaRequest; } & BaseChatModelParams) { - // Just to make Langchain BedrockChat happy super({ ...params, credentials: { accessKeyId: '', secretAccessKey: '' }, @@ -66,7 +65,6 @@ export class ActionsClientBedrockChatModel extends _BedrockChat { params: { subAction: 'invokeAI', subActionParams: { - // model: this.model, messages: inputBody.messages, temperature: inputBody.temperature, stopSequences: inputBody.stopSequences, diff --git a/x-pack/plugins/elastic_assistant/server/lib/langchain/execute_custom_llm_chain/index.ts b/x-pack/plugins/elastic_assistant/server/lib/langchain/execute_custom_llm_chain/index.ts index a4f7bf2c16431..35fa66b73db19 100644 --- a/x-pack/plugins/elastic_assistant/server/lib/langchain/execute_custom_llm_chain/index.ts +++ b/x-pack/plugins/elastic_assistant/server/lib/langchain/execute_custom_llm_chain/index.ts @@ -193,19 +193,23 @@ export const callAgentExecutor: AgentExecutor = 
async ({ tokenParentRunId = parentRunId; } if (payload.length && !didEnd && tokenParentRunId === parentRunId) { - const finalOutputEndIndex = payload.search(finalOutputStopRegex); - const currentOutput = message.replace(/\s/g, ''); + if (llmType === 'bedrock' && bedrockChatEnabled) { + const finalOutputEndIndex = payload.search(finalOutputStopRegex); + const currentOutput = message.replace(/\s/g, ''); - if (currentOutput.includes(finalOutputStartToken)) { - finalOutputIndex = currentOutput.indexOf(finalOutputStartToken); - } + if (currentOutput.includes(finalOutputStartToken)) { + finalOutputIndex = currentOutput.indexOf(finalOutputStartToken); + } - if (finalOutputIndex > -1) { - push({ payload, type: 'content' }); - } + if (finalOutputIndex > -1) { + push({ payload, type: 'content' }); + } - if (finalOutputIndex > -1 && finalOutputEndIndex > -1) { - didEnd = true; + if (finalOutputIndex > -1 && finalOutputEndIndex > -1) { + didEnd = true; + } + } else { + push({ payload, type: 'content' }); } // store message in case of error message += payload; From 281f98e6f1b6a9e747638ded7adfaca9dcd09402 Mon Sep 17 00:00:00 2001 From: kibanamachine <42973632+kibanamachine@users.noreply.github.com> Date: Sat, 29 Jun 2024 22:08:33 +0000 Subject: [PATCH 07/55] [CI] Auto-commit changed files from 'node scripts/lint_ts_projects --fix' --- x-pack/plugins/elastic_assistant/tsconfig.json | 1 + x-pack/plugins/integration_assistant/tsconfig.json | 3 ++- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/x-pack/plugins/elastic_assistant/tsconfig.json b/x-pack/plugins/elastic_assistant/tsconfig.json index f63a8da530196..8f546d6e5fe01 100644 --- a/x-pack/plugins/elastic_assistant/tsconfig.json +++ b/x-pack/plugins/elastic_assistant/tsconfig.json @@ -45,6 +45,7 @@ "@kbn/core-saved-objects-api-server", "@kbn/langchain", "@kbn/stack-connectors-plugin", + "@kbn/security-plugin", ], "exclude": [ "target/**/*", diff --git a/x-pack/plugins/integration_assistant/tsconfig.json b/x-pack/plugins/integration_assistant/tsconfig.json index bd8541866bc4d..fbb29e03b4f6e 100644 --- a/x-pack/plugins/integration_assistant/tsconfig.json +++ b/x-pack/plugins/integration_assistant/tsconfig.json @@ -35,6 +35,7 @@ "@kbn/logging-mocks", "@kbn/core-http-request-handler-context-server", "@kbn/core-http-router-server-mocks", - "@kbn/core-http-server" + "@kbn/core-http-server", + "@kbn/langchain" ] } From 10f5de99f4f170db1e4dc9089fd5d2cb674e678a Mon Sep 17 00:00:00 2001 From: Patryk Kopycinski Date: Sun, 30 Jun 2024 00:16:41 +0200 Subject: [PATCH 08/55] fix --- package.json | 2 +- yarn.lock | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/package.json b/package.json index c85a21a3d2e58..feb89518dc19f 100644 --- a/package.json +++ b/package.json @@ -938,7 +938,7 @@ "@kbn/xstate-utils": "link:packages/kbn-xstate-utils", "@kbn/zod-helpers": "link:packages/kbn-zod-helpers", "@langchain/community": "^0.2.15", - "@langchain/core": "^0.2.1", + "@langchain/core": "^0.2.11", "@langchain/langgraph": "^0.0.25", "@langchain/openai": "^0.2.1", "@langtrase/trace-attributes": "^3.0.8", diff --git a/yarn.lock b/yarn.lock index e9ad7f2247913..b295dc7c34edc 100644 --- a/yarn.lock +++ b/yarn.lock @@ -6992,7 +6992,7 @@ zod "^3.22.3" zod-to-json-schema "^3.22.5" -"@langchain/core@0.2.11", "@langchain/core@>0.1.0 <0.3.0", "@langchain/core@>0.1.61 <0.3.0", "@langchain/core@>=0.2.8 <0.3.0", "@langchain/core@>=0.2.9 <0.3.0", "@langchain/core@^0.2.1", "@langchain/core@~0.2.9": +"@langchain/core@0.2.11", "@langchain/core@>0.1.0 <0.3.0", 
"@langchain/core@>0.1.61 <0.3.0", "@langchain/core@>=0.2.8 <0.3.0", "@langchain/core@>=0.2.9 <0.3.0", "@langchain/core@^0.2.11", "@langchain/core@~0.2.9": version "0.2.11" resolved "https://registry.yarnpkg.com/@langchain/core/-/core-0.2.11.tgz#5f47467e20e56b250831baef20083657c6facb4c" integrity sha512-d4SNL7WI0c3oHrV4WxCRH1/TNqdePXEzYjYwIb4aEH6lW1aM0utGhLbNthX+aYkOL4Ynx2FoG4h91ECIipiKWQ== From 4cd8f7fe6052b3fa88cb19eb341e5ff97cb0e2fa Mon Sep 17 00:00:00 2001 From: Patryk Kopycinski Date: Sun, 30 Jun 2024 01:42:11 +0200 Subject: [PATCH 09/55] fix --- .../impl/schemas/capabilities/get_capabilities_route.gen.ts | 1 + .../schemas/capabilities/get_capabilities_route.schema.yaml | 3 +++ x-pack/packages/kbn-langchain/server/language_models/types.ts | 2 +- .../lib/langchain/graphs/default_assistant_graph/helpers.ts | 1 + 4 files changed, 6 insertions(+), 1 deletion(-) diff --git a/x-pack/packages/kbn-elastic-assistant-common/impl/schemas/capabilities/get_capabilities_route.gen.ts b/x-pack/packages/kbn-elastic-assistant-common/impl/schemas/capabilities/get_capabilities_route.gen.ts index 34a10bd517c1a..d8cecc9c788a4 100644 --- a/x-pack/packages/kbn-elastic-assistant-common/impl/schemas/capabilities/get_capabilities_route.gen.ts +++ b/x-pack/packages/kbn-elastic-assistant-common/impl/schemas/capabilities/get_capabilities_route.gen.ts @@ -20,4 +20,5 @@ export type GetCapabilitiesResponse = z.infer; export const GetCapabilitiesResponse = z.object({ assistantKnowledgeBaseByDefault: z.boolean(), assistantModelEvaluation: z.boolean(), + assistantBedrockChat: z.boolean(), }); diff --git a/x-pack/packages/kbn-elastic-assistant-common/impl/schemas/capabilities/get_capabilities_route.schema.yaml b/x-pack/packages/kbn-elastic-assistant-common/impl/schemas/capabilities/get_capabilities_route.schema.yaml index 7461bdbc93237..8e8325e1501c7 100644 --- a/x-pack/packages/kbn-elastic-assistant-common/impl/schemas/capabilities/get_capabilities_route.schema.yaml +++ b/x-pack/packages/kbn-elastic-assistant-common/impl/schemas/capabilities/get_capabilities_route.schema.yaml @@ -19,11 +19,14 @@ paths: schema: type: object properties: + assistantBedrockChat: + type: boolean assistantKnowledgeBaseByDefault: type: boolean assistantModelEvaluation: type: boolean required: + - assistantBedrockChat - assistantKnowledgeBaseByDefault - assistantModelEvaluation '400': diff --git a/x-pack/packages/kbn-langchain/server/language_models/types.ts b/x-pack/packages/kbn-langchain/server/language_models/types.ts index df866bdf30eb7..43dcad34fda3c 100644 --- a/x-pack/packages/kbn-langchain/server/language_models/types.ts +++ b/x-pack/packages/kbn-langchain/server/language_models/types.ts @@ -16,7 +16,7 @@ export interface InvokeAIActionParamsSchema { function_call?: { arguments: string; name: string; - }; + } | null; tool_calls?: Array<{ id: string; diff --git a/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/helpers.ts b/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/helpers.ts index 44404785e7593..90a66b810b9a4 100644 --- a/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/helpers.ts +++ b/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/helpers.ts @@ -164,6 +164,7 @@ export const streamGraph = async ({ } } + // @ts-expect-error processEvent(); } catch (err) { // if I throw an error here, it crashes the server. Not sure how to get around that. 
From 1bd50b5b56d86961cd5b89249b6c6de9f3e52aee Mon Sep 17 00:00:00 2001 From: kibanamachine <42973632+kibanamachine@users.noreply.github.com> Date: Sun, 30 Jun 2024 00:34:05 +0000 Subject: [PATCH 10/55] [CI] Auto-commit changed files from 'yarn openapi:generate' --- .../impl/schemas/capabilities/get_capabilities_route.gen.ts | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/x-pack/packages/kbn-elastic-assistant-common/impl/schemas/capabilities/get_capabilities_route.gen.ts b/x-pack/packages/kbn-elastic-assistant-common/impl/schemas/capabilities/get_capabilities_route.gen.ts index d8cecc9c788a4..6341f7296f390 100644 --- a/x-pack/packages/kbn-elastic-assistant-common/impl/schemas/capabilities/get_capabilities_route.gen.ts +++ b/x-pack/packages/kbn-elastic-assistant-common/impl/schemas/capabilities/get_capabilities_route.gen.ts @@ -18,7 +18,7 @@ import { z } from 'zod'; export type GetCapabilitiesResponse = z.infer; export const GetCapabilitiesResponse = z.object({ + assistantBedrockChat: z.boolean(), assistantKnowledgeBaseByDefault: z.boolean(), assistantModelEvaluation: z.boolean(), - assistantBedrockChat: z.boolean(), }); From 7696942f1eaa02fe1ae7f43a8bd9b823213e816b Mon Sep 17 00:00:00 2001 From: Patryk Kopycinski Date: Sun, 30 Jun 2024 08:58:06 +0200 Subject: [PATCH 11/55] fix --- .../lib/langchain/execute_custom_llm_chain/index.test.ts | 1 + .../langchain/graphs/default_assistant_graph/helpers.ts | 5 +++-- .../server/lib/conversational_chain.test.ts | 9 ++++++--- 3 files changed, 10 insertions(+), 5 deletions(-) diff --git a/x-pack/plugins/elastic_assistant/server/lib/langchain/execute_custom_llm_chain/index.test.ts b/x-pack/plugins/elastic_assistant/server/lib/langchain/execute_custom_llm_chain/index.test.ts index e1e8cdc50eee0..6c90e28a8de19 100644 --- a/x-pack/plugins/elastic_assistant/server/lib/langchain/execute_custom_llm_chain/index.test.ts +++ b/x-pack/plugins/elastic_assistant/server/lib/langchain/execute_custom_llm_chain/index.test.ts @@ -96,6 +96,7 @@ const esStoreMock = new ElasticsearchStore( ); const defaultProps: AgentExecutorParams = { actions: mockActions, + bedrockChatEnabled: false, isEnabledKnowledgeBase: true, connectorId: mockConnectorId, esClient: esClientMock, diff --git a/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/helpers.ts b/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/helpers.ts index 90a66b810b9a4..434fbae594aef 100644 --- a/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/helpers.ts +++ b/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/helpers.ts @@ -5,6 +5,8 @@ * 2.0. */ +/* eslint-disable complexity */ + import agent, { Span } from 'elastic-apm-node'; import type { Logger } from '@kbn/logging'; import { streamFactory, StreamResponseWithHeaders } from '@kbn/ml-response-stream/server'; @@ -164,8 +166,7 @@ export const streamGraph = async ({ } } - // @ts-expect-error - processEvent(); + void processEvent(); } catch (err) { // if I throw an error here, it crashes the server. Not sure how to get around that. 
// If I put await on this function the error works properly, but when there is not an error diff --git a/x-pack/plugins/search_playground/server/lib/conversational_chain.test.ts b/x-pack/plugins/search_playground/server/lib/conversational_chain.test.ts index fa6ee8e6e1abd..9045f68a83f8d 100644 --- a/x-pack/plugins/search_playground/server/lib/conversational_chain.test.ts +++ b/x-pack/plugins/search_playground/server/lib/conversational_chain.test.ts @@ -433,7 +433,10 @@ describe('conversational chain', () => { expectedDocs: [ { documents: [ - { metadata: { _id: '1', _index: 'index' } }, + { + metadata: { _id: '1', _index: 'index' }, + pageContent: expect.any(String), + }, { metadata: { _id: '1', _index: 'website' }, pageContent: expect.any(String), @@ -444,8 +447,8 @@ describe('conversational chain', () => { ], // Even with body_content of 1000, the token count should be below the model limit of 100 expectedTokens: [ - { type: 'context_token_count', count: 70 }, - { type: 'prompt_token_count', count: 97 }, + { type: 'context_token_count', count: 73 }, + { type: 'prompt_token_count', count: 100 }, ], expectedHasClipped: true, expectedSearchRequest: [ From 340097b38ef305a685ff5313ffca25bbff2037d9 Mon Sep 17 00:00:00 2001 From: Patryk Kopycinski Date: Sun, 30 Jun 2024 11:47:07 +0200 Subject: [PATCH 12/55] fix --- .../server/language_models/bedrock_chat.ts | 24 +++++++++++++++++-- .../esql_language_knowledge_base_tool.ts | 3 ++- .../knowledge_base_retrieval_tool.ts | 3 ++- .../knowledge_base_write_tool.ts | 3 ++- 4 files changed, 28 insertions(+), 5 deletions(-) diff --git a/x-pack/packages/kbn-langchain/server/language_models/bedrock_chat.ts b/x-pack/packages/kbn-langchain/server/language_models/bedrock_chat.ts index c0beb8dc52cc6..687ec06eb175a 100644 --- a/x-pack/packages/kbn-langchain/server/language_models/bedrock_chat.ts +++ b/x-pack/packages/kbn-langchain/server/language_models/bedrock_chat.ts @@ -11,6 +11,7 @@ import { BaseChatModelParams } from '@langchain/core/language_models/chat_models import { Logger } from '@kbn/logging'; import { KibanaRequest } from '@kbn/core/server'; import { Readable } from 'stream'; +import { filter, isArray, map } from 'lodash'; export const DEFAULT_BEDROCK_MODEL = 'anthropic.claude-3-5-sonnet-20240620-v1:0'; export const DEFAULT_BEDROCK_REGION = 'us-east-1'; @@ -39,6 +40,7 @@ export class ActionsClientBedrockChatModel extends _BedrockChat { const actionsClient = await actions.getActionsClientWithRequest(request); const inputBody = JSON.parse(options?.body as string); + const messages = map(inputBody.messages, sanitizeMessage); if (this.streaming) { const data = (await actionsClient.execute({ @@ -46,7 +48,7 @@ export class ActionsClientBedrockChatModel extends _BedrockChat { params: { subAction: 'invokeStream', subActionParams: { - messages: inputBody.messages, + messages, temperature: inputBody.temperature, stopSequences: inputBody.stopSequences, system: inputBody.system, @@ -65,7 +67,7 @@ export class ActionsClientBedrockChatModel extends _BedrockChat { params: { subAction: 'invokeAI', subActionParams: { - messages: inputBody.messages, + messages, temperature: inputBody.temperature, stopSequences: inputBody.stopSequences, system: inputBody.system, @@ -85,3 +87,21 @@ export class ActionsClientBedrockChatModel extends _BedrockChat { }); } } + +const sanitizeMessage = ({ + role, + content, +}: { + role: string; + content: string | Array<{ type: string; text: string }>; +}) => { + if (isArray(content)) { + const textContent = filter(content, ['type', 
'text']); + return { role, content: textContent[textContent.length - 1]?.text }; + } + + return { + role, + content, + }; +}; diff --git a/x-pack/plugins/security_solution/server/assistant/tools/esql_language_knowledge_base/esql_language_knowledge_base_tool.ts b/x-pack/plugins/security_solution/server/assistant/tools/esql_language_knowledge_base/esql_language_knowledge_base_tool.ts index 692753a22dea0..f3a5841adc0a3 100644 --- a/x-pack/plugins/security_solution/server/assistant/tools/esql_language_knowledge_base/esql_language_knowledge_base_tool.ts +++ b/x-pack/plugins/security_solution/server/assistant/tools/esql_language_knowledge_base/esql_language_knowledge_base_tool.ts @@ -47,6 +47,7 @@ export const ESQL_KNOWLEDGE_BASE_TOOL: AssistantTool = { return result.text; }, tags: ['esql', 'query-generation', 'knowledge-base'], - }); + // TODO: Remove after ZodAny is fixed https://github.com/langchain-ai/langchainjs/blob/main/langchain-core/src/tools.ts + }) as unknown as DynamicStructuredTool; }, }; diff --git a/x-pack/plugins/security_solution/server/assistant/tools/knowledge_base/knowledge_base_retrieval_tool.ts b/x-pack/plugins/security_solution/server/assistant/tools/knowledge_base/knowledge_base_retrieval_tool.ts index 47cb35e244d51..29a398b9bbeec 100644 --- a/x-pack/plugins/security_solution/server/assistant/tools/knowledge_base/knowledge_base_retrieval_tool.ts +++ b/x-pack/plugins/security_solution/server/assistant/tools/knowledge_base/knowledge_base_retrieval_tool.ts @@ -52,6 +52,7 @@ export const KNOWLEDGE_BASE_RETRIEVAL_TOOL: AssistantTool = { return JSON.stringify(docs); }, tags: ['knowledge-base'], - }); + // TODO: Remove after ZodAny is fixed https://github.com/langchain-ai/langchainjs/blob/main/langchain-core/src/tools.ts + }) as unknown as DynamicStructuredTool; }, }; diff --git a/x-pack/plugins/security_solution/server/assistant/tools/knowledge_base/knowledge_base_write_tool.ts b/x-pack/plugins/security_solution/server/assistant/tools/knowledge_base/knowledge_base_write_tool.ts index addb2a5580dfc..4522ae8c3e75b 100644 --- a/x-pack/plugins/security_solution/server/assistant/tools/knowledge_base/knowledge_base_write_tool.ts +++ b/x-pack/plugins/security_solution/server/assistant/tools/knowledge_base/knowledge_base_write_tool.ts @@ -64,6 +64,7 @@ export const KNOWLEDGE_BASE_WRITE_TOOL: AssistantTool = { return "I've successfully saved this entry to your knowledge base. 
You can ask me to recall this information at any time."; }, tags: ['knowledge-base'], - }); + // TODO: Remove after ZodAny is fixed https://github.com/langchain-ai/langchainjs/blob/main/langchain-core/src/tools.ts + }) as unknown as DynamicStructuredTool; }, }; From 29788718643d5bab8e8cf9498d7d7ea2d4c15a41 Mon Sep 17 00:00:00 2001 From: Patryk Kopycinski Date: Sun, 30 Jun 2024 19:48:14 +0200 Subject: [PATCH 13/55] cleanup --- x-pack/plugins/integration_assistant/kibana.jsonc | 3 +-- .../graphs/categorization/categorization.test.ts | 7 +++++-- .../server/graphs/categorization/categorization.ts | 7 +++++-- .../server/graphs/categorization/errors.test.ts | 7 +++++-- .../server/graphs/categorization/errors.ts | 7 +++++-- .../server/graphs/categorization/graph.test.ts | 7 +++++-- .../server/graphs/categorization/graph.ts | 7 +++++-- .../server/graphs/categorization/invalid.test.ts | 7 +++++-- .../server/graphs/categorization/invalid.ts | 7 +++++-- .../server/graphs/categorization/review.test.ts | 7 +++++-- .../server/graphs/categorization/review.ts | 7 +++++-- .../server/graphs/ecs/duplicates.test.ts | 7 +++++-- .../server/graphs/ecs/duplicates.ts | 7 +++++-- .../server/graphs/ecs/graph.test.ts | 7 +++++-- .../integration_assistant/server/graphs/ecs/graph.ts | 7 +++++-- .../server/graphs/ecs/invalid.test.ts | 7 +++++-- .../server/graphs/ecs/invalid.ts | 7 +++++-- .../server/graphs/ecs/mapping.test.ts | 7 +++++-- .../server/graphs/ecs/mapping.ts | 7 +++++-- .../server/graphs/ecs/missing.test.ts | 7 +++++-- .../server/graphs/ecs/missing.ts | 7 +++++-- .../server/graphs/related/errors.test.ts | 7 +++++-- .../server/graphs/related/errors.ts | 7 +++++-- .../server/graphs/related/graph.test.ts | 7 +++++-- .../server/graphs/related/graph.ts | 7 +++++-- .../server/graphs/related/related.test.ts | 7 +++++-- .../server/graphs/related/related.ts | 7 +++++-- .../server/graphs/related/review.test.ts | 7 +++++-- .../server/graphs/related/review.ts | 7 +++++-- .../server/routes/categorization_routes.ts | 11 +++++++---- .../integration_assistant/server/routes/ecs_routes.ts | 11 +++++++---- .../server/routes/related_routes.ts | 11 +++++++---- x-pack/plugins/integration_assistant/tsconfig.json | 5 ++--- 33 files changed, 164 insertions(+), 73 deletions(-) diff --git a/x-pack/plugins/integration_assistant/kibana.jsonc b/x-pack/plugins/integration_assistant/kibana.jsonc index bf52e0abcabf4..d7f0a68765b8b 100644 --- a/x-pack/plugins/integration_assistant/kibana.jsonc +++ b/x-pack/plugins/integration_assistant/kibana.jsonc @@ -15,8 +15,7 @@ "kibanaReact", "triggersActionsUi", "actions", - "stackConnectors", - "elasticAssistant" + "stackConnectors" ], } } diff --git a/x-pack/plugins/integration_assistant/server/graphs/categorization/categorization.test.ts b/x-pack/plugins/integration_assistant/server/graphs/categorization/categorization.test.ts index b30fa1d66a534..3ad0926297bbc 100644 --- a/x-pack/plugins/integration_assistant/server/graphs/categorization/categorization.test.ts +++ b/x-pack/plugins/integration_assistant/server/graphs/categorization/categorization.test.ts @@ -13,11 +13,14 @@ import { categorizationMockProcessors, categorizationExpectedHandlerResponse, } from '../../../__jest__/fixtures/categorization'; -import type { AssistantToolLlm } from '@kbn/elastic-assistant-plugin/server/types'; +import { + ActionsClientChatOpenAI, + ActionsClientSimpleChatModel, +} from '@kbn/langchain/server/language_models'; const mockLlm = new FakeLLM({ response: JSON.stringify(categorizationMockProcessors, null, 2), -}) 
as unknown as AssistantToolLlm; +}) as unknown as ActionsClientChatOpenAI | ActionsClientSimpleChatModel; const testState: CategorizationState = categorizationTestState; diff --git a/x-pack/plugins/integration_assistant/server/graphs/categorization/categorization.ts b/x-pack/plugins/integration_assistant/server/graphs/categorization/categorization.ts index 03f95d36ff5a4..01a3d51aa7e18 100644 --- a/x-pack/plugins/integration_assistant/server/graphs/categorization/categorization.ts +++ b/x-pack/plugins/integration_assistant/server/graphs/categorization/categorization.ts @@ -5,14 +5,17 @@ * 2.0. */ -import type { AssistantToolLlm } from '@kbn/elastic-assistant-plugin/server/types'; +import { + ActionsClientChatOpenAI, + ActionsClientSimpleChatModel, +} from '@kbn/langchain/server/language_models'; import { JsonOutputParser } from '@langchain/core/output_parsers'; import type { ESProcessorItem, Pipeline } from '../../../common'; import type { CategorizationState } from '../../types'; import { combineProcessors } from '../../util/processors'; import { CATEGORIZATION_MAIN_PROMPT } from './prompts'; -export async function handleCategorization(state: CategorizationState, model: AssistantToolLlm) { +export async function handleCategorization(state: CategorizationState, model: ActionsClientChatOpenAI | ActionsClientSimpleChatModel) { const categorizationMainPrompt = CATEGORIZATION_MAIN_PROMPT; const outputParser = new JsonOutputParser(); const categorizationMainGraph = categorizationMainPrompt.pipe(model).pipe(outputParser); diff --git a/x-pack/plugins/integration_assistant/server/graphs/categorization/errors.test.ts b/x-pack/plugins/integration_assistant/server/graphs/categorization/errors.test.ts index 93ca5030e5104..18d8c1842080a 100644 --- a/x-pack/plugins/integration_assistant/server/graphs/categorization/errors.test.ts +++ b/x-pack/plugins/integration_assistant/server/graphs/categorization/errors.test.ts @@ -13,11 +13,14 @@ import { categorizationMockProcessors, categorizationExpectedHandlerResponse, } from '../../../__jest__/fixtures/categorization'; -import type { AssistantToolLlm } from '@kbn/elastic-assistant-plugin/server/types'; +import { + ActionsClientChatOpenAI, + ActionsClientSimpleChatModel, +} from '@kbn/langchain/server/language_models'; const mockLlm = new FakeLLM({ response: JSON.stringify(categorizationMockProcessors, null, 2), -}) as unknown as AssistantToolLlm; +}) as unknown as ActionsClientChatOpenAI | ActionsClientSimpleChatModel; const testState: CategorizationState = categorizationTestState; diff --git a/x-pack/plugins/integration_assistant/server/graphs/categorization/errors.ts b/x-pack/plugins/integration_assistant/server/graphs/categorization/errors.ts index 8ce8792604af5..74a1a36a99a99 100644 --- a/x-pack/plugins/integration_assistant/server/graphs/categorization/errors.ts +++ b/x-pack/plugins/integration_assistant/server/graphs/categorization/errors.ts @@ -5,14 +5,17 @@ * 2.0. 
*/ -import type { AssistantToolLlm } from '@kbn/elastic-assistant-plugin/server/types'; +import { + ActionsClientChatOpenAI, + ActionsClientSimpleChatModel, +} from '@kbn/langchain/server/language_models'; import { JsonOutputParser } from '@langchain/core/output_parsers'; import type { ESProcessorItem, Pipeline } from '../../../common'; import type { CategorizationState } from '../../types'; import { combineProcessors } from '../../util/processors'; import { CATEGORIZATION_ERROR_PROMPT } from './prompts'; -export async function handleErrors(state: CategorizationState, model: AssistantToolLlm) { +export async function handleErrors(state: CategorizationState, model: ActionsClientChatOpenAI | ActionsClientSimpleChatModel) { const categorizationErrorPrompt = CATEGORIZATION_ERROR_PROMPT; const outputParser = new JsonOutputParser(); diff --git a/x-pack/plugins/integration_assistant/server/graphs/categorization/graph.test.ts b/x-pack/plugins/integration_assistant/server/graphs/categorization/graph.test.ts index 7fd76e3bd7a60..4122d4540dbc0 100644 --- a/x-pack/plugins/integration_assistant/server/graphs/categorization/graph.test.ts +++ b/x-pack/plugins/integration_assistant/server/graphs/categorization/graph.test.ts @@ -25,11 +25,14 @@ import { handleCategorization } from './categorization'; import { handleErrors } from './errors'; import { handleInvalidCategorization } from './invalid'; import { testPipeline, combineProcessors } from '../../util'; -import type { AssistantToolLlm } from '@kbn/elastic-assistant-plugin/server/types'; +import { + ActionsClientChatOpenAI, + ActionsClientSimpleChatModel, +} from '@kbn/langchain/server/language_models'; const mockLlm = new FakeLLM({ response: "I'll callback later.", -}) as unknown as AssistantToolLlm; +}) as unknown as ActionsClientChatOpenAI | ActionsClientSimpleChatModel; jest.mock('./errors'); jest.mock('./review'); diff --git a/x-pack/plugins/integration_assistant/server/graphs/categorization/graph.ts b/x-pack/plugins/integration_assistant/server/graphs/categorization/graph.ts index 79795b3c39ecf..6c6630200effb 100644 --- a/x-pack/plugins/integration_assistant/server/graphs/categorization/graph.ts +++ b/x-pack/plugins/integration_assistant/server/graphs/categorization/graph.ts @@ -8,7 +8,10 @@ import type { IScopedClusterClient } from '@kbn/core-elasticsearch-server'; import type { StateGraphArgs } from '@langchain/langgraph'; import { StateGraph, END, START } from '@langchain/langgraph'; -import type { AssistantToolLlm } from '@kbn/elastic-assistant-plugin/server/types'; +import { + ActionsClientChatOpenAI, + ActionsClientSimpleChatModel, +} from '@kbn/langchain/server/language_models'; import type { CategorizationState } from '../../types'; import { modifySamples, formatSamples } from '../../util/samples'; import { handleCategorization } from './categorization'; @@ -148,7 +151,7 @@ function chainRouter(state: CategorizationState): string { export async function getCategorizationGraph( client: IScopedClusterClient, - model: AssistantToolLlm + model: ActionsClientChatOpenAI | ActionsClientSimpleChatModel ) { const workflow = new StateGraph({ channels: graphState, diff --git a/x-pack/plugins/integration_assistant/server/graphs/categorization/invalid.test.ts b/x-pack/plugins/integration_assistant/server/graphs/categorization/invalid.test.ts index 7c8e1fe1c4a62..10560137093d8 100644 --- a/x-pack/plugins/integration_assistant/server/graphs/categorization/invalid.test.ts +++ 
b/x-pack/plugins/integration_assistant/server/graphs/categorization/invalid.test.ts @@ -13,11 +13,14 @@ import { categorizationMockProcessors, categorizationExpectedHandlerResponse, } from '../../../__jest__/fixtures/categorization'; -import type { AssistantToolLlm } from '@kbn/elastic-assistant-plugin/server/types'; +import { + ActionsClientChatOpenAI, + ActionsClientSimpleChatModel, +} from '@kbn/langchain/server/language_models'; const mockLlm = new FakeLLM({ response: JSON.stringify(categorizationMockProcessors, null, 2), -}) as unknown as AssistantToolLlm; +}) as unknown as ActionsClientChatOpenAI | ActionsClientSimpleChatModel; const testState: CategorizationState = categorizationTestState; diff --git a/x-pack/plugins/integration_assistant/server/graphs/categorization/invalid.ts b/x-pack/plugins/integration_assistant/server/graphs/categorization/invalid.ts index b5ec203b54cef..9847a76ff5a48 100644 --- a/x-pack/plugins/integration_assistant/server/graphs/categorization/invalid.ts +++ b/x-pack/plugins/integration_assistant/server/graphs/categorization/invalid.ts @@ -5,7 +5,10 @@ * 2.0. */ -import type { AssistantToolLlm } from '@kbn/elastic-assistant-plugin/server/types'; +import { + ActionsClientChatOpenAI, + ActionsClientSimpleChatModel, +} from '@kbn/langchain/server/language_models'; import { JsonOutputParser } from '@langchain/core/output_parsers'; import type { ESProcessorItem, Pipeline } from '../../../common'; import type { CategorizationState } from '../../types'; @@ -15,7 +18,7 @@ import { CATEGORIZATION_VALIDATION_PROMPT } from './prompts'; export async function handleInvalidCategorization( state: CategorizationState, - model: AssistantToolLlm + model: ActionsClientChatOpenAI | ActionsClientSimpleChatModel ) { const categorizationInvalidPrompt = CATEGORIZATION_VALIDATION_PROMPT; diff --git a/x-pack/plugins/integration_assistant/server/graphs/categorization/review.test.ts b/x-pack/plugins/integration_assistant/server/graphs/categorization/review.test.ts index a053226b65afc..7775b69c5b6a8 100644 --- a/x-pack/plugins/integration_assistant/server/graphs/categorization/review.test.ts +++ b/x-pack/plugins/integration_assistant/server/graphs/categorization/review.test.ts @@ -13,11 +13,14 @@ import { categorizationMockProcessors, categorizationExpectedHandlerResponse, } from '../../../__jest__/fixtures/categorization'; -import type { AssistantToolLlm } from '@kbn/elastic-assistant-plugin/server/types'; +import { + ActionsClientChatOpenAI, + ActionsClientSimpleChatModel, +} from '@kbn/langchain/server/language_models'; const mockLlm = new FakeLLM({ response: JSON.stringify(categorizationMockProcessors, null, 2), -}) as unknown as AssistantToolLlm; +}) as unknown as ActionsClientChatOpenAI | ActionsClientSimpleChatModel; const testState: CategorizationState = categorizationTestState; diff --git a/x-pack/plugins/integration_assistant/server/graphs/categorization/review.ts b/x-pack/plugins/integration_assistant/server/graphs/categorization/review.ts index df11a6ef8d4eb..03862ed33f13d 100644 --- a/x-pack/plugins/integration_assistant/server/graphs/categorization/review.ts +++ b/x-pack/plugins/integration_assistant/server/graphs/categorization/review.ts @@ -5,7 +5,10 @@ * 2.0. 
*/ -import type { AssistantToolLlm } from '@kbn/elastic-assistant-plugin/server/types'; +import { + ActionsClientChatOpenAI, + ActionsClientSimpleChatModel, +} from '@kbn/langchain/server/language_models'; import { JsonOutputParser } from '@langchain/core/output_parsers'; import { CATEGORIZATION_REVIEW_PROMPT } from './prompts'; @@ -14,7 +17,7 @@ import type { CategorizationState } from '../../types'; import { combineProcessors } from '../../util/processors'; import { ECS_EVENT_TYPES_PER_CATEGORY } from './constants'; -export async function handleReview(state: CategorizationState, model: AssistantToolLlm) { +export async function handleReview(state: CategorizationState, model: ActionsClientChatOpenAI | ActionsClientSimpleChatModel) { const categorizationReviewPrompt = CATEGORIZATION_REVIEW_PROMPT; const outputParser = new JsonOutputParser(); const categorizationReview = categorizationReviewPrompt.pipe(model).pipe(outputParser); diff --git a/x-pack/plugins/integration_assistant/server/graphs/ecs/duplicates.test.ts b/x-pack/plugins/integration_assistant/server/graphs/ecs/duplicates.test.ts index 2aa950f5a0591..9270b2453e261 100644 --- a/x-pack/plugins/integration_assistant/server/graphs/ecs/duplicates.test.ts +++ b/x-pack/plugins/integration_assistant/server/graphs/ecs/duplicates.test.ts @@ -9,11 +9,14 @@ import { FakeLLM } from '@langchain/core/utils/testing'; import { handleDuplicates } from './duplicates'; import type { EcsMappingState } from '../../types'; import { ecsTestState } from '../../../__jest__/fixtures/ecs_mapping'; -import type { AssistantToolLlm } from '@kbn/elastic-assistant-plugin/server/types'; +import { + ActionsClientChatOpenAI, + ActionsClientSimpleChatModel, +} from '@kbn/langchain/server/language_models'; const mockLlm = new FakeLLM({ response: '{ "message": "ll callback later."}', -}) as unknown as AssistantToolLlm; +}) as unknown as ActionsClientChatOpenAI | ActionsClientSimpleChatModel; const testState: EcsMappingState = ecsTestState; diff --git a/x-pack/plugins/integration_assistant/server/graphs/ecs/duplicates.ts b/x-pack/plugins/integration_assistant/server/graphs/ecs/duplicates.ts index 8e576b3775a9c..a82708bd6b33a 100644 --- a/x-pack/plugins/integration_assistant/server/graphs/ecs/duplicates.ts +++ b/x-pack/plugins/integration_assistant/server/graphs/ecs/duplicates.ts @@ -5,12 +5,15 @@ * 2.0. 
*/ -import type { AssistantToolLlm } from '@kbn/elastic-assistant-plugin/server/types'; +import { + ActionsClientChatOpenAI, + ActionsClientSimpleChatModel, +} from '@kbn/langchain/server/language_models'; import { JsonOutputParser } from '@langchain/core/output_parsers'; import type { EcsMappingState } from '../../types'; import { ECS_DUPLICATES_PROMPT } from './prompts'; -export async function handleDuplicates(state: EcsMappingState, model: AssistantToolLlm) { +export async function handleDuplicates(state: EcsMappingState, model: ActionsClientChatOpenAI | ActionsClientSimpleChatModel) { const ecsDuplicatesPrompt = ECS_DUPLICATES_PROMPT; const outputParser = new JsonOutputParser(); const ecsDuplicatesGraph = ecsDuplicatesPrompt.pipe(model).pipe(outputParser); diff --git a/x-pack/plugins/integration_assistant/server/graphs/ecs/graph.test.ts b/x-pack/plugins/integration_assistant/server/graphs/ecs/graph.test.ts index 41729e7e54c06..0ae626924c349 100644 --- a/x-pack/plugins/integration_assistant/server/graphs/ecs/graph.test.ts +++ b/x-pack/plugins/integration_assistant/server/graphs/ecs/graph.test.ts @@ -19,11 +19,14 @@ import { handleEcsMapping } from './mapping'; import { handleDuplicates } from './duplicates'; import { handleMissingKeys } from './missing'; import { handleInvalidEcs } from './invalid'; -import type { AssistantToolLlm } from '@kbn/elastic-assistant-plugin/server/types'; +import { + ActionsClientChatOpenAI, + ActionsClientSimpleChatModel, +} from '@kbn/langchain/server/language_models'; const mockLlm = new FakeLLM({ response: "I'll callback later.", -}) as unknown as AssistantToolLlm; +}) as unknown as ActionsClientChatOpenAI | ActionsClientSimpleChatModel; jest.mock('./mapping'); jest.mock('./duplicates'); diff --git a/x-pack/plugins/integration_assistant/server/graphs/ecs/graph.ts b/x-pack/plugins/integration_assistant/server/graphs/ecs/graph.ts index b4a9c7b0dfd80..5720c42eb22c9 100644 --- a/x-pack/plugins/integration_assistant/server/graphs/ecs/graph.ts +++ b/x-pack/plugins/integration_assistant/server/graphs/ecs/graph.ts @@ -5,7 +5,10 @@ * 2.0. 
*/ -import type { AssistantToolLlm } from '@kbn/elastic-assistant-plugin/server/types'; +import { + ActionsClientChatOpenAI, + ActionsClientSimpleChatModel, +} from '@kbn/langchain/server/language_models'; import type { StateGraphArgs } from '@langchain/langgraph'; import { END, START, StateGraph } from '@langchain/langgraph'; import type { EcsMappingState } from '../../types'; @@ -137,7 +140,7 @@ function chainRouter(state: EcsMappingState): string { return END; } -export async function getEcsGraph(model: AssistantToolLlm) { +export async function getEcsGraph(model: ActionsClientChatOpenAI | ActionsClientSimpleChatModel) { const workflow = new StateGraph({ channels: graphState, }) diff --git a/x-pack/plugins/integration_assistant/server/graphs/ecs/invalid.test.ts b/x-pack/plugins/integration_assistant/server/graphs/ecs/invalid.test.ts index 15da3809e2d97..ce1f76ce7a721 100644 --- a/x-pack/plugins/integration_assistant/server/graphs/ecs/invalid.test.ts +++ b/x-pack/plugins/integration_assistant/server/graphs/ecs/invalid.test.ts @@ -9,11 +9,14 @@ import { FakeLLM } from '@langchain/core/utils/testing'; import { handleInvalidEcs } from './invalid'; import type { EcsMappingState } from '../../types'; import { ecsTestState } from '../../../__jest__/fixtures/ecs_mapping'; -import type { AssistantToolLlm } from '@kbn/elastic-assistant-plugin/server/types'; +import { + ActionsClientChatOpenAI, + ActionsClientSimpleChatModel, +} from '@kbn/langchain/server/language_models'; const mockLlm = new FakeLLM({ response: '{ "message": "ll callback later."}', -}) as unknown as AssistantToolLlm; +}) as unknown as ActionsClientChatOpenAI | ActionsClientSimpleChatModel; const testState: EcsMappingState = ecsTestState; diff --git a/x-pack/plugins/integration_assistant/server/graphs/ecs/invalid.ts b/x-pack/plugins/integration_assistant/server/graphs/ecs/invalid.ts index e06113135c910..65806f59c2faa 100644 --- a/x-pack/plugins/integration_assistant/server/graphs/ecs/invalid.ts +++ b/x-pack/plugins/integration_assistant/server/graphs/ecs/invalid.ts @@ -5,12 +5,15 @@ * 2.0. 
*/ -import type { AssistantToolLlm } from '@kbn/elastic-assistant-plugin/server/types'; +import { + ActionsClientChatOpenAI, + ActionsClientSimpleChatModel, +} from '@kbn/langchain/server/language_models'; import { JsonOutputParser } from '@langchain/core/output_parsers'; import type { EcsMappingState } from '../../types'; import { ECS_INVALID_PROMPT } from './prompts'; -export async function handleInvalidEcs(state: EcsMappingState, model: AssistantToolLlm) { +export async function handleInvalidEcs(state: EcsMappingState, model: ActionsClientChatOpenAI | ActionsClientSimpleChatModel) { const ecsInvalidEcsPrompt = ECS_INVALID_PROMPT; const outputParser = new JsonOutputParser(); const ecsInvalidEcsGraph = ecsInvalidEcsPrompt.pipe(model).pipe(outputParser); diff --git a/x-pack/plugins/integration_assistant/server/graphs/ecs/mapping.test.ts b/x-pack/plugins/integration_assistant/server/graphs/ecs/mapping.test.ts index 4170505e458cd..dbbfc0608d010 100644 --- a/x-pack/plugins/integration_assistant/server/graphs/ecs/mapping.test.ts +++ b/x-pack/plugins/integration_assistant/server/graphs/ecs/mapping.test.ts @@ -9,11 +9,14 @@ import { FakeLLM } from '@langchain/core/utils/testing'; import { handleEcsMapping } from './mapping'; import type { EcsMappingState } from '../../types'; import { ecsTestState } from '../../../__jest__/fixtures/ecs_mapping'; -import type { AssistantToolLlm } from '@kbn/elastic-assistant-plugin/server/types'; +import { + ActionsClientChatOpenAI, + ActionsClientSimpleChatModel, +} from '@kbn/langchain/server/language_models'; const mockLlm = new FakeLLM({ response: '{ "message": "ll callback later."}', -}) as unknown as AssistantToolLlm; +}) as unknown as ActionsClientChatOpenAI | ActionsClientSimpleChatModel; const testState: EcsMappingState = ecsTestState; diff --git a/x-pack/plugins/integration_assistant/server/graphs/ecs/mapping.ts b/x-pack/plugins/integration_assistant/server/graphs/ecs/mapping.ts index 98c9fe4eca82f..48511a5f4fc4f 100644 --- a/x-pack/plugins/integration_assistant/server/graphs/ecs/mapping.ts +++ b/x-pack/plugins/integration_assistant/server/graphs/ecs/mapping.ts @@ -5,12 +5,15 @@ * 2.0. 
*/ -import type { AssistantToolLlm } from '@kbn/elastic-assistant-plugin/server/types'; +import { + ActionsClientChatOpenAI, + ActionsClientSimpleChatModel, +} from '@kbn/langchain/server/language_models'; import { JsonOutputParser } from '@langchain/core/output_parsers'; import type { EcsMappingState } from '../../types'; import { ECS_MAIN_PROMPT } from './prompts'; -export async function handleEcsMapping(state: EcsMappingState, model: AssistantToolLlm) { +export async function handleEcsMapping(state: EcsMappingState, model: ActionsClientChatOpenAI | ActionsClientSimpleChatModel) { const ecsMainPrompt = ECS_MAIN_PROMPT; const outputParser = new JsonOutputParser(); const ecsMainGraph = ecsMainPrompt.pipe(model).pipe(outputParser); diff --git a/x-pack/plugins/integration_assistant/server/graphs/ecs/missing.test.ts b/x-pack/plugins/integration_assistant/server/graphs/ecs/missing.test.ts index d283a6f3fe1c1..b369d28b1e177 100644 --- a/x-pack/plugins/integration_assistant/server/graphs/ecs/missing.test.ts +++ b/x-pack/plugins/integration_assistant/server/graphs/ecs/missing.test.ts @@ -9,11 +9,14 @@ import { FakeLLM } from '@langchain/core/utils/testing'; import { handleMissingKeys } from './missing'; import type { EcsMappingState } from '../../types'; import { ecsTestState } from '../../../__jest__/fixtures/ecs_mapping'; -import type { AssistantToolLlm } from '@kbn/elastic-assistant-plugin/server/types'; +import { + ActionsClientChatOpenAI, + ActionsClientSimpleChatModel, +} from '@kbn/langchain/server/language_models'; const mockLlm = new FakeLLM({ response: '{ "message": "ll callback later."}', -}) as unknown as AssistantToolLlm; +}) as unknown as ActionsClientChatOpenAI | ActionsClientSimpleChatModel; const testState: EcsMappingState = ecsTestState; diff --git a/x-pack/plugins/integration_assistant/server/graphs/ecs/missing.ts b/x-pack/plugins/integration_assistant/server/graphs/ecs/missing.ts index ca7f7501f4eef..6412bf99d1188 100644 --- a/x-pack/plugins/integration_assistant/server/graphs/ecs/missing.ts +++ b/x-pack/plugins/integration_assistant/server/graphs/ecs/missing.ts @@ -5,12 +5,15 @@ * 2.0. 
*/ -import type { AssistantToolLlm } from '@kbn/elastic-assistant-plugin/server/types'; +import { + ActionsClientChatOpenAI, + ActionsClientSimpleChatModel, +} from '@kbn/langchain/server/language_models'; import { JsonOutputParser } from '@langchain/core/output_parsers'; import type { EcsMappingState } from '../../types'; import { ECS_MISSING_KEYS_PROMPT } from './prompts'; -export async function handleMissingKeys(state: EcsMappingState, model: AssistantToolLlm) { +export async function handleMissingKeys(state: EcsMappingState, model: ActionsClientChatOpenAI | ActionsClientSimpleChatModel) { const ecsMissingPrompt = ECS_MISSING_KEYS_PROMPT; const outputParser = new JsonOutputParser(); const ecsMissingGraph = ecsMissingPrompt.pipe(model).pipe(outputParser); diff --git a/x-pack/plugins/integration_assistant/server/graphs/related/errors.test.ts b/x-pack/plugins/integration_assistant/server/graphs/related/errors.test.ts index 9f530d49fc6f3..24dc4365dcbff 100644 --- a/x-pack/plugins/integration_assistant/server/graphs/related/errors.test.ts +++ b/x-pack/plugins/integration_assistant/server/graphs/related/errors.test.ts @@ -13,11 +13,14 @@ import { relatedMockProcessors, relatedExpectedHandlerResponse, } from '../../../__jest__/fixtures/related'; -import type { AssistantToolLlm } from '@kbn/elastic-assistant-plugin/server/types'; +import { + ActionsClientChatOpenAI, + ActionsClientSimpleChatModel, +} from '@kbn/langchain/server/language_models'; const mockLlm = new FakeLLM({ response: JSON.stringify(relatedMockProcessors, null, 2), -}) as unknown as AssistantToolLlm; +}) as unknown as ActionsClientChatOpenAI | ActionsClientSimpleChatModel; const testState: RelatedState = relatedTestState; diff --git a/x-pack/plugins/integration_assistant/server/graphs/related/errors.ts b/x-pack/plugins/integration_assistant/server/graphs/related/errors.ts index 20c5f5e108226..de5691b845638 100644 --- a/x-pack/plugins/integration_assistant/server/graphs/related/errors.ts +++ b/x-pack/plugins/integration_assistant/server/graphs/related/errors.ts @@ -5,14 +5,17 @@ * 2.0. 
*/ -import type { AssistantToolLlm } from '@kbn/elastic-assistant-plugin/server/types'; +import { + ActionsClientChatOpenAI, + ActionsClientSimpleChatModel, +} from '@kbn/langchain/server/language_models'; import { JsonOutputParser } from '@langchain/core/output_parsers'; import type { ESProcessorItem, Pipeline } from '../../../common'; import type { RelatedState } from '../../types'; import { combineProcessors } from '../../util/processors'; import { RELATED_ERROR_PROMPT } from './prompts'; -export async function handleErrors(state: RelatedState, model: AssistantToolLlm) { +export async function handleErrors(state: RelatedState, model: ActionsClientChatOpenAI | ActionsClientSimpleChatModel) { const relatedErrorPrompt = RELATED_ERROR_PROMPT; const outputParser = new JsonOutputParser(); const relatedErrorGraph = relatedErrorPrompt.pipe(model).pipe(outputParser); diff --git a/x-pack/plugins/integration_assistant/server/graphs/related/graph.test.ts b/x-pack/plugins/integration_assistant/server/graphs/related/graph.test.ts index 19cf49b989ea1..40989e9733800 100644 --- a/x-pack/plugins/integration_assistant/server/graphs/related/graph.test.ts +++ b/x-pack/plugins/integration_assistant/server/graphs/related/graph.test.ts @@ -22,11 +22,14 @@ import { handleReview } from './review'; import { handleRelated } from './related'; import { handleErrors } from './errors'; import { testPipeline, combineProcessors } from '../../util'; -import type { AssistantToolLlm } from '@kbn/elastic-assistant-plugin/server/types'; +import { + ActionsClientChatOpenAI, + ActionsClientSimpleChatModel, +} from '@kbn/langchain/server/language_models'; const mockLlm = new FakeLLM({ response: "I'll callback later.", -}) as unknown as AssistantToolLlm; +}) as unknown as ActionsClientChatOpenAI | ActionsClientSimpleChatModel; jest.mock('./errors'); jest.mock('./review'); diff --git a/x-pack/plugins/integration_assistant/server/graphs/related/graph.ts b/x-pack/plugins/integration_assistant/server/graphs/related/graph.ts index 51a6f9583fe64..3b44d9b65f170 100644 --- a/x-pack/plugins/integration_assistant/server/graphs/related/graph.ts +++ b/x-pack/plugins/integration_assistant/server/graphs/related/graph.ts @@ -8,7 +8,10 @@ import type { IScopedClusterClient } from '@kbn/core-elasticsearch-server'; import type { StateGraphArgs } from '@langchain/langgraph'; import { StateGraph, END, START } from '@langchain/langgraph'; -import type { AssistantToolLlm } from '@kbn/elastic-assistant-plugin/server/types'; +import { + ActionsClientChatOpenAI, + ActionsClientSimpleChatModel, +} from '@kbn/langchain/server/language_models'; import type { RelatedState } from '../../types'; import { modifySamples, formatSamples } from '../../util/samples'; import { handleValidatePipeline } from '../../util/graph'; @@ -134,7 +137,7 @@ function chainRouter(state: RelatedState): string { return END; } -export async function getRelatedGraph(client: IScopedClusterClient, model: AssistantToolLlm) { +export async function getRelatedGraph(client: IScopedClusterClient, model: ActionsClientChatOpenAI | ActionsClientSimpleChatModel) { const workflow = new StateGraph({ channels: graphState }) .addNode('modelInput', modelInput) .addNode('modelOutput', modelOutput) diff --git a/x-pack/plugins/integration_assistant/server/graphs/related/related.test.ts b/x-pack/plugins/integration_assistant/server/graphs/related/related.test.ts index b81de2b1025e0..3a741020fb530 100644 --- a/x-pack/plugins/integration_assistant/server/graphs/related/related.test.ts +++ 
b/x-pack/plugins/integration_assistant/server/graphs/related/related.test.ts @@ -13,11 +13,14 @@ import { relatedMockProcessors, relatedExpectedHandlerResponse, } from '../../../__jest__/fixtures/related'; -import type { AssistantToolLlm } from '@kbn/elastic-assistant-plugin/server/types'; +import { + ActionsClientChatOpenAI, + ActionsClientSimpleChatModel, +} from '@kbn/langchain/server/language_models'; const mockLlm = new FakeLLM({ response: JSON.stringify(relatedMockProcessors, null, 2), -}) as unknown as AssistantToolLlm; +}) as unknown as ActionsClientChatOpenAI | ActionsClientSimpleChatModel; const testState: RelatedState = relatedTestState; diff --git a/x-pack/plugins/integration_assistant/server/graphs/related/related.ts b/x-pack/plugins/integration_assistant/server/graphs/related/related.ts index 0cd1a7f8251b1..044afe0c91930 100644 --- a/x-pack/plugins/integration_assistant/server/graphs/related/related.ts +++ b/x-pack/plugins/integration_assistant/server/graphs/related/related.ts @@ -5,14 +5,17 @@ * 2.0. */ -import type { AssistantToolLlm } from '@kbn/elastic-assistant-plugin/server/types'; +import { + ActionsClientChatOpenAI, + ActionsClientSimpleChatModel, +} from '@kbn/langchain/server/language_models'; import { JsonOutputParser } from '@langchain/core/output_parsers'; import type { ESProcessorItem, Pipeline } from '../../../common'; import type { RelatedState } from '../../types'; import { combineProcessors } from '../../util/processors'; import { RELATED_MAIN_PROMPT } from './prompts'; -export async function handleRelated(state: RelatedState, model: AssistantToolLlm) { +export async function handleRelated(state: RelatedState, model: ActionsClientChatOpenAI | ActionsClientSimpleChatModel) { const relatedMainPrompt = RELATED_MAIN_PROMPT; const outputParser = new JsonOutputParser(); const relatedMainGraph = relatedMainPrompt.pipe(model).pipe(outputParser); diff --git a/x-pack/plugins/integration_assistant/server/graphs/related/review.test.ts b/x-pack/plugins/integration_assistant/server/graphs/related/review.test.ts index a814d25d0c3a2..475f0d72b988d 100644 --- a/x-pack/plugins/integration_assistant/server/graphs/related/review.test.ts +++ b/x-pack/plugins/integration_assistant/server/graphs/related/review.test.ts @@ -13,11 +13,14 @@ import { relatedMockProcessors, relatedExpectedHandlerResponse, } from '../../../__jest__/fixtures/related'; -import type { AssistantToolLlm } from '@kbn/elastic-assistant-plugin/server/types'; +import { + ActionsClientChatOpenAI, + ActionsClientSimpleChatModel, +} from '@kbn/langchain/server/language_models'; const mockLlm = new FakeLLM({ response: JSON.stringify(relatedMockProcessors, null, 2), -}) as unknown as AssistantToolLlm; +}) as unknown as ActionsClientChatOpenAI | ActionsClientSimpleChatModel; const testState: RelatedState = relatedTestState; diff --git a/x-pack/plugins/integration_assistant/server/graphs/related/review.ts b/x-pack/plugins/integration_assistant/server/graphs/related/review.ts index 2e76b822af2ae..ff3c76863ce21 100644 --- a/x-pack/plugins/integration_assistant/server/graphs/related/review.ts +++ b/x-pack/plugins/integration_assistant/server/graphs/related/review.ts @@ -5,14 +5,17 @@ * 2.0. 
*/ -import type { AssistantToolLlm } from '@kbn/elastic-assistant-plugin/server/types'; +import { + ActionsClientChatOpenAI, + ActionsClientSimpleChatModel, +} from '@kbn/langchain/server/language_models'; import { JsonOutputParser } from '@langchain/core/output_parsers'; import type { ESProcessorItem, Pipeline } from '../../../common'; import type { RelatedState } from '../../types'; import { combineProcessors } from '../../util/processors'; import { RELATED_REVIEW_PROMPT } from './prompts'; -export async function handleReview(state: RelatedState, model: AssistantToolLlm) { +export async function handleReview(state: RelatedState, model: ActionsClientChatOpenAI | ActionsClientSimpleChatModel) { const relatedReviewPrompt = RELATED_REVIEW_PROMPT; const outputParser = new JsonOutputParser(); const relatedReviewGraph = relatedReviewPrompt.pipe(model).pipe(outputParser); diff --git a/x-pack/plugins/integration_assistant/server/routes/categorization_routes.ts b/x-pack/plugins/integration_assistant/server/routes/categorization_routes.ts index b0befabd78384..6654898bd0232 100644 --- a/x-pack/plugins/integration_assistant/server/routes/categorization_routes.ts +++ b/x-pack/plugins/integration_assistant/server/routes/categorization_routes.ts @@ -7,7 +7,10 @@ import type { IKibanaResponse, IRouter } from '@kbn/core/server'; import { getRequestAbortedSignal } from '@kbn/data-plugin/server'; -import { getLlmType, getLlmClass } from '@kbn/elastic-assistant-plugin/server/routes/utils'; +import { + ActionsClientChatOpenAI, + ActionsClientSimpleChatModel, +} from '@kbn/langchain/server/language_models'; import { CATEGORIZATION_GRAPH_PATH, CategorizationRequestBody, @@ -56,15 +59,15 @@ export function registerCategorizationRoutes( )[0]; const abortSignal = getRequestAbortedSignal(req.events.aborted$); - const llmType = getLlmType(connector.actionTypeId); - const llmClass = getLlmClass(llmType); + const isOpenAI = connector.actionTypeId === '.gen-ai'; + const llmClass = isOpenAI ? ActionsClientChatOpenAI : ActionsClientSimpleChatModel; const model = new llmClass({ actions: actionsPlugin, connectorId: connector.id, request: req, logger, - llmType, + llmType: isOpenAI ? 'openai' : 'bedrock', model: connector.config?.defaultModel, temperature: 0.05, maxTokens: 4096, diff --git a/x-pack/plugins/integration_assistant/server/routes/ecs_routes.ts b/x-pack/plugins/integration_assistant/server/routes/ecs_routes.ts index c53fcb49442ad..ee461b94feba4 100644 --- a/x-pack/plugins/integration_assistant/server/routes/ecs_routes.ts +++ b/x-pack/plugins/integration_assistant/server/routes/ecs_routes.ts @@ -7,7 +7,10 @@ import type { IKibanaResponse, IRouter } from '@kbn/core/server'; import { getRequestAbortedSignal } from '@kbn/data-plugin/server'; -import { getLlmType, getLlmClass } from '@kbn/elastic-assistant-plugin/server/routes/utils'; +import { + ActionsClientChatOpenAI, + ActionsClientSimpleChatModel, +} from '@kbn/langchain/server/language_models'; import { ECS_GRAPH_PATH, EcsMappingRequestBody, EcsMappingResponse } from '../../common'; import { ROUTE_HANDLER_TIMEOUT } from '../constants'; import { getEcsGraph } from '../graphs/ecs'; @@ -47,15 +50,15 @@ export function registerEcsRoutes(router: IRouter connectorItem.actionTypeId === '.bedrock' )[0]; - const llmType = getLlmType(connector.actionTypeId); - const llmClass = getLlmClass(llmType); + const isOpenAI = connector.actionTypeId === '.gen-ai'; + const llmClass = isOpenAI ? 
ActionsClientChatOpenAI : ActionsClientSimpleChatModel; const abortSignal = getRequestAbortedSignal(req.events.aborted$); const model = new llmClass({ @@ -57,7 +60,7 @@ export function registerRelatedRoutes(router: IRouter Date: Sun, 30 Jun 2024 19:53:50 +0200 Subject: [PATCH 14/55] cleanup --- x-pack/plugins/integration_assistant/kibana.jsonc | 2 +- .../server/graphs/categorization/categorization.ts | 8 +++++--- .../server/graphs/categorization/errors.ts | 8 +++++--- .../server/graphs/categorization/graph.ts | 2 +- .../server/graphs/categorization/invalid.ts | 3 +-- .../server/graphs/categorization/review.ts | 8 +++++--- .../integration_assistant/server/graphs/ecs/duplicates.ts | 8 +++++--- .../integration_assistant/server/graphs/ecs/graph.ts | 2 +- .../integration_assistant/server/graphs/ecs/invalid.ts | 8 +++++--- .../integration_assistant/server/graphs/ecs/mapping.ts | 8 +++++--- .../integration_assistant/server/graphs/ecs/missing.ts | 8 +++++--- .../integration_assistant/server/graphs/related/errors.ts | 8 +++++--- .../integration_assistant/server/graphs/related/graph.ts | 7 +++++-- .../server/graphs/related/related.ts | 8 +++++--- .../integration_assistant/server/graphs/related/review.ts | 8 +++++--- 15 files changed, 59 insertions(+), 37 deletions(-) diff --git a/x-pack/plugins/integration_assistant/kibana.jsonc b/x-pack/plugins/integration_assistant/kibana.jsonc index d7f0a68765b8b..a70120d9cefba 100644 --- a/x-pack/plugins/integration_assistant/kibana.jsonc +++ b/x-pack/plugins/integration_assistant/kibana.jsonc @@ -15,7 +15,7 @@ "kibanaReact", "triggersActionsUi", "actions", - "stackConnectors" + "stackConnectors", ], } } diff --git a/x-pack/plugins/integration_assistant/server/graphs/categorization/categorization.ts b/x-pack/plugins/integration_assistant/server/graphs/categorization/categorization.ts index 01a3d51aa7e18..ed1a88c3a1cfd 100644 --- a/x-pack/plugins/integration_assistant/server/graphs/categorization/categorization.ts +++ b/x-pack/plugins/integration_assistant/server/graphs/categorization/categorization.ts @@ -4,8 +4,7 @@ * 2.0; you may not use this file except in compliance with the Elastic License * 2.0. */ - -import { +import type { ActionsClientChatOpenAI, ActionsClientSimpleChatModel, } from '@kbn/langchain/server/language_models'; @@ -15,7 +14,10 @@ import type { CategorizationState } from '../../types'; import { combineProcessors } from '../../util/processors'; import { CATEGORIZATION_MAIN_PROMPT } from './prompts'; -export async function handleCategorization(state: CategorizationState, model: ActionsClientChatOpenAI | ActionsClientSimpleChatModel) { +export async function handleCategorization( + state: CategorizationState, + model: ActionsClientChatOpenAI | ActionsClientSimpleChatModel +) { const categorizationMainPrompt = CATEGORIZATION_MAIN_PROMPT; const outputParser = new JsonOutputParser(); const categorizationMainGraph = categorizationMainPrompt.pipe(model).pipe(outputParser); diff --git a/x-pack/plugins/integration_assistant/server/graphs/categorization/errors.ts b/x-pack/plugins/integration_assistant/server/graphs/categorization/errors.ts index 74a1a36a99a99..d8cb7beedc9bf 100644 --- a/x-pack/plugins/integration_assistant/server/graphs/categorization/errors.ts +++ b/x-pack/plugins/integration_assistant/server/graphs/categorization/errors.ts @@ -4,8 +4,7 @@ * 2.0; you may not use this file except in compliance with the Elastic License * 2.0. 
*/ - -import { +import type { ActionsClientChatOpenAI, ActionsClientSimpleChatModel, } from '@kbn/langchain/server/language_models'; @@ -15,7 +14,10 @@ import type { CategorizationState } from '../../types'; import { combineProcessors } from '../../util/processors'; import { CATEGORIZATION_ERROR_PROMPT } from './prompts'; -export async function handleErrors(state: CategorizationState, model: ActionsClientChatOpenAI | ActionsClientSimpleChatModel) { +export async function handleErrors( + state: CategorizationState, + model: ActionsClientChatOpenAI | ActionsClientSimpleChatModel +) { const categorizationErrorPrompt = CATEGORIZATION_ERROR_PROMPT; const outputParser = new JsonOutputParser(); diff --git a/x-pack/plugins/integration_assistant/server/graphs/categorization/graph.ts b/x-pack/plugins/integration_assistant/server/graphs/categorization/graph.ts index 6c6630200effb..6834fcf892a9e 100644 --- a/x-pack/plugins/integration_assistant/server/graphs/categorization/graph.ts +++ b/x-pack/plugins/integration_assistant/server/graphs/categorization/graph.ts @@ -8,7 +8,7 @@ import type { IScopedClusterClient } from '@kbn/core-elasticsearch-server'; import type { StateGraphArgs } from '@langchain/langgraph'; import { StateGraph, END, START } from '@langchain/langgraph'; -import { +import type { ActionsClientChatOpenAI, ActionsClientSimpleChatModel, } from '@kbn/langchain/server/language_models'; diff --git a/x-pack/plugins/integration_assistant/server/graphs/categorization/invalid.ts b/x-pack/plugins/integration_assistant/server/graphs/categorization/invalid.ts index 9847a76ff5a48..413694b594518 100644 --- a/x-pack/plugins/integration_assistant/server/graphs/categorization/invalid.ts +++ b/x-pack/plugins/integration_assistant/server/graphs/categorization/invalid.ts @@ -4,8 +4,7 @@ * 2.0; you may not use this file except in compliance with the Elastic License * 2.0. */ - -import { +import type { ActionsClientChatOpenAI, ActionsClientSimpleChatModel, } from '@kbn/langchain/server/language_models'; diff --git a/x-pack/plugins/integration_assistant/server/graphs/categorization/review.ts b/x-pack/plugins/integration_assistant/server/graphs/categorization/review.ts index 03862ed33f13d..12b3880737237 100644 --- a/x-pack/plugins/integration_assistant/server/graphs/categorization/review.ts +++ b/x-pack/plugins/integration_assistant/server/graphs/categorization/review.ts @@ -4,8 +4,7 @@ * 2.0; you may not use this file except in compliance with the Elastic License * 2.0. 
*/ - -import { +import type { ActionsClientChatOpenAI, ActionsClientSimpleChatModel, } from '@kbn/langchain/server/language_models'; @@ -17,7 +16,10 @@ import type { CategorizationState } from '../../types'; import { combineProcessors } from '../../util/processors'; import { ECS_EVENT_TYPES_PER_CATEGORY } from './constants'; -export async function handleReview(state: CategorizationState, model: ActionsClientChatOpenAI | ActionsClientSimpleChatModel) { +export async function handleReview( + state: CategorizationState, + model: ActionsClientChatOpenAI | ActionsClientSimpleChatModel +) { const categorizationReviewPrompt = CATEGORIZATION_REVIEW_PROMPT; const outputParser = new JsonOutputParser(); const categorizationReview = categorizationReviewPrompt.pipe(model).pipe(outputParser); diff --git a/x-pack/plugins/integration_assistant/server/graphs/ecs/duplicates.ts b/x-pack/plugins/integration_assistant/server/graphs/ecs/duplicates.ts index a82708bd6b33a..fd11a660e75ab 100644 --- a/x-pack/plugins/integration_assistant/server/graphs/ecs/duplicates.ts +++ b/x-pack/plugins/integration_assistant/server/graphs/ecs/duplicates.ts @@ -4,8 +4,7 @@ * 2.0; you may not use this file except in compliance with the Elastic License * 2.0. */ - -import { +import type { ActionsClientChatOpenAI, ActionsClientSimpleChatModel, } from '@kbn/langchain/server/language_models'; @@ -13,7 +12,10 @@ import { JsonOutputParser } from '@langchain/core/output_parsers'; import type { EcsMappingState } from '../../types'; import { ECS_DUPLICATES_PROMPT } from './prompts'; -export async function handleDuplicates(state: EcsMappingState, model: ActionsClientChatOpenAI | ActionsClientSimpleChatModel) { +export async function handleDuplicates( + state: EcsMappingState, + model: ActionsClientChatOpenAI | ActionsClientSimpleChatModel +) { const ecsDuplicatesPrompt = ECS_DUPLICATES_PROMPT; const outputParser = new JsonOutputParser(); const ecsDuplicatesGraph = ecsDuplicatesPrompt.pipe(model).pipe(outputParser); diff --git a/x-pack/plugins/integration_assistant/server/graphs/ecs/graph.ts b/x-pack/plugins/integration_assistant/server/graphs/ecs/graph.ts index 5720c42eb22c9..2c8e7283d4728 100644 --- a/x-pack/plugins/integration_assistant/server/graphs/ecs/graph.ts +++ b/x-pack/plugins/integration_assistant/server/graphs/ecs/graph.ts @@ -5,7 +5,7 @@ * 2.0. */ -import { +import type { ActionsClientChatOpenAI, ActionsClientSimpleChatModel, } from '@kbn/langchain/server/language_models'; diff --git a/x-pack/plugins/integration_assistant/server/graphs/ecs/invalid.ts b/x-pack/plugins/integration_assistant/server/graphs/ecs/invalid.ts index 65806f59c2faa..dcbba0ebe9d13 100644 --- a/x-pack/plugins/integration_assistant/server/graphs/ecs/invalid.ts +++ b/x-pack/plugins/integration_assistant/server/graphs/ecs/invalid.ts @@ -4,8 +4,7 @@ * 2.0; you may not use this file except in compliance with the Elastic License * 2.0. 
*/ - -import { +import type { ActionsClientChatOpenAI, ActionsClientSimpleChatModel, } from '@kbn/langchain/server/language_models'; @@ -13,7 +12,10 @@ import { JsonOutputParser } from '@langchain/core/output_parsers'; import type { EcsMappingState } from '../../types'; import { ECS_INVALID_PROMPT } from './prompts'; -export async function handleInvalidEcs(state: EcsMappingState, model: ActionsClientChatOpenAI | ActionsClientSimpleChatModel) { +export async function handleInvalidEcs( + state: EcsMappingState, + model: ActionsClientChatOpenAI | ActionsClientSimpleChatModel +) { const ecsInvalidEcsPrompt = ECS_INVALID_PROMPT; const outputParser = new JsonOutputParser(); const ecsInvalidEcsGraph = ecsInvalidEcsPrompt.pipe(model).pipe(outputParser); diff --git a/x-pack/plugins/integration_assistant/server/graphs/ecs/mapping.ts b/x-pack/plugins/integration_assistant/server/graphs/ecs/mapping.ts index 48511a5f4fc4f..7ecb108659f45 100644 --- a/x-pack/plugins/integration_assistant/server/graphs/ecs/mapping.ts +++ b/x-pack/plugins/integration_assistant/server/graphs/ecs/mapping.ts @@ -4,8 +4,7 @@ * 2.0; you may not use this file except in compliance with the Elastic License * 2.0. */ - -import { +import type { ActionsClientChatOpenAI, ActionsClientSimpleChatModel, } from '@kbn/langchain/server/language_models'; @@ -13,7 +12,10 @@ import { JsonOutputParser } from '@langchain/core/output_parsers'; import type { EcsMappingState } from '../../types'; import { ECS_MAIN_PROMPT } from './prompts'; -export async function handleEcsMapping(state: EcsMappingState, model: ActionsClientChatOpenAI | ActionsClientSimpleChatModel) { +export async function handleEcsMapping( + state: EcsMappingState, + model: ActionsClientChatOpenAI | ActionsClientSimpleChatModel +) { const ecsMainPrompt = ECS_MAIN_PROMPT; const outputParser = new JsonOutputParser(); const ecsMainGraph = ecsMainPrompt.pipe(model).pipe(outputParser); diff --git a/x-pack/plugins/integration_assistant/server/graphs/ecs/missing.ts b/x-pack/plugins/integration_assistant/server/graphs/ecs/missing.ts index 6412bf99d1188..d7f1f65b2b4ea 100644 --- a/x-pack/plugins/integration_assistant/server/graphs/ecs/missing.ts +++ b/x-pack/plugins/integration_assistant/server/graphs/ecs/missing.ts @@ -4,8 +4,7 @@ * 2.0; you may not use this file except in compliance with the Elastic License * 2.0. 
*/ - -import { +import type { ActionsClientChatOpenAI, ActionsClientSimpleChatModel, } from '@kbn/langchain/server/language_models'; @@ -13,7 +12,10 @@ import { JsonOutputParser } from '@langchain/core/output_parsers'; import type { EcsMappingState } from '../../types'; import { ECS_MISSING_KEYS_PROMPT } from './prompts'; -export async function handleMissingKeys(state: EcsMappingState, model: ActionsClientChatOpenAI | ActionsClientSimpleChatModel) { +export async function handleMissingKeys( + state: EcsMappingState, + model: ActionsClientChatOpenAI | ActionsClientSimpleChatModel +) { const ecsMissingPrompt = ECS_MISSING_KEYS_PROMPT; const outputParser = new JsonOutputParser(); const ecsMissingGraph = ecsMissingPrompt.pipe(model).pipe(outputParser); diff --git a/x-pack/plugins/integration_assistant/server/graphs/related/errors.ts b/x-pack/plugins/integration_assistant/server/graphs/related/errors.ts index de5691b845638..025422008c4dc 100644 --- a/x-pack/plugins/integration_assistant/server/graphs/related/errors.ts +++ b/x-pack/plugins/integration_assistant/server/graphs/related/errors.ts @@ -4,8 +4,7 @@ * 2.0; you may not use this file except in compliance with the Elastic License * 2.0. */ - -import { +import type { ActionsClientChatOpenAI, ActionsClientSimpleChatModel, } from '@kbn/langchain/server/language_models'; @@ -15,7 +14,10 @@ import type { RelatedState } from '../../types'; import { combineProcessors } from '../../util/processors'; import { RELATED_ERROR_PROMPT } from './prompts'; -export async function handleErrors(state: RelatedState, model: ActionsClientChatOpenAI | ActionsClientSimpleChatModel) { +export async function handleErrors( + state: RelatedState, + model: ActionsClientChatOpenAI | ActionsClientSimpleChatModel +) { const relatedErrorPrompt = RELATED_ERROR_PROMPT; const outputParser = new JsonOutputParser(); const relatedErrorGraph = relatedErrorPrompt.pipe(model).pipe(outputParser); diff --git a/x-pack/plugins/integration_assistant/server/graphs/related/graph.ts b/x-pack/plugins/integration_assistant/server/graphs/related/graph.ts index 3b44d9b65f170..9b50c05889402 100644 --- a/x-pack/plugins/integration_assistant/server/graphs/related/graph.ts +++ b/x-pack/plugins/integration_assistant/server/graphs/related/graph.ts @@ -8,7 +8,7 @@ import type { IScopedClusterClient } from '@kbn/core-elasticsearch-server'; import type { StateGraphArgs } from '@langchain/langgraph'; import { StateGraph, END, START } from '@langchain/langgraph'; -import { +import type { ActionsClientChatOpenAI, ActionsClientSimpleChatModel, } from '@kbn/langchain/server/language_models'; @@ -137,7 +137,10 @@ function chainRouter(state: RelatedState): string { return END; } -export async function getRelatedGraph(client: IScopedClusterClient, model: ActionsClientChatOpenAI | ActionsClientSimpleChatModel) { +export async function getRelatedGraph( + client: IScopedClusterClient, + model: ActionsClientChatOpenAI | ActionsClientSimpleChatModel +) { const workflow = new StateGraph({ channels: graphState }) .addNode('modelInput', modelInput) .addNode('modelOutput', modelOutput) diff --git a/x-pack/plugins/integration_assistant/server/graphs/related/related.ts b/x-pack/plugins/integration_assistant/server/graphs/related/related.ts index 044afe0c91930..2c98381510d9b 100644 --- a/x-pack/plugins/integration_assistant/server/graphs/related/related.ts +++ b/x-pack/plugins/integration_assistant/server/graphs/related/related.ts @@ -4,8 +4,7 @@ * 2.0; you may not use this file except in compliance with the Elastic 
License * 2.0. */ - -import { +import type { ActionsClientChatOpenAI, ActionsClientSimpleChatModel, } from '@kbn/langchain/server/language_models'; @@ -15,7 +14,10 @@ import type { RelatedState } from '../../types'; import { combineProcessors } from '../../util/processors'; import { RELATED_MAIN_PROMPT } from './prompts'; -export async function handleRelated(state: RelatedState, model: ActionsClientChatOpenAI | ActionsClientSimpleChatModel) { +export async function handleRelated( + state: RelatedState, + model: ActionsClientChatOpenAI | ActionsClientSimpleChatModel +) { const relatedMainPrompt = RELATED_MAIN_PROMPT; const outputParser = new JsonOutputParser(); const relatedMainGraph = relatedMainPrompt.pipe(model).pipe(outputParser); diff --git a/x-pack/plugins/integration_assistant/server/graphs/related/review.ts b/x-pack/plugins/integration_assistant/server/graphs/related/review.ts index ff3c76863ce21..6c07079e18f48 100644 --- a/x-pack/plugins/integration_assistant/server/graphs/related/review.ts +++ b/x-pack/plugins/integration_assistant/server/graphs/related/review.ts @@ -4,8 +4,7 @@ * 2.0; you may not use this file except in compliance with the Elastic License * 2.0. */ - -import { +import type { ActionsClientChatOpenAI, ActionsClientSimpleChatModel, } from '@kbn/langchain/server/language_models'; @@ -15,7 +14,10 @@ import type { RelatedState } from '../../types'; import { combineProcessors } from '../../util/processors'; import { RELATED_REVIEW_PROMPT } from './prompts'; -export async function handleReview(state: RelatedState, model: ActionsClientChatOpenAI | ActionsClientSimpleChatModel) { +export async function handleReview( + state: RelatedState, + model: ActionsClientChatOpenAI | ActionsClientSimpleChatModel +) { const relatedReviewPrompt = RELATED_REVIEW_PROMPT; const outputParser = new JsonOutputParser(); const relatedReviewGraph = relatedReviewPrompt.pipe(model).pipe(outputParser); From 931745db3441709b765780e62c09ba370c92704b Mon Sep 17 00:00:00 2001 From: Patryk Kopycinski Date: Sun, 30 Jun 2024 20:23:16 +0200 Subject: [PATCH 15/55] basic tests --- .../execute_custom_llm_chain/index.test.ts | 78 ++++++++++++++++++- 1 file changed, 77 insertions(+), 1 deletion(-) diff --git a/x-pack/plugins/elastic_assistant/server/lib/langchain/execute_custom_llm_chain/index.test.ts b/x-pack/plugins/elastic_assistant/server/lib/langchain/execute_custom_llm_chain/index.test.ts index 6c90e28a8de19..299685b489380 100644 --- a/x-pack/plugins/elastic_assistant/server/lib/langchain/execute_custom_llm_chain/index.test.ts +++ b/x-pack/plugins/elastic_assistant/server/lib/langchain/execute_custom_llm_chain/index.test.ts @@ -17,7 +17,11 @@ import { langChainMessages } from '../../../__mocks__/lang_chain_messages'; import { KNOWLEDGE_BASE_INDEX_PATTERN } from '../../../routes/knowledge_base/constants'; import { callAgentExecutor } from '.'; import { PassThrough, Stream } from 'stream'; -import { ActionsClientChatOpenAI, ActionsClientSimpleChatModel } from '@kbn/langchain/server'; +import { + ActionsClientChatOpenAI, + ActionsClientBedrockChatModel, + ActionsClientSimpleChatModel, +} from '@kbn/langchain/server'; import { AgentExecutorParams } from '../executors/types'; import { ElasticsearchStore } from '../elasticsearch_store/elasticsearch_store'; @@ -26,6 +30,7 @@ jest.mock('@kbn/langchain/server', () => { return { ...original, ActionsClientChatOpenAI: jest.fn(), + ActionsClientBedrockChatModel: jest.fn(), ActionsClientSimpleChatModel: jest.fn(), }; }); @@ -112,6 +117,11 @@ const bedrockProps = { 
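// A note on the mock setup earlier in this file: jest.mock('@kbn/langchain/server')
// swaps each chat-model class for a bare jest.fn(), so the tests below can assert on
// constructor arguments without a live connector. A condensed sketch of that pattern:
//
//   jest.mock('@kbn/langchain/server', () => ({
//     ...jest.requireActual('@kbn/langchain/server'),
//     ActionsClientBedrockChatModel: jest.fn(),
//   }));
//   await callAgentExecutor(bedrockChatProps);
//   expect(ActionsClientBedrockChatModel).toHaveBeenCalledWith(
//     expect.objectContaining({ llmType: 'bedrock', streaming: false })
//   );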
...defaultProps, llmType: 'bedrock', }; +const bedrockChatProps = { + ...defaultProps, + bedrockChatEnabled: true, + llmType: 'bedrock', +}; const executorMock = initializeAgentExecutorWithOptions as jest.Mock; describe('callAgentExecutor', () => { beforeEach(() => { @@ -274,6 +284,72 @@ describe('callAgentExecutor', () => { }); }); + describe('BedrockChat', () => { + describe('when the agent is not streaming', () => { + it('creates an instance of ActionsClientBedrockChatModel with the expected context from the request', async () => { + await callAgentExecutor(bedrockChatProps); + + expect(ActionsClientBedrockChatModel).toHaveBeenCalledWith({ + actions: mockActions, + connectorId: mockConnectorId, + logger: mockLogger, + maxRetries: 0, + request: mockRequest, + streaming: false, + temperature: 0, + llmType: 'bedrock', + }); + }); + + it('uses the structured-chat-zero-shot-react-description agent type', async () => { + await callAgentExecutor(bedrockChatProps); + expect(mockCall.mock.calls[0][0].agentType).toEqual( + 'structured-chat-zero-shot-react-description' + ); + }); + + it('returns the expected response', async () => { + const result = await callAgentExecutor(bedrockChatProps); + + expect(result).toEqual({ + body: { + connector_id: 'mock-connector-id', + data: mockActionResponse, + status: 'ok', + replacements: {}, + trace_data: undefined, + }, + headers: { + 'content-type': 'application/json', + }, + }); + }); + }); + describe('when the agent is streaming', () => { + it('creates an instance of ActionsClientBedrockChatModel with the expected context from the request', async () => { + await callAgentExecutor({ ...bedrockChatProps, isStream: true }); + + expect(ActionsClientBedrockChatModel).toHaveBeenCalledWith({ + actions: mockActions, + connectorId: mockConnectorId, + logger: mockLogger, + maxRetries: 0, + request: mockRequest, + streaming: true, + temperature: 0, + llmType: 'bedrock', + }); + }); + + it('uses the structured-chat-zero-shot-react-description agent type', async () => { + await callAgentExecutor({ ...bedrockChatProps, isStream: true }); + expect(mockInvoke.mock.calls[0][0].agentType).toEqual( + 'structured-chat-zero-shot-react-description' + ); + }); + }); + }); + describe.each([ ['OpenAI', defaultProps], ['Bedrock', bedrockProps], From 467cd96b7d59f34925f0e60819f0f696b3743913 Mon Sep 17 00:00:00 2001 From: Patryk Kopycinski Date: Sun, 30 Jun 2024 20:30:16 +0200 Subject: [PATCH 16/55] cleanup package.json --- package.json | 7 ++----- yarn.lock | 47 +++++++++++++++++++++++++++++++++++------------ 2 files changed, 37 insertions(+), 17 deletions(-) diff --git a/package.json b/package.json index feb89518dc19f..c05971787034e 100644 --- a/package.json +++ b/package.json @@ -80,9 +80,6 @@ "resolutions": { "**/@bazel/typescript/protobufjs": "6.11.4", "**/@hello-pangea/dnd": "16.6.0", - "**/@langchain/core": "0.2.11", - "**/@langchain/openai": "0.2.1", - "**/@smithy/util-utf8": "3.0.0", "**/@types/node": "20.10.5", "**/@typescript-eslint/utils": "5.62.0", "**/chokidar": "^3.5.3", @@ -937,10 +934,10 @@ "@kbn/watcher-plugin": "link:x-pack/plugins/watcher", "@kbn/xstate-utils": "link:packages/kbn-xstate-utils", "@kbn/zod-helpers": "link:packages/kbn-zod-helpers", - "@langchain/community": "^0.2.15", + "@langchain/community": "^0.2.16", "@langchain/core": "^0.2.11", "@langchain/langgraph": "^0.0.25", - "@langchain/openai": "^0.2.1", + "@langchain/openai": "^0.1.3", "@langtrase/trace-attributes": "^3.0.8", "@launchdarkly/node-server-sdk": "^9.4.6", "@loaders.gl/core": "^3.4.7", diff 
--git a/yarn.lock b/yarn.lock index b295dc7c34edc..dd615f36d8d41 100644 --- a/yarn.lock +++ b/yarn.lock @@ -6975,12 +6975,12 @@ resolved "https://registry.yarnpkg.com/@kwsites/promise-deferred/-/promise-deferred-1.1.1.tgz#8ace5259254426ccef57f3175bc64ed7095ed919" integrity sha512-GaHYm+c0O9MjZRu0ongGBRbinu8gVAMd2UZjji6jVmqKtZluZnptXGWhz1E8j8D2HJ3f/yMxKAUC0b+57wncIw== -"@langchain/community@^0.2.15": - version "0.2.15" - resolved "https://registry.yarnpkg.com/@langchain/community/-/community-0.2.15.tgz#91ad5af44bfc72e83f6b6bd58cf5b29e53effed6" - integrity sha512-WOsNQGhriXh5tqRWfX3nthWO6RoVtM5kceX2GbJhqk09KV4R+1QmrOyph3OrpmjRA/YuSf0a+94LHY+c/QGolw== +"@langchain/community@^0.2.16": + version "0.2.16" + resolved "https://registry.yarnpkg.com/@langchain/community/-/community-0.2.16.tgz#5888baf7fc7ea272c5f91aaa0e71bc444167262d" + integrity sha512-dFDcMabKACvuRd0w6EIRLWf1ubPGZEeEwFt9v1jiEr4HCFxH0OF+iM1QUCcVRbB2fK5lqmKeTD1XAeZV8+AyXA== dependencies: - "@langchain/core" "~0.2.9" + "@langchain/core" "~0.2.11" "@langchain/openai" "~0.1.0" binary-extensions "^2.2.0" expr-eval "^2.0.2" @@ -6992,7 +6992,7 @@ zod "^3.22.3" zod-to-json-schema "^3.22.5" -"@langchain/core@0.2.11", "@langchain/core@>0.1.0 <0.3.0", "@langchain/core@>0.1.61 <0.3.0", "@langchain/core@>=0.2.8 <0.3.0", "@langchain/core@>=0.2.9 <0.3.0", "@langchain/core@^0.2.11", "@langchain/core@~0.2.9": +"@langchain/core@>0.1.0 <0.3.0", "@langchain/core@>0.1.61 <0.3.0", "@langchain/core@>=0.2.5 <0.3.0", "@langchain/core@>=0.2.9 <0.3.0", "@langchain/core@^0.2.11", "@langchain/core@~0.2.11": version "0.2.11" resolved "https://registry.yarnpkg.com/@langchain/core/-/core-0.2.11.tgz#5f47467e20e56b250831baef20083657c6facb4c" integrity sha512-d4SNL7WI0c3oHrV4WxCRH1/TNqdePXEzYjYwIb4aEH6lW1aM0utGhLbNthX+aYkOL4Ynx2FoG4h91ECIipiKWQ== @@ -7018,12 +7018,12 @@ "@langchain/core" ">0.1.61 <0.3.0" uuid "^9.0.1" -"@langchain/openai@0.2.1", "@langchain/openai@>=0.1.0 <0.3.0", "@langchain/openai@^0.2.1", "@langchain/openai@~0.1.0": - version "0.2.1" - resolved "https://registry.yarnpkg.com/@langchain/openai/-/openai-0.2.1.tgz#2c0c2cb6bd7839d8ce342c97099c6e35f2dde40d" - integrity sha512-Ti3C6ZIUPaueIPAfMljMnLu3GSGNq5KmrlHeWkIbrLShOBlzj4xj7mRfR73oWgAC0qivfxdkfbB0e+WCY+oRJw== +"@langchain/openai@>=0.1.0 <0.3.0", "@langchain/openai@^0.1.3", "@langchain/openai@~0.1.0": + version "0.1.3" + resolved "https://registry.yarnpkg.com/@langchain/openai/-/openai-0.1.3.tgz#6eb0994e970d85ffa9aaeafb94449024ccf6ca63" + integrity sha512-riv/JC9x2A8b7GcHu8sx+mlZJ8KAwSSi231IPTlcciYnKozmrQ5H0vrtiD31fxiDbaRsk7tyCpkSBIOQEo7CyQ== dependencies: - "@langchain/core" ">=0.2.8 <0.3.0" + "@langchain/core" ">=0.2.5 <0.3.0" js-tiktoken "^1.0.12" openai "^4.49.1" zod "^3.22.4" @@ -8329,6 +8329,13 @@ "@smithy/types" "^3.2.0" tslib "^2.6.2" +"@smithy/is-array-buffer@^2.0.0": + version "2.0.0" + resolved "https://registry.yarnpkg.com/@smithy/is-array-buffer/-/is-array-buffer-2.0.0.tgz#8fa9b8040651e7ba0b2f6106e636a91354ff7d34" + integrity sha512-z3PjFjMyZNI98JFRJi/U0nGoLWMSJlDjAW4QUX2WNZLas5C0CmVV6LJ01JI0k90l7FvpmixjWxPFmENSClQ7ug== + dependencies: + tslib "^2.5.0" + "@smithy/is-array-buffer@^3.0.0": version "3.0.0" resolved "https://registry.yarnpkg.com/@smithy/is-array-buffer/-/is-array-buffer-3.0.0.tgz#9a95c2d46b8768946a9eec7f935feaddcffa5e7a" @@ -8364,6 +8371,14 @@ dependencies: tslib "^2.6.2" +"@smithy/util-buffer-from@^2.0.0": + version "2.0.0" + resolved "https://registry.yarnpkg.com/@smithy/util-buffer-from/-/util-buffer-from-2.0.0.tgz#7eb75d72288b6b3001bc5f75b48b711513091deb" + integrity 
sha512-/YNnLoHsR+4W4Vf2wL5lGv0ksg8Bmk3GEGxn2vEQt52AQaPSCuaO5PM5VM7lP1K9qHRKHwrPGktqVoAHKWHxzw== + dependencies: + "@smithy/is-array-buffer" "^2.0.0" + tslib "^2.5.0" + "@smithy/util-buffer-from@^3.0.0": version "3.0.0" resolved "https://registry.yarnpkg.com/@smithy/util-buffer-from/-/util-buffer-from-3.0.0.tgz#559fc1c86138a89b2edaefc1e6677780c24594e3" @@ -8394,7 +8409,15 @@ dependencies: tslib "^2.6.2" -"@smithy/util-utf8@3.0.0", "@smithy/util-utf8@^2.0.0", "@smithy/util-utf8@^3.0.0": +"@smithy/util-utf8@^2.0.0": + version "2.0.0" + resolved "https://registry.yarnpkg.com/@smithy/util-utf8/-/util-utf8-2.0.0.tgz#b4da87566ea7757435e153799df9da717262ad42" + integrity sha512-rctU1VkziY84n5OXe3bPNpKR001ZCME2JCaBBFgtiM2hfKbHFudc/BkMuPab8hRbLd0j3vbnBTTZ1igBf0wgiQ== + dependencies: + "@smithy/util-buffer-from" "^2.0.0" + tslib "^2.5.0" + +"@smithy/util-utf8@^3.0.0": version "3.0.0" resolved "https://registry.yarnpkg.com/@smithy/util-utf8/-/util-utf8-3.0.0.tgz#1a6a823d47cbec1fd6933e5fc87df975286d9d6a" integrity sha512-rUeT12bxFnplYDe815GXbq/oixEGHfRFFtcTF3YdDi/JaENIM6aSYYLJydG83UNzLXeRI5K8abYd/8Sp/QM0kA== From cb8e2a4639820f4365afde6c1c577cb678edb2fd Mon Sep 17 00:00:00 2001 From: Patryk Kopycinski Date: Sun, 30 Jun 2024 22:55:46 +0200 Subject: [PATCH 17/55] test --- x-pack/plugins/elastic_assistant/tsconfig.json | 1 - 1 file changed, 1 deletion(-) diff --git a/x-pack/plugins/elastic_assistant/tsconfig.json b/x-pack/plugins/elastic_assistant/tsconfig.json index 8f546d6e5fe01..f63a8da530196 100644 --- a/x-pack/plugins/elastic_assistant/tsconfig.json +++ b/x-pack/plugins/elastic_assistant/tsconfig.json @@ -45,7 +45,6 @@ "@kbn/core-saved-objects-api-server", "@kbn/langchain", "@kbn/stack-connectors-plugin", - "@kbn/security-plugin", ], "exclude": [ "target/**/*", From bd64e0b290daa11941b9cd345b8dc6e8097076fc Mon Sep 17 00:00:00 2001 From: kibanamachine <42973632+kibanamachine@users.noreply.github.com> Date: Sun, 30 Jun 2024 21:06:33 +0000 Subject: [PATCH 18/55] [CI] Auto-commit changed files from 'node scripts/lint_ts_projects --fix' --- x-pack/plugins/elastic_assistant/tsconfig.json | 1 + 1 file changed, 1 insertion(+) diff --git a/x-pack/plugins/elastic_assistant/tsconfig.json b/x-pack/plugins/elastic_assistant/tsconfig.json index f63a8da530196..8f546d6e5fe01 100644 --- a/x-pack/plugins/elastic_assistant/tsconfig.json +++ b/x-pack/plugins/elastic_assistant/tsconfig.json @@ -45,6 +45,7 @@ "@kbn/core-saved-objects-api-server", "@kbn/langchain", "@kbn/stack-connectors-plugin", + "@kbn/security-plugin", ], "exclude": [ "target/**/*", From abbae71574c7b3e61cd9df23cb145999c9eae978 Mon Sep 17 00:00:00 2001 From: Patryk Kopycinski Date: Mon, 8 Jul 2024 01:39:01 +0200 Subject: [PATCH 19/55] bump --- package.json | 7 ++++--- yarn.lock | 33 +++++++++++++++++++-------------- 2 files changed, 23 insertions(+), 17 deletions(-) diff --git a/package.json b/package.json index 366c06df3bd84..f3cf46eced6dd 100644 --- a/package.json +++ b/package.json @@ -80,6 +80,7 @@ "resolutions": { "**/@bazel/typescript/protobufjs": "6.11.4", "**/@hello-pangea/dnd": "16.6.0", + "**/@langchain/core": "^0.2.14", "**/@types/node": "20.10.5", "**/@typescript-eslint/utils": "5.62.0", "**/chokidar": "^3.5.3", @@ -934,9 +935,9 @@ "@kbn/watcher-plugin": "link:x-pack/plugins/watcher", "@kbn/xstate-utils": "link:packages/kbn-xstate-utils", "@kbn/zod-helpers": "link:packages/kbn-zod-helpers", - "@langchain/community": "^0.2.16", - "@langchain/core": "^0.2.11", - "@langchain/langgraph": "^0.0.25", + "@langchain/community": "^0.2.17", + 
"@langchain/core": "^0.2.14", + "@langchain/langgraph": "^0.0.26", "@langchain/openai": "^0.1.3", "@langtrase/trace-attributes": "^3.0.8", "@launchdarkly/node-server-sdk": "^9.4.6", diff --git a/yarn.lock b/yarn.lock index 5df5d5792131c..2ebc638f01d8d 100644 --- a/yarn.lock +++ b/yarn.lock @@ -6975,10 +6975,10 @@ resolved "https://registry.yarnpkg.com/@kwsites/promise-deferred/-/promise-deferred-1.1.1.tgz#8ace5259254426ccef57f3175bc64ed7095ed919" integrity sha512-GaHYm+c0O9MjZRu0ongGBRbinu8gVAMd2UZjji6jVmqKtZluZnptXGWhz1E8j8D2HJ3f/yMxKAUC0b+57wncIw== -"@langchain/community@^0.2.16": - version "0.2.16" - resolved "https://registry.yarnpkg.com/@langchain/community/-/community-0.2.16.tgz#5888baf7fc7ea272c5f91aaa0e71bc444167262d" - integrity sha512-dFDcMabKACvuRd0w6EIRLWf1ubPGZEeEwFt9v1jiEr4HCFxH0OF+iM1QUCcVRbB2fK5lqmKeTD1XAeZV8+AyXA== +"@langchain/community@^0.2.17": + version "0.2.17" + resolved "https://registry.yarnpkg.com/@langchain/community/-/community-0.2.17.tgz#83e776b5d4d022b22bce907ce08668c56422e8de" + integrity sha512-lbmOvOvE0L2EV8lUb/ZcYyrLGF0sveGpYg9A0m6F/nDhuPG1HZqHvU/LiHsCaVO2WJPGowibMPTC02fUG/6dKA== dependencies: "@langchain/core" "~0.2.11" "@langchain/openai" "~0.1.0" @@ -6992,10 +6992,10 @@ zod "^3.22.3" zod-to-json-schema "^3.22.5" -"@langchain/core@>0.1.0 <0.3.0", "@langchain/core@>0.1.61 <0.3.0", "@langchain/core@>=0.2.5 <0.3.0", "@langchain/core@>=0.2.9 <0.3.0", "@langchain/core@^0.2.11", "@langchain/core@~0.2.11": - version "0.2.11" - resolved "https://registry.yarnpkg.com/@langchain/core/-/core-0.2.11.tgz#5f47467e20e56b250831baef20083657c6facb4c" - integrity sha512-d4SNL7WI0c3oHrV4WxCRH1/TNqdePXEzYjYwIb4aEH6lW1aM0utGhLbNthX+aYkOL4Ynx2FoG4h91ECIipiKWQ== +"@langchain/core@>0.1.0 <0.3.0", "@langchain/core@>0.1.61 <0.3.0", "@langchain/core@>=0.2.5 <0.3.0", "@langchain/core@>=0.2.9 <0.3.0", "@langchain/core@^0.2.14", "@langchain/core@~0.2.11": + version "0.2.14" + resolved "https://registry.yarnpkg.com/@langchain/core/-/core-0.2.14.tgz#2f607b3da03717bdb1ddb70ec64cf0f6d03bede3" + integrity sha512-e+dB2Sra+5sY4fPCXrZh4w1fLR68yfX2XSWS97DzHpeJRRnMGbOzjaaks6gy3wPyWhfS1WGLisJCMfCmFvH9fw== dependencies: ansi-styles "^5.0.0" camelcase "6" @@ -7006,17 +7006,17 @@ mustache "^4.2.0" p-queue "^6.6.2" p-retry "4" - uuid "^9.0.0" + uuid "^10.0.0" zod "^3.22.4" zod-to-json-schema "^3.22.3" -"@langchain/langgraph@^0.0.25": - version "0.0.25" - resolved "https://registry.yarnpkg.com/@langchain/langgraph/-/langgraph-0.0.25.tgz#2582d8652e2dda722f0c5043c1d0254a778e2486" - integrity sha512-DiTnB5Psm0y7TSgHdK4r/r+xzLohbN4zMQL+5Wk3EmOGX45ioBp98AqL8hYdyxKgHM6cjoIFHavHF7EhMg+ugQ== +"@langchain/langgraph@^0.0.26": + version "0.0.26" + resolved "https://registry.yarnpkg.com/@langchain/langgraph/-/langgraph-0.0.26.tgz#4971577366f7ec4ff1c634f0ab1aa9f0ec73fb0b" + integrity sha512-edfy9C9e5E1/HvWY9Jk4P3kF3RToA7OIVj9aXjCCb/8dvoMcmMsQSB990Uc29H+3lSpxHWrzSaZVQCAxD6XAAg== dependencies: "@langchain/core" ">0.1.61 <0.3.0" - uuid "^9.0.1" + uuid "^10.0.0" "@langchain/openai@>=0.1.0 <0.3.0", "@langchain/openai@^0.1.3", "@langchain/openai@~0.1.0": version "0.1.3" @@ -31290,6 +31290,11 @@ uuid@9.0.0: resolved "https://registry.yarnpkg.com/uuid/-/uuid-9.0.0.tgz#592f550650024a38ceb0c562f2f6aa435761efb5" integrity sha512-MXcSTerfPa4uqyzStbRoTgt5XIe3x5+42+q1sDuy3R5MDk66URdLMOZe5aPX/SQd+kuYAh0FdP/pO28IkQyTeg== +uuid@^10.0.0: + version "10.0.0" + resolved "https://registry.yarnpkg.com/uuid/-/uuid-10.0.0.tgz#5a95aa454e6e002725c79055fd42aaba30ca6294" + integrity 
sha512-8XkAphELsDnEGrDxUOHB3RGvXz6TeuYSGEZBOjtTtPm2lwhGBjLgOzLHB63IUWfBpNucQjND6d3AOudO+H3RWQ== + uuid@^3.3.2, uuid@^3.3.3: version "3.4.0" resolved "https://registry.yarnpkg.com/uuid/-/uuid-3.4.0.tgz#b23e4358afa8a202fe7a100af1f5f883f02007ee" From e6e33d929792ea26bb5286bc0960bf774ec238bb Mon Sep 17 00:00:00 2001 From: Patryk Kopycinski Date: Wed, 10 Jul 2024 15:30:15 +0200 Subject: [PATCH 20/55] fix tool calling for non-graph --- package.json | 10 +-- .../server/language_models/bedrock_chat.ts | 39 +++-------- .../execute_custom_llm_chain/index.ts | 45 ++++++------- .../graphs/default_assistant_graph/helpers.ts | 5 +- .../graphs/default_assistant_graph/index.ts | 20 +++++- .../server/routes/helpers.ts | 1 + .../common/bedrock/constants.ts | 1 + .../stack_connectors/common/bedrock/schema.ts | 42 +++++++++++- .../stack_connectors/common/bedrock/types.ts | 4 ++ .../server/connector_types/bedrock/bedrock.ts | 67 ++++++++++++++++--- yarn.lock | 30 ++++----- 11 files changed, 182 insertions(+), 82 deletions(-) diff --git a/package.json b/package.json index b09d81079ad97..51624efd2f337 100644 --- a/package.json +++ b/package.json @@ -80,7 +80,7 @@ "resolutions": { "**/@bazel/typescript/protobufjs": "6.11.4", "**/@hello-pangea/dnd": "16.6.0", - "**/@langchain/core": "^0.2.14", + "**/@langchain/core": "^0.2.15", "**/@types/node": "20.10.5", "**/@typescript-eslint/utils": "5.62.0", "**/chokidar": "^3.5.3", @@ -88,7 +88,7 @@ "**/globule/minimatch": "^3.1.2", "**/hoist-non-react-statics": "^3.3.2", "**/isomorphic-fetch/node-fetch": "^2.6.7", - "**/langchain": "0.2.8", + "**/langchain": "^0.2.9", "**/react-intl/**/@types/react": "^17.0.45", "**/remark-parse/trim": "1.0.1", "**/sharp": "0.32.6", @@ -937,8 +937,8 @@ "@kbn/watcher-plugin": "link:x-pack/plugins/watcher", "@kbn/xstate-utils": "link:packages/kbn-xstate-utils", "@kbn/zod-helpers": "link:packages/kbn-zod-helpers", - "@langchain/community": "^0.2.17", - "@langchain/core": "^0.2.14", + "@langchain/community": "^0.2.18", + "@langchain/core": "^0.2.15", "@langchain/langgraph": "^0.0.26", "@langchain/openai": "^0.1.3", "@langtrase/trace-attributes": "^3.0.8", @@ -1080,7 +1080,7 @@ "jsonwebtoken": "^9.0.2", "jsts": "^1.6.2", "kea": "^2.6.0", - "langchain": "^0.2.8", + "langchain": "^0.2.9", "langsmith": "^0.1.30", "launchdarkly-js-client-sdk": "^3.4.0", "launchdarkly-node-server-sdk": "^7.0.3", diff --git a/x-pack/packages/kbn-langchain/server/language_models/bedrock_chat.ts b/x-pack/packages/kbn-langchain/server/language_models/bedrock_chat.ts index 7d4f7cb34f842..da92f5fb78d3f 100644 --- a/x-pack/packages/kbn-langchain/server/language_models/bedrock_chat.ts +++ b/x-pack/packages/kbn-langchain/server/language_models/bedrock_chat.ts @@ -10,7 +10,6 @@ import type { ActionsClient } from '@kbn/actions-plugin/server'; import { BaseChatModelParams } from '@langchain/core/language_models/chat_models'; import { Logger } from '@kbn/logging'; import { Readable } from 'stream'; -import { filter, isArray, map } from 'lodash'; import { PublicMethodsOf } from '@kbn/utility-types'; export const DEFAULT_BEDROCK_MODEL = 'anthropic.claude-3-5-sonnet-20240620-v1:0'; @@ -21,11 +20,13 @@ export class ActionsClientBedrockChatModel extends _BedrockChat { actionsClient, connectorId, logger, + graph, ...params }: { actionsClient: PublicMethodsOf; connectorId: string; logger: Logger; + graph?: boolean; } & BaseChatModelParams) { super({ ...params, @@ -35,19 +36,20 @@ export class ActionsClientBedrockChatModel extends _BedrockChat { region: DEFAULT_BEDROCK_REGION, fetchFn: 
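// The fetchFn override below is the heart of this class: BedrockChat from
// @langchain/community assembles a normal Bedrock HTTP request, and this hook
// intercepts it and replays the parsed body through the Kibana actions client
// instead of calling AWS directly. A stripped-down sketch of the same idea,
// with `myActionsClient` and `myConnectorId` as stand-ins for the real plumbing:
//
//   const proxyFetch = async (_url: RequestInfo, options?: RequestInit): Promise<Response> => {
//     const body = JSON.parse(options?.body as string);
//     const result = (await myActionsClient.execute({
//       actionId: myConnectorId,
//       params: { subAction: 'invokeAIRaw', subActionParams: { messages: body.messages } },
//     })) as { status: string; data: Record<string, unknown> };
//     // Mimic just enough of the Fetch Response surface for BedrockChat to consume.
//     return { ok: result.status === 'ok', json: () => result.data } as unknown as Response;
//   };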
async (url, options) => { const inputBody = JSON.parse(options?.body as string); - const messages = map(inputBody.messages, sanitizeMessage); - if (this.streaming) { + if (this.streaming && graph) { const data = (await actionsClient.execute({ actionId: connectorId, params: { subAction: 'invokeStream', subActionParams: { - messages, + messages: inputBody.messages, temperature: inputBody.temperature, stopSequences: inputBody.stopSequences, system: inputBody.system, maxTokens: inputBody.maxTokens, + tools: inputBody.tools, + anthropicVersion: inputBody.anthropic_version, }, }, })) as { data: Readable }; @@ -60,43 +62,24 @@ export class ActionsClientBedrockChatModel extends _BedrockChat { const data = (await actionsClient.execute({ actionId: connectorId, params: { - subAction: 'invokeAI', + subAction: 'invokeAIRaw', subActionParams: { - messages, + messages: inputBody.messages, temperature: inputBody.temperature, stopSequences: inputBody.stopSequences, system: inputBody.system, maxTokens: inputBody.maxTokens, + tools: inputBody.tools, + anthropicVersion: inputBody.anthropic_version, }, }, })) as { status: string; data: { message: string } }; return { ok: data.status === 'ok', - json: () => ({ - content: data.data.message, - type: 'message', - }), + json: () => data.data, } as unknown as Response; }, }); } } - -const sanitizeMessage = ({ - role, - content, -}: { - role: string; - content: string | Array<{ type: string; text: string }>; -}) => { - if (isArray(content)) { - const textContent = filter(content, ['type', 'text']); - return { role, content: textContent[textContent.length - 1]?.text }; - } - - return { - role, - content, - }; -}; diff --git a/x-pack/plugins/elastic_assistant/server/lib/langchain/execute_custom_llm_chain/index.ts b/x-pack/plugins/elastic_assistant/server/lib/langchain/execute_custom_llm_chain/index.ts index bbc5f3fc423b0..66374150f2e37 100644 --- a/x-pack/plugins/elastic_assistant/server/lib/langchain/execute_custom_llm_chain/index.ts +++ b/x-pack/plugins/elastic_assistant/server/lib/langchain/execute_custom_llm_chain/index.ts @@ -5,7 +5,11 @@ * 2.0. */ import agent, { Span } from 'elastic-apm-node'; -import { initializeAgentExecutorWithOptions } from 'langchain/agents'; +import { + initializeAgentExecutorWithOptions, + createToolCallingAgent, + AgentExecutor as lcAgentExecutor, +} from 'langchain/agents'; import { BufferMemory, ChatMessageHistory } from 'langchain/memory'; import { ToolInterface } from '@langchain/core/tools'; @@ -13,7 +17,7 @@ import { streamFactory } from '@kbn/ml-response-stream/server'; import { transformError } from '@kbn/securitysolution-es-utils'; import { RetrievalQAChain } from 'langchain/chains'; import { getDefaultArguments } from '@kbn/langchain/server'; -import { MessagesPlaceholder } from '@langchain/core/prompts'; +import { ChatPromptTemplate, MessagesPlaceholder } from '@langchain/core/prompts'; import { APMTracer } from '@kbn/langchain/server/tracers/apm'; import { withAssistantSpan } from '../tracers/apm/with_assistant_span'; import { getLlmClass } from '../../../routes/utils'; @@ -132,6 +136,21 @@ export const callAgentExecutor: AgentExecutor = async ({ agentType: 'openai-functions', ...executorArgs, }) + : llmType === 'bedrock' && bedrockChatEnabled + ? 
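// When bedrockChatEnabled is set, the branch below replaces the ReAct-style
// structured-chat agent with LangChain's tool-calling agent, which uses the
// model's native tool-call support instead of parsing JSON actions out of free
// text. A minimal self-contained sketch, assuming `llm` and `tools` already exist:
//
//   import { AgentExecutor, createToolCallingAgent } from 'langchain/agents';
//   import { ChatPromptTemplate } from '@langchain/core/prompts';
//
//   const prompt = ChatPromptTemplate.fromMessages([
//     ['system', 'You are a helpful assistant'],
//     ['placeholder', '{chat_history}'],
//     ['human', '{input}'],
//     ['placeholder', '{agent_scratchpad}'],
//   ]);
//   const executor = new AgentExecutor({
//     agent: await createToolCallingAgent({ llm, tools, prompt }),
//     tools,
//   });
//   await executor.invoke({ input: 'Which alerts are open?' });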
new lcAgentExecutor({ + agent: await createToolCallingAgent({ + llm, + tools, + prompt: ChatPromptTemplate.fromMessages([ + ['system', 'You are a helpful assistant'], + ['placeholder', '{chat_history}'], + ['human', '{input}'], + ['placeholder', '{agent_scratchpad}'], + ]), + streamRunnable: isStream, + }), + tools, + }) : await initializeAgentExecutorWithOptions(tools, llm, { agentType: 'structured-chat-zero-shot-react-description', ...executorArgs, @@ -184,9 +203,6 @@ export const callAgentExecutor: AgentExecutor = async ({ let message = ''; let tokenParentRunId = ''; - let finalOutputIndex = -1; - const finalOutputStartToken = '"action":"FinalAnswer","action_input":"'; - const finalOutputStopRegex = /(? = async ({ tokenParentRunId = parentRunId; } if (payload.length && !didEnd && tokenParentRunId === parentRunId) { - if (llmType === 'bedrock' && bedrockChatEnabled) { - const finalOutputEndIndex = payload.search(finalOutputStopRegex); - const currentOutput = message.replace(/\s/g, ''); - - if (currentOutput.includes(finalOutputStartToken)) { - finalOutputIndex = currentOutput.indexOf(finalOutputStartToken); - } - - if (finalOutputIndex > -1) { - push({ payload, type: 'content' }); - } - - if (finalOutputIndex > -1 && finalOutputEndIndex > -1) { - didEnd = true; - } - } else { - push({ payload, type: 'content' }); - } + push({ payload, type: 'content' }); // store message in case of error message += payload; } diff --git a/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/helpers.ts b/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/helpers.ts index da1cdf74de761..e66ca4791aa5b 100644 --- a/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/helpers.ts +++ b/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/helpers.ts @@ -164,7 +164,10 @@ export const streamGraph = async ({ if (event.name === 'ActionsClientBedrockChatModel') { const generations = event.data.output?.generations[0]; - if (generations && generations[0]?.generationInfo.stop_reason === 'end_turn') { + if ( + (generations && generations[0]?.generationInfo?.stop_reason === 'end_turn') || + generations?.[0]?.text + ) { handleStreamEnd(finalMessage); } } diff --git a/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/index.ts b/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/index.ts index 885c357c76b20..711f1d6ba2b87 100644 --- a/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/index.ts +++ b/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/index.ts @@ -8,8 +8,13 @@ import { StructuredTool } from '@langchain/core/tools'; import { RetrievalQAChain } from 'langchain/chains'; import { getDefaultArguments } from '@kbn/langchain/server'; -import { createOpenAIFunctionsAgent, createStructuredChatAgent } from 'langchain/agents'; +import { + createOpenAIFunctionsAgent, + createStructuredChatAgent, + createToolCallingAgent, +} from 'langchain/agents'; import { APMTracer } from '@kbn/langchain/server/tracers/apm'; +import { ChatPromptTemplate } from '@langchain/core/prompts'; import { getLlmClass } from '../../../../routes/utils'; import { EsAnonymizationFieldsSchema } from '../../../../ai_assistant_data_clients/anonymization_fields/types'; import { AssistantToolParams } from '../../../../types'; @@ -66,6 +71,7 @@ export const callAssistantGraph: AgentExecutor = 
async ({ // prevents the agent from retrying on failure // failure could be due to bad connector, we should deliver that result to the client asap maxRetries: 0, + graph: true, }); const anonymizationFieldsRes = @@ -113,6 +119,18 @@ export const callAssistantGraph: AgentExecutor = async ({ prompt: openAIFunctionAgentPrompt, streamRunnable: isStream, }) + : llmType === 'bedrock' && bedrockChatEnabled + ? await createToolCallingAgent({ + llm, + tools, + prompt: ChatPromptTemplate.fromMessages([ + ['system', 'You are a helpful assistant'], + ['placeholder', '{chat_history}'], + ['human', '{input}'], + ['placeholder', '{agent_scratchpad}'], + ]), + streamRunnable: isStream, + }) : await createStructuredChatAgent({ llm, tools, diff --git a/x-pack/plugins/elastic_assistant/server/routes/helpers.ts b/x-pack/plugins/elastic_assistant/server/routes/helpers.ts index 89a9a25793d5e..87cc105797343 100644 --- a/x-pack/plugins/elastic_assistant/server/routes/helpers.ts +++ b/x-pack/plugins/elastic_assistant/server/routes/helpers.ts @@ -460,6 +460,7 @@ export const langChainExecute = async ({ // New code path for LangGraph implementation, behind `assistantKnowledgeBaseByDefault` FF let result: StreamResponseWithHeaders | StaticReturnType; + if (enableKnowledgeBaseByDefault && request.body.isEnabledKnowledgeBase) { result = await callAssistantGraph(executorParams); } else { diff --git a/x-pack/plugins/stack_connectors/common/bedrock/constants.ts b/x-pack/plugins/stack_connectors/common/bedrock/constants.ts index e2414f46dd985..f3b133dd783f6 100644 --- a/x-pack/plugins/stack_connectors/common/bedrock/constants.ts +++ b/x-pack/plugins/stack_connectors/common/bedrock/constants.ts @@ -17,6 +17,7 @@ export const BEDROCK_CONNECTOR_ID = '.bedrock'; export enum SUB_ACTION { RUN = 'run', INVOKE_AI = 'invokeAI', + INVOKE_AI_RAW = 'invokeAIRaw', INVOKE_STREAM = 'invokeStream', DASHBOARD = 'getDashboard', TEST = 'test', diff --git a/x-pack/plugins/stack_connectors/common/bedrock/schema.ts b/x-pack/plugins/stack_connectors/common/bedrock/schema.ts index bf35aa6bb8e0d..bb89265219a78 100644 --- a/x-pack/plugins/stack_connectors/common/bedrock/schema.ts +++ b/x-pack/plugins/stack_connectors/common/bedrock/schema.ts @@ -25,13 +25,14 @@ export const RunActionParamsSchema = schema.object({ // abort signal from client signal: schema.maybe(schema.any()), timeout: schema.maybe(schema.number()), + raw: schema.maybe(schema.boolean()), }); export const InvokeAIActionParamsSchema = schema.object({ messages: schema.arrayOf( schema.object({ role: schema.string(), - content: schema.string(), + content: schema.any(), }) ), model: schema.maybe(schema.string()), @@ -42,12 +43,51 @@ export const InvokeAIActionParamsSchema = schema.object({ // abort signal from client signal: schema.maybe(schema.any()), timeout: schema.maybe(schema.number()), + anthropicVersion: schema.maybe(schema.string()), + tools: schema.maybe( + schema.arrayOf( + schema.object({ + name: schema.string(), + description: schema.string(), + input_schema: schema.object({}, { unknowns: 'allow' }), + }) + ) + ), }); export const InvokeAIActionResponseSchema = schema.object({ message: schema.string(), }); +export const InvokeAIRawActionParamsSchema = schema.object({ + messages: schema.arrayOf( + schema.object({ + role: schema.string(), + content: schema.any(), + }) + ), + model: schema.maybe(schema.string()), + temperature: schema.maybe(schema.number()), + stopSequences: schema.maybe(schema.arrayOf(schema.string())), + system: schema.maybe(schema.string()), + maxTokens: 
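// The optional `tools` array introduced in these schemas follows the Anthropic
// tool-definition shape that Bedrock's Claude models accept: a name, a natural
// language description, and a JSON-schema `input_schema`. An illustrative value
// that would pass validation (the tool itself is made up):
//
//   const exampleTool = {
//     name: 'get_open_alerts',
//     description: 'Return the currently open security alerts',
//     input_schema: {
//       type: 'object',
//       properties: { severity: { type: 'string' } },
//     },
//   };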
schema.maybe(schema.number()), + // abort signal from client + signal: schema.maybe(schema.any()), + anthropicVersion: schema.maybe(schema.string()), + timeout: schema.maybe(schema.number()), + tools: schema.maybe( + schema.arrayOf( + schema.object({ + name: schema.string(), + description: schema.string(), + input_schema: schema.object({}, { unknowns: 'allow' }), + }) + ) + ), +}); + +export const InvokeAIRawActionResponseSchema = schema.object({}, { unknowns: 'allow' }); + export const RunApiLatestResponseSchema = schema.object( { stop_reason: schema.maybe(schema.string()), diff --git a/x-pack/plugins/stack_connectors/common/bedrock/types.ts b/x-pack/plugins/stack_connectors/common/bedrock/types.ts index 1256831fc7fa0..b144f78b91edd 100644 --- a/x-pack/plugins/stack_connectors/common/bedrock/types.ts +++ b/x-pack/plugins/stack_connectors/common/bedrock/types.ts @@ -15,6 +15,8 @@ import { RunActionResponseSchema, InvokeAIActionParamsSchema, InvokeAIActionResponseSchema, + InvokeAIRawActionParamsSchema, + InvokeAIRawActionResponseSchema, StreamingResponseSchema, RunApiLatestResponseSchema, } from './schema'; @@ -24,6 +26,8 @@ export type Secrets = TypeOf; export type RunActionParams = TypeOf; export type InvokeAIActionParams = TypeOf; export type InvokeAIActionResponse = TypeOf; +export type InvokeAIRawActionParams = TypeOf; +export type InvokeAIRawActionResponse = TypeOf; export type RunApiLatestResponse = TypeOf; export type RunActionResponse = TypeOf; export type StreamingResponse = TypeOf; diff --git a/x-pack/plugins/stack_connectors/server/connector_types/bedrock/bedrock.ts b/x-pack/plugins/stack_connectors/server/connector_types/bedrock/bedrock.ts index 8b05c30a5b0cb..2b91a1ed948ee 100644 --- a/x-pack/plugins/stack_connectors/server/connector_types/bedrock/bedrock.ts +++ b/x-pack/plugins/stack_connectors/server/connector_types/bedrock/bedrock.ts @@ -15,6 +15,8 @@ import { initDashboard } from '../lib/gen_ai/create_gen_ai_dashboard'; import { RunActionParamsSchema, InvokeAIActionParamsSchema, + InvokeAIRawActionParamsSchema, + InvokeAIRawActionResponseSchema, StreamingResponseSchema, RunActionResponseSchema, RunApiLatestResponseSchema, @@ -26,6 +28,8 @@ import { RunActionResponse, InvokeAIActionParams, InvokeAIActionResponse, + InvokeAIRawActionParams, + InvokeAIRawActionResponse, RunApiLatestResponse, } from '../../../common/bedrock/types'; import { @@ -90,6 +94,12 @@ export class BedrockConnector extends SubActionConnector { method: 'invokeStream', schema: InvokeAIActionParamsSchema, }); + + this.registerSubAction({ + name: SUB_ACTION.INVOKE_AI_RAW, + method: 'invokeAIRaw', + schema: InvokeAIRawActionParamsSchema, + }); } protected getResponseErrorMessage(error: AxiosError<{ message?: string }>): string { @@ -183,9 +193,9 @@ The Kibana Connector in use may need to be reconfigured with an updated Amazon B return { available: response.success }; } - private async runApiDeprecated( - params: SubActionRequestParams // : SubActionRequestParams - ): Promise { + private async runApiRaw( + params: SubActionRequestParams // : SubActionRequestParams + ): Promise { const response = await this.request(params); return response.data; } @@ -213,7 +223,8 @@ The Kibana Connector in use may need to be reconfigured with an updated Amazon B model: reqModel, signal, timeout, - }: RunActionParams): Promise { + raw, + }: RunActionParams): Promise { // set model on per request basis const currentModel = reqModel ?? 
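// The new `raw` flag makes runApi return the Bedrock JSON untouched (via
// runApiRaw below) instead of validating and unwrapping a Claude `completion`,
// which is what the LangChain fetchFn shim upstream expects. Note how
// invokeAIRaw maps its camelCase params onto Anthropic's snake_case body; a
// hypothetical direct call would look like:
//
//   const res = await connector.invokeAIRaw({
//     messages: [{ role: 'user', content: 'Hello' }],
//     stopSequences: ['\n\nHuman:'], // serialized as stop_sequences
//     maxTokens: 512,                // serialized as max_tokens
//     anthropicVersion: 'bedrock-2023-05-31',
//   });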
this.model; const path = `/model/${currentModel}/invoke`; @@ -227,9 +238,13 @@ The Kibana Connector in use may need to be reconfigured with an updated Amazon B // give up to 2 minutes for response timeout: timeout ?? DEFAULT_TIMEOUT_MS, }; + + if (raw) { + return this.runApiRaw({ ...requestArgs, responseSchema: InvokeAIRawActionResponseSchema }); + } // possible api received deprecated arguments, which will still work with the deprecated Claude 2 models if (usesDeprecatedArguments(body)) { - return this.runApiDeprecated({ ...requestArgs, responseSchema: RunActionResponseSchema }); + return this.runApiRaw({ ...requestArgs, responseSchema: RunActionResponseSchema }); } return this.runApiLatest({ ...requestArgs, responseSchema: RunApiLatestResponseSchema }); } @@ -282,9 +297,12 @@ The Kibana Connector in use may need to be reconfigured with an updated Amazon B temperature, signal, timeout, + tools, }: InvokeAIActionParams): Promise { const res = (await this.streamApi({ - body: JSON.stringify(formatBedrockBody({ messages, stopSequences, system, temperature })), + body: JSON.stringify( + formatBedrockBody({ messages, stopSequences, system, temperature, tools }) + ), model, signal, timeout, @@ -310,16 +328,46 @@ The Kibana Connector in use may need to be reconfigured with an updated Amazon B signal, timeout, }: InvokeAIActionParams): Promise { - const res = await this.runApi({ + const res = (await this.runApi({ body: JSON.stringify( formatBedrockBody({ messages, stopSequences, system, temperature, maxTokens }) ), model, signal, timeout, - }); + })) as RunActionResponse; return { message: res.completion.trim() }; } + + public async invokeAIRaw({ + messages, + model, + stopSequences, + system, + temperature, + maxTokens = DEFAULT_TOKEN_LIMIT, + signal, + timeout, + tools, + anthropicVersion, + }: InvokeAIRawActionParams): Promise { + const res = await this.runApi({ + body: JSON.stringify({ + messages, + stop_sequences: stopSequences, + system, + temperature, + max_tokens: maxTokens, + tools, + anthropic_version: anthropicVersion, + }), + model, + signal, + timeout, + raw: true, + }); + return res; + } } const formatBedrockBody = ({ @@ -328,6 +376,7 @@ const formatBedrockBody = ({ temperature = 0, system, maxTokens = DEFAULT_TOKEN_LIMIT, + tools, }: { messages: Array<{ role: string; content: string }>; stopSequences?: string[]; @@ -335,12 +384,14 @@ const formatBedrockBody = ({ maxTokens?: number; // optional system message to be sent to the API system?: string; + tools?: Array<{ name: string; description: string }>; }) => ({ anthropic_version: 'bedrock-2023-05-31', ...ensureMessageFormat(messages, system), max_tokens: maxTokens, stop_sequences: stopSequences, temperature, + tools, }); /** diff --git a/yarn.lock b/yarn.lock index 0338fe1ea41e8..660a21a0964f1 100644 --- a/yarn.lock +++ b/yarn.lock @@ -6983,10 +6983,10 @@ resolved "https://registry.yarnpkg.com/@kwsites/promise-deferred/-/promise-deferred-1.1.1.tgz#8ace5259254426ccef57f3175bc64ed7095ed919" integrity sha512-GaHYm+c0O9MjZRu0ongGBRbinu8gVAMd2UZjji6jVmqKtZluZnptXGWhz1E8j8D2HJ3f/yMxKAUC0b+57wncIw== -"@langchain/community@^0.2.17": - version "0.2.17" - resolved "https://registry.yarnpkg.com/@langchain/community/-/community-0.2.17.tgz#83e776b5d4d022b22bce907ce08668c56422e8de" - integrity sha512-lbmOvOvE0L2EV8lUb/ZcYyrLGF0sveGpYg9A0m6F/nDhuPG1HZqHvU/LiHsCaVO2WJPGowibMPTC02fUG/6dKA== +"@langchain/community@^0.2.18": + version "0.2.18" + resolved 
"https://registry.yarnpkg.com/@langchain/community/-/community-0.2.18.tgz#127a7ac53a30dd6dedede887811fdd992061e2d2" + integrity sha512-UsCB97dMG87giQLniKx4bjv7OnMw2vQeavSt9gqOnGCnfb5IQBAgdjX4SjwFPbVGMz1HQoQKVlNqQ64ozCdgNg== dependencies: "@langchain/core" "~0.2.11" "@langchain/openai" "~0.1.0" @@ -6996,14 +6996,14 @@ js-yaml "^4.1.0" langchain "0.2.3" langsmith "~0.1.30" - uuid "^9.0.0" + uuid "^10.0.0" zod "^3.22.3" zod-to-json-schema "^3.22.5" -"@langchain/core@>0.1.0 <0.3.0", "@langchain/core@>0.1.61 <0.3.0", "@langchain/core@>=0.2.5 <0.3.0", "@langchain/core@>=0.2.9 <0.3.0", "@langchain/core@^0.2.14", "@langchain/core@~0.2.11": - version "0.2.14" - resolved "https://registry.yarnpkg.com/@langchain/core/-/core-0.2.14.tgz#2f607b3da03717bdb1ddb70ec64cf0f6d03bede3" - integrity sha512-e+dB2Sra+5sY4fPCXrZh4w1fLR68yfX2XSWS97DzHpeJRRnMGbOzjaaks6gy3wPyWhfS1WGLisJCMfCmFvH9fw== +"@langchain/core@>0.1.0 <0.3.0", "@langchain/core@>0.1.61 <0.3.0", "@langchain/core@>=0.2.11 <0.3.0", "@langchain/core@>=0.2.5 <0.3.0", "@langchain/core@^0.2.15", "@langchain/core@~0.2.11": + version "0.2.15" + resolved "https://registry.yarnpkg.com/@langchain/core/-/core-0.2.15.tgz#1bb99ac4fffe935c7ba37edcaa91abfba3c82219" + integrity sha512-L096itIBQ5XNsy5BCCPqIQEk/x4rzI+U4BhYT+fDBYtljESshIi/WzXdmiGfY/6MpVjB76jNuaRgMDmo1m9NeQ== dependencies: ansi-styles "^5.0.0" camelcase "6" @@ -21851,12 +21851,12 @@ kuler@^2.0.0: resolved "https://registry.yarnpkg.com/kuler/-/kuler-2.0.0.tgz#e2c570a3800388fb44407e851531c1d670b061b3" integrity sha512-Xq9nH7KlWZmXAtodXDDRE7vs6DU1gTU8zYDHDiWLSip45Egwq3plLHzPn27NgvzL2r1LMPC1vdqh98sQxtqj4A== -langchain@0.2.3, langchain@0.2.8, langchain@^0.2.8: - version "0.2.8" - resolved "https://registry.yarnpkg.com/langchain/-/langchain-0.2.8.tgz#9bd77f5c12071d0ccb637c04fc33415e5369e5aa" - integrity sha512-kb2IOMA71xH8e6EXFg0l4S+QSMC/c796pj1+7mPBkR91HHwoyHZhFRrBaZv4tV+Td+Ba91J2uEDBmySklZLpNQ== +langchain@0.2.3, langchain@^0.2.9: + version "0.2.9" + resolved "https://registry.yarnpkg.com/langchain/-/langchain-0.2.9.tgz#1341bdd7166f4f6da0b9337f363e409a79523dbb" + integrity sha512-iZ0l7BDVfoifqZlDl1gy3JP5mIdhYjWiToPlDnlmfHD748cw3okvF0gZo0ruT4nbftnQcaM7JzPUiNC43UPfgg== dependencies: - "@langchain/core" ">=0.2.9 <0.3.0" + "@langchain/core" ">=0.2.11 <0.3.0" "@langchain/openai" ">=0.1.0 <0.3.0" "@langchain/textsplitters" "~0.0.0" binary-extensions "^2.2.0" @@ -21868,7 +21868,7 @@ langchain@0.2.3, langchain@0.2.8, langchain@^0.2.8: ml-distance "^4.0.0" openapi-types "^12.1.3" p-retry "4" - uuid "^9.0.0" + uuid "^10.0.0" yaml "^2.2.1" zod "^3.22.4" zod-to-json-schema "^3.22.3" From 1ff60bf4d3df8d8dfdd82d1cacc49a3bce594f09 Mon Sep 17 00:00:00 2001 From: Patryk Kopycinski Date: Wed, 10 Jul 2024 16:05:19 +0200 Subject: [PATCH 21/55] graph-workflow --- .../server/language_models/bedrock_chat.ts | 43 ++++++++++--------- .../graphs/default_assistant_graph/helpers.ts | 43 +------------------ .../nodes/execute_tools.ts | 3 +- .../server/routes/helpers.ts | 1 - 4 files changed, 25 insertions(+), 65 deletions(-) diff --git a/x-pack/packages/kbn-langchain/server/language_models/bedrock_chat.ts b/x-pack/packages/kbn-langchain/server/language_models/bedrock_chat.ts index da92f5fb78d3f..d130d7413d564 100644 --- a/x-pack/packages/kbn-langchain/server/language_models/bedrock_chat.ts +++ b/x-pack/packages/kbn-langchain/server/language_models/bedrock_chat.ts @@ -9,7 +9,7 @@ import { BedrockChat as _BedrockChat } from '@langchain/community/chat_models/be import type { ActionsClient } from '@kbn/actions-plugin/server'; import { 
BaseChatModelParams } from '@langchain/core/language_models/chat_models'; import { Logger } from '@kbn/logging'; -import { Readable } from 'stream'; +// import { Readable } from 'stream'; import { PublicMethodsOf } from '@kbn/utility-types'; export const DEFAULT_BEDROCK_MODEL = 'anthropic.claude-3-5-sonnet-20240620-v1:0'; @@ -37,27 +37,28 @@ export class ActionsClientBedrockChatModel extends _BedrockChat { fetchFn: async (url, options) => { const inputBody = JSON.parse(options?.body as string); - if (this.streaming && graph) { - const data = (await actionsClient.execute({ - actionId: connectorId, - params: { - subAction: 'invokeStream', - subActionParams: { - messages: inputBody.messages, - temperature: inputBody.temperature, - stopSequences: inputBody.stopSequences, - system: inputBody.system, - maxTokens: inputBody.maxTokens, - tools: inputBody.tools, - anthropicVersion: inputBody.anthropic_version, - }, - }, - })) as { data: Readable }; + // if (this.streaming && graph) { + // const data = (await actionsClient.execute({ + // actionId: connectorId, + // params: { + // subAction: 'invokeStream', + // subActionParams: { + // messages: inputBody.messages, + // temperature: inputBody.temperature, + // stopSequences: inputBody.stopSequences, + // system: inputBody.system, + // maxTokens: inputBody.maxTokens, + // tools: inputBody.tools, + // anthropicVersion: inputBody.anthropic_version, + // }, + // }, + // })) as { data: Readable }; - return { - body: Readable.toWeb(data.data), - } as unknown as Response; - } + // return { + // body: Readable.toWeb(data.data), + // json: () => Readable.toWeb(data.data), + // } as unknown as Response; + // } const data = (await actionsClient.execute({ actionId: connectorId, diff --git a/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/helpers.ts b/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/helpers.ts index e66ca4791aa5b..5d1b0f32357b3 100644 --- a/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/helpers.ts +++ b/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/helpers.ts @@ -5,8 +5,6 @@ * 2.0. */ -/* eslint-disable complexity */ - import agent, { Span } from 'elastic-apm-node'; import type { Logger } from '@kbn/logging'; import { streamFactory, StreamResponseWithHeaders } from '@kbn/ml-response-stream/server'; @@ -88,11 +86,6 @@ export const streamGraph = async ({ version: 'v1', }); - let message = ''; - let finalOutputIndex = -1; - const finalOutputStartToken = '"action":"FinalAnswer","action_input":"'; - const finalOutputStopRegex = /(? 
{ try { const { value, done } = await stream.next(); @@ -119,40 +112,6 @@ export const streamGraph = async ({ } } } - - if (event.name === 'ActionsClientBedrockChatModel') { - const msg = chunk; - - if (msg) { - if (msg.tool_call_chunks && msg.tool_call_chunks.length > 0) { - /* empty */ - } else if (!didEnd) { - if (msg.response_metadata.finish_reason === 'stop') { - handleStreamEnd(finalMessage); - } else { - const finalOutputEndIndex = msg.content.search(finalOutputStopRegex); - const currentOutput = message.replace(/\s/g, ''); - - if (currentOutput.includes(finalOutputStartToken)) { - finalOutputIndex = currentOutput.indexOf(finalOutputStartToken); - } - - if (finalOutputIndex > -1 && finalOutputEndIndex > -1) { - didEnd = true; - handleStreamEnd(finalMessage); - return; - } - - if (finalOutputIndex > -1) { - finalMessage += msg.content; - push({ payload: msg.content, type: 'content' }); - } - - message += msg.content; - } - } - } - } } else if (event.event === 'on_llm_end') { if (event.name === 'ActionsClientChatOpenAI') { const generations = event.data.output?.generations[0]; @@ -168,7 +127,7 @@ export const streamGraph = async ({ (generations && generations[0]?.generationInfo?.stop_reason === 'end_turn') || generations?.[0]?.text ) { - handleStreamEnd(finalMessage); + handleStreamEnd(generations?.[0]?.text); } } } diff --git a/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/nodes/execute_tools.ts b/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/nodes/execute_tools.ts index 8f3f9a0ae1d46..985707f391a39 100644 --- a/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/nodes/execute_tools.ts +++ b/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/nodes/execute_tools.ts @@ -8,6 +8,7 @@ import { RunnableConfig } from '@langchain/core/runnables'; import { StructuredTool } from '@langchain/core/tools'; import { ToolExecutor } from '@langchain/langgraph/prebuilt'; +import { isArray } from 'lodash'; import { AgentState, NodeParamsBase } from '../types'; export interface ExecuteToolsParams extends NodeParamsBase { @@ -33,7 +34,7 @@ export const executeTools = async ({ config, logger, state, tools }: ExecuteTool logger.debug(() => `Node state:\n${JSON.stringify(state, null, 2)}`); const toolExecutor = new ToolExecutor({ tools }); - const agentAction = state.agentOutcome; + const agentAction = isArray(state.agentOutcome) ? 
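// Tool-calling agents may emit an array of AgentAction objects (one per tool
// call in a single model turn), while the legacy agents emit a single action;
// the ternary below normalizes both shapes, executing only the first call of a
// multi-tool turn. A tiny illustration of the normalization:
//
//   import { isArray } from 'lodash';
//   type OneOrMany<T> = T | T[];
//   const first = <T>(value: OneOrMany<T>): T => (isArray(value) ? value[0] : value);
//   first([{ tool: 'alerts' }, { tool: 'esql' }]); // -> { tool: 'alerts' }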
state.agentOutcome[0] : state.agentOutcome; if (!agentAction || 'returnValues' in agentAction) { throw new Error('Agent has not been run yet'); diff --git a/x-pack/plugins/elastic_assistant/server/routes/helpers.ts b/x-pack/plugins/elastic_assistant/server/routes/helpers.ts index 87cc105797343..89a9a25793d5e 100644 --- a/x-pack/plugins/elastic_assistant/server/routes/helpers.ts +++ b/x-pack/plugins/elastic_assistant/server/routes/helpers.ts @@ -460,7 +460,6 @@ export const langChainExecute = async ({ // New code path for LangGraph implementation, behind `assistantKnowledgeBaseByDefault` FF let result: StreamResponseWithHeaders | StaticReturnType; - if (enableKnowledgeBaseByDefault && request.body.isEnabledKnowledgeBase) { result = await callAssistantGraph(executorParams); } else { From cdb428305821cacfbfffe117d80edbbb62d19d31 Mon Sep 17 00:00:00 2001 From: Patryk Kopycinski Date: Mon, 15 Jul 2024 16:01:53 +0200 Subject: [PATCH 22/55] test bedrock streaming --- package.json | 2 +- .../impl/capabilities/index.ts | 4 +- .../server/language_models/bedrock_chat.ts | 47 +++--- .../graphs/default_assistant_graph/graph.ts | 17 +- .../graphs/default_assistant_graph/helpers.ts | 149 ++++++++++-------- .../graphs/default_assistant_graph/index.ts | 50 +++--- .../default_assistant_graph/nodes/respond.ts | 37 +++++ .../common/experimental_features.ts | 4 +- yarn.lock | 13 +- 9 files changed, 206 insertions(+), 117 deletions(-) create mode 100644 x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/nodes/respond.ts diff --git a/package.json b/package.json index 625d9219b1dbe..e8cfdecefca32 100644 --- a/package.json +++ b/package.json @@ -1081,7 +1081,7 @@ "jsts": "^1.6.2", "kea": "^2.6.0", "langchain": "^0.2.9", - "langsmith": "^0.1.30", + "langsmith": "^0.1.36", "launchdarkly-js-client-sdk": "^3.4.0", "launchdarkly-node-server-sdk": "^7.0.3", "load-json-file": "^6.2.0", diff --git a/x-pack/packages/kbn-elastic-assistant-common/impl/capabilities/index.ts b/x-pack/packages/kbn-elastic-assistant-common/impl/capabilities/index.ts index 1e759df2819ed..819432bae6ec6 100644 --- a/x-pack/packages/kbn-elastic-assistant-common/impl/capabilities/index.ts +++ b/x-pack/packages/kbn-elastic-assistant-common/impl/capabilities/index.ts @@ -19,7 +19,7 @@ export type AssistantFeatureKey = keyof AssistantFeatures; * Default features available to the elastic assistant */ export const defaultAssistantFeatures = Object.freeze({ - assistantKnowledgeBaseByDefault: false, + assistantKnowledgeBaseByDefault: true, assistantModelEvaluation: false, - assistantBedrockChat: false, + assistantBedrockChat: true, }); diff --git a/x-pack/packages/kbn-langchain/server/language_models/bedrock_chat.ts b/x-pack/packages/kbn-langchain/server/language_models/bedrock_chat.ts index d130d7413d564..0a329ba41e03f 100644 --- a/x-pack/packages/kbn-langchain/server/language_models/bedrock_chat.ts +++ b/x-pack/packages/kbn-langchain/server/language_models/bedrock_chat.ts @@ -9,7 +9,7 @@ import { BedrockChat as _BedrockChat } from '@langchain/community/chat_models/be import type { ActionsClient } from '@kbn/actions-plugin/server'; import { BaseChatModelParams } from '@langchain/core/language_models/chat_models'; import { Logger } from '@kbn/logging'; -// import { Readable } from 'stream'; +import { Readable } from 'stream'; import { PublicMethodsOf } from '@kbn/utility-types'; export const DEFAULT_BEDROCK_MODEL = 'anthropic.claude-3-5-sonnet-20240620-v1:0'; @@ -37,28 +37,27 @@ export class ActionsClientBedrockChatModel 
extends _BedrockChat { fetchFn: async (url, options) => { const inputBody = JSON.parse(options?.body as string); - // if (this.streaming && graph) { - // const data = (await actionsClient.execute({ - // actionId: connectorId, - // params: { - // subAction: 'invokeStream', - // subActionParams: { - // messages: inputBody.messages, - // temperature: inputBody.temperature, - // stopSequences: inputBody.stopSequences, - // system: inputBody.system, - // maxTokens: inputBody.maxTokens, - // tools: inputBody.tools, - // anthropicVersion: inputBody.anthropic_version, - // }, - // }, - // })) as { data: Readable }; + if (this.streaming && !inputBody.tools?.length) { + const data = (await actionsClient.execute({ + actionId: connectorId, + params: { + subAction: 'invokeStream', + subActionParams: { + messages: inputBody.messages, + temperature: inputBody.temperature, + stopSequences: inputBody.stop_sequences, + system: inputBody.system, + maxTokens: inputBody.max_tokens, + tools: inputBody.tools, + anthropicVersion: inputBody.anthropic_version, + }, + }, + })) as { data: Readable }; - // return { - // body: Readable.toWeb(data.data), - // json: () => Readable.toWeb(data.data), - // } as unknown as Response; - // } + return { + body: Readable.toWeb(data.data), + } as unknown as Response; + } const data = (await actionsClient.execute({ actionId: connectorId, @@ -67,9 +66,9 @@ export class ActionsClientBedrockChatModel extends _BedrockChat { subActionParams: { messages: inputBody.messages, temperature: inputBody.temperature, - stopSequences: inputBody.stopSequences, + stopSequences: inputBody.stop_sequences, system: inputBody.system, - maxTokens: inputBody.maxTokens, + maxTokens: inputBody.max_tokens, tools: inputBody.tools, anthropicVersion: inputBody.anthropic_version, }, diff --git a/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/graph.ts b/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/graph.ts index b16f7d9693e5f..50ad6ba09cc8c 100644 --- a/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/graph.ts +++ b/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/graph.ts @@ -33,6 +33,7 @@ import { PERSIST_CONVERSATION_CHANGES_NODE, persistConversationChanges, } from './nodes/persist_conversation_changes'; +import { RESPOND_NODE, respond } from './nodes/respond'; export const DEFAULT_ASSISTANT_GRAPH_ID = 'Default Security Assistant Graph'; @@ -40,6 +41,7 @@ interface GetDefaultAssistantGraphParams { agentRunnable: AgentRunnableSequence; dataClients?: AssistantDataClients; conversationId?: string; + getLlmInstance: () => BaseChatModel; llm: BaseChatModel; logger: Logger; tools: StructuredTool[]; @@ -56,6 +58,7 @@ export const getDefaultAssistantGraph = ({ agentRunnable, conversationId, dataClients, + getLlmInstance, llm, logger, responseLanguage, @@ -142,6 +145,12 @@ export const getDefaultAssistantGraph = ({ conversationId, replacements, }); + const respondNode = (state: AgentState) => + respond({ + ...nodeParams, + llm: getLlmInstance(), + state, + }); const shouldContinueEdge = (state: AgentState) => shouldContinue({ ...nodeParams, state }); const shouldContinueGenerateTitleEdge = (state: AgentState) => shouldContinueGenerateTitle({ ...nodeParams, state }); @@ -158,6 +167,7 @@ export const getDefaultAssistantGraph = ({ | 'generateChatTitle' | 'getPersistedConversation' | 'persistConversationChanges' + | 'respond' >({ channels: graphState, }); @@ -167,6 +177,7 @@ 
export const getDefaultAssistantGraph = ({ graph.addNode(PERSIST_CONVERSATION_CHANGES_NODE, persistConversationChangesNode); graph.addNode(AGENT_NODE, runAgentNode); graph.addNode(TOOLS_NODE, executeToolsNode); + graph.addNode(RESPOND_NODE, respondNode); // Add edges, alternating between agent and action until finished graph.addConditionalEdges(START, shouldContinueGetConversationEdge, { @@ -180,7 +191,11 @@ export const getDefaultAssistantGraph = ({ graph.addEdge(GENERATE_CHAT_TITLE_NODE, PERSIST_CONVERSATION_CHANGES_NODE); graph.addEdge(PERSIST_CONVERSATION_CHANGES_NODE, AGENT_NODE); // Add conditional edge for basic routing - graph.addConditionalEdges(AGENT_NODE, shouldContinueEdge, { continue: TOOLS_NODE, end: END }); + graph.addConditionalEdges(AGENT_NODE, shouldContinueEdge, { + continue: TOOLS_NODE, + end: RESPOND_NODE, + }); + graph.addEdge(RESPOND_NODE, END); graph.addEdge(TOOLS_NODE, AGENT_NODE); // Compile the graph return graph.compile(); diff --git a/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/helpers.ts b/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/helpers.ts index 5d1b0f32357b3..390a5aa20aa85 100644 --- a/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/helpers.ts +++ b/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/helpers.ts @@ -12,6 +12,7 @@ import { transformError } from '@kbn/securitysolution-es-utils'; import type { KibanaRequest } from '@kbn/core-http-server'; import type { ExecuteConnectorRequestBody, TraceData } from '@kbn/elastic-assistant-common'; import { APMTracer } from '@kbn/langchain/server/tracers/apm'; +import { AIMessageChunk } from '@langchain/core/messages'; import { withAssistantSpan } from '../../tracers/apm/with_assistant_span'; import { AGENT_NODE_TAG } from './nodes/run_agent'; import { DEFAULT_ASSISTANT_GRAPH_ID, DefaultAssistantGraph } from './graph'; @@ -20,7 +21,9 @@ import type { OnLlmResponse, TraceOptions } from '../../executors/types'; interface StreamGraphParams { apmTracer: APMTracer; assistantGraph: DefaultAssistantGraph; + bedrockChatEnabled: boolean; inputs: { input: string }; + llmType: string | undefined; logger: Logger; onLlmResponse?: OnLlmResponse; request: KibanaRequest; @@ -40,6 +43,8 @@ interface StreamGraphParams { */ export const streamGraph = async ({ apmTracer, + llmType, + bedrockChatEnabled, assistantGraph, inputs, logger, @@ -77,81 +82,91 @@ export const streamGraph = async ({ streamingSpan?.end(); }; - let finalMessage = ''; - const stream = assistantGraph.streamEvents(inputs, { - callbacks: [apmTracer, ...(traceOptions?.tracers ?? [])], - runName: DEFAULT_ASSISTANT_GRAPH_ID, - streamMode: 'values', - tags: traceOptions?.tags ?? [], - version: 'v1', - }); - - const processEvent = async () => { - try { - const { value, done } = await stream.next(); - if (done) return; - - const event = value; - - // only process events that are part of the agent run - if ((event.tags || []).includes(AGENT_NODE_TAG)) { - if (event.event === 'on_llm_stream') { - const chunk = event.data?.chunk; + if (llmType === 'bedrock' && bedrockChatEnabled) { + const stream = await assistantGraph.streamEvents( + inputs, + { + // callbacks: [apmTracer, ...(traceOptions?.tracers ?? [])], + // runName: DEFAULT_ASSISTANT_GRAPH_ID, + // // streamMode: 'updates', + // tags: traceOptions?.tags ?? 
[], + version: 'v2', + }, + { includeNames: ['Summarizer'] } + ); + + for await (const { event, data } of stream) { + if (event === 'on_chat_model_stream') { + const msg = data.chunk as AIMessageChunk; + + if (!msg.tool_call_chunks?.length) { + push({ payload: msg.content, type: 'content' }); + } + } - if (event.name === 'ActionsClientChatOpenAI') { - const msg = chunk.message; + if (event === 'on_chat_model_end') { + handleStreamEnd(data.output.content); + } + } + } else { + let finalMessage = ''; + const stream = assistantGraph.streamEvents(inputs, { + callbacks: [apmTracer, ...(traceOptions?.tracers ?? [])], + runName: DEFAULT_ASSISTANT_GRAPH_ID, + streamMode: 'values', + tags: traceOptions?.tags ?? [], + version: 'v1', + }); - if (msg.tool_call_chunks && msg.tool_call_chunks.length > 0) { - /* empty */ - } else if (!didEnd) { - if (msg.response_metadata.finish_reason === 'stop') { - handleStreamEnd(finalMessage); - } else { - push({ payload: msg.content, type: 'content' }); - finalMessage += msg.content; + const processEvent = async () => { + try { + const { value, done } = await stream.next(); + if (done) return; + const event = value; + // only process events that are part of the agent run + if ((event.tags || []).includes(AGENT_NODE_TAG)) { + if (event.event === 'on_llm_stream') { + const chunk = event.data?.chunk; + if (event.name === 'ActionsClientChatOpenAI') { + const msg = chunk.message; + if (msg.tool_call_chunks && msg.tool_call_chunks.length > 0) { + /* empty */ + } else if (!didEnd) { + if (msg.response_metadata.finish_reason === 'stop') { + handleStreamEnd(finalMessage); + } else { + push({ payload: msg.content, type: 'content' }); + finalMessage += msg.content; + } } } - } - } else if (event.event === 'on_llm_end') { - if (event.name === 'ActionsClientChatOpenAI') { - const generations = event.data.output?.generations[0]; - if (generations && generations[0]?.generationInfo.finish_reason === 'stop') { - handleStreamEnd(finalMessage); - } - } - - if (event.name === 'ActionsClientBedrockChatModel') { - const generations = event.data.output?.generations[0]; - - if ( - (generations && generations[0]?.generationInfo?.stop_reason === 'end_turn') || - generations?.[0]?.text - ) { - handleStreamEnd(generations?.[0]?.text); + } else if (event.event === 'on_llm_end') { + if (event.name === 'ActionsClientChatOpenAI') { + const generations = event.data.output?.generations[0]; + if (generations && generations[0]?.generationInfo.finish_reason === 'stop') { + handleStreamEnd(finalMessage); + } } } } + void processEvent(); + } catch (err) { + // if I throw an error here, it crashes the server. Not sure how to get around that. + // If I put await on this function the error works properly, but when there is not an error + // it waits for the entire stream to complete before resolving + const error = transformError(err); + if (error.message === 'AbortError') { + // user aborted the stream, we must end it manually here + return handleStreamEnd(finalMessage); + } + logger.error(`Error streaming from LangChain: ${error.message}`); + push({ payload: error.message, type: 'content' }); + handleStreamEnd(error.message, true); } - - void processEvent(); - } catch (err) { - // if I throw an error here, it crashes the server. Not sure how to get around that. 
- // If I put await on this function the error works properly, but when there is not an error - // it waits for the entire stream to complete before resolving - const error = transformError(err); - - if (error.message === 'AbortError') { - // user aborted the stream, we must end it manually here - return handleStreamEnd(finalMessage); - } - logger.error(`Error streaming from LangChain: ${error.message}`); - push({ payload: error.message, type: 'content' }); - handleStreamEnd(error.message, true); - } - }; - - // Start processing events, do not await! Return `responseWithHeaders` immediately - await processEvent(); + }; + // Start processing events, do not await! Return `responseWithHeaders` immediately + await processEvent(); + } return responseWithHeaders; }; diff --git a/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/index.ts b/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/index.ts index 711f1d6ba2b87..c9f846b92eaf4 100644 --- a/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/index.ts +++ b/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/index.ts @@ -54,25 +54,26 @@ export const callAssistantGraph: AgentExecutor = async ({ const logger = parentLogger.get('defaultAssistantGraph'); const isOpenAI = llmType === 'openai'; const llmClass = getLlmClass(llmType, bedrockChatEnabled); + const getLlmInstance = () => + new llmClass({ + actionsClient, + connectorId, + llmType, + logger, + // possible client model override, + // let this be undefined otherwise so the connector handles the model + model: request.body.model, + // ensure this is defined because we default to it in the language_models + // This is where the LangSmith logs (Metadata > Invocation Params) are set + temperature: getDefaultArguments(llmType).temperature, + signal: abortSignal, + streaming: isStream, + // prevents the agent from retrying on failure + // failure could be due to bad connector, we should deliver that result to the client asap + maxRetries: 0, + }); - const llm = new llmClass({ - actionsClient, - connectorId, - llmType, - logger, - // possible client model override, - // let this be undefined otherwise so the connector handles the model - model: request.body.model, - // ensure this is defined because we default to it in the language_models - // This is where the LangSmith logs (Metadata > Invocation Params) are set - temperature: getDefaultArguments(llmType).temperature, - signal: abortSignal, - streaming: isStream, - // prevents the agent from retrying on failure - // failure could be due to bad connector, we should deliver that result to the client asap - maxRetries: 0, - graph: true, - }); + const llm = getLlmInstance(); const anonymizationFieldsRes = await dataClients?.anonymizationFieldsDataClient?.findDocuments({ @@ -145,6 +146,7 @@ export const callAssistantGraph: AgentExecutor = async ({ conversationId, dataClients, llm, + getLlmInstance, logger, tools, responseLanguage, @@ -153,7 +155,17 @@ export const callAssistantGraph: AgentExecutor = async ({ const inputs = { input: latestMessage[0]?.content as string }; if (isStream) { - return streamGraph({ apmTracer, assistantGraph, inputs, logger, onLlmResponse, request }); + return streamGraph({ + apmTracer, + assistantGraph, + llmType, + bedrockChatEnabled, + inputs, + logger, + onLlmResponse, + request, + traceOptions, + }); } const graphResponse = await invokeGraph({ diff --git 
a/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/nodes/respond.ts b/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/nodes/respond.ts new file mode 100644 index 0000000000000..86b1d13503f33 --- /dev/null +++ b/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/nodes/respond.ts @@ -0,0 +1,37 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +import { BaseChatModel } from '@langchain/core/language_models/chat_models'; +import { AgentState } from '../types'; + +export const RESPOND_NODE = 'respond'; +export const respond = async ({ llm, state }: { llm: BaseChatModel; state: AgentState }) => { + // Assign the final model call a run name + // console.error('state', state); + // const { messages } = state; + const userMessage = [ + 'user', + `Respond exactly with + ${state.agentOutcome?.returnValues?.output}. + + Do not verify, confirm or anything else. Just reply with the same content as provided above.`, + ]; + // console.error('messages', messages); + // console.error('userMessage', userMessage); + const responseMessage = await llm + // .bindTools([]) + .withConfig({ runName: 'Summarizer' }) + .invoke([userMessage]); + + return { + agentOutcome: { + returnValues: { + output: responseMessage.content, + }, + }, + }; +}; diff --git a/x-pack/plugins/security_solution/common/experimental_features.ts b/x-pack/plugins/security_solution/common/experimental_features.ts index 0182318258c05..5969df250e612 100644 --- a/x-pack/plugins/security_solution/common/experimental_features.ts +++ b/x-pack/plugins/security_solution/common/experimental_features.ts @@ -123,12 +123,12 @@ export const allowedExperimentalValues = Object.freeze({ /** * Enables the Assistant Knowledge Base by default, introduced in `8.15.0`. */ - assistantKnowledgeBaseByDefault: false, + assistantKnowledgeBaseByDefault: true, /** * Enables the Assistant BedrockChat Langchain model, introduced in `8.15.0`. */ - assistantBedrockChat: false, + assistantBedrockChat: true, /** * Enables the Managed User section inside the new user details flyout. 
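Note on the two hunks above: the `respond` node re-issues the agent's final answer through one extra model call tagged with `runName: 'Summarizer'`, and `streamGraph` (helpers.ts, earlier in this patch) narrows its `streamEvents(..., { version: 'v2' })` iterator to that run name via `{ includeNames: ['Summarizer'] }`, so only the final answer is streamed to the client while intermediate tool-calling chatter is dropped. A minimal standalone sketch of the same pattern, assuming a compiled LangGraph `app` and a chat model `llm` (both names are illustrative, not part of this patch):

    // Tag the final model call so its events can be isolated downstream.
    const finalModel = llm.withConfig({ runName: 'Summarizer' });

    // Consume only events emitted by the run named 'Summarizer'.
    const stream = await app.streamEvents(
      { input: 'hello' },
      { version: 'v2' },
      { includeNames: ['Summarizer'] }
    );

    for await (const { event, data } of stream) {
      // Each token of the final answer arrives as an on_chat_model_stream event.
      if (event === 'on_chat_model_stream') {
        process.stdout.write(String(data.chunk.content));
      }
    }
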
diff --git a/yarn.lock b/yarn.lock index f049ed88fcb6f..81d5ba3b84a7a 100644 --- a/yarn.lock +++ b/yarn.lock @@ -21884,7 +21884,18 @@ langchainhub@~0.0.8: resolved "https://registry.yarnpkg.com/langchainhub/-/langchainhub-0.0.8.tgz#fd4b96dc795e22e36c1a20bad31b61b0c33d3110" integrity sha512-Woyb8YDHgqqTOZvWIbm2CaFDGfZ4NTSyXV687AG4vXEfoNo7cGQp7nhl7wL3ehenKWmNEmcxCLgOZzW8jE6lOQ== -langsmith@^0.1.30, langsmith@~0.1.30: +langsmith@^0.1.36: + version "0.1.36" + resolved "https://registry.yarnpkg.com/langsmith/-/langsmith-0.1.36.tgz#5f21b9c6bcd4ea9c0e943f83e304a53e5232297d" + integrity sha512-D5hhkFl31uxFdffx0lA6pin0lt8Pv2dpHFZYpSgEzvQ26PQ/Y/tnniQ+aCNokIXuLhMa7uqLtb6tfwjfiZXgdg== + dependencies: + "@types/uuid" "^9.0.1" + commander "^10.0.1" + p-queue "^6.6.2" + p-retry "4" + uuid "^9.0.0" + +langsmith@~0.1.30: version "0.1.32" resolved "https://registry.yarnpkg.com/langsmith/-/langsmith-0.1.32.tgz#38938b0e8685522087b697b8200c488c6490c137" integrity sha512-EUWHIH6fiOCGRYdzgwGoXwJxCMyUrL+bmUcxoVmkXoXoAGDOVinz8bqJLKbxotsQWqM64NKKsW85OTIutgNaMQ== From 109374d7f94709eae1a6337944bb34e4ca138ecd Mon Sep 17 00:00:00 2001 From: Patryk Kopycinski Date: Mon, 15 Jul 2024 16:27:04 +0200 Subject: [PATCH 23/55] fix --- yarn.lock | 13 +------------ 1 file changed, 1 insertion(+), 12 deletions(-) diff --git a/yarn.lock b/yarn.lock index 81d5ba3b84a7a..22eca8c4500ce 100644 --- a/yarn.lock +++ b/yarn.lock @@ -21884,7 +21884,7 @@ langchainhub@~0.0.8: resolved "https://registry.yarnpkg.com/langchainhub/-/langchainhub-0.0.8.tgz#fd4b96dc795e22e36c1a20bad31b61b0c33d3110" integrity sha512-Woyb8YDHgqqTOZvWIbm2CaFDGfZ4NTSyXV687AG4vXEfoNo7cGQp7nhl7wL3ehenKWmNEmcxCLgOZzW8jE6lOQ== -langsmith@^0.1.36: +langsmith@^0.1.36, langsmith@~0.1.30: version "0.1.36" resolved "https://registry.yarnpkg.com/langsmith/-/langsmith-0.1.36.tgz#5f21b9c6bcd4ea9c0e943f83e304a53e5232297d" integrity sha512-D5hhkFl31uxFdffx0lA6pin0lt8Pv2dpHFZYpSgEzvQ26PQ/Y/tnniQ+aCNokIXuLhMa7uqLtb6tfwjfiZXgdg== @@ -21895,17 +21895,6 @@ langsmith@^0.1.36: p-retry "4" uuid "^9.0.0" -langsmith@~0.1.30: - version "0.1.32" - resolved "https://registry.yarnpkg.com/langsmith/-/langsmith-0.1.32.tgz#38938b0e8685522087b697b8200c488c6490c137" - integrity sha512-EUWHIH6fiOCGRYdzgwGoXwJxCMyUrL+bmUcxoVmkXoXoAGDOVinz8bqJLKbxotsQWqM64NKKsW85OTIutgNaMQ== - dependencies: - "@types/uuid" "^9.0.1" - commander "^10.0.1" - p-queue "^6.6.2" - p-retry "4" - uuid "^9.0.0" - language-subtag-registry@~0.3.2: version "0.3.21" resolved "https://registry.yarnpkg.com/language-subtag-registry/-/language-subtag-registry-0.3.21.tgz#04ac218bea46f04cb039084602c6da9e788dd45a" From 90943b0818ed1eac657bd187213c6fc36b673bcd Mon Sep 17 00:00:00 2001 From: Patryk Kopycinski Date: Tue, 16 Jul 2024 15:14:41 +0200 Subject: [PATCH 24/55] test VertexChat --- package.json | 15 ++-- x-pack/packages/kbn-langchain/server/index.ts | 2 + .../server/language_models/vertex_chat.ts | 76 +++++++++++++++++++ .../graphs/default_assistant_graph/helpers.ts | 9 +-- .../graphs/default_assistant_graph/index.ts | 2 +- .../default_assistant_graph/nodes/respond.ts | 2 +- .../elastic_assistant/server/routes/utils.ts | 3 + .../plugins/elastic_assistant/server/types.ts | 4 +- .../common/gemini/constants.ts | 1 + .../stack_connectors/common/gemini/schema.ts | 14 ++++ .../stack_connectors/common/gemini/types.ts | 5 ++ .../server/connector_types/gemini/gemini.ts | 57 ++++++++++++-- yarn.lock | 66 +++++++++++----- 13 files changed, 213 insertions(+), 43 deletions(-) create mode 100644 
x-pack/packages/kbn-langchain/server/language_models/vertex_chat.ts diff --git a/package.json b/package.json index 0b3c3f37cce1c..2b271a4f0d8d2 100644 --- a/package.json +++ b/package.json @@ -80,7 +80,7 @@ "resolutions": { "**/@bazel/typescript/protobufjs": "6.11.4", "**/@hello-pangea/dnd": "16.6.0", - "**/@langchain/core": "^0.2.15", + "**/@langchain/core": "^0.2.16", "**/@types/node": "20.10.5", "**/@typescript-eslint/utils": "5.62.0", "**/chokidar": "^3.5.3", @@ -88,7 +88,7 @@ "**/globule/minimatch": "^3.1.2", "**/hoist-non-react-statics": "^3.3.2", "**/isomorphic-fetch/node-fetch": "^2.6.7", - "**/langchain": "^0.2.9", + "**/langchain": "^0.2.10", "**/react-intl/**/@types/react": "^17.0.45", "**/remark-parse/trim": "1.0.1", "**/sharp": "0.32.6", @@ -938,9 +938,10 @@ "@kbn/watcher-plugin": "link:x-pack/plugins/watcher", "@kbn/xstate-utils": "link:packages/kbn-xstate-utils", "@kbn/zod-helpers": "link:packages/kbn-zod-helpers", - "@langchain/community": "^0.2.18", - "@langchain/core": "^0.2.15", - "@langchain/langgraph": "^0.0.26", + "@langchain/community": "^0.2.19", + "@langchain/core": "^0.2.16", + "@langchain/google-common": "^0.0.20", + "@langchain/langgraph": "^0.0.27", "@langchain/openai": "^0.1.3", "@langtrase/trace-attributes": "^3.0.8", "@launchdarkly/node-server-sdk": "^9.4.7", @@ -1081,8 +1082,8 @@ "jsonwebtoken": "^9.0.2", "jsts": "^1.6.2", "kea": "^2.6.0", - "langchain": "^0.2.9", - "langsmith": "^0.1.36", + "langchain": "^0.2.10", + "langsmith": "^0.1.37", "launchdarkly-js-client-sdk": "^3.4.0", "launchdarkly-node-server-sdk": "^7.0.3", "load-json-file": "^6.2.0", diff --git a/x-pack/packages/kbn-langchain/server/index.ts b/x-pack/packages/kbn-langchain/server/index.ts index 126a9f6bdbfc6..a2baf6b45a2b3 100644 --- a/x-pack/packages/kbn-langchain/server/index.ts +++ b/x-pack/packages/kbn-langchain/server/index.ts @@ -9,6 +9,7 @@ import { ActionsClientBedrockChatModel } from './language_models/bedrock_chat'; import { ActionsClientChatOpenAI } from './language_models/chat_openai'; import { ActionsClientLlm } from './language_models/llm'; import { ActionsClientSimpleChatModel } from './language_models/simple_chat_model'; +import { ActionsClientVertexChatModel } from './language_models/vertex_chat'; import { parseBedrockStream } from './utils/bedrock'; import { parseGeminiResponse } from './utils/gemini'; import { getDefaultArguments } from './language_models/constants'; @@ -21,4 +22,5 @@ export { ActionsClientChatOpenAI, ActionsClientLlm, ActionsClientSimpleChatModel, + ActionsClientVertexChatModel, }; diff --git a/x-pack/packages/kbn-langchain/server/language_models/vertex_chat.ts b/x-pack/packages/kbn-langchain/server/language_models/vertex_chat.ts new file mode 100644 index 0000000000000..967790b420b98 --- /dev/null +++ b/x-pack/packages/kbn-langchain/server/language_models/vertex_chat.ts @@ -0,0 +1,76 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +import { ActionsClient } from '@kbn/actions-plugin/server'; +import { PublicMethodsOf } from '@kbn/utility-types'; +import { + ChatGoogleBase, + ChatGoogleBaseInput, + GoogleBaseLLMInput, + ReadableJsonStream, +} from '@langchain/google-common'; +import { Readable } from 'stream'; + +export type ChatGoogleInput = ChatGoogleBaseInput<{}>; + +export class ActionsClientVertexChatModel extends ChatGoogleBase<{}> implements ChatGoogleInput { + #actionsClient: PublicMethodsOf; + #connectorId: string; + streaming: boolean; + model: string = ''; + temperature: number = 0; + #maxTokens?: number; + + static lc_name() { + return 'ChatVertexAI'; + } + + constructor({ actionsClient, connectorId, streaming, temperature, model }) { + super({ + // ...fields, + platformType: 'gcp', + }); + + this.#actionsClient = actionsClient; + this.#connectorId = connectorId; + this.model = model; + this.temperature = temperature ?? 0; + this.streaming = streaming; + } + + override buildAbstractedClient(fields: GoogleBaseLLMInput<{}> | undefined) { + return { + request: async (props) => { + // create a new connector request body with the assistant message: + const requestBody = { + actionId: 'my-gemini-ai' || this.#connectorId, + params: { + subAction: this.streaming ? 'invokeStream' : 'invokeAIRaw', + subActionParams: { + model: 'gemini-1.5-pro-preview-0409' || this.model, + messages: props.data, + }, + }, + }; + + const actionResult = await this.#actionsClient.execute(requestBody); + + if (this.streaming) { + return { + data: new ReadableJsonStream( + actionResult.data ? Readable.toWeb(actionResult.data) : null + ), + }; + } + + return actionResult; + }, + getProjectId: () => Promise.resolve(''), + clientType: '', + }; + } +} diff --git a/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/helpers.ts b/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/helpers.ts index 390a5aa20aa85..eb676030fc3f6 100644 --- a/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/helpers.ts +++ b/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/helpers.ts @@ -82,14 +82,13 @@ export const streamGraph = async ({ streamingSpan?.end(); }; - if (llmType === 'bedrock' && bedrockChatEnabled) { + if ((llmType === 'bedrock' || llmType === 'gemini') && bedrockChatEnabled) { const stream = await assistantGraph.streamEvents( inputs, { - // callbacks: [apmTracer, ...(traceOptions?.tracers ?? [])], - // runName: DEFAULT_ASSISTANT_GRAPH_ID, - // // streamMode: 'updates', - // tags: traceOptions?.tags ?? [], + callbacks: [apmTracer, ...(traceOptions?.tracers ?? [])], + runName: DEFAULT_ASSISTANT_GRAPH_ID, + tags: traceOptions?.tags ?? [], version: 'v2', }, { includeNames: ['Summarizer'] } diff --git a/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/index.ts b/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/index.ts index c9f846b92eaf4..425834ea516f2 100644 --- a/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/index.ts +++ b/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/index.ts @@ -120,7 +120,7 @@ export const callAssistantGraph: AgentExecutor = async ({ prompt: openAIFunctionAgentPrompt, streamRunnable: isStream, }) - : llmType === 'bedrock' && bedrockChatEnabled + : llmType && ['bedrock', 'gemini'].includes(llmType) && bedrockChatEnabled ? 
await createToolCallingAgent({ llm, tools, diff --git a/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/nodes/respond.ts b/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/nodes/respond.ts index 86b1d13503f33..68c82ab9482b0 100644 --- a/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/nodes/respond.ts +++ b/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/nodes/respond.ts @@ -16,7 +16,7 @@ export const respond = async ({ llm, state }: { llm: BaseChatModel; state: Agent const userMessage = [ 'user', `Respond exactly with - ${state.agentOutcome?.returnValues?.output}. + ${state.agentOutcome?.returnValues?.output} Do not verify, confirm or anything else. Just reply with the same content as provided above.`, ]; diff --git a/x-pack/plugins/elastic_assistant/server/routes/utils.ts b/x-pack/plugins/elastic_assistant/server/routes/utils.ts index 72aa5218ac6ce..2948453b4fba0 100644 --- a/x-pack/plugins/elastic_assistant/server/routes/utils.ts +++ b/x-pack/plugins/elastic_assistant/server/routes/utils.ts @@ -17,6 +17,7 @@ import { ActionsClientChatOpenAI, ActionsClientBedrockChatModel, ActionsClientSimpleChatModel, + ActionsClientVertexChatModel, } from '@kbn/langchain/server'; import { CustomHttpRequestError } from './custom_http_request_error'; @@ -185,4 +186,6 @@ export const getLlmClass = (llmType?: string, bedrockChatEnabled?: boolean) => ? ActionsClientChatOpenAI : llmType === 'bedrock' && bedrockChatEnabled ? ActionsClientBedrockChatModel + : llmType === 'gemini' && bedrockChatEnabled + ? ActionsClientVertexChatModel : ActionsClientSimpleChatModel; diff --git a/x-pack/plugins/elastic_assistant/server/types.ts b/x-pack/plugins/elastic_assistant/server/types.ts index 010309693e1ae..dc19b23ce45a2 100755 --- a/x-pack/plugins/elastic_assistant/server/types.ts +++ b/x-pack/plugins/elastic_assistant/server/types.ts @@ -39,6 +39,7 @@ import { ActionsClientChatOpenAI, ActionsClientLlm, ActionsClientSimpleChatModel, + ActionsClientVertexChatModel, } from '@kbn/langchain/server'; import { AttackDiscoveryDataClient } from './ai_assistant_data_clients/attack_discovery'; @@ -217,7 +218,8 @@ export interface AssistantTool { export type AssistantToolLlm = | ActionsClientBedrockChatModel | ActionsClientChatOpenAI - | ActionsClientSimpleChatModel; + | ActionsClientSimpleChatModel + | ActionsClientVertexChatModel; export interface AssistantToolParams { alertsIndexPattern?: string; diff --git a/x-pack/plugins/stack_connectors/common/gemini/constants.ts b/x-pack/plugins/stack_connectors/common/gemini/constants.ts index bbad177547033..16bdfb2d91766 100644 --- a/x-pack/plugins/stack_connectors/common/gemini/constants.ts +++ b/x-pack/plugins/stack_connectors/common/gemini/constants.ts @@ -19,6 +19,7 @@ export enum SUB_ACTION { DASHBOARD = 'getDashboard', TEST = 'test', INVOKE_AI = 'invokeAI', + INVOKE_AI_RAW = 'invokeAIRaw', INVOKE_STREAM = 'invokeStream', } diff --git a/x-pack/plugins/stack_connectors/common/gemini/schema.ts b/x-pack/plugins/stack_connectors/common/gemini/schema.ts index 91b523ef4853b..5a040dd07e17c 100644 --- a/x-pack/plugins/stack_connectors/common/gemini/schema.ts +++ b/x-pack/plugins/stack_connectors/common/gemini/schema.ts @@ -26,6 +26,7 @@ export const RunActionParamsSchema = schema.object({ timeout: schema.maybe(schema.number()), temperature: schema.maybe(schema.number()), stopSequences: schema.maybe(schema.arrayOf(schema.string())), + raw: 
schema.maybe(schema.boolean()),
 });
 
 export const RunApiResponseSchema = schema.object({
@@ -52,6 +53,8 @@ export const RunActionResponseSchema = schema.object(
   { unknowns: 'ignore' }
 );
 
+export const RunActionRawResponse = schema.any();
+
 export const InvokeAIActionParamsSchema = schema.object({
   messages: schema.any(),
   model: schema.maybe(schema.string()),
@@ -61,6 +64,15 @@ export const InvokeAIActionParamsSchema = schema.object({
   timeout: schema.maybe(schema.number()),
 });
 
+export const InvokeAIRawActionParamsSchema = schema.object({
+  messages: schema.any(),
+  model: schema.maybe(schema.string()),
+  temperature: schema.maybe(schema.number()),
+  stopSequences: schema.maybe(schema.arrayOf(schema.string())),
+  signal: schema.maybe(schema.any()),
+  timeout: schema.maybe(schema.number()),
+});
+
 export const InvokeAIActionResponseSchema = schema.object({
   message: schema.string(),
   usageMetadata: schema.maybe(
@@ -72,6 +84,8 @@ export const InvokeAIActionResponseSchema = schema.object({
   ),
 });
 
+export const InvokeAIRawActionResponseSchema = schema.any();
+
 export const StreamingResponseSchema = schema.any();
 
 export const DashboardActionParamsSchema = schema.object({
diff --git a/x-pack/plugins/stack_connectors/common/gemini/types.ts b/x-pack/plugins/stack_connectors/common/gemini/types.ts
index bc9a96ebdfdd7..206b2423f46ce 100644
--- a/x-pack/plugins/stack_connectors/common/gemini/types.ts
+++ b/x-pack/plugins/stack_connectors/common/gemini/types.ts
@@ -16,6 +16,8 @@ import {
   RunApiResponseSchema,
   InvokeAIActionParamsSchema,
   InvokeAIActionResponseSchema,
+  InvokeAIRawActionParamsSchema,
+  InvokeAIRawActionResponseSchema,
   StreamingResponseSchema,
 } from './schema';
 
@@ -24,8 +26,11 @@ export type Secrets = TypeOf<typeof SecretsSchema>;
 export type RunActionParams = TypeOf<typeof RunActionParamsSchema>;
 export type RunApiResponse = TypeOf<typeof RunApiResponseSchema>;
 export type RunActionResponse = TypeOf<typeof RunActionResponseSchema>;
+export type RunActionRawResponse = TypeOf<typeof RunActionRawResponse>;
 export type DashboardActionParams = TypeOf<typeof DashboardActionParamsSchema>;
 export type DashboardActionResponse = TypeOf<typeof DashboardActionResponseSchema>;
 export type InvokeAIActionParams = TypeOf<typeof InvokeAIActionParamsSchema>;
 export type InvokeAIActionResponse = TypeOf<typeof InvokeAIActionResponseSchema>;
+export type InvokeAIRawActionParams = TypeOf<typeof InvokeAIRawActionParamsSchema>;
+export type InvokeAIRawActionResponse = TypeOf<typeof InvokeAIRawActionResponseSchema>;
 export type StreamingResponse = TypeOf<typeof StreamingResponseSchema>;
diff --git a/x-pack/plugins/stack_connectors/server/connector_types/gemini/gemini.ts b/x-pack/plugins/stack_connectors/server/connector_types/gemini/gemini.ts
index 31ec65bca0859..6468eccb481a9 100644
--- a/x-pack/plugins/stack_connectors/server/connector_types/gemini/gemini.ts
+++ b/x-pack/plugins/stack_connectors/server/connector_types/gemini/gemini.ts
@@ -17,6 +17,7 @@ import {
   RunActionParamsSchema,
   RunApiResponseSchema,
   InvokeAIActionParamsSchema,
+  InvokeAIRawActionParamsSchema,
   StreamingResponseSchema,
 } from '../../../common/gemini/schema';
 import { initDashboard } from '../lib/gen_ai/create_gen_ai_dashboard';
@@ -31,6 +32,8 @@ import {
   StreamingResponse,
   InvokeAIActionParams,
   InvokeAIActionResponse,
+  InvokeAIRawActionParams,
+  InvokeAIRawActionResponse,
 } from '../../../common/gemini/types';
 import {
   SUB_ACTION,
@@ -103,6 +106,12 @@ export class GeminiConnector extends SubActionConnector<Config, Secrets> {
       schema: InvokeAIActionParamsSchema,
     });
 
+    this.registerSubAction({
+      name: SUB_ACTION.INVOKE_AI_RAW,
+      method: 'invokeAIRaw',
+      schema: InvokeAIRawActionParamsSchema,
+    });
+
     this.registerSubAction({
       name: SUB_ACTION.INVOKE_STREAM,
       method: 'invokeStream',
@@ -193,7 +202,8 @@ export class GeminiConnector extends SubActionConnector<Config, Secrets> {
     model: reqModel,
     signal,
     timeout,
-  }: RunActionParams): Promise<RunActionResponse> {
+    raw,
+  }: RunActionParams):
Promise<RunActionResponse | RunActionRawResponse> {
     // set model on per request basis
     const currentModel = reqModel ?? this.model;
     const path = `/v1/projects/${this.gcpProjectID}/locations/${this.gcpRegion}/publishers/google/models/${currentModel}:generateContent`;
@@ -217,6 +227,10 @@ export class GeminiConnector extends SubActionConnector<Config, Secrets> {
     const usageMetadata = response.data.usageMetadata;
     const completionText = candidate.content.parts[0].text;
 
+    if (raw) {
+      return response.data;
+    }
+
     return { completion: completionText, usageMetadata };
   }
 
@@ -264,6 +278,24 @@ export class GeminiConnector extends SubActionConnector<Config, Secrets> {
     return { message: res.completion, usageMetadata: res.usageMetadata };
   }
 
+  public async invokeAIRaw({
+    messages,
+    model,
+    temperature = 0,
+    signal,
+    timeout,
+  }: InvokeAIRawActionParams): Promise<InvokeAIRawActionResponse> {
+    const res = await this.runApi({
+      body: JSON.stringify(messages),
+      model,
+      signal,
+      timeout,
+      raw: true,
+    });
+
+    return res;
+  }
+
   /**
    * takes in an array of messages and a model as inputs. It calls the streamApi method to make a
    * request to the Gemini API with the formatted messages and model. It then returns a Transform stream
@@ -280,13 +312,22 @@ export class GeminiConnector extends SubActionConnector<Config, Secrets> {
     signal,
     timeout,
   }: InvokeAIActionParams): Promise<StreamingResponse> {
-    const res = (await this.streamAPI({
-      body: JSON.stringify(formatGeminiPayload(messages, temperature)),
-      model,
-      stopSequences,
-      signal,
-      timeout,
-    })) as unknown as IncomingMessage;
+    console.error('invokeStream', JSON.stringify(messages, null, 2));
+    let res;
+
+    try {
+      res = (await this.streamAPI({
+        // body: JSON.stringify(formatGeminiPayload(messages, temperature)),
+        body: JSON.stringify(messages),
+        model,
+        stopSequences,
+        signal,
+        timeout,
+      })) as unknown as IncomingMessage;
+    } catch (e) {
+      console.error('eee', e);
+    }
+
     return res;
   }
 }
diff --git a/yarn.lock b/yarn.lock
index f6001a7038ba2..273c47573c7d3 100644
--- a/yarn.lock
+++ b/yarn.lock
@@ -6986,12 +6986,12 @@
   resolved "https://registry.yarnpkg.com/@kwsites/promise-deferred/-/promise-deferred-1.1.1.tgz#8ace5259254426ccef57f3175bc64ed7095ed919"
   integrity sha512-GaHYm+c0O9MjZRu0ongGBRbinu8gVAMd2UZjji6jVmqKtZluZnptXGWhz1E8j8D2HJ3f/yMxKAUC0b+57wncIw==
 
-"@langchain/community@^0.2.18":
-  version "0.2.18"
-  resolved "https://registry.yarnpkg.com/@langchain/community/-/community-0.2.18.tgz#127a7ac53a30dd6dedede887811fdd992061e2d2"
-  integrity sha512-UsCB97dMG87giQLniKx4bjv7OnMw2vQeavSt9gqOnGCnfb5IQBAgdjX4SjwFPbVGMz1HQoQKVlNqQ64ozCdgNg==
+"@langchain/community@^0.2.19":
+  version "0.2.19"
+  resolved "https://registry.yarnpkg.com/@langchain/community/-/community-0.2.19.tgz#ca1fb64e57f94216d6d05aa9362950d4c5473bc3"
+  integrity sha512-NKUOFW7ykY+WcnxEV6MZJj1hKncogdloBGDsk5zfW/FkZtQQpSHTgA8bgAT7X4Bnr5+Cv1fLkiDtVs/yKI4/Ow==
   dependencies:
-    "@langchain/core" "~0.2.11"
+    "@langchain/core" ">=0.2.16 <0.3.0"
     "@langchain/openai" "~0.1.0"
     binary-extensions "^2.2.0"
     expr-eval "^2.0.2"
@@ -7003,10 +7003,10 @@
     zod "^3.22.3"
     zod-to-json-schema "^3.22.5"
 
-"@langchain/core@>0.1.0 <0.3.0", "@langchain/core@>0.1.61 <0.3.0", "@langchain/core@>=0.2.11 <0.3.0", "@langchain/core@>=0.2.5 <0.3.0", "@langchain/core@^0.2.15", "@langchain/core@~0.2.11":
-  version "0.2.15"
-  resolved "https://registry.yarnpkg.com/@langchain/core/-/core-0.2.15.tgz#1bb99ac4fffe935c7ba37edcaa91abfba3c82219"
-  integrity sha512-L096itIBQ5XNsy5BCCPqIQEk/x4rzI+U4BhYT+fDBYtljESshIi/WzXdmiGfY/6MpVjB76jNuaRgMDmo1m9NeQ==
+"@langchain/core@>0.1.0 <0.3.0", "@langchain/core@>=0.2.11 <0.3.0", "@langchain/core@>=0.2.16 <0.3.0",
"@langchain/core@>=0.2.5 <0.3.0", "@langchain/core@^0.2.16": + version "0.2.16" + resolved "https://registry.yarnpkg.com/@langchain/core/-/core-0.2.16.tgz#0700a7def44c613ef775d351a2e6428e09cfbfda" + integrity sha512-mPmQi0ecJ81QwhvUQX4cwGVAqsM30ly3ygIlWoeUwDOXv9UW/IB2LAq8KKoVYIHTyEsIWJiyMP9Sv3e0xwjV8g== dependencies: ansi-styles "^5.0.0" camelcase "6" @@ -7021,13 +7021,23 @@ zod "^3.22.4" zod-to-json-schema "^3.22.3" -"@langchain/langgraph@^0.0.26": - version "0.0.26" - resolved "https://registry.yarnpkg.com/@langchain/langgraph/-/langgraph-0.0.26.tgz#4971577366f7ec4ff1c634f0ab1aa9f0ec73fb0b" - integrity sha512-edfy9C9e5E1/HvWY9Jk4P3kF3RToA7OIVj9aXjCCb/8dvoMcmMsQSB990Uc29H+3lSpxHWrzSaZVQCAxD6XAAg== +"@langchain/google-common@^0.0.20": + version "0.0.20" + resolved "https://registry.yarnpkg.com/@langchain/google-common/-/google-common-0.0.20.tgz#00d8e9b8f346c986366e199d4aaf19be1f97f1fd" + integrity sha512-kH1Bwh1tKxzIU+IFhOLLxuY7GjYjO+iebd3Gaih3smtQNldMidrYO2CRYtesnvD9AKJxvforU7neeux39fysoA== dependencies: - "@langchain/core" ">0.1.61 <0.3.0" + "@langchain/core" ">=0.2.16 <0.3.0" uuid "^10.0.0" + zod-to-json-schema "^3.22.4" + +"@langchain/langgraph@^0.0.27": + version "0.0.27" + resolved "https://registry.yarnpkg.com/@langchain/langgraph/-/langgraph-0.0.27.tgz#ae7df3f838e6cecd5be5c23071eafce102aa8942" + integrity sha512-+E5VOfIDUS9Rbv2Ut1osQ9Dy5IXPiIT8KasLyEToFbBN7KlzajIC2sm29he5aiR3I3KrKc6AburjPUUw+hw35A== + dependencies: + "@langchain/core" ">=0.2.16 <0.3.0" + uuid "^10.0.0" + zod "^3.23.8" "@langchain/openai@>=0.1.0 <0.3.0", "@langchain/openai@^0.1.3", "@langchain/openai@~0.1.0": version "0.1.3" @@ -21861,10 +21871,10 @@ kuler@^2.0.0: resolved "https://registry.yarnpkg.com/kuler/-/kuler-2.0.0.tgz#e2c570a3800388fb44407e851531c1d670b061b3" integrity sha512-Xq9nH7KlWZmXAtodXDDRE7vs6DU1gTU8zYDHDiWLSip45Egwq3plLHzPn27NgvzL2r1LMPC1vdqh98sQxtqj4A== -langchain@0.2.3, langchain@^0.2.9: - version "0.2.9" - resolved "https://registry.yarnpkg.com/langchain/-/langchain-0.2.9.tgz#1341bdd7166f4f6da0b9337f363e409a79523dbb" - integrity sha512-iZ0l7BDVfoifqZlDl1gy3JP5mIdhYjWiToPlDnlmfHD748cw3okvF0gZo0ruT4nbftnQcaM7JzPUiNC43UPfgg== +langchain@0.2.3, langchain@^0.2.10: + version "0.2.10" + resolved "https://registry.yarnpkg.com/langchain/-/langchain-0.2.10.tgz#35b74038e54650efbd9fe7d9d59765fe2790bb47" + integrity sha512-i0fC+RlX/6w6HKPWL3N5zrhrkijvpe2Xu4t/qbWzq4uFf8WBfPwmNFom3RtO2RatuPnHLm8mViU6nw8YBDiVwA== dependencies: "@langchain/core" ">=0.2.11 <0.3.0" "@langchain/openai" ">=0.1.0 <0.3.0" @@ -21888,7 +21898,18 @@ langchainhub@~0.0.8: resolved "https://registry.yarnpkg.com/langchainhub/-/langchainhub-0.0.8.tgz#fd4b96dc795e22e36c1a20bad31b61b0c33d3110" integrity sha512-Woyb8YDHgqqTOZvWIbm2CaFDGfZ4NTSyXV687AG4vXEfoNo7cGQp7nhl7wL3ehenKWmNEmcxCLgOZzW8jE6lOQ== -langsmith@^0.1.36, langsmith@~0.1.30: +langsmith@^0.1.37: + version "0.1.37" + resolved "https://registry.yarnpkg.com/langsmith/-/langsmith-0.1.37.tgz#7e7bf6fce3eab2a9e95221d5879820ec29d0aa60" + integrity sha512-8JgWykdJywdKWs+WeefOEf4Gkz3YdNkvG5u5JPbgXuodTUwuHPwjmblsldt1OGKkPp7iCWfdtCdnc9z9MYC/Dw== + dependencies: + "@types/uuid" "^9.0.1" + commander "^10.0.1" + p-queue "^6.6.2" + p-retry "4" + uuid "^9.0.0" + +langsmith@~0.1.30: version "0.1.36" resolved "https://registry.yarnpkg.com/langsmith/-/langsmith-0.1.36.tgz#5f21b9c6bcd4ea9c0e943f83e304a53e5232297d" integrity sha512-D5hhkFl31uxFdffx0lA6pin0lt8Pv2dpHFZYpSgEzvQ26PQ/Y/tnniQ+aCNokIXuLhMa7uqLtb6tfwjfiZXgdg== @@ -32828,7 +32849,7 @@ zip-stream@^4.1.0: compress-commons "^4.1.0" readable-stream 
"^3.6.0" -zod-to-json-schema@^3.22.3, zod-to-json-schema@^3.22.5: +zod-to-json-schema@^3.22.3, zod-to-json-schema@^3.22.4, zod-to-json-schema@^3.22.5: version "3.22.5" resolved "https://registry.yarnpkg.com/zod-to-json-schema/-/zod-to-json-schema-3.22.5.tgz#3646e81cfc318dbad2a22519e5ce661615418673" integrity sha512-+akaPo6a0zpVCCseDed504KBJUQpEW5QZw7RMneNmKw+fGaML1Z9tUNLnHHAC8x6dzVRO1eB2oEMyZRnuBZg7Q== @@ -32838,6 +32859,11 @@ zod@3.22.4, zod@^3.22.3, zod@^3.22.4: resolved "https://registry.yarnpkg.com/zod/-/zod-3.22.4.tgz#f31c3a9386f61b1f228af56faa9255e845cf3fff" integrity sha512-iC+8Io04lddc+mVqQ9AZ7OQ2MrUKGN+oIQyq1vemgt46jwCwLfhq7/pwnBnNXXXZb8VTVLKwp9EDkx+ryxIWmg== +zod@^3.23.8: + version "3.23.8" + resolved "https://registry.yarnpkg.com/zod/-/zod-3.23.8.tgz#e37b957b5d52079769fb8097099b592f0ef4067d" + integrity sha512-XBx9AXhXktjUqnepgTiE5flcKIYWi/rme0Eaj+5Y0lftuGBq+jyRu/md4WnuxqgP1ubdpNCsYEYPxrzVHD8d6g== + zwitch@^1.0.0: version "1.0.5" resolved "https://registry.yarnpkg.com/zwitch/-/zwitch-1.0.5.tgz#d11d7381ffed16b742f6af7b3f223d5cd9fe9920" From ab87b848cc807910dd9032c0e0b30e2cd72729f5 Mon Sep 17 00:00:00 2001 From: Patryk Kopycinski Date: Tue, 16 Jul 2024 19:52:25 +0200 Subject: [PATCH 25/55] fix bedrock streaming --- .../kbn-langchain/server/language_models/bedrock_chat.ts | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/x-pack/packages/kbn-langchain/server/language_models/bedrock_chat.ts b/x-pack/packages/kbn-langchain/server/language_models/bedrock_chat.ts index 0a329ba41e03f..c78db1231c73e 100644 --- a/x-pack/packages/kbn-langchain/server/language_models/bedrock_chat.ts +++ b/x-pack/packages/kbn-langchain/server/language_models/bedrock_chat.ts @@ -37,7 +37,7 @@ export class ActionsClientBedrockChatModel extends _BedrockChat { fetchFn: async (url, options) => { const inputBody = JSON.parse(options?.body as string); - if (this.streaming && !inputBody.tools?.length) { + if (this.streaming) { const data = (await actionsClient.execute({ actionId: connectorId, params: { From 19bbc9c5a98320366c7e27faf7a12076ead8bdb1 Mon Sep 17 00:00:00 2001 From: Patryk Kopycinski Date: Tue, 16 Jul 2024 20:29:36 +0200 Subject: [PATCH 26/55] fix bedrock graph streaming --- package.json | 2 +- .../server/language_models/bedrock_chat.ts | 2 +- yarn.lock | 12 ++++++------ 3 files changed, 8 insertions(+), 8 deletions(-) diff --git a/package.json b/package.json index 2aa5837d706b2..062752cae3cfd 100644 --- a/package.json +++ b/package.json @@ -938,7 +938,7 @@ "@kbn/watcher-plugin": "link:x-pack/plugins/watcher", "@kbn/xstate-utils": "link:packages/kbn-xstate-utils", "@kbn/zod-helpers": "link:packages/kbn-zod-helpers", - "@langchain/community": "^0.2.19", + "@langchain/community": "0.2.18", "@langchain/core": "^0.2.16", "@langchain/google-common": "^0.0.20", "@langchain/langgraph": "^0.0.27", diff --git a/x-pack/packages/kbn-langchain/server/language_models/bedrock_chat.ts b/x-pack/packages/kbn-langchain/server/language_models/bedrock_chat.ts index c78db1231c73e..0a329ba41e03f 100644 --- a/x-pack/packages/kbn-langchain/server/language_models/bedrock_chat.ts +++ b/x-pack/packages/kbn-langchain/server/language_models/bedrock_chat.ts @@ -37,7 +37,7 @@ export class ActionsClientBedrockChatModel extends _BedrockChat { fetchFn: async (url, options) => { const inputBody = JSON.parse(options?.body as string); - if (this.streaming) { + if (this.streaming && !inputBody.tools?.length) { const data = (await actionsClient.execute({ actionId: connectorId, params: { diff --git a/yarn.lock b/yarn.lock index 
697e62422c0e2..92602b03ee1e8 100644 --- a/yarn.lock +++ b/yarn.lock @@ -6986,12 +6986,12 @@ resolved "https://registry.yarnpkg.com/@kwsites/promise-deferred/-/promise-deferred-1.1.1.tgz#8ace5259254426ccef57f3175bc64ed7095ed919" integrity sha512-GaHYm+c0O9MjZRu0ongGBRbinu8gVAMd2UZjji6jVmqKtZluZnptXGWhz1E8j8D2HJ3f/yMxKAUC0b+57wncIw== -"@langchain/community@^0.2.19": - version "0.2.19" - resolved "https://registry.yarnpkg.com/@langchain/community/-/community-0.2.19.tgz#ca1fb64e57f94216d6d05aa9362950d4c5473bc3" - integrity sha512-NKUOFW7ykY+WcnxEV6MZJj1hKncogdloBGDsk5zfW/FkZtQQpSHTgA8bgAT7X4Bnr5+Cv1fLkiDtVs/yKI4/Ow== +"@langchain/community@0.2.18": + version "0.2.18" + resolved "https://registry.yarnpkg.com/@langchain/community/-/community-0.2.18.tgz#127a7ac53a30dd6dedede887811fdd992061e2d2" + integrity sha512-UsCB97dMG87giQLniKx4bjv7OnMw2vQeavSt9gqOnGCnfb5IQBAgdjX4SjwFPbVGMz1HQoQKVlNqQ64ozCdgNg== dependencies: - "@langchain/core" ">=0.2.16 <0.3.0" + "@langchain/core" "~0.2.11" "@langchain/openai" "~0.1.0" binary-extensions "^2.2.0" expr-eval "^2.0.2" @@ -7003,7 +7003,7 @@ zod "^3.22.3" zod-to-json-schema "^3.22.5" -"@langchain/core@>0.1.0 <0.3.0", "@langchain/core@>=0.2.11 <0.3.0", "@langchain/core@>=0.2.16 <0.3.0", "@langchain/core@>=0.2.5 <0.3.0", "@langchain/core@^0.2.16": +"@langchain/core@>0.1.0 <0.3.0", "@langchain/core@>=0.2.11 <0.3.0", "@langchain/core@>=0.2.16 <0.3.0", "@langchain/core@>=0.2.5 <0.3.0", "@langchain/core@^0.2.16", "@langchain/core@~0.2.11": version "0.2.16" resolved "https://registry.yarnpkg.com/@langchain/core/-/core-0.2.16.tgz#0700a7def44c613ef775d351a2e6428e09cfbfda" integrity sha512-mPmQi0ecJ81QwhvUQX4cwGVAqsM30ly3ygIlWoeUwDOXv9UW/IB2LAq8KKoVYIHTyEsIWJiyMP9Sv3e0xwjV8g== From f18e641f7bdc8bd9df2044b0bc6b8668f981cde2 Mon Sep 17 00:00:00 2001 From: Patryk Kopycinski Date: Tue, 16 Jul 2024 20:29:52 +0200 Subject: [PATCH 27/55] fix --- .../langchain/graphs/default_assistant_graph/graph.ts | 9 ++++++++- .../langchain/graphs/default_assistant_graph/index.ts | 2 ++ 2 files changed, 10 insertions(+), 1 deletion(-) diff --git a/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/graph.ts b/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/graph.ts index 50ad6ba09cc8c..ad0dbb039dd08 100644 --- a/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/graph.ts +++ b/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/graph.ts @@ -47,6 +47,8 @@ interface GetDefaultAssistantGraphParams { tools: StructuredTool[]; responseLanguage: string; replacements: Replacements; + llmType: string | undefined; + bedrockChatEnabled?: boolean; } export type DefaultAssistantGraph = ReturnType; @@ -64,6 +66,8 @@ export const getDefaultAssistantGraph = ({ responseLanguage, tools, replacements, + llmType, + bedrockChatEnabled, }: GetDefaultAssistantGraphParams) => { try { // Default graph state @@ -193,7 +197,10 @@ export const getDefaultAssistantGraph = ({ // Add conditional edge for basic routing graph.addConditionalEdges(AGENT_NODE, shouldContinueEdge, { continue: TOOLS_NODE, - end: RESPOND_NODE, + end: + llmType && bedrockChatEnabled && ['bedrock', 'gemini'].includes(llmType) + ? 
RESPOND_NODE + : END, }); graph.addEdge(RESPOND_NODE, END); graph.addEdge(TOOLS_NODE, AGENT_NODE); diff --git a/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/index.ts b/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/index.ts index 425834ea516f2..87d7bae47eb43 100644 --- a/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/index.ts +++ b/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/index.ts @@ -151,6 +151,8 @@ export const callAssistantGraph: AgentExecutor = async ({ tools, responseLanguage, replacements, + llmType, + bedrockChatEnabled, }); const inputs = { input: latestMessage[0]?.content as string }; From b6ecc5a3874e99926b4746a7ef90fe1bb07c1ed4 Mon Sep 17 00:00:00 2001 From: Patryk Kopycinski Date: Wed, 17 Jul 2024 17:28:24 +0200 Subject: [PATCH 28/55] switch to ActionsClientGeminiChatModel --- package.json | 3 +- x-pack/packages/kbn-langchain/server/index.ts | 4 +- .../server/language_models/gemini_chat.ts | 384 ++++++++++++++++++ .../server/language_models/index.ts | 1 + .../server/language_models/vertex_chat.ts | 76 ---- .../server/lib/gen_ai_token_tracking.ts | 36 +- .../graphs/default_assistant_graph/graph.ts | 17 +- .../graphs/default_assistant_graph/helpers.ts | 7 +- .../graphs/default_assistant_graph/index.ts | 2 + .../elastic_assistant/server/routes/utils.ts | 4 +- .../stack_connectors/common/gemini/schema.ts | 4 +- .../stack_connectors/common/gemini/types.ts | 1 + .../server/connector_types/gemini/gemini.ts | 36 +- yarn.lock | 22 +- 14 files changed, 461 insertions(+), 136 deletions(-) create mode 100644 x-pack/packages/kbn-langchain/server/language_models/gemini_chat.ts delete mode 100644 x-pack/packages/kbn-langchain/server/language_models/vertex_chat.ts diff --git a/package.json b/package.json index 062752cae3cfd..9ae002a2dcfb9 100644 --- a/package.json +++ b/package.json @@ -941,7 +941,8 @@ "@langchain/community": "0.2.18", "@langchain/core": "^0.2.16", "@langchain/google-common": "^0.0.20", - "@langchain/langgraph": "^0.0.27", + "@langchain/google-genai": "^0.0.22", + "@langchain/langgraph": "^0.0.28", "@langchain/openai": "^0.1.3", "@langtrase/trace-attributes": "^3.0.8", "@launchdarkly/node-server-sdk": "^9.4.7", diff --git a/x-pack/packages/kbn-langchain/server/index.ts b/x-pack/packages/kbn-langchain/server/index.ts index a2baf6b45a2b3..7f5691c8f9907 100644 --- a/x-pack/packages/kbn-langchain/server/index.ts +++ b/x-pack/packages/kbn-langchain/server/index.ts @@ -9,7 +9,7 @@ import { ActionsClientBedrockChatModel } from './language_models/bedrock_chat'; import { ActionsClientChatOpenAI } from './language_models/chat_openai'; import { ActionsClientLlm } from './language_models/llm'; import { ActionsClientSimpleChatModel } from './language_models/simple_chat_model'; -import { ActionsClientVertexChatModel } from './language_models/vertex_chat'; +import { ActionsClientGeminiChatModel } from './language_models/gemini_chat'; import { parseBedrockStream } from './utils/bedrock'; import { parseGeminiResponse } from './utils/gemini'; import { getDefaultArguments } from './language_models/constants'; @@ -20,7 +20,7 @@ export { getDefaultArguments, ActionsClientBedrockChatModel, ActionsClientChatOpenAI, + ActionsClientGeminiChatModel, ActionsClientLlm, ActionsClientSimpleChatModel, - ActionsClientVertexChatModel, }; diff --git a/x-pack/packages/kbn-langchain/server/language_models/gemini_chat.ts 
b/x-pack/packages/kbn-langchain/server/language_models/gemini_chat.ts
new file mode 100644
index 0000000000000..0b40a34779ac4
--- /dev/null
+++ b/x-pack/packages/kbn-langchain/server/language_models/gemini_chat.ts
@@ -0,0 +1,384 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+import {
+  Content,
+  Part,
+  FunctionCallPart,
+  FunctionResponsePart,
+  POSSIBLE_ROLES,
+  EnhancedGenerateContentResponse,
+} from '@google/generative-ai';
+import { ActionsClient } from '@kbn/actions-plugin/server';
+import { PublicMethodsOf } from '@kbn/utility-types';
+import { CallbackManagerForLLMRun } from '@langchain/core/callbacks/manager';
+import { ToolCallChunk } from '@langchain/core/dist/messages/tool';
+import {
+  AIMessageChunk,
+  BaseMessage,
+  ChatMessage,
+  isBaseMessage,
+  UsageMetadata,
+} from '@langchain/core/messages';
+import { ChatGenerationChunk } from '@langchain/core/outputs';
+import { ChatGoogleGenerativeAI } from '@langchain/google-genai';
+
+export class ActionsClientGeminiChatModel extends ChatGoogleGenerativeAI {
+  #actionsClient: PublicMethodsOf<ActionsClient>;
+  #connectorId: string;
+
+  constructor({ actionsClient, connectorId, ...props }) {
+    super({
+      ...props,
+      apiKey: 'asda',
+      temperature: 0,
+    });
+
+    this.#actionsClient = actionsClient;
+    this.#connectorId = connectorId;
+    this.apiKey = 'sadd';
+  }
+
+  async completionWithRetry(
+    request: string | GenerateContentRequest | Array<string | Part>,
+    options?: this['ParsedCallOptions']
+  ) {
+    return this.caller.callWithOptions({ signal: options?.signal }, async () => {
+      try {
+        // console.error('request', request);
+        const requestBody = {
+          actionId: 'my-gemini-ai' || this.#connectorId,
+          params: {
+            subAction: 'invokeAIRaw',
+            subActionParams: {
+              model: 'gemini-1.5-pro-preview-0409' || this.model,
+              messages: request,
+            },
+          },
+        };
+
+        const actionResult = await this.#actionsClient.execute(requestBody);
+
+        return {
+          response: {
+            ...actionResult.data,
+            functionCalls: () =>
+              actionResult.data?.candidates?.[0]?.content.parts[0].functionCall
+                ? [actionResult.data?.candidates?.[0]?.content.parts[0].functionCall]
+                : null,
+          },
+        };
+        // eslint-disable-next-line @typescript-eslint/no-explicit-any
+      } catch (e: any) {
+        // TODO: Improve error handling
+        if (e.message?.includes('400 Bad Request')) {
+          e.status = 400;
+        }
+        throw e;
+      }
+    });
+  }
+
+  async *_streamResponseChunks(
+    messages: BaseMessage[],
+    options: this['ParsedCallOptions'],
+    runManager?: CallbackManagerForLLMRun
+  ): AsyncGenerator<ChatGenerationChunk> {
+    const prompt = convertBaseMessagesToContent(messages, this._isMultimodalModel);
+    const parameters = this.invocationParams(options);
+    const request = {
+      ...parameters,
+      contents: prompt,
+    };
+    const stream = await this.caller.callWithOptions({ signal: options?.signal }, async () => {
+      const requestBody = {
+        actionId: 'my-gemini-ai' || this.#connectorId,
+        params: {
+          subAction: 'invokeStream',
+          subActionParams: {
+            model: 'gemini-1.5-pro-preview-0409' || this.model,
+            messages: request,
+          },
+        },
+      };
+
+      const actionResult = await this.#actionsClient.execute(requestBody);
+
+      return actionResult.data;
+    });
+
+    let usageMetadata: UsageMetadata | undefined;
+    let index = 0;
+    for await (const rawStreamChunk of stream) {
+      const parsedStreamChunk = rawStreamChunk
+        .toString()
+        .split('\n')
+        .filter((line) => line.startsWith('data: ') && !line.endsWith('[DONE]'))
+        .map((line) => JSON.parse(line.replace('data: ', '')))[0];
+
+      const response = {
+        ...parsedStreamChunk,
+        functionCalls: () =>
+          parsedStreamChunk?.candidates?.[0]?.content.parts[0].functionCall
+            ? [parsedStreamChunk.candidates?.[0]?.content.parts[0].functionCall]
+            : null,
+      };
+
+      if (
+        'usageMetadata' in response &&
+        this.streamUsage !== false &&
+        options.streamUsage !== false
+      ) {
+        const genAIUsageMetadata = response.usageMetadata as {
+          promptTokenCount: number;
+          candidatesTokenCount: number;
+          totalTokenCount: number;
+        };
+        if (!usageMetadata) {
+          usageMetadata = {
+            input_tokens: genAIUsageMetadata.promptTokenCount,
+            output_tokens: genAIUsageMetadata.candidatesTokenCount,
+            total_tokens: genAIUsageMetadata.totalTokenCount,
+          };
+        } else {
+          // Under the hood, LangChain combines the prompt tokens. Google returns the updated
+          // total each time, so we need to find the difference between the tokens.
+          const outputTokenDiff =
+            genAIUsageMetadata.candidatesTokenCount - usageMetadata.output_tokens;
+          usageMetadata = {
+            input_tokens: 0,
+            output_tokens: outputTokenDiff,
+            total_tokens: outputTokenDiff,
+          };
+        }
+      }
+
+      const chunk = convertResponseContentToChatGenerationChunk(response, {
+        usageMetadata,
+        index,
+      });
+      index += 1;
+
+      if (!chunk) {
+        continue;
+      }
+
+      yield chunk;
+      await runManager?.handleLLMNewToken(chunk.text ?? '');
+    }
+  }
+}
+
+export function convertResponseContentToChatGenerationChunk(
+  response: EnhancedGenerateContentResponse,
+  extra: {
+    usageMetadata?: UsageMetadata | undefined;
+    index: number;
+  }
+): ChatGenerationChunk | null {
+  if (!response.candidates || response.candidates.length === 0) {
+    return null;
+  }
+  const functionCalls = response.functionCalls();
+  const [candidate] = response.candidates;
+  const { content, ...generationInfo } = candidate;
+  const text = content?.parts[0]?.text ??
''; + + const toolCallChunks: ToolCallChunk[] = []; + if (functionCalls) { + toolCallChunks.push( + ...functionCalls.map((fc) => ({ + ...fc, + args: JSON.stringify(fc.args), + index: extra.index, + type: 'tool_call_chunk' as const, + })) + ); + } + return new ChatGenerationChunk({ + text, + message: new AIMessageChunk({ + content: text, + name: !content ? undefined : content.role, + tool_call_chunks: toolCallChunks, + // Each chunk can have unique "generationInfo", and merging strategy is unclear, + // so leave blank for now. + additional_kwargs: {}, + usage_metadata: extra.usageMetadata, + }), + generationInfo, + }); +} + +export function convertAuthorToRole(author: string): typeof POSSIBLE_ROLES[number] { + switch (author) { + /** + * Note: Gemini currently is not supporting system messages + * we will convert them to human messages and merge with following + * */ + case 'ai': + case 'model': // getMessageAuthor returns message.name. code ex.: return message.name ?? type; + return 'model'; + case 'system': + case 'human': + return 'user'; + case 'tool': + case 'function': + return 'function'; + default: + throw new Error(`Unknown / unsupported author: ${author}`); + } +} +export function convertBaseMessagesToContent(messages: BaseMessage[], isMultimodalModel: boolean) { + return messages.reduce<{ + content: Content[]; + mergeWithPreviousContent: boolean; + }>( + (acc, message, index) => { + if (!isBaseMessage(message)) { + throw new Error('Unsupported message input'); + } + const author = getMessageAuthor(message); + if (author === 'system' && index !== 0) { + throw new Error('System message should be the first one'); + } + const role = convertAuthorToRole(author); + + const prevContent = acc.content[acc.content.length]; + if (!acc.mergeWithPreviousContent && prevContent && prevContent.role === role) { + throw new Error('Google Generative AI requires alternate messages between authors'); + } + + const parts = convertMessageContentToParts(message, isMultimodalModel); + + if (acc.mergeWithPreviousContent) { + const prevContent = acc.content[acc.content.length - 1]; + if (!prevContent) { + throw new Error( + 'There was a problem parsing your system message. Please try a prompt without one.' + ); + } + prevContent.parts.push(...parts); + + return { + mergeWithPreviousContent: false, + content: acc.content, + }; + } + let actualRole = role; + if (actualRole === 'function') { + // GenerativeAI API will throw an error if the role is not "user" or "model." 
+ actualRole = 'user'; + } + const content: Content = { + role: actualRole, + parts, + }; + return { + mergeWithPreviousContent: author === 'system', + content: [...acc.content, content], + }; + }, + { content: [], mergeWithPreviousContent: false } + ).content; +} + +export function convertMessageContentToParts( + message: BaseMessage, + isMultimodalModel: boolean +): Part[] { + if (typeof message.content === 'string' && message.content !== '') { + return [{ text: message.content }]; + } + + let functionCalls: FunctionCallPart[] = []; + let functionResponses: FunctionResponsePart[] = []; + let messageParts: Part[] = []; + + if ( + 'tool_calls' in message && + Array.isArray(message.tool_calls) && + message.tool_calls.length > 0 + ) { + functionCalls = message.tool_calls.map((tc) => ({ + functionCall: { + name: tc.name, + args: tc.args, + }, + })); + } else if (message._getType() === 'tool' && message.name && message.content) { + functionResponses = [ + { + functionResponse: { + name: message.name, + response: message.content, + }, + }, + ]; + } else if (Array.isArray(message.content)) { + messageParts = message.content.map((c) => { + if (c.type === 'text') { + return { + text: c.text, + }; + } + + if (c.type === 'image_url') { + if (!isMultimodalModel) { + throw new Error(`This model does not support images`); + } + let source; + if (typeof c.image_url === 'string') { + source = c.image_url; + } else if (typeof c.image_url === 'object' && 'url' in c.image_url) { + source = c.image_url.url; + } else { + throw new Error('Please provide image as base64 encoded data URL'); + } + const [dm, data] = source.split(','); + if (!dm.startsWith('data:')) { + throw new Error('Please provide image as base64 encoded data URL'); + } + + const [mimeType, encoding] = dm.replace(/^data:/, '').split(';'); + if (encoding !== 'base64') { + throw new Error('Please provide image as base64 encoded data URL'); + } + + return { + inlineData: { + data, + mimeType, + }, + }; + } else if (c.type === 'media') { + return c; + } else if (c.type === 'tool_use') { + return { + functionCall: { + name: c.name, + args: c.input, + }, + }; + } + throw new Error(`Unknown content type ${(c as { type: string }).type}`); + }); + } + + return [...messageParts, ...functionCalls, ...functionResponses]; +} + +export function getMessageAuthor(message: BaseMessage) { + const type = message._getType(); + if (ChatMessage.isInstance(message)) { + return message.role; + } + if (type === 'tool') { + return type; + } + return message.name ?? type; +} diff --git a/x-pack/packages/kbn-langchain/server/language_models/index.ts b/x-pack/packages/kbn-langchain/server/language_models/index.ts index d2039f098c74e..f5415079cbc11 100644 --- a/x-pack/packages/kbn-langchain/server/language_models/index.ts +++ b/x-pack/packages/kbn-langchain/server/language_models/index.ts @@ -7,5 +7,6 @@ export { ActionsClientBedrockChatModel } from './bedrock_chat'; export { ActionsClientChatOpenAI } from './chat_openai'; +export { ActionsClientGeminiChatModel } from './gemini_chat'; export { ActionsClientLlm } from './llm'; export { ActionsClientSimpleChatModel } from './simple_chat_model'; diff --git a/x-pack/packages/kbn-langchain/server/language_models/vertex_chat.ts b/x-pack/packages/kbn-langchain/server/language_models/vertex_chat.ts deleted file mode 100644 index 967790b420b98..0000000000000 --- a/x-pack/packages/kbn-langchain/server/language_models/vertex_chat.ts +++ /dev/null @@ -1,76 +0,0 @@ -/* - * Copyright Elasticsearch B.V. 
and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. - */ - -import { ActionsClient } from '@kbn/actions-plugin/server'; -import { PublicMethodsOf } from '@kbn/utility-types'; -import { - ChatGoogleBase, - ChatGoogleBaseInput, - GoogleBaseLLMInput, - ReadableJsonStream, -} from '@langchain/google-common'; -import { Readable } from 'stream'; - -export type ChatGoogleInput = ChatGoogleBaseInput<{}>; - -export class ActionsClientVertexChatModel extends ChatGoogleBase<{}> implements ChatGoogleInput { - #actionsClient: PublicMethodsOf; - #connectorId: string; - streaming: boolean; - model: string = ''; - temperature: number = 0; - #maxTokens?: number; - - static lc_name() { - return 'ChatVertexAI'; - } - - constructor({ actionsClient, connectorId, streaming, temperature, model }) { - super({ - // ...fields, - platformType: 'gcp', - }); - - this.#actionsClient = actionsClient; - this.#connectorId = connectorId; - this.model = model; - this.temperature = temperature ?? 0; - this.streaming = streaming; - } - - override buildAbstractedClient(fields: GoogleBaseLLMInput<{}> | undefined) { - return { - request: async (props) => { - // create a new connector request body with the assistant message: - const requestBody = { - actionId: 'my-gemini-ai' || this.#connectorId, - params: { - subAction: this.streaming ? 'invokeStream' : 'invokeAIRaw', - subActionParams: { - model: 'gemini-1.5-pro-preview-0409' || this.model, - messages: props.data, - }, - }, - }; - - const actionResult = await this.#actionsClient.execute(requestBody); - - if (this.streaming) { - return { - data: new ReadableJsonStream( - actionResult.data ? Readable.toWeb(actionResult.data) : null - ), - }; - } - - return actionResult; - }, - getProjectId: () => Promise.resolve(''), - clientType: '', - }; - } -} diff --git a/x-pack/plugins/actions/server/lib/gen_ai_token_tracking.ts b/x-pack/plugins/actions/server/lib/gen_ai_token_tracking.ts index 41bfa28605f40..5568904c83dc2 100644 --- a/x-pack/plugins/actions/server/lib/gen_ai_token_tracking.ts +++ b/x-pack/plugins/actions/server/lib/gen_ai_token_tracking.ts @@ -207,24 +207,24 @@ export const getGenAiTokenTracking = async ({ } // Process non-streamed Gemini response from `usageMetadata` object - if (actionTypeId === '.gemini') { - const data = result.data as unknown as { - usageMetadata: { - promptTokenCount?: number; - candidatesTokenCount?: number; - totalTokenCount?: number; - }; - }; - if (data.usageMetadata == null) { - logger.error('Response did not contain usage metadata object'); - return null; - } - return { - total_tokens: data.usageMetadata?.totalTokenCount ?? 0, - prompt_tokens: data.usageMetadata?.promptTokenCount ?? 0, - completion_tokens: data.usageMetadata?.candidatesTokenCount ?? 0, - }; - } + // if (actionTypeId === '.gemini') { + // const data = result.data as unknown as { + // usageMetadata: { + // promptTokenCount?: number; + // candidatesTokenCount?: number; + // totalTokenCount?: number; + // }; + // }; + // if (data.usageMetadata == null) { + // logger.error('Response did not contain usage metadata object'); + // return null; + // } + // return { + // total_tokens: data.usageMetadata?.totalTokenCount ?? 0, + // prompt_tokens: data.usageMetadata?.promptTokenCount ?? 0, + // completion_tokens: data.usageMetadata?.candidatesTokenCount ?? 
0, + // }; + // } // this is a non-streamed Bedrock response used by security solution if (actionTypeId === '.bedrock' && validatedParams.subAction === 'invokeAI') { diff --git a/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/graph.ts b/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/graph.ts index ad0dbb039dd08..d8320624143c8 100644 --- a/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/graph.ts +++ b/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/graph.ts @@ -49,6 +49,7 @@ interface GetDefaultAssistantGraphParams { replacements: Replacements; llmType: string | undefined; bedrockChatEnabled?: boolean; + isStreaming: boolean; } export type DefaultAssistantGraph = ReturnType; @@ -68,6 +69,7 @@ export const getDefaultAssistantGraph = ({ replacements, llmType, bedrockChatEnabled, + isStreaming, }: GetDefaultAssistantGraphParams) => { try { // Default graph state @@ -181,7 +183,14 @@ export const getDefaultAssistantGraph = ({ graph.addNode(PERSIST_CONVERSATION_CHANGES_NODE, persistConversationChangesNode); graph.addNode(AGENT_NODE, runAgentNode); graph.addNode(TOOLS_NODE, executeToolsNode); - graph.addNode(RESPOND_NODE, respondNode); + + const hasRespondStep = + isStreaming && llmType && bedrockChatEnabled && ['bedrock'].includes(llmType); + + if (hasRespondStep) { + graph.addNode(RESPOND_NODE, respondNode); + graph.addEdge(RESPOND_NODE, END); + } // Add edges, alternating between agent and action until finished graph.addConditionalEdges(START, shouldContinueGetConversationEdge, { @@ -197,12 +206,8 @@ export const getDefaultAssistantGraph = ({ // Add conditional edge for basic routing graph.addConditionalEdges(AGENT_NODE, shouldContinueEdge, { continue: TOOLS_NODE, - end: - llmType && bedrockChatEnabled && ['bedrock', 'gemini'].includes(llmType) - ? RESPOND_NODE - : END, + end: hasRespondStep ? RESPOND_NODE : END, }); - graph.addEdge(RESPOND_NODE, END); graph.addEdge(TOOLS_NODE, AGENT_NODE); // Compile the graph return graph.compile(); diff --git a/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/helpers.ts b/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/helpers.ts index 9a3ec63119b71..b7b8557fe2441 100644 --- a/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/helpers.ts +++ b/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/helpers.ts @@ -82,7 +82,6 @@ export const streamGraph = async ({ streamingSpan?.end(); }; - if ((llmType === 'bedrock' || llmType === 'gemini') && bedrockChatEnabled) { const stream = await assistantGraph.streamEvents( inputs, @@ -92,19 +91,19 @@ export const streamGraph = async ({ tags: traceOptions?.tags ?? [], version: 'v2', }, - { includeNames: ['Summarizer'] } + llmType === 'bedrock' ? 
{ includeNames: ['Summarizer'] } : undefined ); for await (const { event, data } of stream) { if (event === 'on_chat_model_stream') { const msg = data.chunk as AIMessageChunk; - if (!msg.tool_call_chunks?.length) { + if (!didEnd && !msg.tool_call_chunks?.length && msg.content.length) { push({ payload: msg.content, type: 'content' }); } } - if (event === 'on_chat_model_end') { + if (event === 'on_chat_model_end' && !data.output.lc_kwargs?.tool_calls?.length) { handleStreamEnd(data.output.content); } } diff --git a/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/index.ts b/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/index.ts index 87d7bae47eb43..7407854eba060 100644 --- a/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/index.ts +++ b/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/index.ts @@ -60,6 +60,7 @@ export const callAssistantGraph: AgentExecutor = async ({ connectorId, llmType, logger, + apiKey: '', // possible client model override, // let this be undefined otherwise so the connector handles the model model: request.body.model, @@ -153,6 +154,7 @@ export const callAssistantGraph: AgentExecutor = async ({ replacements, llmType, bedrockChatEnabled, + isStreaming: isStream, }); const inputs = { input: latestMessage[0]?.content as string }; diff --git a/x-pack/plugins/elastic_assistant/server/routes/utils.ts b/x-pack/plugins/elastic_assistant/server/routes/utils.ts index 2948453b4fba0..e163526d996ae 100644 --- a/x-pack/plugins/elastic_assistant/server/routes/utils.ts +++ b/x-pack/plugins/elastic_assistant/server/routes/utils.ts @@ -17,7 +17,7 @@ import { ActionsClientChatOpenAI, ActionsClientBedrockChatModel, ActionsClientSimpleChatModel, - ActionsClientVertexChatModel, + ActionsClientGeminiChatModel, } from '@kbn/langchain/server'; import { CustomHttpRequestError } from './custom_http_request_error'; @@ -187,5 +187,5 @@ export const getLlmClass = (llmType?: string, bedrockChatEnabled?: boolean) => : llmType === 'bedrock' && bedrockChatEnabled ? ActionsClientBedrockChatModel : llmType === 'gemini' && bedrockChatEnabled - ? ActionsClientVertexChatModel + ? 
ActionsClientGeminiChatModel : ActionsClientSimpleChatModel; diff --git a/x-pack/plugins/stack_connectors/common/gemini/schema.ts b/x-pack/plugins/stack_connectors/common/gemini/schema.ts index 5a040dd07e17c..fa7d44d1c86f8 100644 --- a/x-pack/plugins/stack_connectors/common/gemini/schema.ts +++ b/x-pack/plugins/stack_connectors/common/gemini/schema.ts @@ -20,7 +20,7 @@ export const SecretsSchema = schema.object({ }); export const RunActionParamsSchema = schema.object({ - body: schema.string(), + body: schema.any(), model: schema.maybe(schema.string()), signal: schema.maybe(schema.any()), timeout: schema.maybe(schema.number()), @@ -53,7 +53,7 @@ export const RunActionResponseSchema = schema.object( { unknowns: 'ignore' } ); -export const RunActionRawResponse = schema.any(); +export const RunActionRawResponseSchema = schema.any(); export const InvokeAIActionParamsSchema = schema.object({ messages: schema.any(), diff --git a/x-pack/plugins/stack_connectors/common/gemini/types.ts b/x-pack/plugins/stack_connectors/common/gemini/types.ts index 206b2423f46ce..52a8c090a002e 100644 --- a/x-pack/plugins/stack_connectors/common/gemini/types.ts +++ b/x-pack/plugins/stack_connectors/common/gemini/types.ts @@ -13,6 +13,7 @@ import { SecretsSchema, RunActionParamsSchema, RunActionResponseSchema, + RunActionRawResponseSchema, RunApiResponseSchema, InvokeAIActionParamsSchema, InvokeAIActionResponseSchema, diff --git a/x-pack/plugins/stack_connectors/server/connector_types/gemini/gemini.ts b/x-pack/plugins/stack_connectors/server/connector_types/gemini/gemini.ts index 6468eccb481a9..817b79b679c1e 100644 --- a/x-pack/plugins/stack_connectors/server/connector_types/gemini/gemini.ts +++ b/x-pack/plugins/stack_connectors/server/connector_types/gemini/gemini.ts @@ -16,6 +16,7 @@ import { ConnectorTokenClientContract } from '@kbn/actions-plugin/server/types'; import { RunActionParamsSchema, RunApiResponseSchema, + RunActionRawResponseSchema, InvokeAIActionParamsSchema, InvokeAIRawActionParamsSchema, StreamingResponseSchema, @@ -26,6 +27,7 @@ import { Secrets, RunActionParams, RunActionResponse, + RunActionRawResponse, RunApiResponse, DashboardActionParams, DashboardActionResponse, @@ -219,18 +221,19 @@ export class GeminiConnector extends SubActionConnector { }, signal, timeout: timeout ?? DEFAULT_TIMEOUT_MS, - responseSchema: RunApiResponseSchema, + responseSchema: raw ? 
RunActionRawResponseSchema : RunApiResponseSchema,
     } as SubActionRequestParams<RunApiResponse>;

     const response = await this.request(requestArgs);
-    const candidate = response.data.candidates[0];
-    const usageMetadata = response.data.usageMetadata;
-    const completionText = candidate.content.parts[0].text;

     if (raw) {
       return response.data;
     }

+    const candidate = response.data.candidates[0];
+    const usageMetadata = response.data.usageMetadata;
+    const completionText = candidate.content.parts[0].text;
+
     return { completion: completionText, usageMetadata };
   }

@@ -312,23 +315,14 @@ export class GeminiConnector extends SubActionConnector<Config, Secrets> {
     signal,
     timeout,
   }: InvokeAIActionParams): Promise<IncomingMessage> {
-    console.error('invokeStream', JSON.stringify(messages, null, 2));
-    let res;
-
-    try {
-      res = (await this.streamAPI({
-        // body: JSON.stringify(formatGeminiPayload(messages, temperature)),
-        body: JSON.stringify(messages),
-        model,
-        stopSequences,
-        signal,
-        timeout,
-      })) as unknown as IncomingMessage;
-    } catch (e) {
-      console.error('eee', e);
-    }
-
-    return res;
+    return (await this.streamAPI({
+      // body: JSON.stringify(formatGeminiPayload(messages, temperature)),
+      body: JSON.stringify(messages),
+      model,
+      stopSequences,
+      signal,
+      timeout,
+    })) as unknown as IncomingMessage;
   }
 }
diff --git a/yarn.lock b/yarn.lock
index 92602b03ee1e8..e579be93ee0a6 100644
--- a/yarn.lock
+++ b/yarn.lock
@@ -2514,6 +2514,11 @@
   resolved "https://registry.yarnpkg.com/@gar/promisify/-/promisify-1.1.3.tgz#555193ab2e3bb3b6adc3d551c9c030d9e860daf6"
   integrity sha512-k2Ty1JcVojjJFwrg/ThKi2ujJ7XNLYaFGNB/bWT9wGR+oSMJHMa5w+CUq6p/pVrKeNNgA7pCqEcjSnHVoqJQFw==
 
+"@google/generative-ai@^0.7.0":
+  version "0.7.1"
+  resolved "https://registry.yarnpkg.com/@google/generative-ai/-/generative-ai-0.7.1.tgz#eb187c75080c0706245699dbc06816c830d8c6a7"
+  integrity sha512-WTjMLLYL/xfA5BW6xAycRPiAX7FNHKAxrid/ayqC1QMam0KAK0NbMeS9Lubw80gVg5xFMLE+H7pw4wdNzTOlxw==
+
 "@grpc/grpc-js@^1.7.1", "@grpc/grpc-js@^1.8.22":
   version "1.8.22"
   resolved "https://registry.yarnpkg.com/@grpc/grpc-js/-/grpc-js-1.8.22.tgz#847930c9af46e14df05b57fc12325db140ceff1d"
@@ -7030,10 +7035,19 @@
       uuid "^10.0.0"
       zod-to-json-schema "^3.22.4"
 
-"@langchain/langgraph@^0.0.27":
-  version "0.0.27"
-  resolved "https://registry.yarnpkg.com/@langchain/langgraph/-/langgraph-0.0.27.tgz#ae7df3f838e6cecd5be5c23071eafce102aa8942"
-  integrity sha512-+E5VOfIDUS9Rbv2Ut1osQ9Dy5IXPiIT8KasLyEToFbBN7KlzajIC2sm29he5aiR3I3KrKc6AburjPUUw+hw35A==
+"@langchain/google-genai@^0.0.22":
+  version "0.0.22"
+  resolved "https://registry.yarnpkg.com/@langchain/google-genai/-/google-genai-0.0.22.tgz#520796606f7bdb4b60f9d34ea54a0d3fcaa82930"
+  integrity sha512-egxPpu+GdYigUYOFGDujKt6ziYZ/ELrTvEZ17TJYUgLyvw/gfzr5lz0hU4CgKKUwI/tIk6ZU5I2uaC4oFDkRSQ==
+  dependencies:
+    "@google/generative-ai" "^0.7.0"
+    "@langchain/core" ">=0.2.16 <0.3.0"
+    zod-to-json-schema "^3.22.4"
+
+"@langchain/langgraph@^0.0.28":
+  version "0.0.28"
+  resolved "https://registry.yarnpkg.com/@langchain/langgraph/-/langgraph-0.0.28.tgz#d1aea4e55f8b563a4a5938f8469cb18df4c46465"
+  integrity sha512-tKblc95lKfifjYLcWl0t7Z0kqN2zQoSQgOYNxbuaZZiMfuxZ5NJGAH7PxTpLjWdwRsiMpIBNzS/JZ1vjhQrNvA==
   dependencies:
     "@langchain/core" ">=0.2.16 <0.3.0"
     uuid "^10.0.0"

From 37297e3c1301d3e18765dee8a9968be343f13c28 Mon Sep 17 00:00:00 2001
From: Patryk Kopycinski
Date: Wed, 17 Jul 2024 22:29:35 +0200
Subject: [PATCH 29/55] cleanup

---
 package.json | 1 -
 .../langchain/graphs/default_assistant_graph/graph.ts | 3 +--
 yarn.lock | 9 ---------
 3 files changed, 1 insertion(+), 12 deletions(-)
diff --git a/package.json b/package.json index 1b9b2ab5f3f1a..dcb72c8595967 100644 --- a/package.json +++ b/package.json @@ -941,7 +941,6 @@ "@kbn/zod-helpers": "link:packages/kbn-zod-helpers", "@langchain/community": "0.2.18", "@langchain/core": "^0.2.16", - "@langchain/google-common": "^0.0.20", "@langchain/google-genai": "^0.0.22", "@langchain/langgraph": "^0.0.28", "@langchain/openai": "^0.1.3", diff --git a/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/graph.ts b/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/graph.ts index d8320624143c8..2a20dbc9be866 100644 --- a/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/graph.ts +++ b/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/graph.ts @@ -184,8 +184,7 @@ export const getDefaultAssistantGraph = ({ graph.addNode(AGENT_NODE, runAgentNode); graph.addNode(TOOLS_NODE, executeToolsNode); - const hasRespondStep = - isStreaming && llmType && bedrockChatEnabled && ['bedrock'].includes(llmType); + const hasRespondStep = isStreaming && bedrockChatEnabled && llmType === 'bedrock'; if (hasRespondStep) { graph.addNode(RESPOND_NODE, respondNode); diff --git a/yarn.lock b/yarn.lock index e19f398e411a0..7d98be783af5c 100644 --- a/yarn.lock +++ b/yarn.lock @@ -7064,15 +7064,6 @@ zod "^3.22.4" zod-to-json-schema "^3.22.3" -"@langchain/google-common@^0.0.20": - version "0.0.20" - resolved "https://registry.yarnpkg.com/@langchain/google-common/-/google-common-0.0.20.tgz#00d8e9b8f346c986366e199d4aaf19be1f97f1fd" - integrity sha512-kH1Bwh1tKxzIU+IFhOLLxuY7GjYjO+iebd3Gaih3smtQNldMidrYO2CRYtesnvD9AKJxvforU7neeux39fysoA== - dependencies: - "@langchain/core" ">=0.2.16 <0.3.0" - uuid "^10.0.0" - zod-to-json-schema "^3.22.4" - "@langchain/google-genai@^0.0.22": version "0.0.22" resolved "https://registry.yarnpkg.com/@langchain/google-genai/-/google-genai-0.0.22.tgz#520796606f7bdb4b60f9d34ea54a0d3fcaa82930" From 5fbced682a8f6352c693c462264592a519d37092 Mon Sep 17 00:00:00 2001 From: Patryk Kopycinski Date: Thu, 18 Jul 2024 01:22:36 +0200 Subject: [PATCH 30/55] fix? 
gemini streaming --- .../server/language_models/gemini_chat.ts | 42 +++++++++++++++---- .../nodes/run_agent.ts | 1 + .../stack_connectors/common/gemini/schema.ts | 1 + .../server/connector_types/gemini/gemini.ts | 37 +++++++++------- 4 files changed, 58 insertions(+), 23 deletions(-) diff --git a/x-pack/packages/kbn-langchain/server/language_models/gemini_chat.ts b/x-pack/packages/kbn-langchain/server/language_models/gemini_chat.ts index 0b40a34779ac4..e30ff96bb5e4d 100644 --- a/x-pack/packages/kbn-langchain/server/language_models/gemini_chat.ts +++ b/x-pack/packages/kbn-langchain/server/language_models/gemini_chat.ts @@ -35,7 +35,7 @@ export class ActionsClientGeminiChatModel extends ChatGoogleGenerativeAI { super({ ...props, apiKey: 'asda', - temperature: 0, + maxOutputTokens: 2048, }); this.#actionsClient = actionsClient; @@ -94,6 +94,7 @@ export class ActionsClientGeminiChatModel extends ChatGoogleGenerativeAI { ...parameters, contents: prompt, }; + const stream = await this.caller.callWithOptions({ signal: options?.signal }, async () => { const requestBody = { actionId: 'my-gemini-ai' || this.#connectorId, @@ -101,24 +102,51 @@ export class ActionsClientGeminiChatModel extends ChatGoogleGenerativeAI { subAction: 'invokeStream', subActionParams: { model: 'gemini-1.5-pro-preview-0409' || this.model, - messages: request, + messages: request.contents.reduce((acc, item) => { + if (!acc?.length) { + acc.push(item); + return acc; + } + + if (acc[acc.length - 1].role === item.role) { + acc[acc.length - 1].parts = acc[acc.length - 1].parts.concat(item.parts); + return acc; + } + + acc.push(item); + return acc; + }, []), + tools: request.tools, }, }, }; const actionResult = await this.#actionsClient.execute(requestBody); + if (actionResult.status === 'error') { + throw new Error(actionResult.serviceMessage); + } + return actionResult.data; }); let usageMetadata: UsageMetadata | undefined; let index = 0; + let partialStreamChunk = ''; for await (const rawStreamChunk of stream) { - const parsedStreamChunk = rawStreamChunk - .toString() - .split('\n') - .filter((line) => line.startsWith('data: ') && !line.endsWith('[DONE]')) - .map((line) => JSON.parse(line.replace('data: ', '')))[0]; + const streamChunk = rawStreamChunk.toString(); + + const nextChunk = `${partialStreamChunk + streamChunk}`; + + let parsedStreamChunk; + try { + parsedStreamChunk = JSON.parse(nextChunk.replaceAll('data: ', '').replaceAll('\r\n', '')); + partialStreamChunk = ''; + } catch (_) { + partialStreamChunk += nextChunk; + } + + if (!parsedStreamChunk || parsedStreamChunk.candidates[0].finishReason) continue; const response = { ...parsedStreamChunk, diff --git a/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/nodes/run_agent.ts b/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/nodes/run_agent.ts index aeca4dca21ea6..5cadc29139711 100644 --- a/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/nodes/run_agent.ts +++ b/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/nodes/run_agent.ts @@ -49,6 +49,7 @@ export const runAgent = async ({ const agentOutcome = await agentRunnable.withConfig({ tags: [AGENT_NODE_TAG] }).invoke( { ...state, + messages: state.messages.splice(-1), chat_history: state.messages, // TODO: Message de-dupe with ...state spread knowledge_history: JSON.stringify(knowledgeHistory?.length ? 
knowledgeHistory : NO_HISTORY), }, diff --git a/x-pack/plugins/stack_connectors/common/gemini/schema.ts b/x-pack/plugins/stack_connectors/common/gemini/schema.ts index fa7d44d1c86f8..6cccea399e957 100644 --- a/x-pack/plugins/stack_connectors/common/gemini/schema.ts +++ b/x-pack/plugins/stack_connectors/common/gemini/schema.ts @@ -62,6 +62,7 @@ export const InvokeAIActionParamsSchema = schema.object({ stopSequences: schema.maybe(schema.arrayOf(schema.string())), signal: schema.maybe(schema.any()), timeout: schema.maybe(schema.number()), + tools: schema.maybe(schema.arrayOf(schema.any())), }); export const InvokeAIRawActionParamsSchema = schema.object({ diff --git a/x-pack/plugins/stack_connectors/server/connector_types/gemini/gemini.ts b/x-pack/plugins/stack_connectors/server/connector_types/gemini/gemini.ts index 817b79b679c1e..19a2dc0ceadca 100644 --- a/x-pack/plugins/stack_connectors/server/connector_types/gemini/gemini.ts +++ b/x-pack/plugins/stack_connectors/server/connector_types/gemini/gemini.ts @@ -314,10 +314,10 @@ export class GeminiConnector extends SubActionConnector { temperature = 0, signal, timeout, + tools, }: InvokeAIActionParams): Promise { return (await this.streamAPI({ - // body: JSON.stringify(formatGeminiPayload(messages, temperature)), - body: JSON.stringify(messages), + body: JSON.stringify({ ...formatGeminiPayload(messages, temperature), tools }), model, stopSequences, signal, @@ -342,21 +342,26 @@ const formatGeminiPayload = ( for (const row of data) { const correctRole = row.role === 'assistant' ? 'model' : 'user'; - if (correctRole === 'user' && previousRole === 'user') { - /** Append to the previous 'user' content - * This is to ensure that multiturn requests alternate between user and model - */ - payload.contents[payload.contents.length - 1].parts[0].text += ` ${row.content}`; + // if data is already preformatted by ActionsClientGeminiChatModel + if (row.parts) { + payload.contents.push(row); } else { - // Add a new entry - payload.contents.push({ - role: correctRole, - parts: [ - { - text: row.content, - }, - ], - }); + if (correctRole === 'user' && previousRole === 'user') { + /** Append to the previous 'user' content + * This is to ensure that multiturn requests alternate between user and model + */ + payload.contents[payload.contents.length - 1].parts[0].text += ` ${row.content}`; + } else { + // Add a new entry + payload.contents.push({ + role: correctRole, + parts: [ + { + text: row.content, + }, + ], + }); + } } previousRole = correctRole; } From 47fa2c74122604160c5bec52d5751d8a7fe275db Mon Sep 17 00:00:00 2001 From: Steph Milovic Date: Thu, 18 Jul 2024 09:57:38 -0500 Subject: [PATCH 31/55] only one final answer --- .../lib/langchain/graphs/default_assistant_graph/helpers.ts | 2 +- .../lib/langchain/graphs/default_assistant_graph/index.ts | 5 ++++- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/helpers.ts b/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/helpers.ts index b7b8557fe2441..48567b54c07f1 100644 --- a/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/helpers.ts +++ b/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/helpers.ts @@ -103,7 +103,7 @@ export const streamGraph = async ({ } } - if (event === 'on_chat_model_end' && !data.output.lc_kwargs?.tool_calls?.length) { + if (event === 'on_chat_model_end' && 
!data.output.lc_kwargs?.tool_calls?.length && !didEnd) { handleStreamEnd(data.output.content); } } diff --git a/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/index.ts b/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/index.ts index 7407854eba060..46e5216385f82 100644 --- a/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/index.ts +++ b/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/index.ts @@ -126,7 +126,10 @@ export const callAssistantGraph: AgentExecutor = async ({ llm, tools, prompt: ChatPromptTemplate.fromMessages([ - ['system', 'You are a helpful assistant'], + [ + 'system', + "You are a helpful assistant. Use the available tools to answer the user's question", + ], ['placeholder', '{chat_history}'], ['human', '{input}'], ['placeholder', '{agent_scratchpad}'], From 72658d8de7f82cdee1bbc38e9f3f3f64bf7a3539 Mon Sep 17 00:00:00 2001 From: Patryk Kopycinski Date: Thu, 18 Jul 2024 20:04:52 +0200 Subject: [PATCH 32/55] WIP --- package.json | 8 ++++---- yarn.lock | 24 ++++++++++++------------ 2 files changed, 16 insertions(+), 16 deletions(-) diff --git a/package.json b/package.json index dcb72c8595967..5ce1395849677 100644 --- a/package.json +++ b/package.json @@ -80,7 +80,7 @@ "resolutions": { "**/@bazel/typescript/protobufjs": "6.11.4", "**/@hello-pangea/dnd": "16.6.0", - "**/@langchain/core": "^0.2.16", + "**/@langchain/core": "^0.2.17", "**/@types/node": "20.10.5", "**/@typescript-eslint/utils": "5.62.0", "**/chokidar": "^3.5.3", @@ -940,9 +940,9 @@ "@kbn/xstate-utils": "link:packages/kbn-xstate-utils", "@kbn/zod-helpers": "link:packages/kbn-zod-helpers", "@langchain/community": "0.2.18", - "@langchain/core": "^0.2.16", - "@langchain/google-genai": "^0.0.22", - "@langchain/langgraph": "^0.0.28", + "@langchain/core": "^0.2.17", + "@langchain/google-genai": "^0.0.23", + "@langchain/langgraph": "^0.0.29", "@langchain/openai": "^0.1.3", "@langtrase/trace-attributes": "^3.0.8", "@launchdarkly/node-server-sdk": "^9.4.7", diff --git a/yarn.lock b/yarn.lock index 7d98be783af5c..b82a8d2e5e1e9 100644 --- a/yarn.lock +++ b/yarn.lock @@ -7046,10 +7046,10 @@ zod "^3.22.3" zod-to-json-schema "^3.22.5" -"@langchain/core@>0.1.0 <0.3.0", "@langchain/core@>=0.2.11 <0.3.0", "@langchain/core@>=0.2.16 <0.3.0", "@langchain/core@>=0.2.5 <0.3.0", "@langchain/core@^0.2.16", "@langchain/core@~0.2.11": - version "0.2.16" - resolved "https://registry.yarnpkg.com/@langchain/core/-/core-0.2.16.tgz#0700a7def44c613ef775d351a2e6428e09cfbfda" - integrity sha512-mPmQi0ecJ81QwhvUQX4cwGVAqsM30ly3ygIlWoeUwDOXv9UW/IB2LAq8KKoVYIHTyEsIWJiyMP9Sv3e0xwjV8g== +"@langchain/core@>0.1.0 <0.3.0", "@langchain/core@>=0.2.11 <0.3.0", "@langchain/core@>=0.2.16 <0.3.0", "@langchain/core@>=0.2.5 <0.3.0", "@langchain/core@^0.2.17", "@langchain/core@~0.2.11": + version "0.2.17" + resolved "https://registry.yarnpkg.com/@langchain/core/-/core-0.2.17.tgz#dfd44a2ccf79cef88ba765741a1c277bc22e483f" + integrity sha512-WnFiZ7R/ZUVeHO2IgcSL7Tu+CjApa26Iy99THJP5fax/NF8UQCc/ZRcw2Sb/RUuRPVm6ALDass0fSQE1L9YNJg== dependencies: ansi-styles "^5.0.0" camelcase "6" @@ -7064,19 +7064,19 @@ zod "^3.22.4" zod-to-json-schema "^3.22.3" -"@langchain/google-genai@^0.0.22": - version "0.0.22" - resolved "https://registry.yarnpkg.com/@langchain/google-genai/-/google-genai-0.0.22.tgz#520796606f7bdb4b60f9d34ea54a0d3fcaa82930" - integrity 
sha512-egxPpu+GdYigUYOFGDujKt6ziYZ/ELrTvEZ17TJYUgLyvw/gfzr5lz0hU4CgKKUwI/tIk6ZU5I2uaC4oFDkRSQ== +"@langchain/google-genai@^0.0.23": + version "0.0.23" + resolved "https://registry.yarnpkg.com/@langchain/google-genai/-/google-genai-0.0.23.tgz#e73af501bc1df4c7642b531759b82dc3eb7ae459" + integrity sha512-MTSCJEoKsfU1inz0PWvAjITdNFM4s41uvBCwLpcgx3jWJIEisczFD82x86ahYqJlb2fD6tohYSaCH/4tKAdkXA== dependencies: "@google/generative-ai" "^0.7.0" "@langchain/core" ">=0.2.16 <0.3.0" zod-to-json-schema "^3.22.4" -"@langchain/langgraph@^0.0.28": - version "0.0.28" - resolved "https://registry.yarnpkg.com/@langchain/langgraph/-/langgraph-0.0.28.tgz#d1aea4e55f8b563a4a5938f8469cb18df4c46465" - integrity sha512-tKblc95lKfifjYLcWl0t7Z0kqN2zQoSQgOYNxbuaZZiMfuxZ5NJGAH7PxTpLjWdwRsiMpIBNzS/JZ1vjhQrNvA== +"@langchain/langgraph@^0.0.29": + version "0.0.29" + resolved "https://registry.yarnpkg.com/@langchain/langgraph/-/langgraph-0.0.29.tgz#eda31d101e7a75981e0929661c41ab2461ff8640" + integrity sha512-BSFFJarkXqrMdH9yH6AIiBCw4ww0VsXXpBwqaw+9/7iulW0pBFRSkWXHjEYnmsdCRgyIxoP8vYQAQ8Jtu3qzZA== dependencies: "@langchain/core" ">=0.2.16 <0.3.0" uuid "^10.0.0" From f38f83c1b2e2bcbca59cec14509dc2e8040cccd9 Mon Sep 17 00:00:00 2001 From: Steph Milovic Date: Thu, 18 Jul 2024 14:45:54 -0500 Subject: [PATCH 33/55] fix double response from Gemini --- .../graphs/default_assistant_graph/helpers.ts | 23 ++++++++++++------- .../graphs/default_assistant_graph/index.ts | 3 ++- .../default_assistant_graph/nodes/respond.ts | 4 +++- 3 files changed, 20 insertions(+), 10 deletions(-) diff --git a/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/helpers.ts b/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/helpers.ts index bf86dd0c85dbb..c7e98100c63f3 100644 --- a/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/helpers.ts +++ b/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/helpers.ts @@ -94,17 +94,23 @@ export const streamGraph = async ({ llmType === 'bedrock' ? 
{ includeNames: ['Summarizer'] } : undefined ); - for await (const { event, data } of stream) { - if (event === 'on_chat_model_stream') { - const msg = data.chunk as AIMessageChunk; + for await (const { event, data, tags } of stream) { + if ((tags || []).includes(AGENT_NODE_TAG)) { + if (event === 'on_chat_model_stream') { + const msg = data.chunk as AIMessageChunk; - if (!didEnd && !msg.tool_call_chunks?.length && msg.content.length) { - push({ payload: msg.content, type: 'content' }); + if (!didEnd && !msg.tool_call_chunks?.length && msg.content.length) { + push({ payload: msg.content, type: 'content' }); + } } - } - if (event === 'on_chat_model_end' && !data.output.lc_kwargs?.tool_calls?.length && !didEnd) { - handleStreamEnd(data.output.content); + if ( + event === 'on_chat_model_end' && + !data.output.lc_kwargs?.tool_calls?.length && + !didEnd + ) { + handleStreamEnd(data.output.content); + } } } return responseWithHeaders; @@ -133,6 +139,7 @@ export const streamGraph = async ({ const event = value; // only process events that are part of the agent run if ((event.tags || []).includes(AGENT_NODE_TAG)) { + console.log('stephhh openai event', JSON.stringify(event, null, 2)); if (event.name === 'ActionsClientChatOpenAI') { if (event.event === 'on_llm_stream') { const chunk = event.data?.chunk; diff --git a/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/index.ts b/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/index.ts index e0e0fbfb0d280..22653625e29d7 100644 --- a/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/index.ts +++ b/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/index.ts @@ -130,7 +130,8 @@ export const callAssistantGraph: AgentExecutor = async ({ prompt: ChatPromptTemplate.fromMessages([ [ 'system', - "You are a helpful assistant. Use the available tools to answer the user's question", + 'You are a helpful assistant. 
ALWAYS use the provided tools.\n\n' + + `The final response will be the only output the user sees and should be a complete answer to the user's question.`, ], ['placeholder', '{chat_history}'], ['human', '{input}'], diff --git a/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/nodes/respond.ts b/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/nodes/respond.ts index 68c82ab9482b0..1e4e0016734f6 100644 --- a/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/nodes/respond.ts +++ b/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/nodes/respond.ts @@ -6,6 +6,7 @@ */ import { BaseChatModel } from '@langchain/core/language_models/chat_models'; +import { AGENT_NODE_TAG } from './run_agent'; import { AgentState } from '../types'; export const RESPOND_NODE = 'respond'; @@ -24,7 +25,8 @@ export const respond = async ({ llm, state }: { llm: BaseChatModel; state: Agent // console.error('userMessage', userMessage); const responseMessage = await llm // .bindTools([]) - .withConfig({ runName: 'Summarizer' }) + // use AGENT_NODE_TAG to identify as agent node for stream parsing + .withConfig({ runName: 'Summarizer', tags: [AGENT_NODE_TAG] }) .invoke([userMessage]); return { From 5b4f3777ab37190bac13aaad66832dc1e2900218 Mon Sep 17 00:00:00 2001 From: Steph Milovic Date: Thu, 18 Jul 2024 15:51:40 -0500 Subject: [PATCH 34/55] revert log --- .../lib/langchain/graphs/default_assistant_graph/helpers.ts | 1 - 1 file changed, 1 deletion(-) diff --git a/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/helpers.ts b/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/helpers.ts index c7e98100c63f3..50b386ce0bec8 100644 --- a/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/helpers.ts +++ b/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/helpers.ts @@ -139,7 +139,6 @@ export const streamGraph = async ({ const event = value; // only process events that are part of the agent run if ((event.tags || []).includes(AGENT_NODE_TAG)) { - console.log('stephhh openai event', JSON.stringify(event, null, 2)); if (event.name === 'ActionsClientChatOpenAI') { if (event.event === 'on_llm_stream') { const chunk = event.data?.chunk; From 11e4cb686d1693027b5855fb18c7c31d6f6cee70 Mon Sep 17 00:00:00 2001 From: Patryk Kopycinski Date: Mon, 22 Jul 2024 15:17:58 +0200 Subject: [PATCH 35/55] fix deduplicates --- yarn.lock | 24 ++++-------------------- 1 file changed, 4 insertions(+), 20 deletions(-) diff --git a/yarn.lock b/yarn.lock index 757579da03f1f..4ccbd4e474343 100644 --- a/yarn.lock +++ b/yarn.lock @@ -21883,7 +21883,7 @@ langchainhub@~0.0.8: resolved "https://registry.yarnpkg.com/langchainhub/-/langchainhub-0.0.8.tgz#fd4b96dc795e22e36c1a20bad31b61b0c33d3110" integrity sha512-Woyb8YDHgqqTOZvWIbm2CaFDGfZ4NTSyXV687AG4vXEfoNo7cGQp7nhl7wL3ehenKWmNEmcxCLgOZzW8jE6lOQ== -langsmith@^0.1.37: +langsmith@^0.1.37, langsmith@~0.1.30: version "0.1.37" resolved "https://registry.yarnpkg.com/langsmith/-/langsmith-0.1.37.tgz#7e7bf6fce3eab2a9e95221d5879820ec29d0aa60" integrity sha512-8JgWykdJywdKWs+WeefOEf4Gkz3YdNkvG5u5JPbgXuodTUwuHPwjmblsldt1OGKkPp7iCWfdtCdnc9z9MYC/Dw== @@ -21894,17 +21894,6 @@ langsmith@^0.1.37: p-retry "4" uuid "^9.0.0" -langsmith@~0.1.30: - version "0.1.36" - resolved 
"https://registry.yarnpkg.com/langsmith/-/langsmith-0.1.36.tgz#5f21b9c6bcd4ea9c0e943f83e304a53e5232297d" - integrity sha512-D5hhkFl31uxFdffx0lA6pin0lt8Pv2dpHFZYpSgEzvQ26PQ/Y/tnniQ+aCNokIXuLhMa7uqLtb6tfwjfiZXgdg== - dependencies: - "@types/uuid" "^9.0.1" - commander "^10.0.1" - p-queue "^6.6.2" - p-retry "4" - uuid "^9.0.0" - language-subtag-registry@~0.3.2: version "0.3.21" resolved "https://registry.yarnpkg.com/language-subtag-registry/-/language-subtag-registry-0.3.21.tgz#04ac218bea46f04cb039084602c6da9e788dd45a" @@ -32738,22 +32727,17 @@ zip-stream@^4.1.0: compress-commons "^4.1.0" readable-stream "^3.6.0" -zod-to-json-schema@^3.22.3, zod-to-json-schema@^3.22.5, zod-to-json-schema@^3.23.0: +zod-to-json-schema@^3.22.3, zod-to-json-schema@^3.22.4, zod-to-json-schema@^3.22.5, zod-to-json-schema@^3.23.0: version "3.23.0" resolved "https://registry.yarnpkg.com/zod-to-json-schema/-/zod-to-json-schema-3.23.0.tgz#4fc60e88d3c709eedbfaae3f92f8a7bf786469f2" integrity sha512-az0uJ243PxsRIa2x1WmNE/pnuA05gUq/JB8Lwe1EDCCL/Fz9MgjYQ0fPlyc2Tcv6aF2ZA7WM5TWaRZVEFaAIag== -zod-to-json-schema@^3.22.4: - version "3.22.5" - resolved "https://registry.yarnpkg.com/zod-to-json-schema/-/zod-to-json-schema-3.22.5.tgz#3646e81cfc318dbad2a22519e5ce661615418673" - integrity sha512-+akaPo6a0zpVCCseDed504KBJUQpEW5QZw7RMneNmKw+fGaML1Z9tUNLnHHAC8x6dzVRO1eB2oEMyZRnuBZg7Q== - -zod@3.22.4, zod@^3.22.3, zod@^3.22.4: +zod@3.22.4: version "3.22.4" resolved "https://registry.yarnpkg.com/zod/-/zod-3.22.4.tgz#f31c3a9386f61b1f228af56faa9255e845cf3fff" integrity sha512-iC+8Io04lddc+mVqQ9AZ7OQ2MrUKGN+oIQyq1vemgt46jwCwLfhq7/pwnBnNXXXZb8VTVLKwp9EDkx+ryxIWmg== -zod@^3.23.8: +zod@^3.22.3, zod@^3.22.4, zod@^3.23.8: version "3.23.8" resolved "https://registry.yarnpkg.com/zod/-/zod-3.23.8.tgz#e37b957b5d52079769fb8097099b592f0ef4067d" integrity sha512-XBx9AXhXktjUqnepgTiE5flcKIYWi/rme0Eaj+5Y0lftuGBq+jyRu/md4WnuxqgP1ubdpNCsYEYPxrzVHD8d6g== From c0bae8b6caad15c8bdfd7fea3bd60d5f5e88a54d Mon Sep 17 00:00:00 2001 From: Steph Milovic Date: Mon, 22 Jul 2024 11:05:19 -0500 Subject: [PATCH 36/55] fixing lint and types --- package.json | 1 + .../server/language_models/gemini_chat.ts | 179 +++++++++++------- .../graphs/default_assistant_graph/index.ts | 5 +- .../default_assistant_graph/nodes/respond.ts | 41 ++-- 4 files changed, 132 insertions(+), 94 deletions(-) diff --git a/package.json b/package.json index 1f4b49fe0fafc..5b31d05a962b7 100644 --- a/package.json +++ b/package.json @@ -135,6 +135,7 @@ "@formatjs/intl-relativetimeformat": "^11.2.12", "@formatjs/intl-utils": "^3.8.4", "@formatjs/ts-transformer": "^3.13.14", + "@google/generative-ai": "^0.7.0", "@grpc/grpc-js": "^1.8.22", "@hapi/accept": "^5.0.2", "@hapi/boom": "^9.1.4", diff --git a/x-pack/packages/kbn-langchain/server/language_models/gemini_chat.ts b/x-pack/packages/kbn-langchain/server/language_models/gemini_chat.ts index e30ff96bb5e4d..d0e12fd1e61c7 100644 --- a/x-pack/packages/kbn-langchain/server/language_models/gemini_chat.ts +++ b/x-pack/packages/kbn-langchain/server/language_models/gemini_chat.ts @@ -12,6 +12,10 @@ import { FunctionResponsePart, POSSIBLE_ROLES, EnhancedGenerateContentResponse, + GenerateContentRequest, + TextPart, + InlineDataPart, + GenerateContentResult, } from '@google/generative-ai'; import { ActionsClient } from '@kbn/actions-plugin/server'; import { PublicMethodsOf } from '@kbn/utility-types'; @@ -26,42 +30,65 @@ import { } from '@langchain/core/messages'; import { ChatGenerationChunk } from '@langchain/core/outputs'; import { ChatGoogleGenerativeAI } from 
'@langchain/google-genai';
+import { Logger } from '@kbn/logging';
+import { BaseChatModelParams } from '@langchain/core/language_models/chat_models';
+import { get } from 'lodash/fp';
+import { Readable } from 'stream';
 
+const DEFAULT_GEMINI_TEMPERATURE = 0;
+
+export interface CustomChatModelInput extends BaseChatModelParams {
+  actionsClient: PublicMethodsOf<ActionsClient>;
+  connectorId: string;
+  logger: Logger;
+  temperature?: number;
+  signal?: AbortSignal;
+  model?: string;
+  maxTokens?: number;
+}
+
 export class ActionsClientGeminiChatModel extends ChatGoogleGenerativeAI {
   #actionsClient: PublicMethodsOf<ActionsClient>;
   #connectorId: string;
+  #temperature: number;
+  #model?: string;
 
-  constructor({ actionsClient, connectorId, ...props }) {
+  constructor({ actionsClient, connectorId, ...props }: CustomChatModelInput) {
     super({
       ...props,
       apiKey: 'asda',
-      maxOutputTokens: 2048,
+      maxOutputTokens: props.maxTokens ?? 2048,
     });
-
+    // LangChain needs model to be defined for logging purposes
+    this.model = props.model ?? this.model;
+    // If model is not specified by consumer, the connector will define it so do not pass
+    // a LangChain default to the actionsClient
+    this.#model = props.model;
+    this.#temperature = props.temperature ?? DEFAULT_GEMINI_TEMPERATURE;
     this.#actionsClient = actionsClient;
     this.#connectorId = connectorId;
-    this.apiKey = 'sadd';
   }
 
   async completionWithRetry(
-    request: string | GenerateContentRequest | Array<string | Part>,
+    request: string | GenerateContentRequest | Array<string | Part>,
     options?: this['ParsedCallOptions']
-  ) {
+  ): Promise<GenerateContentResult> {
     return this.caller.callWithOptions({ signal: options?.signal }, async () => {
       try {
-        // console.error('requses', request);
         const requestBody = {
-          actionId: 'my-gemini-ai' || this.#connectorId,
+          actionId: this.#connectorId,
           params: {
             subAction: 'invokeAIRaw',
             subActionParams: {
-              model: 'gemini-1.5-pro-preview-0409' || this.model,
+              model: this.#model,
               messages: request,
             },
           },
         };
 
-        const actionResult = await this.#actionsClient.execute(requestBody);
+        const actionResult = (await this.#actionsClient.execute(requestBody)) as {
+          status: string;
+          data: EnhancedGenerateContentResponse;
+        };
 
         return {
           response: {
             ...actionResult.data,
             functionCalls: () =>
               actionResult.data?.candidates?.[0]?.content.parts[0].functionCall
                 ? 
[actionResult.data?.candidates?.[0]?.content.parts[0].functionCall] - : null, + : [], }, }; // eslint-disable-next-line @typescript-eslint/no-explicit-any @@ -97,12 +124,12 @@ export class ActionsClientGeminiChatModel extends ChatGoogleGenerativeAI { const stream = await this.caller.callWithOptions({ signal: options?.signal }, async () => { const requestBody = { - actionId: 'my-gemini-ai' || this.#connectorId, + actionId: this.#connectorId, params: { subAction: 'invokeStream', subActionParams: { - model: 'gemini-1.5-pro-preview-0409' || this.model, - messages: request.contents.reduce((acc, item) => { + model: this.#model, + messages: request.contents.reduce((acc: Content[], item) => { if (!acc?.length) { acc.push(item); return acc; @@ -118,6 +145,7 @@ export class ActionsClientGeminiChatModel extends ChatGoogleGenerativeAI { }, []), tools: request.tools, }, + temperature: this.#temperature, }, }; @@ -126,8 +154,12 @@ export class ActionsClientGeminiChatModel extends ChatGoogleGenerativeAI { if (actionResult.status === 'error') { throw new Error(actionResult.serviceMessage); } + const readable = get('data', actionResult) as Readable; - return actionResult.data; + if (typeof readable?.read !== 'function') { + throw new Error('Action result status is error: result is not streamable'); + } + return readable; }); let usageMetadata: UsageMetadata | undefined; @@ -138,7 +170,7 @@ export class ActionsClientGeminiChatModel extends ChatGoogleGenerativeAI { const nextChunk = `${partialStreamChunk + streamChunk}`; - let parsedStreamChunk; + let parsedStreamChunk: EnhancedGenerateContentResponse | null = null; try { parsedStreamChunk = JSON.parse(nextChunk.replaceAll('data: ', '').replaceAll('\r\n', '')); partialStreamChunk = ''; @@ -146,57 +178,55 @@ export class ActionsClientGeminiChatModel extends ChatGoogleGenerativeAI { partialStreamChunk += nextChunk; } - if (!parsedStreamChunk || parsedStreamChunk.candidates[0].finishReason) continue; - - const response = { - ...parsedStreamChunk, - functionCalls: () => - parsedStreamChunk?.candidates?.[0]?.content.parts[0].functionCall - ? [parsedStreamChunk.candidates?.[0]?.content.parts[0].functionCall] - : null, - }; - - if ( - 'usageMetadata' in response && - this.streamUsage !== false && - options.streamUsage !== false - ) { - const genAIUsageMetadata = response.usageMetadata as { - promptTokenCount: number; - candidatesTokenCount: number; - totalTokenCount: number; + if (parsedStreamChunk !== null && !parsedStreamChunk.candidates?.[0]?.finishReason) { + const response = { + ...parsedStreamChunk, + functionCalls: () => + parsedStreamChunk?.candidates?.[0]?.content.parts[0].functionCall + ? [parsedStreamChunk.candidates?.[0]?.content.parts[0].functionCall] + : [], }; - if (!usageMetadata) { - usageMetadata = { - input_tokens: genAIUsageMetadata.promptTokenCount, - output_tokens: genAIUsageMetadata.candidatesTokenCount, - total_tokens: genAIUsageMetadata.totalTokenCount, - }; - } else { - // Under the hood, LangChain combines the prompt tokens. Google returns the updated - // total each time, so we need to find the difference between the tokens. 
- const outputTokenDiff = - genAIUsageMetadata.candidatesTokenCount - usageMetadata.output_tokens; - usageMetadata = { - input_tokens: 0, - output_tokens: outputTokenDiff, - total_tokens: outputTokenDiff, + + if ( + 'usageMetadata' in response && + this.streamUsage !== false && + options.streamUsage !== false + ) { + const genAIUsageMetadata = response.usageMetadata as { + promptTokenCount: number; + candidatesTokenCount: number; + totalTokenCount: number; }; + if (!usageMetadata) { + usageMetadata = { + input_tokens: genAIUsageMetadata.promptTokenCount, + output_tokens: genAIUsageMetadata.candidatesTokenCount, + total_tokens: genAIUsageMetadata.totalTokenCount, + }; + } else { + // Under the hood, LangChain combines the prompt tokens. Google returns the updated + // total each time, so we need to find the difference between the tokens. + const outputTokenDiff = + genAIUsageMetadata.candidatesTokenCount - usageMetadata.output_tokens; + usageMetadata = { + input_tokens: 0, + output_tokens: outputTokenDiff, + total_tokens: outputTokenDiff, + }; + } } - } - const chunk = convertResponseContentToChatGenerationChunk(response, { - usageMetadata, - index, - }); - index += 1; + const chunk = convertResponseContentToChatGenerationChunk(response, { + usageMetadata, + index, + }); + index += 1; - if (!chunk) { - continue; + if (chunk) { + yield chunk; + await runManager?.handleLLMNewToken(chunk.text ?? ''); + } } - - yield chunk; - await runManager?.handleLLMNewToken(chunk.text ?? ''); } } } @@ -275,12 +305,6 @@ export function convertBaseMessagesToContent(messages: BaseMessage[], isMultimod throw new Error('System message should be the first one'); } const role = convertAuthorToRole(author); - - const prevContent = acc.content[acc.content.length]; - if (!acc.mergeWithPreviousContent && prevContent && prevContent.role === role) { - throw new Error('Google Generative AI requires alternate messages between authors'); - } - const parts = convertMessageContentToParts(message, isMultimodalModel); if (acc.mergeWithPreviousContent) { @@ -352,7 +376,7 @@ export function convertMessageContentToParts( if (c.type === 'text') { return { text: c.text, - }; + } as TextPart; } if (c.type === 'image_url') { @@ -382,16 +406,16 @@ export function convertMessageContentToParts( data, mimeType, }, - }; + } as InlineDataPart; } else if (c.type === 'media') { - return c; + return messageContentMedia(c); } else if (c.type === 'tool_use') { return { functionCall: { name: c.name, args: c.input, }, - }; + } as FunctionCallPart; } throw new Error(`Unknown content type ${(c as { type: string }).type}`); }); @@ -410,3 +434,16 @@ export function getMessageAuthor(message: BaseMessage) { } return message.name ?? 
type;
 }
+
+// will be removed once FileDataPart is supported in @langchain/google-genai
+function messageContentMedia(content: Record<string, unknown>): InlineDataPart {
+  if ('mimeType' in content && 'data' in content) {
+    return {
+      inlineData: {
+        mimeType: content.mimeType,
+        data: content.data,
+      },
+    } as InlineDataPart;
+  }
+  throw new Error('Invalid media content');
+}
diff --git a/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/index.ts b/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/index.ts
index 22653625e29d7..914f02110b945 100644
--- a/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/index.ts
+++ b/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/index.ts
@@ -59,7 +59,6 @@ export const callAssistantGraph: AgentExecutor<boolean> = async ({
     connectorId,
     llmType,
     logger,
-    apiKey: '',
     // possible client model override,
     // let this be undefined otherwise so the connector handles the model
     model: request.body.model,
@@ -124,14 +123,14 @@ export const callAssistantGraph: AgentExecutor<boolean> = async ({
           streamRunnable: isStream,
         })
       : llmType && ['bedrock', 'gemini'].includes(llmType) && bedrockChatEnabled
-      ? await createToolCallingAgent({
+      ? createToolCallingAgent({
           llm,
           tools,
           prompt: ChatPromptTemplate.fromMessages([
             [
               'system',
               'You are a helpful assistant. ALWAYS use the provided tools.\n\n' +
-                `The final response will be the only output the user sees and should be a complete answer to the user's question.`,
+                `The final response will be the only output the user sees and should be a complete answer to the user's question. The final response should never be empty.`,
             ],
             ['placeholder', '{chat_history}'],
             ['human', '{input}'],
diff --git a/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/nodes/respond.ts b/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/nodes/respond.ts
index 1e4e0016734f6..4820a494560ab 100644
--- a/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/nodes/respond.ts
+++ b/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/nodes/respond.ts
@@ -6,34 +6,35 @@
  */
 
 import { BaseChatModel } from '@langchain/core/language_models/chat_models';
+import { StringWithAutocomplete } from '@langchain/core/dist/utils/types';
 import { AGENT_NODE_TAG } from './run_agent';
 import { AgentState } from '../types';
 
 export const RESPOND_NODE = 'respond';
 
 export const respond = async ({ llm, state }: { llm: BaseChatModel; state: AgentState }) => {
-  // Assign the final model call a run name
-  // console.error('state', state);
-  // const { messages } = state;
-  const userMessage = [
-    'user',
-    `Respond exactly with
+  if (state?.agentOutcome && 'returnValues' in state.agentOutcome) {
+    const userMessage = [
+      'user',
+      `Respond exactly with
 ${state.agentOutcome?.returnValues?.output}
 Do not verify, confirm or anything else. 
Just reply with the same content as provided above.`, - ]; - // console.error('messages', messages); - // console.error('userMessage', userMessage); - const responseMessage = await llm - // .bindTools([]) - // use AGENT_NODE_TAG to identify as agent node for stream parsing - .withConfig({ runName: 'Summarizer', tags: [AGENT_NODE_TAG] }) - .invoke([userMessage]); + ] as [StringWithAutocomplete<'user'>, string]; - return { - agentOutcome: { - returnValues: { - output: responseMessage.content, + const responseMessage = await llm + // .bindTools([]) + // use AGENT_NODE_TAG to identify as agent node for stream parsing + .withConfig({ runName: 'Summarizer', tags: [AGENT_NODE_TAG] }) + .invoke([userMessage]); + + return { + agentOutcome: { + ...state.agentOutcome, + returnValues: { + output: responseMessage.content, + }, }, - }, - }; + }; + } + return state; }; From 6d3203fc2e1d986374a8f6462ebf52420291357b Mon Sep 17 00:00:00 2001 From: Patryk Kopycinski Date: Mon, 22 Jul 2024 18:17:29 +0200 Subject: [PATCH 37/55] update jest --- .../server/language_models/bedrock_chat.ts | 2 - .../connector_types.test.ts.snap | 765 +++++++++++++++++- .../execute_custom_llm_chain/index.test.ts | 29 +- .../execute_custom_llm_chain/index.ts | 4 + .../server/lib/conversational_chain.test.ts | 4 +- 5 files changed, 740 insertions(+), 64 deletions(-) diff --git a/x-pack/packages/kbn-langchain/server/language_models/bedrock_chat.ts b/x-pack/packages/kbn-langchain/server/language_models/bedrock_chat.ts index 0a329ba41e03f..e52a245b35bc9 100644 --- a/x-pack/packages/kbn-langchain/server/language_models/bedrock_chat.ts +++ b/x-pack/packages/kbn-langchain/server/language_models/bedrock_chat.ts @@ -20,13 +20,11 @@ export class ActionsClientBedrockChatModel extends _BedrockChat { actionsClient, connectorId, logger, - graph, ...params }: { actionsClient: PublicMethodsOf; connectorId: string; logger: Logger; - graph?: boolean; } & BaseChatModelParams) { super({ ...params, diff --git a/x-pack/plugins/actions/server/integration_tests/__snapshots__/connector_types.test.ts.snap b/x-pack/plugins/actions/server/integration_tests/__snapshots__/connector_types.test.ts.snap index b757c08a9b238..315840ba14e18 100644 --- a/x-pack/plugins/actions/server/integration_tests/__snapshots__/connector_types.test.ts.snap +++ b/x-pack/plugins/actions/server/integration_tests/__snapshots__/connector_types.test.ts.snap @@ -45,6 +45,19 @@ Object { ], "type": "string", }, + "raw": Object { + "flags": Object { + "default": [Function], + "error": [Function], + "presence": "optional", + }, + "metas": Array [ + Object { + "x-oas-optional": true, + }, + ], + "type": "boolean", + }, "signal": Object { "flags": Object { "default": [Function], @@ -153,6 +166,19 @@ Object { ], "type": "string", }, + "raw": Object { + "flags": Object { + "default": [Function], + "error": [Function], + "presence": "optional", + }, + "metas": Array [ + Object { + "x-oas-optional": true, + }, + ], + "type": "boolean", + }, "signal": Object { "flags": Object { "default": [Function], @@ -197,6 +223,27 @@ Object { "presence": "optional", }, "keys": Object { + "anthropicVersion": Object { + "flags": Object { + "default": [Function], + "error": [Function], + "presence": "optional", + }, + "metas": Array [ + Object { + "x-oas-optional": true, + }, + ], + "rules": Array [ + Object { + "args": Object { + "method": [Function], + }, + "name": "custom", + }, + ], + "type": "string", + }, "maxTokens": Object { "flags": Object { "default": [Function], @@ -228,15 +275,12 @@ Object { "flags": 
Object { "error": [Function], }, - "rules": Array [ + "metas": Array [ Object { - "args": Object { - "method": [Function], - }, - "name": "custom", + "x-oas-any-type": true, }, ], - "type": "string", + "type": "any", }, "role": Object { "flags": Object { @@ -372,6 +416,78 @@ Object { ], "type": "number", }, + "tools": Object { + "flags": Object { + "default": [Function], + "error": [Function], + "presence": "optional", + }, + "items": Array [ + Object { + "flags": Object { + "default": Object { + "special": "deep", + }, + "error": [Function], + "presence": "optional", + }, + "keys": Object { + "description": Object { + "flags": Object { + "error": [Function], + }, + "rules": Array [ + Object { + "args": Object { + "method": [Function], + }, + "name": "custom", + }, + ], + "type": "string", + }, + "input_schema": Object { + "flags": Object { + "default": Object { + "special": "deep", + }, + "error": [Function], + "presence": "optional", + "unknown": true, + }, + "keys": Object {}, + "preferences": Object { + "stripUnknown": Object { + "objects": false, + }, + }, + "type": "object", + }, + "name": Object { + "flags": Object { + "error": [Function], + }, + "rules": Array [ + Object { + "args": Object { + "method": [Function], + }, + "name": "custom", + }, + ], + "type": "string", + }, + }, + "type": "object", + }, + ], + "metas": Array [ + Object { + "x-oas-optional": true, + }, + ], + "type": "array", + }, }, "type": "object", } @@ -387,6 +503,27 @@ Object { "presence": "optional", }, "keys": Object { + "anthropicVersion": Object { + "flags": Object { + "default": [Function], + "error": [Function], + "presence": "optional", + }, + "metas": Array [ + Object { + "x-oas-optional": true, + }, + ], + "rules": Array [ + Object { + "args": Object { + "method": [Function], + }, + "name": "custom", + }, + ], + "type": "string", + }, "maxTokens": Object { "flags": Object { "default": [Function], @@ -415,6 +552,17 @@ Object { }, "keys": Object { "content": Object { + "flags": Object { + "error": [Function], + }, + "metas": Array [ + Object { + "x-oas-any-type": true, + }, + ], + "type": "any", + }, + "role": Object { "flags": Object { "error": [Function], }, @@ -428,6 +576,272 @@ Object { ], "type": "string", }, + }, + "type": "object", + }, + ], + "type": "array", + }, + "model": Object { + "flags": Object { + "default": [Function], + "error": [Function], + "presence": "optional", + }, + "metas": Array [ + Object { + "x-oas-optional": true, + }, + ], + "rules": Array [ + Object { + "args": Object { + "method": [Function], + }, + "name": "custom", + }, + ], + "type": "string", + }, + "signal": Object { + "flags": Object { + "default": [Function], + "error": [Function], + "presence": "optional", + }, + "metas": Array [ + Object { + "x-oas-any-type": true, + }, + Object { + "x-oas-optional": true, + }, + ], + "type": "any", + }, + "stopSequences": Object { + "flags": Object { + "default": [Function], + "error": [Function], + "presence": "optional", + }, + "items": Array [ + Object { + "flags": Object { + "error": [Function], + "presence": "optional", + }, + "rules": Array [ + Object { + "args": Object { + "method": [Function], + }, + "name": "custom", + }, + ], + "type": "string", + }, + ], + "metas": Array [ + Object { + "x-oas-optional": true, + }, + ], + "type": "array", + }, + "system": Object { + "flags": Object { + "default": [Function], + "error": [Function], + "presence": "optional", + }, + "metas": Array [ + Object { + "x-oas-optional": true, + }, + ], + "rules": Array [ + Object { + "args": 
Object { + "method": [Function], + }, + "name": "custom", + }, + ], + "type": "string", + }, + "temperature": Object { + "flags": Object { + "default": [Function], + "error": [Function], + "presence": "optional", + }, + "metas": Array [ + Object { + "x-oas-optional": true, + }, + ], + "type": "number", + }, + "timeout": Object { + "flags": Object { + "default": [Function], + "error": [Function], + "presence": "optional", + }, + "metas": Array [ + Object { + "x-oas-optional": true, + }, + ], + "type": "number", + }, + "tools": Object { + "flags": Object { + "default": [Function], + "error": [Function], + "presence": "optional", + }, + "items": Array [ + Object { + "flags": Object { + "default": Object { + "special": "deep", + }, + "error": [Function], + "presence": "optional", + }, + "keys": Object { + "description": Object { + "flags": Object { + "error": [Function], + }, + "rules": Array [ + Object { + "args": Object { + "method": [Function], + }, + "name": "custom", + }, + ], + "type": "string", + }, + "input_schema": Object { + "flags": Object { + "default": Object { + "special": "deep", + }, + "error": [Function], + "presence": "optional", + "unknown": true, + }, + "keys": Object {}, + "preferences": Object { + "stripUnknown": Object { + "objects": false, + }, + }, + "type": "object", + }, + "name": Object { + "flags": Object { + "error": [Function], + }, + "rules": Array [ + Object { + "args": Object { + "method": [Function], + }, + "name": "custom", + }, + ], + "type": "string", + }, + }, + "type": "object", + }, + ], + "metas": Array [ + Object { + "x-oas-optional": true, + }, + ], + "type": "array", + }, + }, + "type": "object", +} +`; + +exports[`Connector type config checks detect connector type changes for: .bedrock 6`] = ` +Object { + "flags": Object { + "default": Object { + "special": "deep", + }, + "error": [Function], + "presence": "optional", + }, + "keys": Object { + "anthropicVersion": Object { + "flags": Object { + "default": [Function], + "error": [Function], + "presence": "optional", + }, + "metas": Array [ + Object { + "x-oas-optional": true, + }, + ], + "rules": Array [ + Object { + "args": Object { + "method": [Function], + }, + "name": "custom", + }, + ], + "type": "string", + }, + "maxTokens": Object { + "flags": Object { + "default": [Function], + "error": [Function], + "presence": "optional", + }, + "metas": Array [ + Object { + "x-oas-optional": true, + }, + ], + "type": "number", + }, + "messages": Object { + "flags": Object { + "error": [Function], + }, + "items": Array [ + Object { + "flags": Object { + "default": Object { + "special": "deep", + }, + "error": [Function], + "presence": "optional", + }, + "keys": Object { + "content": Object { + "flags": Object { + "error": [Function], + }, + "metas": Array [ + Object { + "x-oas-any-type": true, + }, + ], + "type": "any", + }, "role": Object { "flags": Object { "error": [Function], @@ -531,43 +945,115 @@ Object { "args": Object { "method": [Function], }, - "name": "custom", - }, - ], - "type": "string", - }, - "temperature": Object { - "flags": Object { - "default": [Function], - "error": [Function], - "presence": "optional", - }, - "metas": Array [ - Object { - "x-oas-optional": true, + "name": "custom", + }, + ], + "type": "string", + }, + "temperature": Object { + "flags": Object { + "default": [Function], + "error": [Function], + "presence": "optional", + }, + "metas": Array [ + Object { + "x-oas-optional": true, + }, + ], + "type": "number", + }, + "timeout": Object { + "flags": Object { + "default": 
[Function], + "error": [Function], + "presence": "optional", + }, + "metas": Array [ + Object { + "x-oas-optional": true, + }, + ], + "type": "number", + }, + "tools": Object { + "flags": Object { + "default": [Function], + "error": [Function], + "presence": "optional", + }, + "items": Array [ + Object { + "flags": Object { + "default": Object { + "special": "deep", + }, + "error": [Function], + "presence": "optional", + }, + "keys": Object { + "description": Object { + "flags": Object { + "error": [Function], + }, + "rules": Array [ + Object { + "args": Object { + "method": [Function], + }, + "name": "custom", + }, + ], + "type": "string", + }, + "input_schema": Object { + "flags": Object { + "default": Object { + "special": "deep", + }, + "error": [Function], + "presence": "optional", + "unknown": true, + }, + "keys": Object {}, + "preferences": Object { + "stripUnknown": Object { + "objects": false, + }, + }, + "type": "object", + }, + "name": Object { + "flags": Object { + "error": [Function], + }, + "rules": Array [ + Object { + "args": Object { + "method": [Function], + }, + "name": "custom", + }, + ], + "type": "string", + }, + }, + "type": "object", }, ], - "type": "number", - }, - "timeout": Object { - "flags": Object { - "default": [Function], - "error": [Function], - "presence": "optional", - }, "metas": Array [ Object { "x-oas-optional": true, }, ], - "type": "number", + "type": "array", }, }, "type": "object", } `; -exports[`Connector type config checks detect connector type changes for: .bedrock 6`] = ` +exports[`Connector type config checks detect connector type changes for: .bedrock 7`] = ` Object { "flags": Object { "default": Object { @@ -612,7 +1098,7 @@ Object { } `; -exports[`Connector type config checks detect connector type changes for: .bedrock 7`] = ` +exports[`Connector type config checks detect connector type changes for: .bedrock 8`] = ` Object { "flags": Object { "default": Object { @@ -655,7 +1141,7 @@ Object { } `; -exports[`Connector type config checks detect connector type changes for: .bedrock 8`] = ` +exports[`Connector type config checks detect connector type changes for: .bedrock 9`] = ` Object { "flags": Object { "default": Object { @@ -3459,15 +3945,12 @@ Object { "flags": Object { "error": [Function], }, - "rules": Array [ + "metas": Array [ Object { - "args": Object { - "method": [Function], - }, - "name": "custom", + "x-oas-any-type": true, }, ], - "type": "string", + "type": "any", }, "model": Object { "flags": Object { @@ -3490,6 +3973,19 @@ Object { ], "type": "string", }, + "raw": Object { + "flags": Object { + "default": [Function], + "error": [Function], + "presence": "optional", + }, + "metas": Array [ + Object { + "x-oas-optional": true, + }, + ], + "type": "boolean", + }, "signal": Object { "flags": Object { "default": [Function], @@ -3610,6 +4106,24 @@ Object { "flags": Object { "error": [Function], }, + "metas": Array [ + Object { + "x-oas-any-type": true, + }, + ], + "type": "any", + }, + "model": Object { + "flags": Object { + "default": [Function], + "error": [Function], + "presence": "optional", + }, + "metas": Array [ + Object { + "x-oas-optional": true, + }, + ], "rules": Array [ Object { "args": Object { @@ -3620,6 +4134,117 @@ Object { ], "type": "string", }, + "raw": Object { + "flags": Object { + "default": [Function], + "error": [Function], + "presence": "optional", + }, + "metas": Array [ + Object { + "x-oas-optional": true, + }, + ], + "type": "boolean", + }, + "signal": Object { + "flags": Object { + "default": [Function], 
+ "error": [Function], + "presence": "optional", + }, + "metas": Array [ + Object { + "x-oas-any-type": true, + }, + Object { + "x-oas-optional": true, + }, + ], + "type": "any", + }, + "stopSequences": Object { + "flags": Object { + "default": [Function], + "error": [Function], + "presence": "optional", + }, + "items": Array [ + Object { + "flags": Object { + "error": [Function], + "presence": "optional", + }, + "rules": Array [ + Object { + "args": Object { + "method": [Function], + }, + "name": "custom", + }, + ], + "type": "string", + }, + ], + "metas": Array [ + Object { + "x-oas-optional": true, + }, + ], + "type": "array", + }, + "temperature": Object { + "flags": Object { + "default": [Function], + "error": [Function], + "presence": "optional", + }, + "metas": Array [ + Object { + "x-oas-optional": true, + }, + ], + "type": "number", + }, + "timeout": Object { + "flags": Object { + "default": [Function], + "error": [Function], + "presence": "optional", + }, + "metas": Array [ + Object { + "x-oas-optional": true, + }, + ], + "type": "number", + }, + }, + "type": "object", +} +`; + +exports[`Connector type config checks detect connector type changes for: .gemini 4`] = ` +Object { + "flags": Object { + "default": Object { + "special": "deep", + }, + "error": [Function], + "presence": "optional", + }, + "keys": Object { + "messages": Object { + "flags": Object { + "error": [Function], + }, + "metas": Array [ + Object { + "x-oas-any-type": true, + }, + ], + "type": "any", + }, "model": Object { "flags": Object { "default": [Function], @@ -3713,12 +4338,39 @@ Object { ], "type": "number", }, + "tools": Object { + "flags": Object { + "default": [Function], + "error": [Function], + "presence": "optional", + }, + "items": Array [ + Object { + "flags": Object { + "error": [Function], + "presence": "optional", + }, + "metas": Array [ + Object { + "x-oas-any-type": true, + }, + ], + "type": "any", + }, + ], + "metas": Array [ + Object { + "x-oas-optional": true, + }, + ], + "type": "array", + }, }, "type": "object", } `; -exports[`Connector type config checks detect connector type changes for: .gemini 4`] = ` +exports[`Connector type config checks detect connector type changes for: .gemini 5`] = ` Object { "flags": Object { "default": Object { @@ -3837,7 +4489,7 @@ Object { } `; -exports[`Connector type config checks detect connector type changes for: .gemini 5`] = ` +exports[`Connector type config checks detect connector type changes for: .gemini 6`] = ` Object { "flags": Object { "default": Object { @@ -3951,12 +4603,39 @@ Object { ], "type": "number", }, + "tools": Object { + "flags": Object { + "default": [Function], + "error": [Function], + "presence": "optional", + }, + "items": Array [ + Object { + "flags": Object { + "error": [Function], + "presence": "optional", + }, + "metas": Array [ + Object { + "x-oas-any-type": true, + }, + ], + "type": "any", + }, + ], + "metas": Array [ + Object { + "x-oas-optional": true, + }, + ], + "type": "array", + }, }, "type": "object", } `; -exports[`Connector type config checks detect connector type changes for: .gemini 6`] = ` +exports[`Connector type config checks detect connector type changes for: .gemini 7`] = ` Object { "flags": Object { "default": Object { @@ -4029,7 +4708,7 @@ Object { } `; -exports[`Connector type config checks detect connector type changes for: .gemini 7`] = ` +exports[`Connector type config checks detect connector type changes for: .gemini 8`] = ` Object { "flags": Object { "default": Object { @@ -4058,7 +4737,7 @@ Object { } 
`; -exports[`Connector type config checks detect connector type changes for: .gemini 8`] = ` +exports[`Connector type config checks detect connector type changes for: .gemini 9`] = ` Object { "flags": Object { "default": Object { diff --git a/x-pack/plugins/elastic_assistant/server/lib/langchain/execute_custom_llm_chain/index.test.ts b/x-pack/plugins/elastic_assistant/server/lib/langchain/execute_custom_llm_chain/index.test.ts index 6f4b57675d127..7772c0d267273 100644 --- a/x-pack/plugins/elastic_assistant/server/lib/langchain/execute_custom_llm_chain/index.test.ts +++ b/x-pack/plugins/elastic_assistant/server/lib/langchain/execute_custom_llm_chain/index.test.ts @@ -11,7 +11,7 @@ import { KibanaRequest } from '@kbn/core/server'; import { actionsClientMock } from '@kbn/actions-plugin/server/actions_client/actions_client.mock'; import { loggerMock } from '@kbn/logging-mocks'; -import { initializeAgentExecutorWithOptions } from 'langchain/agents'; +import { initializeAgentExecutorWithOptions, AgentExecutor } from 'langchain/agents'; import { mockActionResponse } from '../../../__mocks__/action_result_data'; import { langChainMessages } from '../../../__mocks__/lang_chain_messages'; @@ -52,6 +52,7 @@ const mockCall = jest.fn().mockImplementation(() => }) ); const mockInvoke = jest.fn().mockImplementation(() => Promise.resolve()); + jest.mock('langchain/agents'); jest.mock('../elasticsearch_store/elasticsearch_store', () => ({ @@ -123,6 +124,8 @@ const bedrockChatProps = { llmType: 'bedrock', }; const executorMock = initializeAgentExecutorWithOptions as jest.Mock; +const agentExecutorMock = AgentExecutor as unknown as jest.Mock; + describe('callAgentExecutor', () => { beforeEach(() => { jest.clearAllMocks(); @@ -132,6 +135,10 @@ describe('callAgentExecutor', () => { // eslint-disable-next-line @typescript-eslint/no-explicit-any invoke: (props: any, more: any) => mockInvoke({ ...props, agentType }, more), })); + agentExecutorMock.mockReturnValue({ + call: mockCall, + invoke: mockInvoke, + }); }); describe('callAgentExecutor', () => { @@ -290,20 +297,14 @@ describe('callAgentExecutor', () => { connectorId: mockConnectorId, logger: mockLogger, maxRetries: 0, - request: mockRequest, + signal: undefined, + model: undefined, streaming: false, temperature: 0, llmType: 'bedrock', }); }); - it('uses the structured-chat-zero-shot-react-description agent type', async () => { - await callAgentExecutor(bedrockChatProps); - expect(mockCall.mock.calls[0][0].agentType).toEqual( - 'structured-chat-zero-shot-react-description' - ); - }); - it('returns the expected response', async () => { const result = await callAgentExecutor(bedrockChatProps); @@ -330,19 +331,13 @@ describe('callAgentExecutor', () => { connectorId: mockConnectorId, logger: mockLogger, maxRetries: 0, - request: mockRequest, + signal: undefined, + model: undefined, streaming: true, temperature: 0, llmType: 'bedrock', }); }); - - it('uses the structured-chat-zero-shot-react-description agent type', async () => { - await callAgentExecutor({ ...bedrockChatProps, isStream: true }); - expect(mockInvoke.mock.calls[0][0].agentType).toEqual( - 'structured-chat-zero-shot-react-description' - ); - }); }); }); diff --git a/x-pack/plugins/elastic_assistant/server/lib/langchain/execute_custom_llm_chain/index.ts b/x-pack/plugins/elastic_assistant/server/lib/langchain/execute_custom_llm_chain/index.ts index 38759d4d68ea3..dc445d1801de7 100644 --- a/x-pack/plugins/elastic_assistant/server/lib/langchain/execute_custom_llm_chain/index.ts +++ 
b/x-pack/plugins/elastic_assistant/server/lib/langchain/execute_custom_llm_chain/index.ts
@@ -164,6 +164,8 @@ export const callAgentExecutor: AgentExecutor = async ({
     },
   });
 
+  console.error('exectuor', executor);
+
   // Sets up tracer for tracing executions to APM. See x-pack/plugins/elastic_assistant/server/lib/langchain/tracers/README.mdx
   // If LangSmith env vars are set, executions will be traced there as well. See https://docs.smith.langchain.com/tracing
   const apmTracer = new APMTracer({ projectName: traceOptions?.projectName ?? 'default' }, logger);
@@ -280,6 +282,8 @@ export const callAgentExecutor: AgentExecutor = async ({
     );
   });
 
+  console.error('langChainResponse', langChainResponse);
+
   const langChainOutput = langChainResponse.output;
   if (onLlmResponse) {
     await onLlmResponse(langChainOutput, traceData);
diff --git a/x-pack/plugins/search_playground/server/lib/conversational_chain.test.ts b/x-pack/plugins/search_playground/server/lib/conversational_chain.test.ts
index 7e2d0a72089fd..88a6052a0bbf7 100644
--- a/x-pack/plugins/search_playground/server/lib/conversational_chain.test.ts
+++ b/x-pack/plugins/search_playground/server/lib/conversational_chain.test.ts
@@ -449,8 +449,8 @@ describe('conversational chain', () => {
       ],
       // Even with body_content of 1000, the token count should be below or equal to model limit of 100
       expectedTokens: [
-        { type: 'context_token_count', count: 68 },
-        { type: 'prompt_token_count', count: 102 },
+        { type: 'context_token_count', count: 63 },
+        { type: 'prompt_token_count', count: 97 },
       ],
       expectedHasClipped: true,
       expectedSearchRequest: [

From 5479131e5c65279f1b7f464b92813810899c872f Mon Sep 17 00:00:00 2001
From: Patryk Kopycinski
Date: Mon, 22 Jul 2024 18:31:50 +0200
Subject: [PATCH 38/55] flip feature flags

---
 .../kbn-elastic-assistant-common/impl/capabilities/index.ts | 4 ++--
 .../plugins/security_solution/common/experimental_features.ts | 4 ++--
 2 files changed, 4 insertions(+), 4 deletions(-)

diff --git a/x-pack/packages/kbn-elastic-assistant-common/impl/capabilities/index.ts b/x-pack/packages/kbn-elastic-assistant-common/impl/capabilities/index.ts
index 819432bae6ec6..1e759df2819ed 100644
--- a/x-pack/packages/kbn-elastic-assistant-common/impl/capabilities/index.ts
+++ b/x-pack/packages/kbn-elastic-assistant-common/impl/capabilities/index.ts
@@ -19,7 +19,7 @@ export type AssistantFeatureKey = keyof AssistantFeatures;
  * Default features available to the elastic assistant
  */
 export const defaultAssistantFeatures = Object.freeze({
-  assistantKnowledgeBaseByDefault: true,
+  assistantKnowledgeBaseByDefault: false,
   assistantModelEvaluation: false,
-  assistantBedrockChat: true,
+  assistantBedrockChat: false,
 });
diff --git a/x-pack/plugins/security_solution/common/experimental_features.ts b/x-pack/plugins/security_solution/common/experimental_features.ts
index da8a1c5a442ed..f18b16d17153f 100644
--- a/x-pack/plugins/security_solution/common/experimental_features.ts
+++ b/x-pack/plugins/security_solution/common/experimental_features.ts
@@ -126,12 +126,12 @@ export const allowedExperimentalValues = Object.freeze({
   /**
    * Enables new Knowledge Base Entries features, introduced in `8.15.0`.
    */
-  assistantKnowledgeBaseByDefault: true,
+  assistantKnowledgeBaseByDefault: false,
 
   /**
    * Enables the Assistant BedrockChat Langchain model, introduced in `8.15.0`.
    */
-  assistantBedrockChat: true,
+  assistantBedrockChat: false,
 
   /**
    * Enables the Managed User section inside the new user details flyout.
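With both defaults flipped back to `false`, the new BedrockChat model class stays dormant until the `assistantBedrockChat` experimental flag is switched on. The gating happens where the assistant graph picks its LLM class (the `getLlmClass(llmType, bedrockChatEnabled)` call that appears in the graph changes further below); a minimal TypeScript sketch of that selection follows — the helper name and fallback wiring here are assumptions for illustration, not the real helper body:

import {
  ActionsClientBedrockChatModel,
  ActionsClientSimpleChatModel,
} from '@kbn/langchain/server';

// Sketch only (hypothetical helper): with `assistantBedrockChat` off, Bedrock
// requests keep using the pre-existing simple chat model; the real helper may
// branch on more llmType values than shown here.
const pickLlmClass = (llmType: string | undefined, bedrockChatEnabled: boolean) =>
  llmType === 'bedrock' && bedrockChatEnabled
    ? ActionsClientBedrockChatModel
    : ActionsClientSimpleChatModel;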
From 8a865f82711412542b558f980c1da1f6fd673550 Mon Sep 17 00:00:00 2001 From: Steph Milovic Date: Mon, 22 Jul 2024 11:48:55 -0500 Subject: [PATCH 39/55] better err handling --- .../server/language_models/gemini_chat.ts | 18 ++++++++++++++++-- 1 file changed, 16 insertions(+), 2 deletions(-) diff --git a/x-pack/packages/kbn-langchain/server/language_models/gemini_chat.ts b/x-pack/packages/kbn-langchain/server/language_models/gemini_chat.ts index d0e12fd1e61c7..0367f79c5ee7a 100644 --- a/x-pack/packages/kbn-langchain/server/language_models/gemini_chat.ts +++ b/x-pack/packages/kbn-langchain/server/language_models/gemini_chat.ts @@ -88,13 +88,24 @@ export class ActionsClientGeminiChatModel extends ChatGoogleGenerativeAI { const actionResult = (await this.#actionsClient.execute(requestBody)) as { status: string; data: EnhancedGenerateContentResponse; + message?: string; + serviceMessage?: string; }; + if (actionResult.status === 'error') { + throw new Error( + `ActionsClientGeminiChatModel: action result status is error: ${actionResult?.message} - ${actionResult?.serviceMessage}` + ); + } + + if (!actionResult.data?.candidates?.[0]?.content) { + console.log('stephhh actionResult', JSON.stringify(actionResult)); + } return { response: { ...actionResult.data, functionCalls: () => - actionResult.data?.candidates?.[0]?.content.parts[0].functionCall + actionResult.data?.candidates?.[0]?.content?.parts[0].functionCall ? [actionResult.data?.candidates?.[0]?.content.parts[0].functionCall] : [], }, @@ -152,8 +163,11 @@ export class ActionsClientGeminiChatModel extends ChatGoogleGenerativeAI { const actionResult = await this.#actionsClient.execute(requestBody); if (actionResult.status === 'error') { - throw new Error(actionResult.serviceMessage); + throw new Error( + `ActionsClientGeminiChatModel: action result status is error: ${actionResult?.message} - ${actionResult?.serviceMessage}` + ); } + const readable = get('data', actionResult) as Readable; if (typeof readable?.read !== 'function') { From da2d127cdfb26addc4fe2e8e004715527379a236 Mon Sep 17 00:00:00 2001 From: Steph Milovic Date: Mon, 22 Jul 2024 11:50:28 -0500 Subject: [PATCH 40/55] rm logs --- .../kbn-langchain/server/language_models/gemini_chat.ts | 3 --- .../server/lib/langchain/execute_custom_llm_chain/index.ts | 4 ---- 2 files changed, 7 deletions(-) diff --git a/x-pack/packages/kbn-langchain/server/language_models/gemini_chat.ts b/x-pack/packages/kbn-langchain/server/language_models/gemini_chat.ts index 0367f79c5ee7a..8971b49fbd876 100644 --- a/x-pack/packages/kbn-langchain/server/language_models/gemini_chat.ts +++ b/x-pack/packages/kbn-langchain/server/language_models/gemini_chat.ts @@ -98,9 +98,6 @@ export class ActionsClientGeminiChatModel extends ChatGoogleGenerativeAI { ); } - if (!actionResult.data?.candidates?.[0]?.content) { - console.log('stephhh actionResult', JSON.stringify(actionResult)); - } return { response: { ...actionResult.data, diff --git a/x-pack/plugins/elastic_assistant/server/lib/langchain/execute_custom_llm_chain/index.ts b/x-pack/plugins/elastic_assistant/server/lib/langchain/execute_custom_llm_chain/index.ts index dc445d1801de7..38759d4d68ea3 100644 --- a/x-pack/plugins/elastic_assistant/server/lib/langchain/execute_custom_llm_chain/index.ts +++ b/x-pack/plugins/elastic_assistant/server/lib/langchain/execute_custom_llm_chain/index.ts @@ -164,8 +164,6 @@ export const callAgentExecutor: AgentExecutor = async ({ }, }); - console.error('exectuor', executor); - // Sets up tracer for tracing executions to APM. 
See x-pack/plugins/elastic_assistant/server/lib/langchain/tracers/README.mdx // If LangSmith env vars are set, executions will be traced there as well. See https://docs.smith.langchain.com/tracing const apmTracer = new APMTracer({ projectName: traceOptions?.projectName ?? 'default' }, logger); @@ -282,8 +280,6 @@ export const callAgentExecutor: AgentExecutor = async ({ ); }); - console.error('langChainResponse', langChainResponse); - const langChainOutput = langChainResponse.output; if (onLlmResponse) { await onLlmResponse(langChainOutput, traceData); From 311d92295cd5429e363d3535a03f992ed0554808 Mon Sep 17 00:00:00 2001 From: Steph Milovic Date: Mon, 22 Jul 2024 12:17:54 -0500 Subject: [PATCH 41/55] fix prompt, readd gemini token code --- .../server/lib/gen_ai_token_tracking.ts | 36 +++++++++---------- .../graphs/default_assistant_graph/index.ts | 4 +-- 2 files changed, 20 insertions(+), 20 deletions(-) diff --git a/x-pack/plugins/actions/server/lib/gen_ai_token_tracking.ts b/x-pack/plugins/actions/server/lib/gen_ai_token_tracking.ts index 5568904c83dc2..41bfa28605f40 100644 --- a/x-pack/plugins/actions/server/lib/gen_ai_token_tracking.ts +++ b/x-pack/plugins/actions/server/lib/gen_ai_token_tracking.ts @@ -207,24 +207,24 @@ export const getGenAiTokenTracking = async ({ } // Process non-streamed Gemini response from `usageMetadata` object - // if (actionTypeId === '.gemini') { - // const data = result.data as unknown as { - // usageMetadata: { - // promptTokenCount?: number; - // candidatesTokenCount?: number; - // totalTokenCount?: number; - // }; - // }; - // if (data.usageMetadata == null) { - // logger.error('Response did not contain usage metadata object'); - // return null; - // } - // return { - // total_tokens: data.usageMetadata?.totalTokenCount ?? 0, - // prompt_tokens: data.usageMetadata?.promptTokenCount ?? 0, - // completion_tokens: data.usageMetadata?.candidatesTokenCount ?? 0, - // }; - // } + if (actionTypeId === '.gemini') { + const data = result.data as unknown as { + usageMetadata: { + promptTokenCount?: number; + candidatesTokenCount?: number; + totalTokenCount?: number; + }; + }; + if (data.usageMetadata == null) { + logger.error('Response did not contain usage metadata object'); + return null; + } + return { + total_tokens: data.usageMetadata?.totalTokenCount ?? 0, + prompt_tokens: data.usageMetadata?.promptTokenCount ?? 0, + completion_tokens: data.usageMetadata?.candidatesTokenCount ?? 0, + }; + } // this is a non-streamed Bedrock response used by security solution if (actionTypeId === '.bedrock' && validatedParams.subAction === 'invokeAI') { diff --git a/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/index.ts b/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/index.ts index 914f02110b945..de6793e580369 100644 --- a/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/index.ts +++ b/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/index.ts @@ -129,8 +129,8 @@ export const callAssistantGraph: AgentExecutor = async ({ prompt: ChatPromptTemplate.fromMessages([ [ 'system', - 'You are a helpful assistant. ALWAYS use the provided tools.\n\n' + - `The final response will be the only output the user sees and should be a complete answer to the user's question. The final response should never be empty.`, + 'You are a helpful assistant. ALWAYS use the provided tools. 
Use tools as often as possible, as they have access to the latest data and syntax.\n\n' + + `The final response will be the only output the user sees and should be a complete answer to the user's question, as if you were responding to the user's initial question, which is "{input}". The final response should never be empty.`, ], ['placeholder', '{chat_history}'], ['human', '{input}'], From 8255c11503d7ed019aa36d3e681b775deb6cd45c Mon Sep 17 00:00:00 2001 From: Steph Milovic Date: Mon, 22 Jul 2024 12:33:47 -0500 Subject: [PATCH 42/55] add safety settings arg --- .../server/language_models/gemini_chat.ts | 18 +++++++++++++----- 1 file changed, 13 insertions(+), 5 deletions(-) diff --git a/x-pack/packages/kbn-langchain/server/language_models/gemini_chat.ts b/x-pack/packages/kbn-langchain/server/language_models/gemini_chat.ts index 8971b49fbd876..36c6d4244b4d2 100644 --- a/x-pack/packages/kbn-langchain/server/language_models/gemini_chat.ts +++ b/x-pack/packages/kbn-langchain/server/language_models/gemini_chat.ts @@ -7,15 +7,17 @@ import { Content, - Part, + EnhancedGenerateContentResponse, FunctionCallPart, FunctionResponsePart, - POSSIBLE_ROLES, - EnhancedGenerateContentResponse, GenerateContentRequest, - TextPart, - InlineDataPart, GenerateContentResult, + HarmBlockThreshold, + HarmCategory, + InlineDataPart, + POSSIBLE_ROLES, + Part, + TextPart, } from '@google/generative-ai'; import { ActionsClient } from '@kbn/actions-plugin/server'; import { PublicMethodsOf } from '@kbn/utility-types'; @@ -57,6 +59,12 @@ export class ActionsClientGeminiChatModel extends ChatGoogleGenerativeAI { ...props, apiKey: 'asda', maxOutputTokens: props.maxTokens ?? 2048, + safetySettings: [ + { + category: HarmCategory.HARM_CATEGORY_HARASSMENT, + threshold: HarmBlockThreshold.BLOCK_LOW_AND_ABOVE, + }, + ], }); // LangChain needs model to be defined for logging purposes this.model = props.model ?? 
this.model; From 440a9c4d683dc69f44da60b5d0f73e74e0a2e712 Mon Sep 17 00:00:00 2001 From: Patryk Kopycinski Date: Mon, 22 Jul 2024 21:35:16 +0200 Subject: [PATCH 43/55] fix types --- .../graphs/default_assistant_graph/helpers.test.ts | 4 ++++ .../graphs/default_assistant_graph/helpers.ts | 2 +- x-pack/plugins/elastic_assistant/server/types.ts | 6 +++--- .../plugins/stack_connectors/common/bedrock/schema.ts | 2 +- .../server/connector_types/bedrock/bedrock.ts | 10 +++++----- .../server/connector_types/gemini/gemini.ts | 2 +- 6 files changed, 15 insertions(+), 11 deletions(-) diff --git a/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/helpers.test.ts b/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/helpers.test.ts index 42eb8c81c4205..2ac1c5b0ea373 100644 --- a/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/helpers.test.ts +++ b/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/helpers.test.ts @@ -100,6 +100,8 @@ describe('streamGraph', () => { logger: mockLogger, onLlmResponse: mockOnLlmResponse, request: mockRequest, + bedrockChatEnabled: false, + llmType: 'openai', }); expect(response).toBe(mockResponseWithHeaders); @@ -179,6 +181,8 @@ describe('streamGraph', () => { logger: mockLogger, onLlmResponse: mockOnLlmResponse, request: mockRequest, + bedrockChatEnabled: false, + llmType: 'gemini', }); expect(response).toBe(mockResponseWithHeaders); diff --git a/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/helpers.ts b/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/helpers.ts index 50b386ce0bec8..0d39e7f0711ae 100644 --- a/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/helpers.ts +++ b/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/helpers.ts @@ -100,7 +100,7 @@ export const streamGraph = async ({ const msg = data.chunk as AIMessageChunk; if (!didEnd && !msg.tool_call_chunks?.length && msg.content.length) { - push({ payload: msg.content, type: 'content' }); + push({ payload: msg.content as string, type: 'content' }); } } diff --git a/x-pack/plugins/elastic_assistant/server/types.ts b/x-pack/plugins/elastic_assistant/server/types.ts index eecb9e65499c0..e4a72aef17557 100755 --- a/x-pack/plugins/elastic_assistant/server/types.ts +++ b/x-pack/plugins/elastic_assistant/server/types.ts @@ -37,9 +37,9 @@ import { LicensingApiRequestHandlerContext } from '@kbn/licensing-plugin/server' import { ActionsClientBedrockChatModel, ActionsClientChatOpenAI, + ActionsClientGeminiChatModel, ActionsClientLlm, ActionsClientSimpleChatModel, - ActionsClientVertexChatModel, } from '@kbn/langchain/server'; import { AttackDiscoveryDataClient } from './ai_assistant_data_clients/attack_discovery'; @@ -216,8 +216,8 @@ export interface AssistantTool { export type AssistantToolLlm = | ActionsClientBedrockChatModel | ActionsClientChatOpenAI - | ActionsClientSimpleChatModel - | ActionsClientVertexChatModel; + | ActionsClientGeminiChatModel + | ActionsClientSimpleChatModel; export interface AssistantToolParams { alertsIndexPattern?: string; diff --git a/x-pack/plugins/stack_connectors/common/bedrock/schema.ts b/x-pack/plugins/stack_connectors/common/bedrock/schema.ts index bb89265219a78..093a5f9b11518 100644 --- a/x-pack/plugins/stack_connectors/common/bedrock/schema.ts +++ b/x-pack/plugins/stack_connectors/common/bedrock/schema.ts 
@@ -32,7 +32,7 @@ export const InvokeAIActionParamsSchema = schema.object({ messages: schema.arrayOf( schema.object({ role: schema.string(), - content: schema.any(), + content: schema.string(), }) ), model: schema.maybe(schema.string()), diff --git a/x-pack/plugins/stack_connectors/server/connector_types/bedrock/bedrock.ts b/x-pack/plugins/stack_connectors/server/connector_types/bedrock/bedrock.ts index 2b91a1ed948ee..5167cc34ad11f 100644 --- a/x-pack/plugins/stack_connectors/server/connector_types/bedrock/bedrock.ts +++ b/x-pack/plugins/stack_connectors/server/connector_types/bedrock/bedrock.ts @@ -298,7 +298,7 @@ The Kibana Connector in use may need to be reconfigured with an updated Amazon B signal, timeout, tools, - }: InvokeAIActionParams): Promise { + }: InvokeAIActionParams | InvokeAIRawActionParams): Promise { const res = (await this.streamApi({ body: JSON.stringify( formatBedrockBody({ messages, stopSequences, system, temperature, tools }) @@ -378,7 +378,7 @@ const formatBedrockBody = ({ maxTokens = DEFAULT_TOKEN_LIMIT, tools, }: { - messages: Array<{ role: string; content: string }>; + messages: Array<{ role: string; content?: string }>; stopSequences?: string[]; temperature?: number; maxTokens?: number; @@ -401,12 +401,12 @@ const formatBedrockBody = ({ * @param messages */ const ensureMessageFormat = ( - messages: Array<{ role: string; content: string }>, + messages: Array<{ role: string; content?: string }>, systemPrompt?: string -): { messages: Array<{ role: string; content: string }>; system?: string } => { +): { messages: Array<{ role: string; content?: string }>; system?: string } => { let system = systemPrompt ? systemPrompt : ''; - const newMessages = messages.reduce((acc: Array<{ role: string; content: string }>, m) => { + const newMessages = messages.reduce((acc: Array<{ role: string; content?: string }>, m) => { const lastMessage = acc[acc.length - 1]; if (m.role === 'system') { system = `${system.length ? 
`${system}\n` : ''}${m.content}`; diff --git a/x-pack/plugins/stack_connectors/server/connector_types/gemini/gemini.ts b/x-pack/plugins/stack_connectors/server/connector_types/gemini/gemini.ts index 19a2dc0ceadca..d89a88f122ae6 100644 --- a/x-pack/plugins/stack_connectors/server/connector_types/gemini/gemini.ts +++ b/x-pack/plugins/stack_connectors/server/connector_types/gemini/gemini.ts @@ -328,7 +328,7 @@ export class GeminiConnector extends SubActionConnector { /** Format the json body to meet Gemini payload requirements */ const formatGeminiPayload = ( - data: Array<{ role: string; content: string }>, + data: Array<{ role: string; content: string; parts: [{ text: string }] }>, temperature: number ): Payload => { const payload: Payload = { From c0725dfe700f44fb8b96a0b6d0f13d37cc0be11d Mon Sep 17 00:00:00 2001 From: Steph Milovic Date: Mon, 22 Jul 2024 14:35:51 -0500 Subject: [PATCH 44/55] fix stack connector types --- .../server/connector_types/bedrock/bedrock.ts | 8 ++++---- .../server/connector_types/gemini/gemini.ts | 2 +- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/x-pack/plugins/stack_connectors/server/connector_types/bedrock/bedrock.ts b/x-pack/plugins/stack_connectors/server/connector_types/bedrock/bedrock.ts index 2b91a1ed948ee..5240c0b2746e4 100644 --- a/x-pack/plugins/stack_connectors/server/connector_types/bedrock/bedrock.ts +++ b/x-pack/plugins/stack_connectors/server/connector_types/bedrock/bedrock.ts @@ -378,7 +378,7 @@ const formatBedrockBody = ({ maxTokens = DEFAULT_TOKEN_LIMIT, tools, }: { - messages: Array<{ role: string; content: string }>; + messages: Array<{ role: string; content?: string }>; stopSequences?: string[]; temperature?: number; maxTokens?: number; @@ -401,12 +401,12 @@ const formatBedrockBody = ({ * @param messages */ const ensureMessageFormat = ( - messages: Array<{ role: string; content: string }>, + messages: Array<{ role: string; content?: string }>, systemPrompt?: string -): { messages: Array<{ role: string; content: string }>; system?: string } => { +): { messages: Array<{ role: string; content?: string }>; system?: string } => { let system = systemPrompt ? systemPrompt : ''; - const newMessages = messages.reduce((acc: Array<{ role: string; content: string }>, m) => { + const newMessages = messages.reduce((acc: Array<{ role: string; content?: string }>, m) => { const lastMessage = acc[acc.length - 1]; if (m.role === 'system') { system = `${system.length ? 
`${system}\n` : ''}${m.content}`; diff --git a/x-pack/plugins/stack_connectors/server/connector_types/gemini/gemini.ts b/x-pack/plugins/stack_connectors/server/connector_types/gemini/gemini.ts index 19a2dc0ceadca..ee2b1e1d5de22 100644 --- a/x-pack/plugins/stack_connectors/server/connector_types/gemini/gemini.ts +++ b/x-pack/plugins/stack_connectors/server/connector_types/gemini/gemini.ts @@ -328,7 +328,7 @@ export class GeminiConnector extends SubActionConnector { /** Format the json body to meet Gemini payload requirements */ const formatGeminiPayload = ( - data: Array<{ role: string; content: string }>, + data: Array<{ role: string; content: string; parts: MessagePart[] }>, temperature: number ): Payload => { const payload: Payload = { From debf37545fbd332be8f54f47981a6a72f1a412e3 Mon Sep 17 00:00:00 2001 From: Steph Milovic Date: Mon, 22 Jul 2024 14:50:41 -0500 Subject: [PATCH 45/55] cleanup bedrockchat --- .../server/language_models/bedrock_chat.ts | 31 ++++++++++--------- 1 file changed, 16 insertions(+), 15 deletions(-) diff --git a/x-pack/packages/kbn-langchain/server/language_models/bedrock_chat.ts b/x-pack/packages/kbn-langchain/server/language_models/bedrock_chat.ts index e52a245b35bc9..55117ba491559 100644 --- a/x-pack/packages/kbn-langchain/server/language_models/bedrock_chat.ts +++ b/x-pack/packages/kbn-langchain/server/language_models/bedrock_chat.ts @@ -15,22 +15,23 @@ import { PublicMethodsOf } from '@kbn/utility-types'; export const DEFAULT_BEDROCK_MODEL = 'anthropic.claude-3-5-sonnet-20240620-v1:0'; export const DEFAULT_BEDROCK_REGION = 'us-east-1'; +export interface CustomChatModelInput extends BaseChatModelParams { + actionsClient: PublicMethodsOf; + connectorId: string; + logger: Logger; + temperature?: number; + signal?: AbortSignal; + model?: string; + maxTokens?: number; +} + export class ActionsClientBedrockChatModel extends _BedrockChat { - constructor({ - actionsClient, - connectorId, - logger, - ...params - }: { - actionsClient: PublicMethodsOf; - connectorId: string; - logger: Logger; - } & BaseChatModelParams) { + constructor({ actionsClient, connectorId, logger, ...params }: CustomChatModelInput) { super({ ...params, credentials: { accessKeyId: '', secretAccessKey: '' }, // only needed to force BedrockChat to use messages api for Claude v2 - model: DEFAULT_BEDROCK_MODEL, + model: params.model ?? DEFAULT_BEDROCK_MODEL, region: DEFAULT_BEDROCK_REGION, fetchFn: async (url, options) => { const inputBody = JSON.parse(options?.body as string); @@ -42,10 +43,10 @@ export class ActionsClientBedrockChatModel extends _BedrockChat { subAction: 'invokeStream', subActionParams: { messages: inputBody.messages, - temperature: inputBody.temperature, + temperature: params.temperature ?? inputBody.temperature, stopSequences: inputBody.stop_sequences, system: inputBody.system, - maxTokens: inputBody.max_tokens, + maxTokens: params.maxTokens ?? inputBody.max_tokens, tools: inputBody.tools, anthropicVersion: inputBody.anthropic_version, }, @@ -63,10 +64,10 @@ export class ActionsClientBedrockChatModel extends _BedrockChat { subAction: 'invokeAIRaw', subActionParams: { messages: inputBody.messages, - temperature: inputBody.temperature, + temperature: params.temperature ?? inputBody.temperature, stopSequences: inputBody.stop_sequences, system: inputBody.system, - maxTokens: inputBody.max_tokens, + maxTokens: params.maxTokens ?? 
inputBody.max_tokens, tools: inputBody.tools, anthropicVersion: inputBody.anthropic_version, }, From 4db15dc1fdcbfd21692258e3761d601f0e309af5 Mon Sep 17 00:00:00 2001 From: Steph Milovic Date: Mon, 22 Jul 2024 14:54:47 -0500 Subject: [PATCH 46/55] clean up comments --- .../langchain/graphs/default_assistant_graph/nodes/respond.ts | 1 - .../server/connector_types/bedrock/bedrock.ts | 4 ++-- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/nodes/respond.ts b/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/nodes/respond.ts index 4820a494560ab..bb3b3a518e06d 100644 --- a/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/nodes/respond.ts +++ b/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/nodes/respond.ts @@ -22,7 +22,6 @@ export const respond = async ({ llm, state }: { llm: BaseChatModel; state: Agent ] as [StringWithAutocomplete<'user'>, string]; const responseMessage = await llm - // .bindTools([]) // use AGENT_NODE_TAG to identify as agent node for stream parsing .withConfig({ runName: 'Summarizer', tags: [AGENT_NODE_TAG] }) .invoke([userMessage]); diff --git a/x-pack/plugins/stack_connectors/server/connector_types/bedrock/bedrock.ts b/x-pack/plugins/stack_connectors/server/connector_types/bedrock/bedrock.ts index 5167cc34ad11f..6b981a365b63a 100644 --- a/x-pack/plugins/stack_connectors/server/connector_types/bedrock/bedrock.ts +++ b/x-pack/plugins/stack_connectors/server/connector_types/bedrock/bedrock.ts @@ -194,14 +194,14 @@ The Kibana Connector in use may need to be reconfigured with an updated Amazon B } private async runApiRaw( - params: SubActionRequestParams // : SubActionRequestParams + params: SubActionRequestParams ): Promise { const response = await this.request(params); return response.data; } private async runApiLatest( - params: SubActionRequestParams // : SubActionRequestParams + params: SubActionRequestParams ): Promise { const response = await this.request(params); // keeping the response the same as claude 2 for our APIs From 50d73f16bd6b3a7c885996b08dc70e0261135f89 Mon Sep 17 00:00:00 2001 From: Steph Milovic Date: Mon, 22 Jul 2024 15:05:08 -0500 Subject: [PATCH 47/55] fix weird type thing --- .../graphs/default_assistant_graph/graph.ts | 4 +- .../graphs/default_assistant_graph/index.ts | 38 +++++++++---------- 2 files changed, 18 insertions(+), 24 deletions(-) diff --git a/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/graph.ts b/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/graph.ts index 2a20dbc9be866..d7cb716e1ee0d 100644 --- a/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/graph.ts +++ b/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/graph.ts @@ -41,7 +41,6 @@ interface GetDefaultAssistantGraphParams { agentRunnable: AgentRunnableSequence; dataClients?: AssistantDataClients; conversationId?: string; - getLlmInstance: () => BaseChatModel; llm: BaseChatModel; logger: Logger; tools: StructuredTool[]; @@ -61,7 +60,6 @@ export const getDefaultAssistantGraph = ({ agentRunnable, conversationId, dataClients, - getLlmInstance, llm, logger, responseLanguage, @@ -154,7 +152,7 @@ export const getDefaultAssistantGraph = ({ const respondNode = (state: AgentState) => respond({ ...nodeParams, - llm: getLlmInstance(), + llm, 
state, }); const shouldContinueEdge = (state: AgentState) => shouldContinue({ ...nodeParams, state }); diff --git a/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/index.ts b/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/index.ts index de6793e580369..bd73d251060b5 100644 --- a/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/index.ts +++ b/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/index.ts @@ -53,26 +53,23 @@ export const callAssistantGraph: AgentExecutor = async ({ const logger = parentLogger.get('defaultAssistantGraph'); const isOpenAI = llmType === 'openai'; const llmClass = getLlmClass(llmType, bedrockChatEnabled); - const getLlmInstance = () => - new llmClass({ - actionsClient, - connectorId, - llmType, - logger, - // possible client model override, - // let this be undefined otherwise so the connector handles the model - model: request.body.model, - // ensure this is defined because we default to it in the language_models - // This is where the LangSmith logs (Metadata > Invocation Params) are set - temperature: getDefaultArguments(llmType).temperature, - signal: abortSignal, - streaming: isStream, - // prevents the agent from retrying on failure - // failure could be due to bad connector, we should deliver that result to the client asap - maxRetries: 0, - }); - - const llm = getLlmInstance(); + const llm = new llmClass({ + actionsClient, + connectorId, + llmType, + logger, + // possible client model override, + // let this be undefined otherwise so the connector handles the model + model: request.body.model, + // ensure this is defined because we default to it in the language_models + // This is where the LangSmith logs (Metadata > Invocation Params) are set + temperature: getDefaultArguments(llmType).temperature, + signal: abortSignal, + streaming: isStream, + // prevents the agent from retrying on failure + // failure could be due to bad connector, we should deliver that result to the client asap + maxRetries: 0, + }); const anonymizationFieldsRes = await dataClients?.anonymizationFieldsDataClient?.findDocuments({ @@ -152,7 +149,6 @@ export const callAssistantGraph: AgentExecutor = async ({ conversationId, dataClients, llm, - getLlmInstance, logger, tools, responseLanguage, From 91a3dfa1d6410e399c271842ab87479f92de67c4 Mon Sep 17 00:00:00 2001 From: Steph Milovic Date: Mon, 22 Jul 2024 15:18:05 -0500 Subject: [PATCH 48/55] fix gemini subaction args --- .../kbn-langchain/server/language_models/gemini_chat.ts | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/x-pack/packages/kbn-langchain/server/language_models/gemini_chat.ts b/x-pack/packages/kbn-langchain/server/language_models/gemini_chat.ts index 36c6d4244b4d2..5597a5b4bb0b4 100644 --- a/x-pack/packages/kbn-langchain/server/language_models/gemini_chat.ts +++ b/x-pack/packages/kbn-langchain/server/language_models/gemini_chat.ts @@ -89,6 +89,7 @@ export class ActionsClientGeminiChatModel extends ChatGoogleGenerativeAI { subActionParams: { model: this.#model, messages: request, + temperature: this.#temperature, }, }, }; @@ -159,9 +160,9 @@ export class ActionsClientGeminiChatModel extends ChatGoogleGenerativeAI { acc.push(item); return acc; }, []), + temperature: this.#temperature, tools: request.tools, }, - temperature: this.#temperature, }, }; From 8fe746c0da42d797be66f6e37cf6a435f012ff82 Mon Sep 17 00:00:00 2001 From: Steph Milovic Date: Mon, 22 Jul 2024 
15:47:54 -0500 Subject: [PATCH 49/55] fix safety settings --- .../server/language_models/gemini_chat.ts | 8 -------- .../server/connector_types/gemini/gemini.ts | 11 ++++++++++- 2 files changed, 10 insertions(+), 9 deletions(-) diff --git a/x-pack/packages/kbn-langchain/server/language_models/gemini_chat.ts b/x-pack/packages/kbn-langchain/server/language_models/gemini_chat.ts index 5597a5b4bb0b4..7a82e35ba2877 100644 --- a/x-pack/packages/kbn-langchain/server/language_models/gemini_chat.ts +++ b/x-pack/packages/kbn-langchain/server/language_models/gemini_chat.ts @@ -12,8 +12,6 @@ import { FunctionResponsePart, GenerateContentRequest, GenerateContentResult, - HarmBlockThreshold, - HarmCategory, InlineDataPart, POSSIBLE_ROLES, Part, @@ -59,12 +57,6 @@ export class ActionsClientGeminiChatModel extends ChatGoogleGenerativeAI { ...props, apiKey: 'asda', maxOutputTokens: props.maxTokens ?? 2048, - safetySettings: [ - { - category: HarmCategory.HARM_CATEGORY_HARASSMENT, - threshold: HarmBlockThreshold.BLOCK_LOW_AND_ABOVE, - }, - ], }); // LangChain needs model to be defined for logging purposes this.model = props.model ?? this.model; diff --git a/x-pack/plugins/stack_connectors/server/connector_types/gemini/gemini.ts b/x-pack/plugins/stack_connectors/server/connector_types/gemini/gemini.ts index ee2b1e1d5de22..ed75effe13b61 100644 --- a/x-pack/plugins/stack_connectors/server/connector_types/gemini/gemini.ts +++ b/x-pack/plugins/stack_connectors/server/connector_types/gemini/gemini.ts @@ -13,6 +13,7 @@ import { SubActionRequestParams } from '@kbn/actions-plugin/server/sub_action_fr import { getGoogleOAuthJwtAccessToken } from '@kbn/actions-plugin/server/lib/get_gcp_oauth_access_token'; import { ConnectorTokenClientContract } from '@kbn/actions-plugin/server/types'; +import { HarmBlockThreshold, HarmCategory } from '@google/generative-ai'; import { RunActionParamsSchema, RunApiResponseSchema, @@ -60,6 +61,7 @@ interface Payload { temperature: number; maxOutputTokens: number; }; + safety_settings: Array<{ category: string; threshold: string }>; } export class GeminiConnector extends SubActionConnector { @@ -289,7 +291,7 @@ export class GeminiConnector extends SubActionConnector { timeout, }: InvokeAIRawActionParams): Promise { const res = await this.runApi({ - body: JSON.stringify(messages), + body: JSON.stringify(formatGeminiPayload(messages, temperature)), model, signal, timeout, @@ -337,6 +339,13 @@ const formatGeminiPayload = ( temperature, maxOutputTokens: DEFAULT_TOKEN_LIMIT, }, + safety_settings: [ + { + category: HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT, + // when block high, the model will block responses about suspicious alerts + threshold: HarmBlockThreshold.BLOCK_NONE, + }, + ], }; let previousRole: string | null = null; From c5cc32e6d211efcdeda48fd37fda68b202d59739 Mon Sep 17 00:00:00 2001 From: Steph Milovic Date: Mon, 22 Jul 2024 15:59:39 -0500 Subject: [PATCH 50/55] Revert "fix weird type thing" This reverts commit 50d73f16bd6b3a7c885996b08dc70e0261135f89. 
--- .../graphs/default_assistant_graph/graph.ts | 4 +- .../graphs/default_assistant_graph/index.ts | 38 ++++++++++--------- 2 files changed, 24 insertions(+), 18 deletions(-) diff --git a/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/graph.ts b/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/graph.ts index d7cb716e1ee0d..2a20dbc9be866 100644 --- a/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/graph.ts +++ b/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/graph.ts @@ -41,6 +41,7 @@ interface GetDefaultAssistantGraphParams { agentRunnable: AgentRunnableSequence; dataClients?: AssistantDataClients; conversationId?: string; + getLlmInstance: () => BaseChatModel; llm: BaseChatModel; logger: Logger; tools: StructuredTool[]; @@ -60,6 +61,7 @@ export const getDefaultAssistantGraph = ({ agentRunnable, conversationId, dataClients, + getLlmInstance, llm, logger, responseLanguage, @@ -152,7 +154,7 @@ export const getDefaultAssistantGraph = ({ const respondNode = (state: AgentState) => respond({ ...nodeParams, - llm, + llm: getLlmInstance(), state, }); const shouldContinueEdge = (state: AgentState) => shouldContinue({ ...nodeParams, state }); diff --git a/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/index.ts b/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/index.ts index bd73d251060b5..de6793e580369 100644 --- a/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/index.ts +++ b/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/index.ts @@ -53,23 +53,26 @@ export const callAssistantGraph: AgentExecutor = async ({ const logger = parentLogger.get('defaultAssistantGraph'); const isOpenAI = llmType === 'openai'; const llmClass = getLlmClass(llmType, bedrockChatEnabled); - const llm = new llmClass({ - actionsClient, - connectorId, - llmType, - logger, - // possible client model override, - // let this be undefined otherwise so the connector handles the model - model: request.body.model, - // ensure this is defined because we default to it in the language_models - // This is where the LangSmith logs (Metadata > Invocation Params) are set - temperature: getDefaultArguments(llmType).temperature, - signal: abortSignal, - streaming: isStream, - // prevents the agent from retrying on failure - // failure could be due to bad connector, we should deliver that result to the client asap - maxRetries: 0, - }); + const getLlmInstance = () => + new llmClass({ + actionsClient, + connectorId, + llmType, + logger, + // possible client model override, + // let this be undefined otherwise so the connector handles the model + model: request.body.model, + // ensure this is defined because we default to it in the language_models + // This is where the LangSmith logs (Metadata > Invocation Params) are set + temperature: getDefaultArguments(llmType).temperature, + signal: abortSignal, + streaming: isStream, + // prevents the agent from retrying on failure + // failure could be due to bad connector, we should deliver that result to the client asap + maxRetries: 0, + }); + + const llm = getLlmInstance(); const anonymizationFieldsRes = await dataClients?.anonymizationFieldsDataClient?.findDocuments({ @@ -149,6 +152,7 @@ export const callAssistantGraph: AgentExecutor = async ({ conversationId, dataClients, llm, + getLlmInstance, logger, 
tools, responseLanguage, From 392307480194137934c3c828803bc9387f00f268 Mon Sep 17 00:00:00 2001 From: Steph Milovic Date: Mon, 22 Jul 2024 16:01:34 -0500 Subject: [PATCH 51/55] add comment --- .../server/lib/langchain/graphs/default_assistant_graph/index.ts | 1 + 1 file changed, 1 insertion(+) diff --git a/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/index.ts b/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/index.ts index de6793e580369..8e1b58d3ac683 100644 --- a/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/index.ts +++ b/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/index.ts @@ -152,6 +152,7 @@ export const callAssistantGraph: AgentExecutor = async ({ conversationId, dataClients, llm, + // we need to pass it like this or streaming does not work for bedrock getLlmInstance, logger, tools, From 1133f28521e041a08cd62371ba3f55e4af266bd3 Mon Sep 17 00:00:00 2001 From: Steph Milovic Date: Mon, 22 Jul 2024 18:25:05 -0500 Subject: [PATCH 52/55] fix jest 1 --- .../server/connector_types/gemini/gemini.test.ts | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/x-pack/plugins/stack_connectors/server/connector_types/gemini/gemini.test.ts b/x-pack/plugins/stack_connectors/server/connector_types/gemini/gemini.test.ts index ed825d9aecbf3..10370facac202 100644 --- a/x-pack/plugins/stack_connectors/server/connector_types/gemini/gemini.test.ts +++ b/x-pack/plugins/stack_connectors/server/connector_types/gemini/gemini.test.ts @@ -161,6 +161,9 @@ describe('GeminiConnector', () => { temperature: 0, maxOutputTokens: 8192, }, + safety_settings: [ + { category: 'HARM_CATEGORY_DANGEROUS_CONTENT', threshold: 'BLOCK_NONE' }, + ], }), headers: { Authorization: 'Bearer mock_access_token', @@ -190,6 +193,9 @@ describe('GeminiConnector', () => { temperature: 0, maxOutputTokens: 8192, }, + safety_settings: [ + { category: 'HARM_CATEGORY_DANGEROUS_CONTENT', threshold: 'BLOCK_NONE' }, + ], }), headers: { Authorization: 'Bearer mock_access_token', @@ -237,6 +243,9 @@ describe('GeminiConnector', () => { temperature: 0, maxOutputTokens: 8192, }, + safety_settings: [ + { category: 'HARM_CATEGORY_DANGEROUS_CONTENT', threshold: 'BLOCK_NONE' }, + ], }), responseType: 'stream', headers: { @@ -267,6 +276,9 @@ describe('GeminiConnector', () => { temperature: 0, maxOutputTokens: 8192, }, + safety_settings: [ + { category: 'HARM_CATEGORY_DANGEROUS_CONTENT', threshold: 'BLOCK_NONE' }, + ], }), responseType: 'stream', headers: { From 6e30606bac147149c6b476705c322a3587da8cd3 Mon Sep 17 00:00:00 2001 From: Steph Milovic Date: Mon, 22 Jul 2024 18:26:36 -0500 Subject: [PATCH 53/55] fix jest 2 --- .../__snapshots__/connector_types.test.ts.snap | 18 ++++++++++++------ 1 file changed, 12 insertions(+), 6 deletions(-) diff --git a/x-pack/plugins/actions/server/integration_tests/__snapshots__/connector_types.test.ts.snap b/x-pack/plugins/actions/server/integration_tests/__snapshots__/connector_types.test.ts.snap index 315840ba14e18..a8477a44e1ada 100644 --- a/x-pack/plugins/actions/server/integration_tests/__snapshots__/connector_types.test.ts.snap +++ b/x-pack/plugins/actions/server/integration_tests/__snapshots__/connector_types.test.ts.snap @@ -275,12 +275,15 @@ Object { "flags": Object { "error": [Function], }, - "metas": Array [ + "rules": Array [ Object { - "x-oas-any-type": true, + "args": Object { + "method": [Function], + }, + "name": "custom", }, ], - "type": "any", + 
"type": "string", }, "role": Object { "flags": Object { @@ -555,12 +558,15 @@ Object { "flags": Object { "error": [Function], }, - "metas": Array [ + "rules": Array [ Object { - "x-oas-any-type": true, + "args": Object { + "method": [Function], + }, + "name": "custom", }, ], - "type": "any", + "type": "string", }, "role": Object { "flags": Object { From 6f50245d53dbd19cf2aa45649378063676d2827f Mon Sep 17 00:00:00 2001 From: Steph Milovic Date: Tue, 23 Jul 2024 09:29:21 -0500 Subject: [PATCH 54/55] fix for error --- .../server/connector_types/gemini/gemini.ts | 16 +++++++++++++--- 1 file changed, 13 insertions(+), 3 deletions(-) diff --git a/x-pack/plugins/stack_connectors/server/connector_types/gemini/gemini.ts b/x-pack/plugins/stack_connectors/server/connector_types/gemini/gemini.ts index ed75effe13b61..5bf28c830b679 100644 --- a/x-pack/plugins/stack_connectors/server/connector_types/gemini/gemini.ts +++ b/x-pack/plugins/stack_connectors/server/connector_types/gemini/gemini.ts @@ -123,10 +123,20 @@ export class GeminiConnector extends SubActionConnector { }); } - protected getResponseErrorMessage(error: AxiosError<{ message?: string }>): string { + protected getResponseErrorMessage( + error: AxiosError<{ + error?: { code?: number; message?: string; status?: string }; + message?: string; + }> + ): string { if (!error.response?.status) { return `Unexpected API Error: ${error.code ?? ''} - ${error.message ?? 'Unknown error'}`; } + if (error.response?.data?.error) { + return `API Error: ${ + error.response?.data?.error.status ? `${error.response.data.error.status}: ` : '' + }${error.response?.data?.error.message ? `${error.response.data.error.message}` : ''}`; + } if ( error.response.status === 400 && error.response?.data?.message === 'The requested operation is not recognized by the service.' 
@@ -342,8 +352,8 @@ const formatGeminiPayload = ( safety_settings: [ { category: HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT, - // when block high, the model will block responses about suspicious alerts - threshold: HarmBlockThreshold.BLOCK_NONE, + // without setting threshold, the model will block responses about suspicious alerts + threshold: HarmBlockThreshold.BLOCK_ONLY_HIGH, }, ], }; From ffe2a377c70d035a867313f97113f6c229534c23 Mon Sep 17 00:00:00 2001 From: Patryk Kopycinski Date: Tue, 23 Jul 2024 17:57:11 +0200 Subject: [PATCH 55/55] update jest --- .../server/connector_types/gemini/gemini.test.ts | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/x-pack/plugins/stack_connectors/server/connector_types/gemini/gemini.test.ts b/x-pack/plugins/stack_connectors/server/connector_types/gemini/gemini.test.ts index 10370facac202..d58cefe12f839 100644 --- a/x-pack/plugins/stack_connectors/server/connector_types/gemini/gemini.test.ts +++ b/x-pack/plugins/stack_connectors/server/connector_types/gemini/gemini.test.ts @@ -162,7 +162,7 @@ describe('GeminiConnector', () => { maxOutputTokens: 8192, }, safety_settings: [ - { category: 'HARM_CATEGORY_DANGEROUS_CONTENT', threshold: 'BLOCK_NONE' }, + { category: 'HARM_CATEGORY_DANGEROUS_CONTENT', threshold: 'BLOCK_ONLY_HIGH' }, ], }), headers: { @@ -194,7 +194,7 @@ describe('GeminiConnector', () => { maxOutputTokens: 8192, }, safety_settings: [ - { category: 'HARM_CATEGORY_DANGEROUS_CONTENT', threshold: 'BLOCK_NONE' }, + { category: 'HARM_CATEGORY_DANGEROUS_CONTENT', threshold: 'BLOCK_ONLY_HIGH' }, ], }), headers: { @@ -244,7 +244,7 @@ describe('GeminiConnector', () => { maxOutputTokens: 8192, }, safety_settings: [ - { category: 'HARM_CATEGORY_DANGEROUS_CONTENT', threshold: 'BLOCK_NONE' }, + { category: 'HARM_CATEGORY_DANGEROUS_CONTENT', threshold: 'BLOCK_ONLY_HIGH' }, ], }), responseType: 'stream', @@ -277,7 +277,7 @@ describe('GeminiConnector', () => { maxOutputTokens: 8192, }, safety_settings: [ - { category: 'HARM_CATEGORY_DANGEROUS_CONTENT', threshold: 'BLOCK_NONE' }, + { category: 'HARM_CATEGORY_DANGEROUS_CONTENT', threshold: 'BLOCK_ONLY_HIGH' }, ], }), responseType: 'stream',