[Security Assistant] Migrates to LangGraph and adds KB Tools (#184554)
## Summary

Migrates our existing RAG pipeline to use LangGraph, and adds tools for
Knowledge Base retrieval/storage.

When the `assistantKnowledgeBaseByDefault` feature flag is enabled,
`postActionsConnectorExecuteRoute` takes a new branch,
`callAssistantGraph()`, that exercises the LangGraph implementation.
This is a drop-in replacement for the existing `callAgentExecutor()`,
in an effort to keep adoption as clean and easy as possible.

The new control flow is as follows:

`postActionsConnectorExecuteRoute` -> `callAssistantGraph()` ->
`getDefaultAssistantGraph()` -> `isStreamingEnabled ? streamGraph() :
invokeGraph()`

Graph creation is isolated to `getDefaultAssistantGraph()`, and
execution has been extracted to `streamGraph()` (streaming) and
`invokeGraph()` (non-streaming). Note: Streaming currently only works
with `ChatOpenAI` models. `SimpleChatModelStreaming` has been de-risked;
we just need to discuss potential solutions with @stephmilovic. See
[comment
here](https://github.com/elastic/kibana/pull/184554/files#diff-ad87c5621b231a40810419fc1e56f28aeb4f8328e125e465dfe95ae0e1c305b8R97-R98).
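
The branching described above can be sketched roughly as follows. This is a minimal, synchronous illustration and not the code from this PR: the `ExecuteParams`, `ExecuteResult`, and `AssistantGraph` shapes and the `execute()` wrapper are hypothetical, and every function body is a stub that only records which path was taken.

```typescript
// Hypothetical, simplified stand-ins for the functions named in the PR description.
interface ExecuteParams {
  // Mirrors the `assistantKnowledgeBaseByDefault` feature flag.
  assistantKnowledgeBaseByDefault: boolean;
  isStreamingEnabled: boolean;
  input: string;
}

interface ExecuteResult {
  executor: 'callAssistantGraph' | 'callAgentExecutor';
  mode: 'stream' | 'invoke';
  output: string;
}

// Stand-in for a compiled graph (assumption: the real graph is a LangGraph object).
interface AssistantGraph {
  run: (input: string) => string;
}

// Graph creation is isolated here, as in the description.
const getDefaultAssistantGraph = (): AssistantGraph => ({
  run: (input) => `graph:${input}`,
});

// Execution is split by mode; both stubs simply run the graph.
const streamGraph = (graph: AssistantGraph, input: string): string => graph.run(input);
const invokeGraph = (graph: AssistantGraph, input: string): string => graph.run(input);

const callAssistantGraph = ({ isStreamingEnabled, input }: ExecuteParams): ExecuteResult => {
  const graph = getDefaultAssistantGraph();
  return isStreamingEnabled
    ? { executor: 'callAssistantGraph', mode: 'stream', output: streamGraph(graph, input) }
    : { executor: 'callAssistantGraph', mode: 'invoke', output: invokeGraph(graph, input) };
};

const callAgentExecutor = ({ isStreamingEnabled, input }: ExecuteParams): ExecuteResult => ({
  executor: 'callAgentExecutor',
  mode: isStreamingEnabled ? 'stream' : 'invoke',
  output: `executor:${input}`,
});

// Route-level branch: both executors take the same params, so the flag only picks one.
const execute = (params: ExecuteParams): ExecuteResult =>
  params.assistantKnowledgeBaseByDefault ? callAssistantGraph(params) : callAgentExecutor(params);
```

Because both executors accept the same parameters in this sketch, the route only needs a single conditional, which is the sense in which the new path is a drop-in replacement.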

#### DefaultAssistantGraph

To start with a predictable and piecemeal migration, our existing
`agentExecutor` pipeline has been recreated in LangGraph. It consists of
a single agent node (either `OpenAIFunctionsAgent` or
`StructuredChatAgent`, depending on the backing LLM), a tool-executing
node, and a conditional edge that routes between the two nodes until
the agent chooses no more function calls. This differs from our initial
implementation in that multiple tool calls are now supported, so a user
could ask about their alerts AND retrieve additional knowledge base
information in the same response.
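
As a rough illustration of that loop, the sketch below models the agent node, the tool node, and the conditional edge in plain TypeScript. It deliberately does not use the LangGraph API, and the state shape, the two tool names, and the scripted agent behaviour are all assumptions made for the example.

```typescript
// Plain-TypeScript model of the agent/tool loop; not the LangGraph API.
interface ToolCall {
  tool: string;
  input: string;
}

interface GraphState {
  input: string;
  pendingToolCalls: ToolCall[];
  toolResults: string[];
  output?: string;
}

type GraphNode = (state: GraphState) => GraphState;

// Hypothetical tools keyed by name.
const tools: Record<string, (input: string) => string> = {
  alerts: (input) => `alerts(${input})`,
  knowledgeBase: (input) => `kb(${input})`,
};

// Agent node (scripted): first pass requests two tools at once, second pass answers.
const agentNode: GraphNode = (state) =>
  state.toolResults.length === 0
    ? {
        ...state,
        pendingToolCalls: [
          { tool: 'alerts', input: state.input },
          { tool: 'knowledgeBase', input: state.input },
        ],
      }
    : { ...state, pendingToolCalls: [], output: state.toolResults.join(' + ') };

// Tool node: executes every pending call and clears the queue.
const toolNode: GraphNode = (state) => ({
  ...state,
  pendingToolCalls: [],
  toolResults: [
    ...state.toolResults,
    ...state.pendingToolCalls.map(({ tool, input }) => {
      const run = tools[tool];
      return run ? run(input) : `unknown tool: ${tool}`;
    }),
  ],
});

// Conditional edge: keep routing to the tool node while the agent asks for tools.
const shouldContinue = (state: GraphState): 'tools' | 'end' =>
  state.pendingToolCalls.length > 0 ? 'tools' : 'end';

const runGraph = (input: string, maxSteps = 10): GraphState => {
  let state = agentNode({ input, pendingToolCalls: [], toolResults: [] });
  for (let step = 0; step < maxSteps && shouldContinue(state) === 'tools'; step++) {
    state = agentNode(toolNode(state));
  }
  return state;
};
```

Here a single pass through the tool node handles both pending calls, which is the multi-tool behaviour the description contrasts with the earlier implementation. The `maxSteps` guard is an addition for the sketch so a misbehaving agent cannot loop forever.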

> [!NOTE]
> While `chat_history` has been plumbed into the graph, after discussing
> with @YulNaumenko we decided to wait to plumb the rest of persistence
> into the graph until #184485 is merged. I had already plumbed through
> the `chatTitleGeneration` node
> ([here](https://github.com/elastic/kibana/pull/184554/files#diff-26038489e9a3f1a14c5ea2ac2954671973d833349ef3ffaddcf9b29ce9e2b96eR33)),
> so what remains is initial conversation creation and the
> append/update operations.

#### Knowledge History & KB Tools

Knowledge History is now always added to the initial prompt for any KB
documents marked as `required`, and two new tools were added for
creating and recalling KB entries from within the conversation:
`KnowledgeBaseWriteTool` and `KnowledgeBaseRetrievalTool`, respectively.
All three methods of storing and retrieving KB content use the
`kbDataClient` for access, and scope all requests to the
`authenticatedUser` that made the initial request.
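
The user scoping can be pictured with a small in-memory stand-in for `kbDataClient`. This is only a sketch under stated assumptions: `createKbDataClient`, `KbEntry`, and `buildInitialPrompt` are hypothetical names, the store is an array rather than Elasticsearch, and retrieval is a naive substring match rather than a vector search.

```typescript
// Hypothetical in-memory stand-in; the real kbDataClient queries Elasticsearch.
interface KbEntry {
  owner: string;
  text: string;
  required: boolean;
}

// A client is created per user, so every operation is scoped to that user.
const createKbDataClient = (store: KbEntry[], currentUser: string) => ({
  // Rough counterpart of KnowledgeBaseWriteTool.
  write: (text: string, required = false): void => {
    store.push({ owner: currentUser, text, required });
  },
  // Rough counterpart of KnowledgeBaseRetrievalTool (substring match, not vector search).
  retrieve: (query: string): string[] =>
    store
      .filter((e) => e.owner === currentUser && e.text.toLowerCase().includes(query.toLowerCase()))
      .map((e) => e.text),
  // Entries marked `required`, which feed the Knowledge History.
  requiredEntries: (): string[] =>
    store.filter((e) => e.owner === currentUser && e.required).map((e) => e.text),
});

// Knowledge History is appended to the initial prompt only when there is something to add.
const buildInitialPrompt = (systemPrompt: string, knowledgeHistory: string[]): string =>
  knowledgeHistory.length > 0
    ? `${systemPrompt}\n\nKnowledge History:\n${knowledgeHistory.map((k) => `- ${k}`).join('\n')}`
    : systemPrompt;
```

Two clients built over the same store for different users never see each other's entries, which mirrors the scoping described above.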




Additional Notes:
* LangChain dependencies have been updated, and a new dependency on
`LangGraph` has been added.



### Checklist

Delete any items that are not applicable to this PR.

- [X] Any text added follows [EUI's writing
guidelines](https://elastic.github.io/eui/#/guidelines/writing), uses
sentence case text and includes [i18n
support](https://github.com/elastic/kibana/blob/main/packages/kbn-i18n/README.md)
- [ ]
[Documentation](https://www.elastic.co/guide/en/kibana/master/development-documentation.html)
was added for features that require explanation or tutorials
  * Feature is currently behind a feature flag; documentation will be
added once the feature is complete. Tracked in
elastic/security-docs#5337.
- [ ] [Unit or functional
tests](https://www.elastic.co/guide/en/kibana/master/development-tests.html)
were updated or added to match the most common scenarios
  * Test coverage in progress...

---------

Co-authored-by: Patryk Kopycinski <contact@patrykkopycinski.com>
Co-authored-by: Kibana Machine <42973632+kibanamachine@users.noreply.github.com>
3 people authored Jun 14, 2024
1 parent 011b7eb commit 199eb64
Showing 29 changed files with 1,130 additions and 26 deletions.
Original file line number Diff line number Diff line change
@@ -33,11 +33,11 @@ export const KnowledgeBaseEntryErrorSchema = z
 export type Metadata = z.infer<typeof Metadata>;
 export const Metadata = z.object({
   /**
-   * Knowledge Base resource name
+   * Knowledge Base resource name for grouping entries, e.g. 'esql', 'lens-docs', etc
    */
   kbResource: z.string(),
   /**
-   * Original text content source
+   * Source document name or filepath
    */
   source: z.string(),
   /**
@@ -32,10 +32,10 @@ components:
       properties:
         kbResource:
           type: string
-          description: Knowledge Base resource name
+          description: Knowledge Base resource name for grouping entries, e.g. 'esql', 'lens-docs', etc
         source:
           type: string
-          description: Original text content source
+          description: Source document name or filepath
         required:
           type: boolean
           description: Whether or not this resource should always be included
@@ -18,8 +18,8 @@ import { CreateKnowledgeBaseEntrySchema } from './types';

 export interface CreateKnowledgeBaseEntryParams {
   esClient: ElasticsearchClient;
-  logger: Logger;
   knowledgeBaseIndex: string;
+  logger: Logger;
   spaceId: string;
   user: AuthenticatedUser;
   knowledgeBaseEntry: KnowledgeBaseEntryCreateProps;
@@ -6,6 +6,8 @@
  */

 import { errors } from '@elastic/elasticsearch';
+import { QueryDslQueryContainer } from '@elastic/elasticsearch/lib/api/types';
+import { AuthenticatedUser } from '@kbn/core-security-common';

 export const isModelAlreadyExistsError = (error: Error) => {
   return (
@@ -14,3 +16,87 @@ export const isModelAlreadyExistsError = (error: Error) => {
       error.body.error.type === 'status_exception')
   );
 };

+/**
+ * Returns an Elasticsearch query DSL that performs a vector search against the Knowledge Base for the given query/user/filter.
+ *
+ * @param filter - Optional filter to apply to the search
+ * @param kbResource - Specific resource tag to filter for, e.g. 'esql' or 'user'
+ * @param modelId - ID of the model to search with, e.g. `.elser_model_2`
+ * @param query - The search query provided by the user
+ * @param required - Whether to only include required entries
+ * @param user - The authenticated user
+ * @returns
+ */
+export const getKBVectorSearchQuery = ({
+  filter,
+  kbResource,
+  modelId,
+  query,
+  required,
+  user,
+}: {
+  filter?: QueryDslQueryContainer | undefined;
+  kbResource?: string | undefined;
+  modelId: string;
+  query: string;
+  required?: boolean | undefined;
+  user: AuthenticatedUser;
+}): QueryDslQueryContainer => {
+  const resourceFilter = kbResource
+    ? [
+        {
+          term: {
+            'metadata.kbResource': kbResource,
+          },
+        },
+      ]
+    : [];
+  const requiredFilter = required
+    ? [
+        {
+          term: {
+            'metadata.required': required,
+          },
+        },
+      ]
+    : [];

+  const userFilter = [
+    {
+      nested: {
+        path: 'users',
+        query: {
+          bool: {
+            must: [
+              {
+                match: user.profile_uid
+                  ? { 'users.id': user.profile_uid }
+                  : { 'users.name': user.username },
+              },
+            ],
+          },
+        },
+      },
+    },
+  ];

+  return {
+    bool: {
+      must: [
+        {
+          text_expansion: {
+            'vector.tokens': {
+              model_id: modelId,
+              model_text: query,
+            },
+          },
+        },
+        ...requiredFilter,
+        ...resourceFilter,
+        ...userFilter,
+      ],
+      filter,
+    },
+  };
+};
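
To make the shape of the resulting query easier to see, here is a standalone re-statement of the helper above using plain object types instead of the Elasticsearch client types. The names `buildKbQuery`, `Clause`, and `SketchUser` are hypothetical and exist only so the example runs without Kibana's dependencies; the clause construction follows the diff.

```typescript
// Simplified, dependency-free re-statement of getKBVectorSearchQuery for illustration.
type Clause = Record<string, unknown>;

interface SketchUser {
  username: string;
  profile_uid?: string;
}

const buildKbQuery = ({
  filter,
  kbResource,
  modelId,
  query,
  required,
  user,
}: {
  filter?: Clause;
  kbResource?: string;
  modelId: string;
  query: string;
  required?: boolean;
  user: SketchUser;
}): { bool: { must: Clause[]; filter?: Clause } } => {
  // Optional clauses become empty arrays so they vanish when spread into `must`.
  const resourceFilter: Clause[] = kbResource
    ? [{ term: { 'metadata.kbResource': kbResource } }]
    : [];
  const requiredFilter: Clause[] = required ? [{ term: { 'metadata.required': required } }] : [];

  // Always present: match on profile_uid when available, otherwise on username.
  const userFilter: Clause[] = [
    {
      nested: {
        path: 'users',
        query: {
          bool: {
            must: [
              {
                match: user.profile_uid
                  ? { 'users.id': user.profile_uid }
                  : { 'users.name': user.username },
              },
            ],
          },
        },
      },
    },
  ];

  return {
    bool: {
      must: [
        { text_expansion: { 'vector.tokens': { model_id: modelId, model_text: query } } },
        ...requiredFilter,
        ...resourceFilter,
        ...userFilter,
      ],
      filter,
    },
  };
};

// Example call: an 'esql' lookup for a user without a profile_uid.
const example = buildKbQuery({
  kbResource: 'esql',
  modelId: '.elser_model_2',
  query: 'How do I count events by host?',
  user: { username: 'elastic' },
});
```

One consequence of the truthiness checks is that `required: false` adds no clause at all, so it means "no constraint on `required`" rather than "only non-required entries".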
@@ -11,19 +11,23 @@ import {
 } from '@elastic/elasticsearch/lib/api/types';
 import type { MlPluginSetup } from '@kbn/ml-plugin/server';
 import type { KibanaRequest } from '@kbn/core-http-server';
-import type { Document } from 'langchain/document';
+import { Document } from 'langchain/document';
 import type { SavedObjectsClientContract } from '@kbn/core-saved-objects-api-server';
-import { KnowledgeBaseEntryResponse } from '@kbn/elastic-assistant-common';
+import {
+  KnowledgeBaseEntryCreateProps,
+  KnowledgeBaseEntryResponse,
+} from '@kbn/elastic-assistant-common';
 import pRetry from 'p-retry';
+import { QueryDslQueryContainer } from '@elastic/elasticsearch/lib/api/typesWithBodyKey';
 import { AIAssistantDataClient, AIAssistantDataClientParams } from '..';
 import { ElasticsearchStore } from '../../lib/langchain/elasticsearch_store/elasticsearch_store';
 import { loadESQL } from '../../lib/langchain/content_loaders/esql_loader';
 import { GetElser } from '../../types';
-import { transformToCreateSchema } from './create_knowledge_base_entry';
+import { createKnowledgeBaseEntry, transformToCreateSchema } from './create_knowledge_base_entry';
 import { EsKnowledgeBaseEntrySchema } from './types';
 import { transformESSearchToKnowledgeBaseEntry } from './transforms';
 import { ESQL_DOCS_LOADED_QUERY } from '../../routes/knowledge_base/constants';
-import { isModelAlreadyExistsError } from './helpers';
+import { getKBVectorSearchQuery, isModelAlreadyExistsError } from './helpers';

 interface KnowledgeBaseDataClientParams extends AIAssistantDataClientParams {
   ml: MlPluginSetup;
@@ -217,8 +221,7 @@ export class AIAssistantKnowledgeBaseDataClient extends AIAssistantDataClient {
   /**
    * Adds LangChain Documents to the knowledge base
    *
-   * @param documents
-   * @param authenticatedUser
+   * @param documents LangChain Documents to add to the knowledge base
    */
   public addKnowledgeBaseDocuments = async ({
     documents,
@@ -261,4 +264,100 @@ export class AIAssistantKnowledgeBaseDataClient extends AIAssistantDataClient {

     return created?.data ? transformESSearchToKnowledgeBaseEntry(created?.data) : [];
   };

+  /**
+   * Performs similarity search to retrieve LangChain Documents from the knowledge base
+   */
+  public getKnowledgeBaseDocuments = async ({
+    filter,
+    kbResource,
+    query,
+    required,
+  }: {
+    filter?: QueryDslQueryContainer;
+    kbResource?: string;
+    query: string;
+    required?: boolean;
+  }): Promise<Document[]> => {
+    const user = this.options.currentUser;
+    if (user == null) {
+      throw new Error(
+        'Authenticated user not found! Ensure kbDataClient was initialized from a request.'
+      );
+    }

+    const esClient = await this.options.elasticsearchClientPromise;
+    const modelId = await this.options.getElserId();

+    const vectorSearchQuery = getKBVectorSearchQuery({
+      filter,
+      kbResource,
+      modelId,
+      query,
+      required,
+      user,
+    });

+    try {
+      const result = await esClient.search<EsKnowledgeBaseEntrySchema>({
+        index: this.indexTemplateAndPattern.alias,
+        size: 10,
+        query: vectorSearchQuery,
+      });

+      const results = result.hits.hits.map(
+        (hit) =>
+          new Document({
+            pageContent: hit?._source?.text ?? '',
+            metadata: hit?._source?.metadata ?? {},
+          })
+      );

+      this.options.logger.debug(
+        `getKnowledgeBaseDocuments() - Similarity Search Query:\n ${JSON.stringify(
+          vectorSearchQuery
+        )}`
+      );
+      this.options.logger.debug(
+        `getKnowledgeBaseDocuments() - Similarity Search Results:\n ${JSON.stringify(results)}`
+      );

+      return results;
+    } catch (e) {
+      this.options.logger.error(`Error performing KB Similarity Search: ${e.message}`);
+      return [];
+    }
+  };

+  /**
+   * Creates a new Knowledge Base Entry.
+   *
+   * @param knowledgeBaseEntry
+   */
+  public createKnowledgeBaseEntry = async ({
+    knowledgeBaseEntry,
+  }: {
+    knowledgeBaseEntry: KnowledgeBaseEntryCreateProps;
+  }): Promise<KnowledgeBaseEntryResponse | null> => {
+    const authenticatedUser = this.options.currentUser;
+    if (authenticatedUser == null) {
+      throw new Error(
+        'Authenticated user not found! Ensure kbDataClient was initialized from a request.'
+      );
+    }

+    this.options.logger.debug(
+      `Creating Knowledge Base Entry:\n ${JSON.stringify(knowledgeBaseEntry, null, 2)}`
+    );
+    this.options.logger.debug(`kbIndex: ${this.indexTemplateAndPattern.alias}`);
+    const esClient = await this.options.elasticsearchClientPromise;
+    return createKnowledgeBaseEntry({
+      esClient,
+      knowledgeBaseIndex: this.indexTemplateAndPattern.alias,
+      logger: this.options.logger,
+      spaceId: this.spaceId,
+      user: authenticatedUser,
+      knowledgeBaseEntry,
+    });
+  };
 }
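
The hit-to-document mapping and the error handling in `getKnowledgeBaseDocuments()` can be isolated into a small synchronous sketch. The `SearchHit` and `KbDocument` types and the `hitsToDocuments()` and `searchOrEmpty()` helpers are hypothetical stand-ins for the Elasticsearch response and LangChain `Document` used in the diff.

```typescript
// Hypothetical stand-ins for the Elasticsearch hit and the LangChain Document.
interface SearchHit {
  _source?: { text?: string; metadata?: Record<string, unknown> };
}

interface KbDocument {
  pageContent: string;
  metadata: Record<string, unknown>;
}

// Missing fields fall back to an empty string / empty object, as in the diff.
const hitsToDocuments = (hits: SearchHit[]): KbDocument[] =>
  hits.map((hit) => ({
    pageContent: hit._source?.text ?? '',
    metadata: hit._source?.metadata ?? {},
  }));

// Search failures are logged and turned into an empty result instead of being rethrown.
const searchOrEmpty = (
  runSearch: () => SearchHit[],
  logError: (message: string) => void
): KbDocument[] => {
  try {
    return hitsToDocuments(runSearch());
  } catch (e) {
    const message = e instanceof Error ? e.message : String(e);
    logError(`Error performing KB Similarity Search: ${message}`);
    return [];
  }
};
```

With this shape a caller cannot tell "no matches" apart from "the search failed" without checking the logs, whereas a missing user still throws in the method above.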
@@ -89,12 +89,13 @@ export const callAgentExecutor: AgentExecutor<true | false> = async ({

   // Fetch any applicable tools that the source plugin may have registered
   const assistantToolParams: AssistantToolParams = {
-    anonymizationFields,
     alertsIndexPattern,
-    isEnabledKnowledgeBase,
+    anonymizationFields,
     chain,
-    llm,
     esClient,
+    isEnabledKnowledgeBase,
+    llm,
+    logger,
     modelExists,
     onNewReplacements,
     replacements,
@@ -9,14 +9,29 @@ import { PluginStartContract as ActionsPluginStart } from '@kbn/actions-plugin/s
 import { ElasticsearchClient } from '@kbn/core-elasticsearch-server';
 import { BaseMessage } from '@langchain/core/messages';
 import { Logger } from '@kbn/logging';
-import { KibanaRequest, ResponseHeaders } from '@kbn/core-http-server';
+import { KibanaRequest, KibanaResponseFactory, ResponseHeaders } from '@kbn/core-http-server';
 import type { LangChainTracer } from '@langchain/core/tracers/tracer_langchain';
 import { ExecuteConnectorRequestBody, Message, Replacements } from '@kbn/elastic-assistant-common';
 import { StreamResponseWithHeaders } from '@kbn/ml-response-stream/server';
 import { AnonymizationFieldResponse } from '@kbn/elastic-assistant-common/impl/schemas/anonymization_fields/bulk_crud_anonymization_fields_route.gen';
 import { ResponseBody } from '../types';
 import type { AssistantTool } from '../../../types';
 import { ElasticsearchStore } from '../elasticsearch_store/elasticsearch_store';
+import { AIAssistantKnowledgeBaseDataClient } from '../../../ai_assistant_data_clients/knowledge_base';
+import { AIAssistantConversationsDataClient } from '../../../ai_assistant_data_clients/conversations';
+import { AIAssistantDataClient } from '../../../ai_assistant_data_clients';

+export type OnLlmResponse = (
+  content: string,
+  traceData?: Message['traceData'],
+  isError?: boolean
+) => Promise<void>;

+export interface AssistantDataClients {
+  anonymizationFieldsDataClient?: AIAssistantDataClient;
+  conversationsDataClient?: AIAssistantConversationsDataClient;
+  kbDataClient?: AIAssistantKnowledgeBaseDataClient;
+}

 export interface AgentExecutorParams<T extends boolean> {
   abortSignal?: AbortSignal;
@@ -26,6 +41,8 @@ export interface AgentExecutorParams<T extends boolean> {
   isEnabledKnowledgeBase: boolean;
   assistantTools?: AssistantTool[];
   connectorId: string;
+  conversationId?: string;
+  dataClients?: AssistantDataClients;
   esClient: ElasticsearchClient;
   esStore: ElasticsearchStore;
   langChainMessages: BaseMessage[];
@@ -34,12 +51,9 @@ export interface AgentExecutorParams<T extends boolean> {
   onNewReplacements?: (newReplacements: Replacements) => void;
   replacements: Replacements;
   isStream?: T;
-  onLlmResponse?: (
-    content: string,
-    traceData?: Message['traceData'],
-    isError?: boolean
-  ) => Promise<void>;
+  onLlmResponse?: OnLlmResponse;
   request: KibanaRequest<unknown, unknown, ExecuteConnectorRequestBody>;
+  response?: KibanaResponseFactory;
   size?: number;
   traceOptions?: TraceOptions;
 }