Skip to content

Commit

Permalink
Merge branch 'main' into pattern-analysis-tab-in-discover
Browse files Browse the repository at this point in the history
  • Loading branch information
jgowdyelastic authored May 22, 2024
2 parents fd42267 + 51f9eed commit ec9b46e
Show file tree
Hide file tree
Showing 11 changed files with 415 additions and 114 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,9 @@ export const crossClusterApiKeySchema = restApiKeySchema.extends({
schema.arrayOf(
schema.object({
names: schema.arrayOf(schema.string()),
query: schema.maybe(schema.any()),
field_security: schema.maybe(schema.any()),
allow_restricted_indices: schema.maybe(schema.boolean()),
})
)
),
Expand Down
154 changes: 106 additions & 48 deletions x-pack/plugins/search_playground/server/lib/conversational_chain.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -9,19 +9,30 @@ import type { Client } from '@elastic/elasticsearch';
import { createAssist as Assist } from '../utils/assist';
import { ConversationalChain } from './conversational_chain';
import { FakeListChatModel } from '@langchain/core/utils/testing';
import { FakeListLLM } from 'langchain/llms/fake';
import { BaseChatModel } from '@langchain/core/language_models/chat_models';
import { Message } from 'ai';

describe('conversational chain', () => {
const createTestChain = async (
responses: string[],
chat: Message[],
expectedFinalAnswer: string,
expectedDocs: any,
expectedTokens: any,
expectedSearchRequest: any,
contentField: Record<string, string> = { index: 'field', website: 'body_content' }
) => {
const createTestChain = async ({
responses,
chat,
expectedFinalAnswer,
expectedDocs,
expectedTokens,
expectedSearchRequest,
contentField = { index: 'field', website: 'body_content' },
isChatModel = true,
}: {
responses: string[];
chat: Message[];
expectedFinalAnswer: string;
expectedDocs: any;
expectedTokens: any;
expectedSearchRequest: any;
contentField?: Record<string, string>;
isChatModel?: boolean;
}) => {
const searchMock = jest.fn().mockImplementation(() => {
return {
hits: {
Expand Down Expand Up @@ -54,9 +65,11 @@ describe('conversational chain', () => {
},
};

const llm = new FakeListChatModel({
responses,
});
const llm = isChatModel
? new FakeListChatModel({
responses,
})
: new FakeListLLM({ responses });

const aiClient = Assist({
es_client: mockElasticsearchClient as unknown as Client,
Expand Down Expand Up @@ -118,17 +131,17 @@ describe('conversational chain', () => {
};

it('should be able to create a conversational chain', async () => {
await createTestChain(
['the final answer'],
[
await createTestChain({
responses: ['the final answer'],
chat: [
{
id: '1',
role: 'user',
content: 'what is the work from home policy?',
},
],
'the final answer',
[
expectedFinalAnswer: 'the final answer',
expectedDocs: [
{
documents: [
{ metadata: { _id: '1', _index: 'index' }, pageContent: 'value' },
Expand All @@ -137,32 +150,32 @@ describe('conversational chain', () => {
type: 'retrieved_docs',
},
],
[
expectedTokens: [
{ type: 'context_token_count', count: 15 },
{ type: 'prompt_token_count', count: 5 },
],
[
expectedSearchRequest: [
{
method: 'POST',
path: '/index,website/_search',
body: { query: { match: { field: 'what is the work from home policy?' } }, size: 3 },
},
]
);
],
});
});

it('should be able to create a conversational chain with nested field', async () => {
await createTestChain(
['the final answer'],
[
await createTestChain({
responses: ['the final answer'],
chat: [
{
id: '1',
role: 'user',
content: 'what is the work from home policy?',
},
],
'the final answer',
[
expectedFinalAnswer: 'the final answer',
expectedDocs: [
{
documents: [
{ metadata: { _id: '1', _index: 'index' }, pageContent: 'value' },
Expand All @@ -171,25 +184,25 @@ describe('conversational chain', () => {
type: 'retrieved_docs',
},
],
[
expectedTokens: [
{ type: 'context_token_count', count: 15 },
{ type: 'prompt_token_count', count: 5 },
],
[
expectedSearchRequest: [
{
method: 'POST',
path: '/index,website/_search',
body: { query: { match: { field: 'what is the work from home policy?' } }, size: 3 },
},
],
{ index: 'field', website: 'metadata.source' }
);
contentField: { index: 'field', website: 'metadata.source' },
});
});

it('asking with chat history should re-write the question', async () => {
await createTestChain(
['rewrite the question', 'the final answer'],
[
await createTestChain({
responses: ['rewrite the question', 'the final answer'],
chat: [
{
id: '1',
role: 'user',
Expand All @@ -206,8 +219,8 @@ describe('conversational chain', () => {
content: 'what is the work from home policy?',
},
],
'the final answer',
[
expectedFinalAnswer: 'the final answer',
expectedDocs: [
{
documents: [
{ metadata: { _id: '1', _index: 'index' }, pageContent: 'value' },
Expand All @@ -216,24 +229,24 @@ describe('conversational chain', () => {
type: 'retrieved_docs',
},
],
[
expectedTokens: [
{ type: 'context_token_count', count: 15 },
{ type: 'prompt_token_count', count: 5 },
],
[
expectedSearchRequest: [
{
method: 'POST',
path: '/index,website/_search',
body: { query: { match: { field: 'rewrite the question' } }, size: 3 },
},
]
);
],
});
});

it('should cope with quotes in the query', async () => {
await createTestChain(
['rewrite "the" question', 'the final answer'],
[
await createTestChain({
responses: ['rewrite "the" question', 'the final answer'],
chat: [
{
id: '1',
role: 'user',
Expand All @@ -250,8 +263,8 @@ describe('conversational chain', () => {
content: 'what is the work from home policy?',
},
],
'the final answer',
[
expectedFinalAnswer: 'the final answer',
expectedDocs: [
{
documents: [
{ metadata: { _id: '1', _index: 'index' }, pageContent: 'value' },
Expand All @@ -260,17 +273,62 @@ describe('conversational chain', () => {
type: 'retrieved_docs',
},
],
[
expectedTokens: [
{ type: 'context_token_count', count: 15 },
{ type: 'prompt_token_count', count: 5 },
],
[
expectedSearchRequest: [
{
method: 'POST',
path: '/index,website/_search',
body: { query: { match: { field: 'rewrite "the" question' } }, size: 3 },
},
]
);
],
});
});

// Verifies the conversational chain also works with a prompt-based (non-chat) LLM:
// isChatModel: false makes the test harness build a FakeListLLM instead of a
// FakeListChatModel, exercising the handleLLMStart callback path used for
// prompt-based models (e.g. Bedrock via ActionsClientLlm).
it('should work with an LLM based model', async () => {
await createTestChain({
// First response answers the rewrite-the-question step (chat history present),
// second response is the final answer streamed to the user.
responses: ['rewrite "the" question', 'the final answer'],
chat: [
{
id: '1',
role: 'user',
content: 'what is the work from home policy?',
},
{
id: '2',
role: 'assistant',
content: 'the final answer',
},
{
id: '3',
role: 'user',
content: 'what is the work from home policy?',
},
],
expectedFinalAnswer: 'the final answer',
expectedDocs: [
{
documents: [
{ metadata: { _id: '1', _index: 'index' }, pageContent: 'value' },
{ metadata: { _id: '1', _index: 'website' }, pageContent: 'value2' },
],
type: 'retrieved_docs',
},
],
expectedTokens: [
{ type: 'context_token_count', count: 15 },
// NOTE(review): prompt_token_count is 7 here vs 5 in the chat-model tests —
// presumably the prompt-based token estimate differs from the chat-message
// estimate; confirm against getTokenEstimate in conversational_chain.ts.
{ type: 'prompt_token_count', count: 7 },
],
// The rewritten question (not the raw user question) is used for retrieval.
expectedSearchRequest: [
{
method: 'POST',
path: '/index,website/_search',
body: { query: { match: { field: 'rewrite "the" question' } }, size: 3 },
},
],
// Switches the harness from FakeListChatModel to FakeListLLM.
isChatModel: false,
});
});
});
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,7 @@ class ConversationalChainFn {
{
callbacks: [
{
// callback for chat based models (OpenAI)
handleChatModelStart(
llm,
msg: BaseMessage[][],
Expand All @@ -166,6 +167,15 @@ class ConversationalChainFn {
});
}
},
// callback for prompt based models (Bedrock uses ActionsClientLlm)
handleLLMStart(llm, input, runId, parentRunId, extraParams, tags, metadata) {
if (metadata?.type === 'question_answer_qa') {
data.appendMessageAnnotation({
type: 'prompt_token_count',
count: getTokenEstimate(input[0]),
});
}
},
handleRetrieverEnd(documents) {
retrievedDocs.push(...documents);
data.appendMessageAnnotation({
Expand Down
Loading

0 comments on commit ec9b46e

Please sign in to comment.