diff --git a/opensearch_dashboards.json b/opensearch_dashboards.json index b9733635..ca95ebdc 100644 --- a/opensearch_dashboards.json +++ b/opensearch_dashboards.json @@ -12,7 +12,7 @@ "dataSourceManagement" ], "optionalPlugins": [ - "securityDashboards" + "securityDashboards","dataSource", "dataSourceManagement" ], "configPath": [ "assistant" diff --git a/public/chat_flyout.tsx b/public/chat_flyout.tsx index 8f3af9e4..73c4092d 100644 --- a/public/chat_flyout.tsx +++ b/public/chat_flyout.tsx @@ -7,6 +7,7 @@ import { EuiResizableContainer } from '@elastic/eui'; import cs from 'classnames'; import React, { useMemo, useRef } from 'react'; import { useObservable } from 'react-use'; +import { BehaviorSubject } from 'rxjs'; import { useChatContext } from './contexts/chat_context'; import { ChatPage } from './tabs/chat/chat_page'; import { ChatWindowHeader } from './tabs/chat_window_header'; @@ -25,8 +26,10 @@ export const ChatFlyout = (props: ChatFlyoutProps) => { const chatContext = useChatContext(); const chatHistoryPageLoadedRef = useRef(false); const core = useCore(); + // TODO: use DataSourceService to replace const selectedDS = useObservable( - core.services.setupDeps.dataSourceManagement.dataSourceSelection.getSelection$() + core.services.setupDeps?.dataSourceManagement?.dataSourceSelection?.getSelection$() ?? + new BehaviorSubject(new Map()) ); const currentDS = useMemo(() => { return selectedDS?.values().next().value; @@ -97,7 +100,7 @@ export const ChatFlyout = (props: ChatFlyoutProps) => {
Data Source: - {currentDS?.[0].label} + {currentDS?.[0]?.label}
diff --git a/public/components/feedback_modal.tsx b/public/components/feedback_modal.tsx deleted file mode 100644 index e9889da3..00000000 --- a/public/components/feedback_modal.tsx +++ /dev/null @@ -1,309 +0,0 @@ -/* - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - */ - -import { - EuiButton, - EuiButtonEmpty, - EuiForm, - EuiFormRow, - EuiModal, - EuiModalBody, - EuiModalFooter, - EuiModalHeader, - EuiModalHeaderTitle, - EuiRadioGroup, - EuiTextArea, -} from '@elastic/eui'; -import React, { useState } from 'react'; -import { HttpStart } from '../../../../src/core/public'; -import { ASSISTANT_API } from '../../common/constants/llm'; -import { getCoreStart } from '../plugin'; - -export interface LabelData { - formHeader: string; - inputPlaceholder: string; - outputPlaceholder: string; -} - -export interface FeedbackFormData { - input: string; - output: string; - correct: boolean | undefined; - expectedOutput: string; - comment: string; -} - -interface FeedbackMetaData { - type: 'event_analytics' | 'chat' | 'ppl_submit'; - conversationId?: string; - interactionId?: string; - error?: boolean; - selectedIndex?: string; -} - -interface FeedbackModelProps { - input?: string; - output?: string; - metadata: FeedbackMetaData; - onClose: () => void; -} - -export const FeedbackModal: React.FC = (props) => { - const [formData, setFormData] = useState({ - input: props.input ?? '', - output: props.output ?? '', - correct: undefined, - expectedOutput: '', - comment: '', - }); - return ( - - - - ); -}; - -interface FeedbackModalContentProps { - formData: FeedbackFormData; - setFormData: React.Dispatch>; - metadata: FeedbackMetaData; - displayLabels?: Partial> & Partial; - onClose: () => void; -} - -export const FeedbackModalContent: React.FC = (props) => { - const core = getCoreStart(); - const labels: NonNullable> = Object.assign( - { - formHeader: 'Olly Skills Feedback', - inputPlaceholder: 'Your input question', - input: 'Input question', - outputPlaceholder: 'The LLM response', - output: 'Output', - correct: 'Does the output match your expectations?', - expectedOutput: 'Expected output', - comment: 'Comment', - }, - props.displayLabels - ); - const { loading, submitFeedback } = useSubmitFeedback(props.formData, props.metadata, core.http); - const [formErrors, setFormErrors] = useState< - Partial<{ [x in keyof FeedbackFormData]: string[] }> - >({ - input: [], - output: [], - expectedOutput: [], - }); - - const hasError = (key?: keyof FeedbackFormData) => { - if (!key) return Object.values(formErrors).some((e) => !!e.length); - return !!formErrors[key]?.length; - }; - - const onSubmit = async (event: React.FormEvent) => { - event.preventDefault(); - const errors = { - input: validator - .input(props.formData.input) - .concat(await validator.validateQuery(props.formData.input, props.metadata.type)), - output: validator.output(props.formData.output), - correct: validator.correct(props.formData.correct), - expectedOutput: validator.expectedOutput( - props.formData.expectedOutput, - props.formData.correct === false - ), - }; - if (Object.values(errors).some((e) => !!e.length)) { - setFormErrors(errors); - return; - } - - try { - await submitFeedback(); - props.setFormData({ - input: '', - output: '', - correct: undefined, - expectedOutput: '', - comment: '', - }); - core.notifications.toasts.addSuccess('Thanks for your feedback!'); - props.onClose(); - } catch (e) { - core.notifications.toasts.addError(e, { title: 'Failed to submit feedback' }); - } - }; - - return ( - <> - - {labels.formHeader} - - - - - - props.setFormData({ ...props.formData, input: e.target.value })} - onBlur={(e) => { - setFormErrors({ ...formErrors, input: validator.input(e.target.value) }); - }} - isInvalid={hasError('input')} - /> - - - props.setFormData({ ...props.formData, output: e.target.value })} - onBlur={(e) => { - setFormErrors({ ...formErrors, output: validator.output(e.target.value) }); - }} - isInvalid={hasError('output')} - /> - - {props.metadata.type !== 'ppl_submit' && ( - - { - props.setFormData({ ...props.formData, correct: id === 'yes' }); - setFormErrors({ ...formErrors, expectedOutput: [] }); - }} - onBlur={() => setFormErrors({ ...formErrors, correct: [] })} - /> - - )} - {props.formData.correct === false && ( - - - props.setFormData({ ...props.formData, expectedOutput: e.target.value }) - } - onBlur={(e) => { - setFormErrors({ - ...formErrors, - expectedOutput: validator.expectedOutput( - e.target.value, - props.formData.correct === false - ), - }); - }} - isInvalid={hasError('expectedOutput')} - /> - - )} - - props.setFormData({ ...props.formData, comment: e.target.value })} - /> - - - - - - Cancel - - Send - - - - ); -}; - -const useSubmitFeedback = (data: FeedbackFormData, metadata: FeedbackMetaData, http: HttpStart) => { - const [loading, setLoading] = useState(false); - return { - loading, - submitFeedback: async () => { - setLoading(true); - const auth = await http - .get<{ data: { user_name: string; user_requested_tenant: string; roles: string[] } }>( - '/api/v1/configuration/account' - ) - .then((res) => ({ user: res.data.user_name, tenant: res.data.user_requested_tenant })); - - return http - .post(ASSISTANT_API.FEEDBACK, { - body: JSON.stringify({ metadata: { ...metadata, ...auth }, ...data }), - }) - .finally(() => setLoading(false)); - }, - }; -}; - -const validatePPLQuery = async (logsQuery: string, feedBackType: FeedbackMetaData['type']) => { - return []; - // TODO remove - // let responseMessage: [] | string[] = []; - // const errorMessage = [' Invalid PPL Query, please re-check the ppl syntax']; - - // if (feedBackType === 'ppl_submit') { - // const pplService = getPPLService(); - // await pplService - // .fetch({ query: logsQuery, format: 'jdbc' }) - // .then((res) => { - // if (res === undefined) responseMessage = errorMessage; - // }) - // .catch((error: Error) => { - // responseMessage = errorMessage; - // }); - // } - // return responseMessage; -}; - -const validator = { - input: (text: string) => (text.trim().length === 0 ? ['Input is required'] : []), - output: (text: string) => (text.trim().length === 0 ? ['Output is required'] : []), - correct: (correct: boolean | undefined) => - correct === undefined ? ['Correctness is required'] : [], - expectedOutput: (text: string, required: boolean) => - required && text.trim().length === 0 ? ['expectedOutput is required'] : [], - validateQuery: async (logsQuery: string, feedBackType: FeedbackMetaData['type']) => - await validatePPLQuery(logsQuery, feedBackType), -}; diff --git a/public/contexts/core_context.tsx b/public/contexts/core_context.tsx index c2d4ff08..6db5dd2e 100644 --- a/public/contexts/core_context.tsx +++ b/public/contexts/core_context.tsx @@ -8,13 +8,15 @@ import { useOpenSearchDashboards, } from '../../../../src/plugins/opensearch_dashboards_react/public'; import { AssistantPluginStartDependencies, AssistantPluginSetupDependencies } from '../types'; -import { ConversationLoadService, ConversationsService } from '../services'; +import { ConversationLoadService, ConversationsService, DataSourceService } from '../services'; export interface AssistantServices extends Required { setupDeps: AssistantPluginSetupDependencies; startDeps: AssistantPluginStartDependencies; conversationLoad: ConversationLoadService; conversations: ConversationsService; + // This service is maintained in chatbot instead of dataSource exported from core plugin. + dataSource: DataSourceService; } export const useCore: () => OpenSearchDashboardsReactContextValue< diff --git a/public/hooks/use_chat_actions.tsx b/public/hooks/use_chat_actions.tsx index a8f6358b..6723a0c6 100644 --- a/public/hooks/use_chat_actions.tsx +++ b/public/hooks/use_chat_actions.tsx @@ -35,6 +35,7 @@ export const useChatActions = (): AssistantActions => { ...(!chatContext.conversationId && { messages: chatState.messages }), // include all previous messages for new chats input, }), + query: core.services.dataSource.getDataSourceQuery(), }); if (abortController.signal.aborted) return; // Refresh history list after new conversation created if new conversation saved and history list page visible @@ -162,6 +163,7 @@ export const useChatActions = (): AssistantActions => { // abort agent execution await core.services.http.post(`${ASSISTANT_API.ABORT_AGENT_EXECUTION}`, { body: JSON.stringify({ conversationId }), + query: core.services.dataSource.getDataSourceQuery(), }); } }; @@ -178,6 +180,7 @@ export const useChatActions = (): AssistantActions => { conversationId: chatContext.conversationId, interactionId, }), + query: core.services.dataSource.getDataSourceQuery(), }); if (abortController.signal.aborted) { diff --git a/public/hooks/use_conversations.ts b/public/hooks/use_conversations.ts index 89e1a058..9fbed688 100644 --- a/public/hooks/use_conversations.ts +++ b/public/hooks/use_conversations.ts @@ -20,6 +20,7 @@ export const useDeleteConversation = () => { return core.services.http .delete(`${ASSISTANT_API.CONVERSATION}/${conversationId}`, { signal: abortControllerRef.current.signal, + query: core.services.dataSource.getDataSourceQuery(), }) .then((payload) => { dispatch({ type: 'success', payload }); diff --git a/public/hooks/use_feed_back.tsx b/public/hooks/use_feed_back.tsx index 50e56380..7a37bd37 100644 --- a/public/hooks/use_feed_back.tsx +++ b/public/hooks/use_feed_back.tsx @@ -38,6 +38,7 @@ export const useFeedback = (interaction?: Interaction | null) => { try { await core.services.http.put(`${ASSISTANT_API.FEEDBACK}/${message.interactionId}`, { body: JSON.stringify(body), + query: core.services.dataSource.getDataSourceQuery(), }); setFeedbackResult(correct); } catch (error) { diff --git a/public/plugin.tsx b/public/plugin.tsx index 2d15682f..1ef050ae 100644 --- a/public/plugin.tsx +++ b/public/plugin.tsx @@ -31,6 +31,7 @@ import { setIncontextInsightRegistry, } from './services'; import { ConfigSchema } from '../common/types/config'; +import { DataSourceService } from './services/data_source_service'; export const [getCoreStart, setCoreStart] = createGetterSetter('CoreStart'); @@ -58,9 +59,11 @@ export class AssistantPlugin > { private config: ConfigSchema; incontextInsightRegistry: IncontextInsightRegistry | undefined; + private dataSourceService: DataSourceService; constructor(initializerContext: PluginInitializerContext) { this.config = initializerContext.config.get(); + this.dataSourceService = new DataSourceService(); } public setup( @@ -103,8 +106,9 @@ export class AssistantPlugin ...coreStart, setupDeps, startDeps, - conversationLoad: new ConversationLoadService(coreStart.http), - conversations: new ConversationsService(coreStart.http), + conversationLoad: new ConversationLoadService(coreStart.http, this.dataSourceService), + conversations: new ConversationsService(coreStart.http, this.dataSourceService), + dataSource: this.dataSourceService, }); const account = await getAccount(); const username = account.data.user_name; diff --git a/public/services/conversation_load_service.ts b/public/services/conversation_load_service.ts index 44b2e970..b2875a20 100644 --- a/public/services/conversation_load_service.ts +++ b/public/services/conversation_load_service.ts @@ -7,6 +7,7 @@ import { BehaviorSubject } from 'rxjs'; import { HttpStart } from '../../../../src/core/public'; import { IConversation } from '../../common/types/chat_saved_object_attributes'; import { ASSISTANT_API } from '../../common/constants/llm'; +import { DataSourceService } from './data_source_service'; export class ConversationLoadService { status$: BehaviorSubject< @@ -14,7 +15,7 @@ export class ConversationLoadService { > = new BehaviorSubject<'idle' | 'loading' | { status: 'error'; error: Error }>('idle'); abortController?: AbortController; - constructor(private _http: HttpStart) {} + constructor(private _http: HttpStart, private _dataSource: DataSourceService) {} load = async (conversationId: string) => { this.abortController?.abort(); @@ -25,6 +26,7 @@ export class ConversationLoadService { `${ASSISTANT_API.CONVERSATION}/${conversationId}`, { signal: this.abortController.signal, + query: this._dataSource.getDataSourceQuery(), } ); this.status$.next('idle'); diff --git a/public/services/conversations_service.ts b/public/services/conversations_service.ts index 4f95794a..3c1faf9d 100644 --- a/public/services/conversations_service.ts +++ b/public/services/conversations_service.ts @@ -7,6 +7,7 @@ import { BehaviorSubject } from 'rxjs'; import { HttpFetchQuery, HttpStart, SavedObjectsFindOptions } from '../../../../src/core/public'; import { IConversationFindResponse } from '../../common/types/chat_saved_object_attributes'; import { ASSISTANT_API } from '../../common/constants/llm'; +import { DataSourceService } from './data_source_service'; export class ConversationsService { conversations$: BehaviorSubject = new BehaviorSubject( @@ -21,7 +22,7 @@ export class ConversationsService { >; abortController?: AbortController; - constructor(private _http: HttpStart) {} + constructor(private _http: HttpStart, private _dataSource: DataSourceService) {} public get options() { return this._options; @@ -37,7 +38,7 @@ export class ConversationsService { this.status$.next('loading'); this.conversations$.next( await this._http.get(ASSISTANT_API.CONVERSATIONS, { - query: this._options as HttpFetchQuery, + query: { ...this._options, ...this._dataSource.getDataSourceQuery() } as HttpFetchQuery, signal: this.abortController.signal, }) ); diff --git a/public/services/data_source_service.ts b/public/services/data_source_service.ts new file mode 100644 index 00000000..1607d7fa --- /dev/null +++ b/public/services/data_source_service.ts @@ -0,0 +1,32 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +import { BehaviorSubject } from 'rxjs'; +import type { DataSourceManagementPluginSetup } from '../../../../src/plugins/data_source_management/public/plugin'; + +export class DataSourceService { + dataSourceId$ = new BehaviorSubject(null); + private dataSourceManagement: DataSourceManagementPluginSetup | undefined | null = null; + + constructor() {} + + getDataSourceQuery() { + if (!this.dataSourceManagement) { + return {}; + } + // TODO: use new handle logic to update + const dataSourceId = this.dataSourceManagement.dataSourceSelection + .getSelectionValue() + ?.values() + .next().value; + if (dataSourceId === null) { + throw new Error('No data source id'); + } + if (dataSourceId === '') { + return {}; + } + return { dataSourceId }; + } +} diff --git a/public/services/index.ts b/public/services/index.ts index 72243960..66c1d040 100644 --- a/public/services/index.ts +++ b/public/services/index.ts @@ -20,3 +20,5 @@ export const [getChrome, setChrome] = createGetterSetter('Chrome'); export const [getNotifications, setNotifications] = createGetterSetter( 'Notifications' ); + +export { DataSourceService } from './data_source_service'; diff --git a/public/types.ts b/public/types.ts index 08c4a6ca..9d7926e2 100644 --- a/public/types.ts +++ b/public/types.ts @@ -10,6 +10,7 @@ import { IChatContext } from './contexts/chat_context'; import { MessageContentProps } from './tabs/chat/messages/message_content'; import { IncontextInsightRegistry } from './services'; import { DataSourceManagementPluginSetup } from '../../../src/plugins/data_source_management/public'; +import { DataSourcePluginSetup } from '../../../src/plugins/data_source/public'; export interface RenderProps { props: MessageContentProps; @@ -35,7 +36,8 @@ export interface AssistantPluginStartDependencies { export interface AssistantPluginSetupDependencies { embeddable: EmbeddableSetup; securityDashboards?: {}; - dataSourceManagement: DataSourceManagementPluginSetup; + dataSource?: DataSourcePluginSetup; + dataSourceManagement?: DataSourceManagementPluginSetup; } export interface AssistantSetup { diff --git a/server/routes/chat_routes.ts b/server/routes/chat_routes.ts index 0a2b7b51..e0809be5 100644 --- a/server/routes/chat_routes.ts +++ b/server/routes/chat_routes.ts @@ -17,6 +17,7 @@ import { OllyChatService } from '../services/chat/olly_chat_service'; import { AgentFrameworkStorageService } from '../services/storage/agent_framework_storage_service'; import { RoutesOptions } from '../types'; import { ChatService } from '../services/chat/chat_service'; +import { getOpenSearchClientTransport } from '../utils/get_opensearch_client_transport'; const llmRequestRoute = { path: ASSISTANT_API.SEND_MESSAGE, @@ -33,6 +34,9 @@ const llmRequestRoute = { contentType: schema.literal('text'), }), }), + query: schema.object({ + dataSourceId: schema.maybe(schema.string()), + }), }, }; export type LLMRequestSchema = TypeOf; @@ -43,6 +47,9 @@ const getConversationRoute = { params: schema.object({ conversationId: schema.string(), }), + query: schema.object({ + dataSourceId: schema.maybe(schema.string()), + }), }, }; export type GetConversationSchema = TypeOf; @@ -53,6 +60,9 @@ const abortAgentExecutionRoute = { body: schema.object({ conversationId: schema.string(), }), + query: schema.object({ + dataSourceId: schema.maybe(schema.string()), + }), }, }; export type AbortAgentExecutionSchema = TypeOf; @@ -64,6 +74,9 @@ const regenerateRoute = { conversationId: schema.string(), interactionId: schema.string(), }), + query: schema.object({ + dataSourceId: schema.maybe(schema.string()), + }), }, }; export type RegenerateSchema = TypeOf; @@ -79,6 +92,7 @@ const getConversationsRoute = { fields: schema.maybe(schema.arrayOf(schema.string())), search: schema.maybe(schema.string()), searchFields: schema.maybe(schema.oneOf([schema.string(), schema.arrayOf(schema.string())])), + dataSourceId: schema.maybe(schema.string()), }), }, }; @@ -90,6 +104,9 @@ const deleteConversationRoute = { params: schema.object({ conversationId: schema.string(), }), + query: schema.object({ + dataSourceId: schema.maybe(schema.string()), + }), }, }; @@ -102,6 +119,9 @@ const updateConversationRoute = { body: schema.object({ title: schema.string(), }), + query: schema.object({ + dataSourceId: schema.maybe(schema.string()), + }), }, }; @@ -111,6 +131,9 @@ const getTracesRoute = { params: schema.object({ interactionId: schema.string(), }), + query: schema.object({ + dataSourceId: schema.maybe(schema.string()), + }), }, }; @@ -123,16 +146,20 @@ const feedbackRoute = { body: schema.object({ satisfaction: schema.boolean(), }), + query: schema.object({ + dataSourceId: schema.maybe(schema.string()), + }), }, }; export function registerChatRoutes(router: IRouter, routeOptions: RoutesOptions) { - const createStorageService = (context: RequestHandlerContext) => + const createStorageService = async (context: RequestHandlerContext, dataSourceId?: string) => new AgentFrameworkStorageService( - context.core.opensearch.client.asCurrentUser, + await getOpenSearchClientTransport({ context, dataSourceId }), routeOptions.messageParsers ); - const createChatService = (context: RequestHandlerContext) => new OllyChatService(context); + const createChatService = async (context: RequestHandlerContext, dataSourceId?: string) => + new OllyChatService(await getOpenSearchClientTransport({ context, dataSourceId })); router.post( llmRequestRoute, @@ -142,8 +169,8 @@ export function registerChatRoutes(router: IRouter, routeOptions: RoutesOptions) response ): Promise> => { const { messages = [], input, conversationId: conversationIdInRequestBody } = request.body; - const storageService = createStorageService(context); - const chatService = createChatService(context); + const storageService = await createStorageService(context, request.query.dataSourceId); + const chatService = await createChatService(context, request.query.dataSourceId); let outputs: Awaited> | undefined; @@ -217,7 +244,7 @@ export function registerChatRoutes(router: IRouter, routeOptions: RoutesOptions) request, response ): Promise> => { - const storageService = createStorageService(context); + const storageService = await createStorageService(context, request.query.dataSourceId); try { const getResponse = await storageService.getConversation(request.params.conversationId); @@ -236,7 +263,7 @@ export function registerChatRoutes(router: IRouter, routeOptions: RoutesOptions) request, response ): Promise> => { - const storageService = createStorageService(context); + const storageService = await createStorageService(context, request.query.dataSourceId); try { const getResponse = await storageService.getConversations(request.query); @@ -255,7 +282,7 @@ export function registerChatRoutes(router: IRouter, routeOptions: RoutesOptions) request, response ): Promise> => { - const storageService = createStorageService(context); + const storageService = await createStorageService(context, request.query.dataSourceId); try { const getResponse = await storageService.deleteConversation(request.params.conversationId); @@ -274,7 +301,7 @@ export function registerChatRoutes(router: IRouter, routeOptions: RoutesOptions) request, response ): Promise> => { - const storageService = createStorageService(context); + const storageService = await createStorageService(context, request.query.dataSourceId); try { const getResponse = await storageService.updateConversation( @@ -296,7 +323,7 @@ export function registerChatRoutes(router: IRouter, routeOptions: RoutesOptions) request, response ): Promise> => { - const storageService = createStorageService(context); + const storageService = await createStorageService(context, request.query.dataSourceId); try { const getResponse = await storageService.getTraces(request.params.interactionId); @@ -315,7 +342,7 @@ export function registerChatRoutes(router: IRouter, routeOptions: RoutesOptions) request, response ): Promise> => { - const chatService = createChatService(context, ''); + const chatService = await createChatService(context, request.query.dataSourceId); try { chatService.abortAgentExecution(request.body.conversationId); context.assistant_plugin.logger.info( @@ -337,8 +364,8 @@ export function registerChatRoutes(router: IRouter, routeOptions: RoutesOptions) response ): Promise> => { const { conversationId, interactionId } = request.body; - const storageService = createStorageService(context); - const chatService = createChatService(context); + const storageService = await createStorageService(context, request.query.dataSourceId); + const chatService = await createChatService(context, request.query.dataSourceId); let outputs: Awaited> | undefined; @@ -386,7 +413,7 @@ export function registerChatRoutes(router: IRouter, routeOptions: RoutesOptions) request, response ): Promise> => { - const storageService = createStorageService(context); + const storageService = await createStorageService(context, request.query.dataSourceId); const { interactionId } = request.params; try { diff --git a/server/services/chat/olly_chat_service.ts b/server/services/chat/olly_chat_service.ts index cb239079..bdebec5a 100644 --- a/server/services/chat/olly_chat_service.ts +++ b/server/services/chat/olly_chat_service.ts @@ -4,7 +4,7 @@ */ import { ApiResponse } from '@opensearch-project/opensearch'; -import { RequestHandlerContext } from '../../../../../src/core/server'; +import { OpenSearchClient } from '../../../../../src/core/server'; import { IMessage, IInput } from '../../../common/types/chat_saved_object_attributes'; import { ChatService } from './chat_service'; import { ML_COMMONS_BASE_API, ROOT_AGENT_CONFIG_ID } from '../../utils/constants'; @@ -22,13 +22,12 @@ const INTERACTION_ID_FIELD = 'parent_interaction_id'; export class OllyChatService implements ChatService { static abortControllers: Map = new Map(); - constructor(private readonly context: RequestHandlerContext) {} + constructor(private readonly opensearchClientTransport: OpenSearchClient['transport']) {} private async getRootAgent(): Promise { try { - const opensearchClient = this.context.core.opensearch.client.asCurrentUser; const path = `${ML_COMMONS_BASE_API}/config/${ROOT_AGENT_CONFIG_ID}`; - const response = await opensearchClient.transport.request({ + const response = await this.opensearchClientTransport.request({ method: 'GET', path, }); @@ -53,9 +52,8 @@ export class OllyChatService implements ChatService { } private async callExecuteAgentAPI(payload: AgentRunPayload, rootAgentId: string) { - const opensearchClient = this.context.core.opensearch.client.asCurrentUser; try { - const agentFrameworkResponse = (await opensearchClient.transport.request( + const agentFrameworkResponse = (await this.opensearchClientTransport.request( { method: 'POST', path: `${ML_COMMONS_BASE_API}/agents/${rootAgentId}/_execute`, diff --git a/server/services/storage/agent_framework_storage_service.ts b/server/services/storage/agent_framework_storage_service.ts index f0f1a735..05f57b19 100644 --- a/server/services/storage/agent_framework_storage_service.ts +++ b/server/services/storage/agent_framework_storage_service.ts @@ -28,12 +28,12 @@ export interface ConversationOptResponse { export class AgentFrameworkStorageService implements StorageService { constructor( - private readonly client: OpenSearchClient, + private readonly clientTransport: OpenSearchClient['transport'], private readonly messageParsers: MessageParser[] = [] ) {} async getConversation(conversationId: string): Promise { const [interactionsResp, conversation] = await Promise.all([ - this.client.transport.request({ + this.clientTransport.request({ method: 'GET', path: `${ML_COMMONS_BASE_API}/memory/${encodeURIComponent( conversationId @@ -43,7 +43,7 @@ export class AgentFrameworkStorageService implements StorageService { messages: InteractionFromAgentFramework[]; }> >, - this.client.transport.request({ + this.clientTransport.request({ method: 'GET', path: `${ML_COMMONS_BASE_API}/memory/${encodeURIComponent(conversationId)}`, }) as TransportRequestPromise< @@ -103,7 +103,7 @@ export class AgentFrameworkStorageService implements StorageService { ...(sortField && query.sortOrder && { sort: [{ [sortField]: query.sortOrder }] }), }; - const conversations = await this.client.transport.request({ + const conversations = await this.clientTransport.request({ method: 'GET', path: `${ML_COMMONS_BASE_API}/memory/_search`, body: requestParams, @@ -146,7 +146,7 @@ export class AgentFrameworkStorageService implements StorageService { } async deleteConversation(conversationId: string): Promise { - const response = await this.client.transport.request({ + const response = await this.clientTransport.request({ method: 'DELETE', path: `${ML_COMMONS_BASE_API}/memory/${encodeURIComponent(conversationId)}`, }); @@ -167,7 +167,7 @@ export class AgentFrameworkStorageService implements StorageService { conversationId: string, title: string ): Promise { - const response = await this.client.transport.request({ + const response = await this.clientTransport.request({ method: 'PUT', path: `${ML_COMMONS_BASE_API}/memory/${encodeURIComponent(conversationId)}`, body: { @@ -188,7 +188,7 @@ export class AgentFrameworkStorageService implements StorageService { } async getTraces(interactionId: string): Promise { - const response = (await this.client.transport.request({ + const response = (await this.clientTransport.request({ method: 'GET', path: `${ML_COMMONS_BASE_API}/memory/message/${encodeURIComponent(interactionId)}/traces`, })) as ApiResponse<{ @@ -216,7 +216,7 @@ export class AgentFrameworkStorageService implements StorageService { interactionId: string, additionalInfo: Record> ): Promise { - const response = await this.client.transport.request({ + const response = await this.clientTransport.request({ method: 'PUT', path: `${ML_COMMONS_BASE_API}/memory/message/${encodeURIComponent(interactionId)}`, body: { @@ -260,7 +260,7 @@ export class AgentFrameworkStorageService implements StorageService { if (!interactionId) { throw new Error('interactionId is required'); } - const interactionsResp = (await this.client.transport.request({ + const interactionsResp = (await this.clientTransport.request({ method: 'GET', path: `${ML_COMMONS_BASE_API}/memory/message/${encodeURIComponent(interactionId)}`, })) as ApiResponse; diff --git a/server/utils/get_opensearch_client_transport.test.ts b/server/utils/get_opensearch_client_transport.test.ts new file mode 100644 index 00000000..0c59c842 --- /dev/null +++ b/server/utils/get_opensearch_client_transport.test.ts @@ -0,0 +1,43 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +import { getOpenSearchClientTransport } from './get_opensearch_client_transport'; +import { coreMock } from '../../../../src/core/server/mocks'; +import { loggerMock } from '../../../../src/core/server/logging/logger.mock'; + +const mockedLogger = loggerMock.create(); + +describe('getOpenSearchClientTransport', () => { + it('should return current user opensearch transport', async () => { + const core = coreMock.createRequestHandlerContext(); + + expect( + await getOpenSearchClientTransport({ + context: { core, assistant_plugin: { logger: mockedLogger } }, + }) + ).toBe(core.opensearch.client.asCurrentUser.transport); + }); + it('should data source id related opensearch transport', async () => { + const transportMock = {}; + const core = coreMock.createRequestHandlerContext(); + const context = { + core, + dataSource: { + opensearch: { + getClient: async (_dataSourceId: string) => ({ + transport: transportMock, + }), + }, + }, + }; + + expect( + await getOpenSearchClientTransport({ + context: { core, assistant_plugin: { logger: mockedLogger } }, + dataSourceId: 'foo', + }) + ).toBe(transportMock); + }); +}); diff --git a/server/utils/get_opensearch_client_transport.ts b/server/utils/get_opensearch_client_transport.ts new file mode 100644 index 00000000..9f54a4ff --- /dev/null +++ b/server/utils/get_opensearch_client_transport.ts @@ -0,0 +1,25 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +import { OpenSearchClient, RequestHandlerContext } from '../../../../src/core/server'; + +export const getOpenSearchClientTransport = async ({ + context, + dataSourceId, +}: { + context: RequestHandlerContext & { + dataSource?: { + opensearch: { + getClient: (dataSourceId: string) => Promise; + }; + }; + }; + dataSourceId?: string; +}) => { + if (dataSourceId && context.dataSource) { + return (await context.dataSource.opensearch.getClient(dataSourceId)).transport; + } + return context.core.opensearch.client.asCurrentUser.transport; +};