diff --git a/x-pack/plugins/observability_solution/observability_ai_assistant/public/service/create_service.ts b/x-pack/plugins/observability_solution/observability_ai_assistant/public/service/create_service.ts index 9e0adc5a94d8f..7232078d2efe8 100644 --- a/x-pack/plugins/observability_solution/observability_ai_assistant/public/service/create_service.ts +++ b/x-pack/plugins/observability_solution/observability_ai_assistant/public/service/create_service.ts @@ -72,10 +72,6 @@ export function createService({ return of( createFunctionRequestMessage({ name: 'context', - args: { - queries: [], - categories: [], - }, }), createFunctionResponseMessage({ name: 'context', diff --git a/x-pack/plugins/observability_solution/observability_ai_assistant/server/functions/context.ts b/x-pack/plugins/observability_solution/observability_ai_assistant/server/functions/context.ts index 4bc32a2330acd..baf006844c516 100644 --- a/x-pack/plugins/observability_solution/observability_ai_assistant/server/functions/context.ts +++ b/x-pack/plugins/observability_solution/observability_ai_assistant/server/functions/context.ts @@ -40,34 +40,10 @@ export function registerContextFunction({ description: 'This function provides context as to what the user is looking at on their screen, and recalled documents from the knowledge base that matches their query', visibility: FunctionVisibility.Internal, - parameters: { - type: 'object', - properties: { - queries: { - type: 'array', - description: 'The query for the semantic search', - items: { - type: 'string', - }, - }, - categories: { - type: 'array', - description: - 'Categories of internal documentation that you want to search for. By default internal documentation will be excluded. Use `apm` to get internal APM documentation, `lens` to get internal Lens documentation, or both.', - items: { - type: 'string', - enum: ['apm', 'lens'], - }, - }, - }, - required: ['queries', 'categories'], - } as const, }, - async ({ arguments: args, messages, screenContexts, chat }, signal) => { + async ({ messages, screenContexts, chat }, signal) => { const { analytics } = (await resources.context.core).coreStart; - const { queries, categories } = args; - async function getContext() { const screenDescription = compact( screenContexts.map((context) => context.screenDescription) @@ -94,30 +70,21 @@ export function registerContextFunction({ messages.filter((message) => message.message.role === MessageRole.User) ); - const nonEmptyQueries = compact(queries); - - const queriesOrUserPrompt = nonEmptyQueries.length - ? nonEmptyQueries - : compact([userMessage?.message.content]); - - queriesOrUserPrompt.push(screenDescription); - - const suggestions = await retrieveSuggestions({ - client, - categories, - queries: queriesOrUserPrompt, - }); + const userPrompt = userMessage?.message.content; + const queries = [{ text: userPrompt, boost: 3 }, { text: screenDescription }].filter( + ({ text }) => text + ) as Array<{ text: string; boost?: number }>; + const suggestions = await retrieveSuggestions({ client, queries }); if (suggestions.length === 0) { - return { - content, - }; + return { content }; } try { const { relevantDocuments, scores } = await scoreSuggestions({ suggestions, - queries: queriesOrUserPrompt, + screenDescription, + userPrompt, messages, chat, signal, @@ -125,7 +92,7 @@ export function registerContextFunction({ }); analytics.reportEvent(RecallRankingEventType, { - prompt: queriesOrUserPrompt.join('|'), + prompt: queries.map((query) => query.text).join('|'), scoredDocuments: suggestions.map((suggestion) => { const llmScore = scores.find((score) => score.id === suggestion.id); return { @@ -178,15 +145,12 @@ export function registerContextFunction({ async function retrieveSuggestions({ queries, client, - categories, }: { - queries: string[]; + queries: Array<{ text: string; boost?: number }>; client: ObservabilityAIAssistantClient; - categories: Array<'apm' | 'lens'>; }) { const recallResponse = await client.recall({ queries, - categories, }); return recallResponse.entries.map((entry) => omit(entry, 'labels', 'is_correction')); @@ -208,14 +172,16 @@ const scoreFunctionArgumentsRt = t.type({ async function scoreSuggestions({ suggestions, messages, - queries, + userPrompt, + screenDescription, chat, signal, logger, }: { suggestions: Awaited>; messages: Message[]; - queries: string[]; + userPrompt: string | undefined; + screenDescription: string; chat: FunctionCallChatFunction; signal: AbortSignal; logger: Logger; @@ -237,7 +203,10 @@ async function scoreSuggestions({ - The document contains new information not mentioned before in the conversation Question: - ${queries.join('\n')} + ${userPrompt} + + Screen description: + ${screenDescription} Documents: ${JSON.stringify(indexedSuggestions, null, 2)}`); diff --git a/x-pack/plugins/observability_solution/observability_ai_assistant/server/routes/functions/route.ts b/x-pack/plugins/observability_solution/observability_ai_assistant/server/routes/functions/route.ts index 8d509271c1e37..52be33c2a372d 100644 --- a/x-pack/plugins/observability_solution/observability_ai_assistant/server/routes/functions/route.ts +++ b/x-pack/plugins/observability_solution/observability_ai_assistant/server/routes/functions/route.ts @@ -65,7 +65,16 @@ const functionRecallRoute = createObservabilityAIAssistantServerRoute({ params: t.type({ body: t.intersection([ t.type({ - queries: t.array(nonEmptyStringRt), + queries: t.array( + t.intersection([ + t.type({ + text: t.string, + }), + t.partial({ + boost: t.number, + }), + ]) + ), }), t.partial({ categories: t.array(t.string), diff --git a/x-pack/plugins/observability_solution/observability_ai_assistant/server/service/client/get_context_function_request_if_needed.ts b/x-pack/plugins/observability_solution/observability_ai_assistant/server/service/client/get_context_function_request_if_needed.ts index 74cc19d8aa153..e5ea0ad0ff829 100644 --- a/x-pack/plugins/observability_solution/observability_ai_assistant/server/service/client/get_context_function_request_if_needed.ts +++ b/x-pack/plugins/observability_solution/observability_ai_assistant/server/service/client/get_context_function_request_if_needed.ts @@ -28,9 +28,5 @@ export function getContextFunctionRequestIfNeeded( return createFunctionRequestMessage({ name: CONTEXT_FUNCTION_NAME, - args: { - queries: [], - categories: [], - }, }); } diff --git a/x-pack/plugins/observability_solution/observability_ai_assistant/server/service/client/index.test.ts b/x-pack/plugins/observability_solution/observability_ai_assistant/server/service/client/index.test.ts index 4ffc8dc926fc7..0349d597b7ba0 100644 --- a/x-pack/plugins/observability_solution/observability_ai_assistant/server/service/client/index.test.ts +++ b/x-pack/plugins/observability_solution/observability_ai_assistant/server/service/client/index.test.ts @@ -1232,7 +1232,6 @@ describe('Observability AI Assistant client', () => { role: MessageRole.Assistant, function_call: { name: CONTEXT_FUNCTION_NAME, - arguments: JSON.stringify({ queries: [], categories: [] }), trigger: MessageRole.Assistant, }, }, @@ -1456,7 +1455,6 @@ describe('Observability AI Assistant client', () => { role: MessageRole.Assistant, function_call: { name: CONTEXT_FUNCTION_NAME, - arguments: JSON.stringify({ queries: [], categories: [] }), trigger: MessageRole.Assistant, }, }, diff --git a/x-pack/plugins/observability_solution/observability_ai_assistant/server/service/client/index.ts b/x-pack/plugins/observability_solution/observability_ai_assistant/server/service/client/index.ts index 803e0e904223e..9739a59125011 100644 --- a/x-pack/plugins/observability_solution/observability_ai_assistant/server/service/client/index.ts +++ b/x-pack/plugins/observability_solution/observability_ai_assistant/server/service/client/index.ts @@ -694,7 +694,7 @@ export class ObservabilityAIAssistantClient { queries, categories, }: { - queries: string[]; + queries: Array<{ text: string; boost?: number }>; categories?: string[]; }): Promise<{ entries: RecalledEntry[] }> => { return this.dependencies.knowledgeBaseService.recall({ @@ -757,11 +757,9 @@ export class ObservabilityAIAssistantClient { }; fetchUserInstructions = async () => { - const userInstructions = await this.dependencies.knowledgeBaseService.getUserInstructions( + return this.dependencies.knowledgeBaseService.getUserInstructions( this.dependencies.namespace, this.dependencies.user ); - - return userInstructions; }; } diff --git a/x-pack/plugins/observability_solution/observability_ai_assistant/server/service/knowledge_base_service/index.ts b/x-pack/plugins/observability_solution/observability_ai_assistant/server/service/knowledge_base_service/index.ts index 576fd8dc5552b..7c504aa43c38c 100644 --- a/x-pack/plugins/observability_solution/observability_ai_assistant/server/service/knowledge_base_service/index.ts +++ b/x-pack/plugins/observability_solution/observability_ai_assistant/server/service/knowledge_base_service/index.ts @@ -303,7 +303,7 @@ export class KnowledgeBaseService { user, modelId, }: { - queries: string[]; + queries: Array<{ text: string; boost?: number }>; categories?: string[]; namespace: string; user?: { name: string }; @@ -311,11 +311,12 @@ export class KnowledgeBaseService { }): Promise { const query = { bool: { - should: queries.map((text) => ({ + should: queries.map(({ text, boost = 1 }) => ({ text_expansion: { 'ml.tokens': { model_text: text, model_id: modelId, + boost, }, }, })), @@ -385,7 +386,7 @@ export class KnowledgeBaseService { uiSettingsClient, modelId, }: { - queries: string[]; + queries: Array<{ text: string; boost?: number }>; asCurrentUser: ElasticsearchClient; uiSettingsClient: IUiSettingsClient; modelId: string; @@ -414,15 +415,16 @@ export class KnowledgeBaseService { const vectorField = `${ML_INFERENCE_PREFIX}${field}_expanded.predicted_value`; const modelField = `${ML_INFERENCE_PREFIX}${field}_expanded.model_id`; - return queries.map((query) => { + return queries.map(({ text, boost = 1 }) => { return { bool: { should: [ { text_expansion: { [vectorField]: { - model_text: query, + model_text: text, model_id: modelId, + boost, }, }, }, @@ -470,7 +472,7 @@ export class KnowledgeBaseService { asCurrentUser, uiSettingsClient, }: { - queries: string[]; + queries: Array<{ text: string; boost?: number }>; categories?: string[]; user?: { name: string }; namespace: string; diff --git a/x-pack/plugins/observability_solution/observability_ai_assistant_app/public/components/chat/chat_body.test.tsx b/x-pack/plugins/observability_solution/observability_ai_assistant_app/public/components/chat/chat_body.test.tsx index e39bcf5d1891e..65ac65264f307 100644 --- a/x-pack/plugins/observability_solution/observability_ai_assistant_app/public/components/chat/chat_body.test.tsx +++ b/x-pack/plugins/observability_solution/observability_ai_assistant_app/public/components/chat/chat_body.test.tsx @@ -40,7 +40,7 @@ describe('', () => { role: 'assistant', function_call: { name: CONTEXT_FUNCTION_NAME, - arguments: '{"queries":[],"categories":[]}', + arguments: '{}', trigger: 'assistant', }, content: '', @@ -88,7 +88,7 @@ describe('', () => { role: 'assistant', function_call: { name: CONTEXT_FUNCTION_NAME, - arguments: '{"queries":[],"categories":[]}', + arguments: '{}', trigger: 'assistant', }, content: '', diff --git a/x-pack/test/observability_ai_assistant_api_integration/tests/complete/complete.spec.ts b/x-pack/test/observability_ai_assistant_api_integration/tests/complete/complete.spec.ts index 01f6e8cdd7bce..eb5ed07d3ea08 100644 --- a/x-pack/test/observability_ai_assistant_api_integration/tests/complete/complete.spec.ts +++ b/x-pack/test/observability_ai_assistant_api_integration/tests/complete/complete.spec.ts @@ -193,7 +193,6 @@ export default function ApiTest({ getService }: FtrProviderContext) { role: MessageRole.Assistant, function_call: { name: 'context', - arguments: JSON.stringify({ queries: [], categories: [] }), trigger: MessageRole.Assistant, }, }, diff --git a/x-pack/test/observability_ai_assistant_api_integration/tests/public_complete/public_complete.spec.ts b/x-pack/test/observability_ai_assistant_api_integration/tests/public_complete/public_complete.spec.ts index ac2fa36f6b0fd..f496e42868ac8 100644 --- a/x-pack/test/observability_ai_assistant_api_integration/tests/public_complete/public_complete.spec.ts +++ b/x-pack/test/observability_ai_assistant_api_integration/tests/public_complete/public_complete.spec.ts @@ -72,6 +72,7 @@ export default function ApiTest({ getService }: FtrProviderContext) { format, }) .set('kbn-xsrf', 'foo') + .set('elastic-api-version', '2023-10-31') .send({ messages, connectorId, @@ -83,13 +84,20 @@ export default function ApiTest({ getService }: FtrProviderContext) { if (err) { return reject(err); } + if (response.status !== 200) { + return reject(new Error(`${response.status}: ${JSON.stringify(response.body)}`)); + } return resolve(response); }); }); - const [conversationSimulator, titleSimulator] = await Promise.all([ - conversationInterceptor.waitForIntercept(), - titleInterceptor.waitForIntercept(), + const [conversationSimulator, titleSimulator] = await Promise.race([ + Promise.all([ + conversationInterceptor.waitForIntercept(), + titleInterceptor.waitForIntercept(), + ]), + // make sure any request failures (like 400s) are properly propagated + responsePromise.then(() => []), ]); await titleSimulator.status(200); diff --git a/x-pack/test/observability_ai_assistant_functional/tests/conversations/index.spec.ts b/x-pack/test/observability_ai_assistant_functional/tests/conversations/index.spec.ts index b7c33db0a4122..3e766877c5bca 100644 --- a/x-pack/test/observability_ai_assistant_functional/tests/conversations/index.spec.ts +++ b/x-pack/test/observability_ai_assistant_functional/tests/conversations/index.spec.ts @@ -94,7 +94,7 @@ export default function ApiTest({ getService, getPageObjects }: FtrProviderConte content: '', function_call: { name: 'context', - arguments: '{"queries":[],"categories":[]}', + arguments: '{}', trigger: MessageRole.Assistant, }, }, @@ -290,7 +290,6 @@ export default function ApiTest({ getService, getPageObjects }: FtrProviderConte expect(pick(contextRequest.function_call, 'name', 'arguments')).to.eql({ name: 'context', - arguments: JSON.stringify({ queries: [], categories: [] }), }); expect(contextResponse.name).to.eql('context'); @@ -354,7 +353,6 @@ export default function ApiTest({ getService, getPageObjects }: FtrProviderConte expect(pick(contextRequest.function_call, 'name', 'arguments')).to.eql({ name: 'context', - arguments: JSON.stringify({ queries: [], categories: [] }), }); expect(contextResponse.name).to.eql('context');