Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Inference] Minor cleanup and restructure #191069

Merged
merged 14 commits into from
Aug 26, 2024
Merged
1 change: 1 addition & 0 deletions x-pack/plugins/inference/common/chat_complete/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

import type { Observable } from 'rxjs';
import type { InferenceTaskEventBase } from '../tasks';
import type { ToolCall, ToolCallsOf, ToolOptions } from './tools';
Expand Down
12 changes: 7 additions & 5 deletions x-pack/plugins/inference/common/connectors.ts
Original file line number Diff line number Diff line change
Expand Up @@ -11,16 +11,18 @@ export enum InferenceConnectorType {
Gemini = '.gemini',
}

const allSupportedConnectorTypes = Object.values(InferenceConnectorType);

/**
 * Represents a stack connector that can be used for inference tasks.
 */
export interface InferenceConnector {
  /** The connector's type — one of the supported {@link InferenceConnectorType} values. */
  type: InferenceConnectorType;
  /** Human-readable display name of the connector. */
  name: string;
  /** Id of the underlying stack connector. */
  connectorId: string;
}

/**
 * Type guard checking whether the given id is one of the supported
 * inference connector types.
 *
 * Note: the scraped diff fused the removed chained-comparison body with the
 * added array-based body, leaving two return statements; this keeps the new
 * implementation, which derives from `allSupportedConnectorTypes` so new enum
 * members are picked up automatically.
 */
export function isSupportedConnectorType(id: string): id is InferenceConnectorType {
  // The cast is required because `includes` on `InferenceConnectorType[]`
  // does not accept a plain string; the runtime check is still exact.
  return allSupportedConnectorTypes.includes(id as InferenceConnectorType);
}

/**
 * Response payload of the internal `GET /internal/inference/connectors` route.
 */
export interface GetConnectorsResponseBody {
  /** All inference-capable connectors available to the current user. */
  connectors: InferenceConnector[];
}
4 changes: 2 additions & 2 deletions x-pack/plugins/inference/public/chat_complete/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@
* 2.0.
*/

import type { HttpStart } from '@kbn/core/public';
import { from } from 'rxjs';
import { ChatCompleteAPI } from '../../common/chat_complete';
import type { HttpStart } from '@kbn/core/public';
import type { ChatCompleteAPI } from '../../common/chat_complete';
import type { ChatCompleteRequestBody } from '../../common/chat_complete/request';
import { httpResponseIntoObservable } from '../util/http_response_into_observable';

Expand Down
12 changes: 9 additions & 3 deletions x-pack/plugins/inference/public/plugin.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,11 @@
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
import { CoreSetup, CoreStart, Plugin, PluginInitializerContext } from '@kbn/core/public';

