[Security AI Assistant] Chat complete API (#184485)
YulNaumenko authored Jul 2, 2024
1 parent d00f36e commit d5a91fc
Showing 45 changed files with 2,090 additions and 733 deletions.
x-pack/packages/kbn-elastic-assistant-common/constants.ts (3 additions, 1 deletion)

@@ -7,13 +7,15 @@

export const ELASTIC_AI_ASSISTANT_INTERNAL_API_VERSION = '1';

-export const ELASTIC_AI_ASSISTANT_URL = '/api/elastic_assistant';
+export const ELASTIC_AI_ASSISTANT_URL = '/api/security_ai_assistant';
export const ELASTIC_AI_ASSISTANT_INTERNAL_URL = '/internal/elastic_assistant';

export const ELASTIC_AI_ASSISTANT_CONVERSATIONS_URL = `${ELASTIC_AI_ASSISTANT_INTERNAL_URL}/current_user/conversations`;
export const ELASTIC_AI_ASSISTANT_CONVERSATIONS_URL_BY_ID = `${ELASTIC_AI_ASSISTANT_CONVERSATIONS_URL}/{id}`;
export const ELASTIC_AI_ASSISTANT_CONVERSATIONS_URL_BY_ID_MESSAGES = `${ELASTIC_AI_ASSISTANT_CONVERSATIONS_URL_BY_ID}/messages`;

+export const ELASTIC_AI_ASSISTANT_CHAT_COMPLETE_URL = `${ELASTIC_AI_ASSISTANT_URL}/chat/complete`;

export const ELASTIC_AI_ASSISTANT_CONVERSATIONS_URL_BULK_ACTION = `${ELASTIC_AI_ASSISTANT_CONVERSATIONS_URL}/_bulk_action`;
export const ELASTIC_AI_ASSISTANT_CONVERSATIONS_URL_FIND = `${ELASTIC_AI_ASSISTANT_CONVERSATIONS_URL}/_find`;
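
A quick orientation sketch, not part of the commit: after the base-path rename, the new constant resolves to the route below. The import specifier is an assumption for illustration (it presumes the package re-exports its constants from the root).

// Sketch only; the import path is an assumption.
import { ELASTIC_AI_ASSISTANT_CHAT_COMPLETE_URL } from '@kbn/elastic-assistant-common';

// Resolves to '/api/security_ai_assistant/chat/complete'
console.log(ELASTIC_AI_ASSISTANT_CHAT_COMPLETE_URL);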

…/chat/post_chat_complete_route.gen.ts (new file)

@@ -0,0 +1,69 @@
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License
 * 2.0; you may not use this file except in compliance with the Elastic License
 * 2.0.
 */

/*
 * NOTICE: Do not edit this file manually.
 * This file is automatically generated by the OpenAPI Generator, @kbn/openapi-generator.
 *
 * info:
 *   title: Chat Complete API endpoint
 *   version: 2023-10-31
 */

import { z } from 'zod';

export type RootContext = z.infer<typeof RootContext>;
export const RootContext = z.literal('security');

/**
 * Message role.
 */
export type ChatMessageRole = z.infer<typeof ChatMessageRole>;
export const ChatMessageRole = z.enum(['system', 'user', 'assistant']);
export type ChatMessageRoleEnum = typeof ChatMessageRole.enum;
export const ChatMessageRoleEnum = ChatMessageRole.enum;

export type MessageData = z.infer<typeof MessageData>;
export const MessageData = z.object({}).catchall(z.unknown());

/**
 * AI assistant message.
 */
export type ChatMessage = z.infer<typeof ChatMessage>;
export const ChatMessage = z.object({
  /**
   * Message content.
   */
  content: z.string().optional(),
  /**
   * Message role.
   */
  role: ChatMessageRole,
  /**
   * ECS object to attach to the context of the message.
   */
  data: MessageData.optional(),
  fields_to_anonymize: z.array(z.string()).optional(),
});

export type ChatCompleteProps = z.infer<typeof ChatCompleteProps>;
export const ChatCompleteProps = z.object({
  conversationId: z.string().optional(),
  promptId: z.string().optional(),
  isStream: z.boolean().optional(),
  responseLanguage: z.string().optional(),
  langSmithProject: z.string().optional(),
  langSmithApiKey: z.string().optional(),
  connectorId: z.string(),
  model: z.string().optional(),
  persist: z.boolean(),
  messages: z.array(ChatMessage),
});

export type ChatCompleteRequestBody = z.infer<typeof ChatCompleteRequestBody>;
export const ChatCompleteRequestBody = ChatCompleteProps;
export type ChatCompleteRequestBodyInput = z.input<typeof ChatCompleteRequestBody>;
…/chat/post_chat_complete_route.schema.yaml (new file)

@@ -0,0 +1,109 @@
openapi: 3.0.0
info:
  title: Chat Complete API endpoint
  version: '2023-10-31'
paths:
  /api/elastic_assistant/chat/complete:
    post:
      operationId: ChatComplete
      x-codegen-enabled: true
      description: Creates a model response for the given chat conversation.
      summary: Creates a model response for the given chat conversation.
      tags:
        - Chat Complete API
      requestBody:
        required: true
        content:
          application/json:
            schema:
              $ref: '#/components/schemas/ChatCompleteProps'
      responses:
        200:
          description: Indicates a successful call.
          content:
            application/octet-stream:
              schema:
                type: string
                format: binary
        400:
          description: Generic Error
          content:
            application/json:
              schema:
                type: object
                properties:
                  statusCode:
                    type: number
                  error:
                    type: string
                  message:
                    type: string

components:
  schemas:
    RootContext:
      type: string
      enum:
        - security

    ChatMessageRole:
      type: string
      description: Message role.
      enum:
        - system
        - user
        - assistant

    MessageData:
      type: object
      additionalProperties: true

    ChatMessage:
      type: object
      description: AI assistant message.
      required:
        - 'role'
      properties:
        content:
          type: string
          description: Message content.
        role:
          $ref: '#/components/schemas/ChatMessageRole'
          description: Message role.
        data:
          description: ECS object to attach to the context of the message.
          $ref: '#/components/schemas/MessageData'
        fields_to_anonymize:
          type: array
          items:
            type: string

    ChatCompleteProps:
      type: object
      properties:
        conversationId:
          type: string
        promptId:
          type: string
        isStream:
          type: boolean
        responseLanguage:
          type: string
        langSmithProject:
          type: string
        langSmithApiKey:
          type: string
        connectorId:
          type: string
        model:
          type: string
        persist:
          type: boolean
        messages:
          type: array
          items:
            $ref: '#/components/schemas/ChatMessage'
      required:
        - messages
        - persist
        - connectorId
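
A hedged usage sketch, not part of the commit, showing a POST with a body that satisfies ChatCompleteProps. The host, credentials, and headers are assumptions about a typical Kibana setup (public APIs require kbn-xsrf on non-GET calls, and versioned routes accept elastic-api-version); the path follows the updated constants above rather than the literal in this schema.

// Sketch: POST a ChatCompleteProps body to the new route (Node 18+ fetch).
const res = await fetch('http://localhost:5601/api/security_ai_assistant/chat/complete', {
  method: 'POST',
  headers: {
    'Content-Type': 'application/json',
    'kbn-xsrf': 'true', // assumed: required by Kibana for non-GET API calls
    'elastic-api-version': '2023-10-31',
    Authorization: 'Basic ' + Buffer.from('elastic:changeme').toString('base64'), // hypothetical credentials
  },
  body: JSON.stringify({
    connectorId: 'my-openai-connector', // hypothetical connector id
    persist: false, // required: do not save this exchange as a conversation
    isStream: false,
    messages: [{ role: 'user', content: 'Which hosts triggered alerts today?' }],
  }),
});
console.log(await res.text()); // 200 responses come back as application/octet-stream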
index.ts (schema exports)

@@ -27,6 +27,9 @@
export * from './attack_discovery/get_attack_discovery_route.gen';
export * from './attack_discovery/post_attack_discovery_route.gen';
export * from './attack_discovery/cancel_attack_discovery_route.gen';

+// Chat Schemas
+export * from './chat/post_chat_complete_route.gen';

// Evaluation Schemas
export * from './evaluation/post_evaluate_route.gen';
export * from './evaluation/get_evaluate_route.gen';
chat_openai.test.ts

@@ -7,10 +7,10 @@

import type OpenAI from 'openai';
import { Stream } from 'openai/streaming';
-import type { PluginStartContract as ActionsPluginStart } from '@kbn/actions-plugin/server';
import { loggerMock } from '@kbn/logging-mocks';
+import { actionsClientMock } from '@kbn/actions-plugin/server/actions_client/actions_client.mock';

-import { ActionsClientChatOpenAI, ActionsClientChatOpenAIParams } from './chat_openai';
+import { ActionsClientChatOpenAI } from './chat_openai';
import { mockActionResponse, mockChatCompletion } from './mocks';

const connectorId = 'mock-connector-id';

@@ -19,11 +19,8 @@
const mockExecute = jest.fn();

const mockLogger = loggerMock.create();

-const mockActions = {
-  getActionsClientWithRequest: jest.fn().mockImplementation(() => ({
-    execute: mockExecute,
-  })),
-} as unknown as ActionsPluginStart;
+const actionsClient = actionsClientMock.create();

const chunk = {
  object: 'chat.completion.chunk',
  choices: [

@@ -40,30 +37,15 @@
export async function* asyncGenerator() {
  yield chunk;
}
const mockStreamExecute = jest.fn();
-const mockStreamActions = {
-  getActionsClientWithRequest: jest.fn().mockImplementation(() => ({
-    execute: mockStreamExecute,
-  })),
-} as unknown as ActionsPluginStart;

const prompt = 'Do you know my name?';

const { signal } = new AbortController();

-const mockRequest = {
-  params: { connectorId },
-  body: {
-    message: prompt,
-    subAction: 'invokeAI',
-    isEnabledKnowledgeBase: true,
-  },
-} as ActionsClientChatOpenAIParams['request'];

const defaultArgs = {
-  actions: mockActions,
+  actionsClient,
  connectorId,
  logger: mockLogger,
-  request: mockRequest,
  streaming: false,
  signal,
  timeout: 999999,

@@ -77,6 +59,7 @@
describe('ActionsClientChatOpenAI', () => {
      data: mockChatCompletion,
      status: 'ok',
    }));
+    actionsClient.execute.mockImplementation(mockExecute);
  });

  describe('_llmType', () => {

@@ -116,10 +99,11 @@
      functions: [jest.fn()],
    };
    it('returns the expected data', async () => {
+      actionsClient.execute.mockImplementation(mockStreamExecute);
      const actionsClientChatOpenAI = new ActionsClientChatOpenAI({
        ...defaultArgs,
        streaming: true,
-        actions: mockStreamActions,
+        actionsClient,
      });

      const result: AsyncIterable<OpenAI.ChatCompletionChunk> =

@@ -178,16 +162,11 @@
        serviceMessage: 'action-result-service-message',
        status: 'error', // <-- error status
      }));
-
-      const badActions = {
-        getActionsClientWithRequest: jest.fn().mockImplementation(() => ({
-          execute: hasErrorStatus,
-        })),
-      } as unknown as ActionsPluginStart;
+      actionsClient.execute.mockRejectedValueOnce(hasErrorStatus);

      const actionsClientChatOpenAI = new ActionsClientChatOpenAI({
        ...defaultArgs,
-        actions: badActions,
+        actionsClient,
      });

      expect(actionsClientChatOpenAI.completionWithRetry(defaultNonStreamingArgs))
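
The net effect of the test changes, as a sketch rather than a quote from the diff: tests now stub ActionsClient.execute directly instead of faking the whole plugin start contract. The response payload below is hypothetical.

import { actionsClientMock } from '@kbn/actions-plugin/server/actions_client/actions_client.mock';

// One shared mock client; per-test behavior is swapped via mockImplementation.
const actionsClient = actionsClientMock.create();
actionsClient.execute.mockImplementation(async () => ({
  actionId: 'mock-connector-id',
  status: 'ok',
  data: { message: 'mocked response' }, // hypothetical action result payload
}));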
chat_openai.ts

@@ -6,24 +6,24 @@
 */

import { v4 as uuidv4 } from 'uuid';
-import { KibanaRequest, Logger } from '@kbn/core/server';
-import type { PluginStartContract as ActionsPluginStart } from '@kbn/actions-plugin/server';
+import { Logger } from '@kbn/core/server';
+import type { ActionsClient } from '@kbn/actions-plugin/server';
import { get } from 'lodash/fp';

import { ChatOpenAI } from '@langchain/openai';
import { Stream } from 'openai/streaming';
import type OpenAI from 'openai';
+import { PublicMethodsOf } from '@kbn/utility-types';
import { DEFAULT_OPEN_AI_MODEL, DEFAULT_TIMEOUT } from './constants';
import { InvokeAIActionParamsSchema, RunActionParamsSchema } from './types';

const LLM_TYPE = 'ActionsClientChatOpenAI';

export interface ActionsClientChatOpenAIParams {
-  actions: ActionsPluginStart;
+  actionsClient: PublicMethodsOf<ActionsClient>;
  connectorId: string;
  llmType?: string;
  logger: Logger;
-  request: KibanaRequest;
  streaming?: boolean;
  traceId?: string;
  maxRetries?: number;

@@ -54,22 +54,20 @@
export class ActionsClientChatOpenAI extends ChatOpenAI {
  #temperature?: number;

  // Kibana variables
-  #actions: ActionsPluginStart;
+  #actionsClient: PublicMethodsOf<ActionsClient>;
  #connectorId: string;
  #logger: Logger;
-  #request: KibanaRequest;
  #actionResultData: string;
  #traceId: string;
  #signal?: AbortSignal;
  #timeout?: number;

  constructor({
-    actions,
+    actionsClient,
    connectorId,
    traceId = uuidv4(),
    llmType,
    logger,
-    request,
    maxRetries,
    model,
    signal,

@@ -92,12 +90,11 @@
      azureOpenAIApiVersion: 'nothing',
      openAIApiKey: '',
    });
-    this.#actions = actions;
+    this.#actionsClient = actionsClient;
    this.#connectorId = connectorId;
    this.#traceId = traceId;
    this.llmType = llmType ?? LLM_TYPE;
    this.#logger = logger;
-    this.#request = request;
    this.#timeout = timeout;
    this.#actionResultData = '';
    this.streaming = streaming;

@@ -146,10 +143,7 @@
      )} `
    );

-    // create an actions client from the authenticated request context:
-    const actionsClient = await this.#actions.getActionsClientWithRequest(this.#request);
-
-    const actionResult = await actionsClient.execute(requestBody);
+    const actionResult = await this.#actionsClient.execute(requestBody);

    if (actionResult.status === 'error') {
      throw new Error(`${LLM_TYPE}: ${actionResult?.message} - ${actionResult?.serviceMessage}`);
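
A construction sketch, not from the commit, showing the refactored contract: callers resolve an ActionsClient themselves (here via the mock, purely for illustration) and pass it in, rather than handing over the plugin start contract plus a KibanaRequest.

import { loggerMock } from '@kbn/logging-mocks';
import { actionsClientMock } from '@kbn/actions-plugin/server/actions_client/actions_client.mock';
import { ActionsClientChatOpenAI } from './chat_openai';

const actionsClient = actionsClientMock.create();

// Only actionsClient, connectorId, and logger are required; streaming,
// signal, timeout, maxRetries, and model are optional per the params above.
const llm = new ActionsClientChatOpenAI({
  actionsClient,
  connectorId: 'my-openai-connector', // hypothetical connector id
  logger: loggerMock.create(),
  streaming: false,
});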