import type { CoreSetup, CoreStart, Plugin, PluginInitializerContext } from '@kbn/core/public';
import type { Logger } from '@kbn/logging';
import { createOutputApi } from '../common/output/create_output_api';
import type { GetConnectorsResponseBody } from '../common/connectors';
import { createChatCompleteApi } from './chat_complete';
import type {
ConfigSchema,
Expand Down Expand Up @@ -39,11 +41,15 @@ export class InferencePlugin

start(coreStart: CoreStart, pluginsStart: InferenceStartDependencies): InferencePublicStart {
const chatComplete = createChatCompleteApi({ http: coreStart.http });

return {
chatComplete,
output: createOutputApi(chatComplete),
getConnectors: () => {
return coreStart.http.get('/internal/inference/connectors');
getConnectors: async () => {
const res = await coreStart.http.get<GetConnectorsResponseBody>(
'/internal/inference/connectors'
);
return res.connectors;
},
};
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

import { InferenceConnectorType } from '../../../common/connectors';
import { getInferenceAdapter } from './get_inference_adapter';
import { openAIAdapter } from './openai';

describe('getInferenceAdapter', () => {
  // OpenAI is currently the only connector type with a concrete adapter.
  it('returns the openAI adapter for OpenAI type', () => {
    expect(getInferenceAdapter(InferenceConnectorType.OpenAI)).toBe(openAIAdapter);
  });

  // Bedrock adapter is not implemented yet, so resolution yields undefined.
  it('returns undefined for Bedrock type', () => {
    expect(getInferenceAdapter(InferenceConnectorType.Bedrock)).toBe(undefined);
  });

  // Gemini adapter is not implemented yet, so resolution yields undefined.
  it('returns undefined for Gemini type', () => {
    expect(getInferenceAdapter(InferenceConnectorType.Gemini)).toBe(undefined);
  });
});
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

import { InferenceConnectorType } from '../../../common/connectors';
import type { InferenceConnectorAdapter } from '../types';
import { openAIAdapter } from './openai';

/**
 * Resolves the inference adapter for the given connector type, or
 * `undefined` when no adapter is implemented for that type.
 */
export const getInferenceAdapter = (
  connectorType: InferenceConnectorType
): InferenceConnectorAdapter | undefined => {
  // OpenAI is the only connector type with a concrete adapter so far.
  if (connectorType === InferenceConnectorType.OpenAI) {
    return openAIAdapter;
  }

  // Bedrock and Gemini: not implemented yet.
  return undefined;
};
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

export { getInferenceAdapter } from './get_inference_adapter';
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,14 @@
*/

import OpenAI from 'openai';
import { openAIAdapter } from '.';
import type { ActionsClient } from '@kbn/actions-plugin/server/actions_client';
import { ChatCompletionEventType, MessageRole } from '../../../../common/chat_complete';
import { v4 } from 'uuid';
import { PassThrough } from 'stream';
import { pick } from 'lodash';
import { lastValueFrom, Subject, toArray } from 'rxjs';
import { ChatCompletionEventType, MessageRole } from '../../../../common/chat_complete';
import { observableIntoEventSourceStream } from '../../../util/observable_into_event_source_stream';
import { v4 } from 'uuid';
import { InferenceExecutor } from '../../utils/inference_executor';
import { openAIAdapter } from '.';

function createOpenAIChunk({
delta,
Expand All @@ -39,38 +39,27 @@ function createOpenAIChunk({
}

describe('openAIAdapter', () => {
const actionsClientMock = {
execute: jest.fn(),
} as ActionsClient & { execute: jest.MockedFn<ActionsClient['execute']> };
const executorMock = {
invoke: jest.fn(),
} as InferenceExecutor & { invoke: jest.MockedFn<InferenceExecutor['invoke']> };

beforeEach(() => {
actionsClientMock.execute.mockReset();
executorMock.invoke.mockReset();
});

const defaultArgs = {
connector: {
id: 'foo',
actionTypeId: '.gen-ai',
name: 'OpenAI',
isPreconfigured: false,
isDeprecated: false,
isSystemAction: false,
},
actionsClient: actionsClientMock,
executor: executorMock,
};

describe('when creating the request', () => {
function getRequest() {
const params = actionsClientMock.execute.mock.calls[0][0].params.subActionParams as Record<
string,
any
>;
const params = executorMock.invoke.mock.calls[0][0].subActionParams as Record<string, any>;

return { stream: params.stream, body: JSON.parse(params.body) };
}

beforeEach(() => {
actionsClientMock.execute.mockImplementation(async () => {
executorMock.invoke.mockImplementation(async () => {
return {
actionId: '',
status: 'ok',
Expand Down Expand Up @@ -262,7 +251,7 @@ describe('openAIAdapter', () => {
beforeEach(() => {
source$ = new Subject<Record<string, any>>();

actionsClientMock.execute.mockImplementation(async () => {
executorMock.invoke.mockImplementation(async () => {
return {
actionId: '',
status: 'ok',
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,60 +21,30 @@ import {
Message,
MessageRole,
} from '../../../../common/chat_complete';
import type { ToolOptions } from '../../../../common/chat_complete/tools';
import { createTokenLimitReachedError } from '../../../../common/chat_complete/errors';
import { createInferenceInternalError } from '../../../../common/errors';
import { eventSourceStreamIntoObservable } from '../../../util/event_source_stream_into_observable';
import { InferenceConnectorAdapter } from '../../types';
import { eventSourceStreamIntoObservable } from '../event_source_stream_into_observable';

export const openAIAdapter: InferenceConnectorAdapter = {
chatComplete: ({ connector, actionsClient, system, messages, toolChoice, tools }) => {
const openAIMessages = messagesToOpenAI({ system, messages });

const toolChoiceForOpenAI =
typeof toolChoice === 'string'
? toolChoice
: toolChoice
? {
function: {
name: toolChoice.function,
},
type: 'function' as const,
}
: undefined;

chatComplete: ({ executor, system, messages, toolChoice, tools }) => {
const stream = true;

const request: Omit<OpenAI.ChatCompletionCreateParams, 'model'> & { model?: string } = {
stream,
messages: openAIMessages,
messages: messagesToOpenAI({ system, messages }),
tool_choice: toolChoiceToOpenAI(toolChoice),
tools: toolsToOpenAI(tools),
temperature: 0,
tool_choice: toolChoiceForOpenAI,
tools: tools
? Object.entries(tools).map(([toolName, { description, schema }]) => {
return {
type: 'function',
function: {
name: toolName,
description,
parameters: (schema ?? {
type: 'object' as const,
properties: {},
}) as unknown as Record<string, unknown>,
},
};
})
: undefined,
};

return from(
actionsClient.execute({
actionId: connector.id,
params: {
subAction: 'stream',
subActionParams: {
body: JSON.stringify(request),
stream,
},
executor.invoke({
subAction: 'stream',
subActionParams: {
body: JSON.stringify(request),
stream,
},
})
).pipe(
Expand Down Expand Up @@ -125,6 +95,39 @@ export const openAIAdapter: InferenceConnectorAdapter = {
},
};

/**
 * Converts the inference tool definitions into OpenAI's function-tool
 * array format, or `undefined` when no tools are configured.
 */
function toolsToOpenAI(tools: ToolOptions['tools']): OpenAI.ChatCompletionCreateParams['tools'] {
  // No tools configured -> omit the field from the request entirely.
  if (!tools) {
    return undefined;
  }

  return Object.entries(tools).map(([toolName, { description, schema }]) => {
    // OpenAI requires a parameters schema; fall back to an empty object schema.
    const parameters = (schema ?? {
      type: 'object' as const,
      properties: {},
    }) as unknown as Record<string, unknown>;

    return {
      type: 'function' as const,
      function: {
        name: toolName,
        description,
        parameters,
      },
    };
  });
}

/**
 * Converts the inference tool-choice option into OpenAI's `tool_choice`
 * request parameter.
 */
function toolChoiceToOpenAI(
  toolChoice: ToolOptions['toolChoice']
): OpenAI.ChatCompletionCreateParams['tool_choice'] {
  // String modes pass through unchanged (presumably 'auto'/'none'-style
  // values — confirm against ToolOptions).
  if (typeof toolChoice === 'string') {
    return toolChoice;
  }

  // Unset -> let OpenAI apply its default tool-choice behavior.
  if (!toolChoice) {
    return undefined;
  }

  // Named choice: force the model to call the specified function.
  return {
    type: 'function' as const,
    function: {
      name: toolChoice.function,
    },
  };
}

function messagesToOpenAI({
system,
messages,
Expand Down
63 changes: 63 additions & 0 deletions x-pack/plugins/inference/server/chat_complete/api.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

import type { KibanaRequest } from '@kbn/core-http-server';
import { defer, switchMap, throwError } from 'rxjs';
import type { ChatCompleteAPI, ChatCompletionResponse } from '../../common/chat_complete';
import { createInferenceRequestError } from '../../common/errors';
import type { InferenceStartDependencies } from '../types';
import { getConnectorById } from '../util/get_connector_by_id';
import { getInferenceAdapter } from './adapters';
import { createInferenceExecutor, chunksIntoMessage } from './utils';

/**
 * Creates the server-side chatComplete API, bound to the given Kibana request.
 *
 * The returned function resolves the target connector by id, selects the
 * matching inference adapter, and returns the adapter's event stream as an
 * observable (a `ChatCompletionResponse`).
 */
export function createChatCompleteApi({
  request,
  actions,
}: {
  request: KibanaRequest;
  actions: InferenceStartDependencies['actions'];
}) {
  const chatCompleteAPI: ChatCompleteAPI = ({
    connectorId,
    messages,
    toolChoice,
    tools,
    system,
  }): ChatCompletionResponse => {
    // defer() keeps the async setup lazy: the actions client and connector
    // are resolved per subscription, and any setup failure surfaces as an
    // observable error rather than a synchronous throw.
    return defer(async () => {
      const actionsClient = await actions.getActionsClientWithRequest(request);
      const connector = await getConnectorById({ connectorId, actionsClient });
      const executor = createInferenceExecutor({ actionsClient, connector });
      return { executor, connector };
    }).pipe(
      switchMap(({ executor, connector }) => {
        const connectorType = connector.type;
        const inferenceAdapter = getInferenceAdapter(connectorType);

        // Connector types without an adapter (e.g. Bedrock / Gemini, not
        // implemented yet) are rejected with a 400 request error.
        if (!inferenceAdapter) {
          return throwError(() =>
            createInferenceRequestError(`Adapter for type ${connectorType} not implemented`, 400)
          );
        }

        return inferenceAdapter.chatComplete({
          system,
          executor,
          messages,
          toolChoice,
          tools,
        });
      }),
      // NOTE(review): presumably folds the streamed chunk events into
      // complete message events — see chunksIntoMessage for the contract.
      chunksIntoMessage({
        toolChoice,
        tools,
      })
    );
  };

  return chatCompleteAPI;
}
Loading