diff --git a/opensearch_dashboards.json b/opensearch_dashboards.json
index b9733635..ca95ebdc 100644
--- a/opensearch_dashboards.json
+++ b/opensearch_dashboards.json
@@ -12,7 +12,7 @@
"dataSourceManagement"
],
"optionalPlugins": [
- "securityDashboards"
+ "securityDashboards","dataSource", "dataSourceManagement"
],
"configPath": [
"assistant"
diff --git a/public/chat_flyout.tsx b/public/chat_flyout.tsx
index 8f3af9e4..73c4092d 100644
--- a/public/chat_flyout.tsx
+++ b/public/chat_flyout.tsx
@@ -7,6 +7,7 @@ import { EuiResizableContainer } from '@elastic/eui';
import cs from 'classnames';
import React, { useMemo, useRef } from 'react';
import { useObservable } from 'react-use';
+import { BehaviorSubject } from 'rxjs';
import { useChatContext } from './contexts/chat_context';
import { ChatPage } from './tabs/chat/chat_page';
import { ChatWindowHeader } from './tabs/chat_window_header';
@@ -25,8 +26,10 @@ export const ChatFlyout = (props: ChatFlyoutProps) => {
const chatContext = useChatContext();
const chatHistoryPageLoadedRef = useRef(false);
const core = useCore();
+ // TODO: use DataSourceService to replace
const selectedDS = useObservable(
- core.services.setupDeps.dataSourceManagement.dataSourceSelection.getSelection$()
+ core.services.setupDeps?.dataSourceManagement?.dataSourceSelection?.getSelection$() ??
+ new BehaviorSubject(new Map())
);
const currentDS = useMemo(() => {
return selectedDS?.values().next().value;
@@ -97,7 +100,7 @@ export const ChatFlyout = (props: ChatFlyoutProps) => {
Data Source:
- {currentDS?.[0].label}
+ {currentDS?.[0]?.label}
diff --git a/public/components/feedback_modal.tsx b/public/components/feedback_modal.tsx
deleted file mode 100644
index e9889da3..00000000
--- a/public/components/feedback_modal.tsx
+++ /dev/null
@@ -1,309 +0,0 @@
-/*
- * Copyright OpenSearch Contributors
- * SPDX-License-Identifier: Apache-2.0
- */
-
-import {
- EuiButton,
- EuiButtonEmpty,
- EuiForm,
- EuiFormRow,
- EuiModal,
- EuiModalBody,
- EuiModalFooter,
- EuiModalHeader,
- EuiModalHeaderTitle,
- EuiRadioGroup,
- EuiTextArea,
-} from '@elastic/eui';
-import React, { useState } from 'react';
-import { HttpStart } from '../../../../src/core/public';
-import { ASSISTANT_API } from '../../common/constants/llm';
-import { getCoreStart } from '../plugin';
-
-export interface LabelData {
- formHeader: string;
- inputPlaceholder: string;
- outputPlaceholder: string;
-}
-
-export interface FeedbackFormData {
- input: string;
- output: string;
- correct: boolean | undefined;
- expectedOutput: string;
- comment: string;
-}
-
-interface FeedbackMetaData {
- type: 'event_analytics' | 'chat' | 'ppl_submit';
- conversationId?: string;
- interactionId?: string;
- error?: boolean;
- selectedIndex?: string;
-}
-
-interface FeedbackModelProps {
- input?: string;
- output?: string;
- metadata: FeedbackMetaData;
- onClose: () => void;
-}
-
-export const FeedbackModal: React.FC = (props) => {
- const [formData, setFormData] = useState({
- input: props.input ?? '',
- output: props.output ?? '',
- correct: undefined,
- expectedOutput: '',
- comment: '',
- });
- return (
-
-
-
- );
-};
-
-interface FeedbackModalContentProps {
- formData: FeedbackFormData;
- setFormData: React.Dispatch>;
- metadata: FeedbackMetaData;
- displayLabels?: Partial> & Partial;
- onClose: () => void;
-}
-
-export const FeedbackModalContent: React.FC = (props) => {
- const core = getCoreStart();
- const labels: NonNullable> = Object.assign(
- {
- formHeader: 'Olly Skills Feedback',
- inputPlaceholder: 'Your input question',
- input: 'Input question',
- outputPlaceholder: 'The LLM response',
- output: 'Output',
- correct: 'Does the output match your expectations?',
- expectedOutput: 'Expected output',
- comment: 'Comment',
- },
- props.displayLabels
- );
- const { loading, submitFeedback } = useSubmitFeedback(props.formData, props.metadata, core.http);
- const [formErrors, setFormErrors] = useState<
- Partial<{ [x in keyof FeedbackFormData]: string[] }>
- >({
- input: [],
- output: [],
- expectedOutput: [],
- });
-
- const hasError = (key?: keyof FeedbackFormData) => {
- if (!key) return Object.values(formErrors).some((e) => !!e.length);
- return !!formErrors[key]?.length;
- };
-
- const onSubmit = async (event: React.FormEvent) => {
- event.preventDefault();
- const errors = {
- input: validator
- .input(props.formData.input)
- .concat(await validator.validateQuery(props.formData.input, props.metadata.type)),
- output: validator.output(props.formData.output),
- correct: validator.correct(props.formData.correct),
- expectedOutput: validator.expectedOutput(
- props.formData.expectedOutput,
- props.formData.correct === false
- ),
- };
- if (Object.values(errors).some((e) => !!e.length)) {
- setFormErrors(errors);
- return;
- }
-
- try {
- await submitFeedback();
- props.setFormData({
- input: '',
- output: '',
- correct: undefined,
- expectedOutput: '',
- comment: '',
- });
- core.notifications.toasts.addSuccess('Thanks for your feedback!');
- props.onClose();
- } catch (e) {
- core.notifications.toasts.addError(e, { title: 'Failed to submit feedback' });
- }
- };
-
- return (
- <>
-
- {labels.formHeader}
-
-
-
-
-
- props.setFormData({ ...props.formData, input: e.target.value })}
- onBlur={(e) => {
- setFormErrors({ ...formErrors, input: validator.input(e.target.value) });
- }}
- isInvalid={hasError('input')}
- />
-
-
- props.setFormData({ ...props.formData, output: e.target.value })}
- onBlur={(e) => {
- setFormErrors({ ...formErrors, output: validator.output(e.target.value) });
- }}
- isInvalid={hasError('output')}
- />
-
- {props.metadata.type !== 'ppl_submit' && (
-
- {
- props.setFormData({ ...props.formData, correct: id === 'yes' });
- setFormErrors({ ...formErrors, expectedOutput: [] });
- }}
- onBlur={() => setFormErrors({ ...formErrors, correct: [] })}
- />
-
- )}
- {props.formData.correct === false && (
-
-
- props.setFormData({ ...props.formData, expectedOutput: e.target.value })
- }
- onBlur={(e) => {
- setFormErrors({
- ...formErrors,
- expectedOutput: validator.expectedOutput(
- e.target.value,
- props.formData.correct === false
- ),
- });
- }}
- isInvalid={hasError('expectedOutput')}
- />
-
- )}
-
- props.setFormData({ ...props.formData, comment: e.target.value })}
- />
-
-
-
-
-
- Cancel
-
- Send
-
-
- >
- );
-};
-
-const useSubmitFeedback = (data: FeedbackFormData, metadata: FeedbackMetaData, http: HttpStart) => {
- const [loading, setLoading] = useState(false);
- return {
- loading,
- submitFeedback: async () => {
- setLoading(true);
- const auth = await http
- .get<{ data: { user_name: string; user_requested_tenant: string; roles: string[] } }>(
- '/api/v1/configuration/account'
- )
- .then((res) => ({ user: res.data.user_name, tenant: res.data.user_requested_tenant }));
-
- return http
- .post(ASSISTANT_API.FEEDBACK, {
- body: JSON.stringify({ metadata: { ...metadata, ...auth }, ...data }),
- })
- .finally(() => setLoading(false));
- },
- };
-};
-
-const validatePPLQuery = async (logsQuery: string, feedBackType: FeedbackMetaData['type']) => {
- return [];
- // TODO remove
- // let responseMessage: [] | string[] = [];
- // const errorMessage = [' Invalid PPL Query, please re-check the ppl syntax'];
-
- // if (feedBackType === 'ppl_submit') {
- // const pplService = getPPLService();
- // await pplService
- // .fetch({ query: logsQuery, format: 'jdbc' })
- // .then((res) => {
- // if (res === undefined) responseMessage = errorMessage;
- // })
- // .catch((error: Error) => {
- // responseMessage = errorMessage;
- // });
- // }
- // return responseMessage;
-};
-
-const validator = {
- input: (text: string) => (text.trim().length === 0 ? ['Input is required'] : []),
- output: (text: string) => (text.trim().length === 0 ? ['Output is required'] : []),
- correct: (correct: boolean | undefined) =>
- correct === undefined ? ['Correctness is required'] : [],
- expectedOutput: (text: string, required: boolean) =>
- required && text.trim().length === 0 ? ['expectedOutput is required'] : [],
- validateQuery: async (logsQuery: string, feedBackType: FeedbackMetaData['type']) =>
- await validatePPLQuery(logsQuery, feedBackType),
-};
diff --git a/public/contexts/core_context.tsx b/public/contexts/core_context.tsx
index c2d4ff08..6db5dd2e 100644
--- a/public/contexts/core_context.tsx
+++ b/public/contexts/core_context.tsx
@@ -8,13 +8,15 @@ import {
useOpenSearchDashboards,
} from '../../../../src/plugins/opensearch_dashboards_react/public';
import { AssistantPluginStartDependencies, AssistantPluginSetupDependencies } from '../types';
-import { ConversationLoadService, ConversationsService } from '../services';
+import { ConversationLoadService, ConversationsService, DataSourceService } from '../services';
export interface AssistantServices extends Required {
setupDeps: AssistantPluginSetupDependencies;
startDeps: AssistantPluginStartDependencies;
conversationLoad: ConversationLoadService;
conversations: ConversationsService;
+ // This service is maintained in chatbot instead of dataSource exported from core plugin.
+ dataSource: DataSourceService;
}
export const useCore: () => OpenSearchDashboardsReactContextValue<
diff --git a/public/hooks/use_chat_actions.tsx b/public/hooks/use_chat_actions.tsx
index a8f6358b..6723a0c6 100644
--- a/public/hooks/use_chat_actions.tsx
+++ b/public/hooks/use_chat_actions.tsx
@@ -35,6 +35,7 @@ export const useChatActions = (): AssistantActions => {
...(!chatContext.conversationId && { messages: chatState.messages }), // include all previous messages for new chats
input,
}),
+ query: core.services.dataSource.getDataSourceQuery(),
});
if (abortController.signal.aborted) return;
// Refresh history list after new conversation created if new conversation saved and history list page visible
@@ -162,6 +163,7 @@ export const useChatActions = (): AssistantActions => {
// abort agent execution
await core.services.http.post(`${ASSISTANT_API.ABORT_AGENT_EXECUTION}`, {
body: JSON.stringify({ conversationId }),
+ query: core.services.dataSource.getDataSourceQuery(),
});
}
};
@@ -178,6 +180,7 @@ export const useChatActions = (): AssistantActions => {
conversationId: chatContext.conversationId,
interactionId,
}),
+ query: core.services.dataSource.getDataSourceQuery(),
});
if (abortController.signal.aborted) {
diff --git a/public/hooks/use_conversations.ts b/public/hooks/use_conversations.ts
index 89e1a058..9fbed688 100644
--- a/public/hooks/use_conversations.ts
+++ b/public/hooks/use_conversations.ts
@@ -20,6 +20,7 @@ export const useDeleteConversation = () => {
return core.services.http
.delete(`${ASSISTANT_API.CONVERSATION}/${conversationId}`, {
signal: abortControllerRef.current.signal,
+ query: core.services.dataSource.getDataSourceQuery(),
})
.then((payload) => {
dispatch({ type: 'success', payload });
diff --git a/public/hooks/use_feed_back.tsx b/public/hooks/use_feed_back.tsx
index 50e56380..7a37bd37 100644
--- a/public/hooks/use_feed_back.tsx
+++ b/public/hooks/use_feed_back.tsx
@@ -38,6 +38,7 @@ export const useFeedback = (interaction?: Interaction | null) => {
try {
await core.services.http.put(`${ASSISTANT_API.FEEDBACK}/${message.interactionId}`, {
body: JSON.stringify(body),
+ query: core.services.dataSource.getDataSourceQuery(),
});
setFeedbackResult(correct);
} catch (error) {
diff --git a/public/plugin.tsx b/public/plugin.tsx
index 2d15682f..1ef050ae 100644
--- a/public/plugin.tsx
+++ b/public/plugin.tsx
@@ -31,6 +31,7 @@ import {
setIncontextInsightRegistry,
} from './services';
import { ConfigSchema } from '../common/types/config';
+import { DataSourceService } from './services/data_source_service';
export const [getCoreStart, setCoreStart] = createGetterSetter('CoreStart');
@@ -58,9 +59,11 @@ export class AssistantPlugin
> {
private config: ConfigSchema;
incontextInsightRegistry: IncontextInsightRegistry | undefined;
+ private dataSourceService: DataSourceService;
constructor(initializerContext: PluginInitializerContext) {
this.config = initializerContext.config.get();
+ this.dataSourceService = new DataSourceService();
}
public setup(
@@ -103,8 +106,9 @@ export class AssistantPlugin
...coreStart,
setupDeps,
startDeps,
- conversationLoad: new ConversationLoadService(coreStart.http),
- conversations: new ConversationsService(coreStart.http),
+ conversationLoad: new ConversationLoadService(coreStart.http, this.dataSourceService),
+ conversations: new ConversationsService(coreStart.http, this.dataSourceService),
+ dataSource: this.dataSourceService,
});
const account = await getAccount();
const username = account.data.user_name;
diff --git a/public/services/conversation_load_service.ts b/public/services/conversation_load_service.ts
index 44b2e970..b2875a20 100644
--- a/public/services/conversation_load_service.ts
+++ b/public/services/conversation_load_service.ts
@@ -7,6 +7,7 @@ import { BehaviorSubject } from 'rxjs';
import { HttpStart } from '../../../../src/core/public';
import { IConversation } from '../../common/types/chat_saved_object_attributes';
import { ASSISTANT_API } from '../../common/constants/llm';
+import { DataSourceService } from './data_source_service';
export class ConversationLoadService {
status$: BehaviorSubject<
@@ -14,7 +15,7 @@ export class ConversationLoadService {
> = new BehaviorSubject<'idle' | 'loading' | { status: 'error'; error: Error }>('idle');
abortController?: AbortController;
- constructor(private _http: HttpStart) {}
+ constructor(private _http: HttpStart, private _dataSource: DataSourceService) {}
load = async (conversationId: string) => {
this.abortController?.abort();
@@ -25,6 +26,7 @@ export class ConversationLoadService {
`${ASSISTANT_API.CONVERSATION}/${conversationId}`,
{
signal: this.abortController.signal,
+ query: this._dataSource.getDataSourceQuery(),
}
);
this.status$.next('idle');
diff --git a/public/services/conversations_service.ts b/public/services/conversations_service.ts
index 4f95794a..3c1faf9d 100644
--- a/public/services/conversations_service.ts
+++ b/public/services/conversations_service.ts
@@ -7,6 +7,7 @@ import { BehaviorSubject } from 'rxjs';
import { HttpFetchQuery, HttpStart, SavedObjectsFindOptions } from '../../../../src/core/public';
import { IConversationFindResponse } from '../../common/types/chat_saved_object_attributes';
import { ASSISTANT_API } from '../../common/constants/llm';
+import { DataSourceService } from './data_source_service';
export class ConversationsService {
conversations$: BehaviorSubject = new BehaviorSubject(
@@ -21,7 +22,7 @@ export class ConversationsService {
>;
abortController?: AbortController;
- constructor(private _http: HttpStart) {}
+ constructor(private _http: HttpStart, private _dataSource: DataSourceService) {}
public get options() {
return this._options;
@@ -37,7 +38,7 @@ export class ConversationsService {
this.status$.next('loading');
this.conversations$.next(
await this._http.get(ASSISTANT_API.CONVERSATIONS, {
- query: this._options as HttpFetchQuery,
+ query: { ...this._options, ...this._dataSource.getDataSourceQuery() } as HttpFetchQuery,
signal: this.abortController.signal,
})
);
diff --git a/public/services/data_source_service.ts b/public/services/data_source_service.ts
new file mode 100644
index 00000000..1607d7fa
--- /dev/null
+++ b/public/services/data_source_service.ts
@@ -0,0 +1,32 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+import { BehaviorSubject } from 'rxjs';
+import type { DataSourceManagementPluginSetup } from '../../../../src/plugins/data_source_management/public/plugin';
+
+export class DataSourceService {
+ dataSourceId$ = new BehaviorSubject(null);
+ private dataSourceManagement: DataSourceManagementPluginSetup | undefined | null = null;
+
+ constructor() {}
+
+ getDataSourceQuery() {
+ if (!this.dataSourceManagement) {
+ return {};
+ }
+ // TODO: use new handle logic to update
+ const dataSourceId = this.dataSourceManagement.dataSourceSelection
+ .getSelectionValue()
+ ?.values()
+ .next().value;
+ if (dataSourceId === null) {
+ throw new Error('No data source id');
+ }
+ if (dataSourceId === '') {
+ return {};
+ }
+ return { dataSourceId };
+ }
+}
diff --git a/public/services/index.ts b/public/services/index.ts
index 72243960..66c1d040 100644
--- a/public/services/index.ts
+++ b/public/services/index.ts
@@ -20,3 +20,5 @@ export const [getChrome, setChrome] = createGetterSetter('Chrome');
export const [getNotifications, setNotifications] = createGetterSetter(
'Notifications'
);
+
+export { DataSourceService } from './data_source_service';
diff --git a/public/types.ts b/public/types.ts
index 08c4a6ca..9d7926e2 100644
--- a/public/types.ts
+++ b/public/types.ts
@@ -10,6 +10,7 @@ import { IChatContext } from './contexts/chat_context';
import { MessageContentProps } from './tabs/chat/messages/message_content';
import { IncontextInsightRegistry } from './services';
import { DataSourceManagementPluginSetup } from '../../../src/plugins/data_source_management/public';
+import { DataSourcePluginSetup } from '../../../src/plugins/data_source/public';
export interface RenderProps {
props: MessageContentProps;
@@ -35,7 +36,8 @@ export interface AssistantPluginStartDependencies {
export interface AssistantPluginSetupDependencies {
embeddable: EmbeddableSetup;
securityDashboards?: {};
- dataSourceManagement: DataSourceManagementPluginSetup;
+ dataSource?: DataSourcePluginSetup;
+ dataSourceManagement?: DataSourceManagementPluginSetup;
}
export interface AssistantSetup {
diff --git a/server/routes/chat_routes.ts b/server/routes/chat_routes.ts
index 0a2b7b51..e0809be5 100644
--- a/server/routes/chat_routes.ts
+++ b/server/routes/chat_routes.ts
@@ -17,6 +17,7 @@ import { OllyChatService } from '../services/chat/olly_chat_service';
import { AgentFrameworkStorageService } from '../services/storage/agent_framework_storage_service';
import { RoutesOptions } from '../types';
import { ChatService } from '../services/chat/chat_service';
+import { getOpenSearchClientTransport } from '../utils/get_opensearch_client_transport';
const llmRequestRoute = {
path: ASSISTANT_API.SEND_MESSAGE,
@@ -33,6 +34,9 @@ const llmRequestRoute = {
contentType: schema.literal('text'),
}),
}),
+ query: schema.object({
+ dataSourceId: schema.maybe(schema.string()),
+ }),
},
};
export type LLMRequestSchema = TypeOf;
@@ -43,6 +47,9 @@ const getConversationRoute = {
params: schema.object({
conversationId: schema.string(),
}),
+ query: schema.object({
+ dataSourceId: schema.maybe(schema.string()),
+ }),
},
};
export type GetConversationSchema = TypeOf;
@@ -53,6 +60,9 @@ const abortAgentExecutionRoute = {
body: schema.object({
conversationId: schema.string(),
}),
+ query: schema.object({
+ dataSourceId: schema.maybe(schema.string()),
+ }),
},
};
export type AbortAgentExecutionSchema = TypeOf;
@@ -64,6 +74,9 @@ const regenerateRoute = {
conversationId: schema.string(),
interactionId: schema.string(),
}),
+ query: schema.object({
+ dataSourceId: schema.maybe(schema.string()),
+ }),
},
};
export type RegenerateSchema = TypeOf;
@@ -79,6 +92,7 @@ const getConversationsRoute = {
fields: schema.maybe(schema.arrayOf(schema.string())),
search: schema.maybe(schema.string()),
searchFields: schema.maybe(schema.oneOf([schema.string(), schema.arrayOf(schema.string())])),
+ dataSourceId: schema.maybe(schema.string()),
}),
},
};
@@ -90,6 +104,9 @@ const deleteConversationRoute = {
params: schema.object({
conversationId: schema.string(),
}),
+ query: schema.object({
+ dataSourceId: schema.maybe(schema.string()),
+ }),
},
};
@@ -102,6 +119,9 @@ const updateConversationRoute = {
body: schema.object({
title: schema.string(),
}),
+ query: schema.object({
+ dataSourceId: schema.maybe(schema.string()),
+ }),
},
};
@@ -111,6 +131,9 @@ const getTracesRoute = {
params: schema.object({
interactionId: schema.string(),
}),
+ query: schema.object({
+ dataSourceId: schema.maybe(schema.string()),
+ }),
},
};
@@ -123,16 +146,20 @@ const feedbackRoute = {
body: schema.object({
satisfaction: schema.boolean(),
}),
+ query: schema.object({
+ dataSourceId: schema.maybe(schema.string()),
+ }),
},
};
export function registerChatRoutes(router: IRouter, routeOptions: RoutesOptions) {
- const createStorageService = (context: RequestHandlerContext) =>
+ const createStorageService = async (context: RequestHandlerContext, dataSourceId?: string) =>
new AgentFrameworkStorageService(
- context.core.opensearch.client.asCurrentUser,
+ await getOpenSearchClientTransport({ context, dataSourceId }),
routeOptions.messageParsers
);
- const createChatService = (context: RequestHandlerContext) => new OllyChatService(context);
+ const createChatService = async (context: RequestHandlerContext, dataSourceId?: string) =>
+ new OllyChatService(await getOpenSearchClientTransport({ context, dataSourceId }));
router.post(
llmRequestRoute,
@@ -142,8 +169,8 @@ export function registerChatRoutes(router: IRouter, routeOptions: RoutesOptions)
response
): Promise> => {
const { messages = [], input, conversationId: conversationIdInRequestBody } = request.body;
- const storageService = createStorageService(context);
- const chatService = createChatService(context);
+ const storageService = await createStorageService(context, request.query.dataSourceId);
+ const chatService = await createChatService(context, request.query.dataSourceId);
let outputs: Awaited> | undefined;
@@ -217,7 +244,7 @@ export function registerChatRoutes(router: IRouter, routeOptions: RoutesOptions)
request,
response
): Promise> => {
- const storageService = createStorageService(context);
+ const storageService = await createStorageService(context, request.query.dataSourceId);
try {
const getResponse = await storageService.getConversation(request.params.conversationId);
@@ -236,7 +263,7 @@ export function registerChatRoutes(router: IRouter, routeOptions: RoutesOptions)
request,
response
): Promise> => {
- const storageService = createStorageService(context);
+ const storageService = await createStorageService(context, request.query.dataSourceId);
try {
const getResponse = await storageService.getConversations(request.query);
@@ -255,7 +282,7 @@ export function registerChatRoutes(router: IRouter, routeOptions: RoutesOptions)
request,
response
): Promise> => {
- const storageService = createStorageService(context);
+ const storageService = await createStorageService(context, request.query.dataSourceId);
try {
const getResponse = await storageService.deleteConversation(request.params.conversationId);
@@ -274,7 +301,7 @@ export function registerChatRoutes(router: IRouter, routeOptions: RoutesOptions)
request,
response
): Promise> => {
- const storageService = createStorageService(context);
+ const storageService = await createStorageService(context, request.query.dataSourceId);
try {
const getResponse = await storageService.updateConversation(
@@ -296,7 +323,7 @@ export function registerChatRoutes(router: IRouter, routeOptions: RoutesOptions)
request,
response
): Promise> => {
- const storageService = createStorageService(context);
+ const storageService = await createStorageService(context, request.query.dataSourceId);
try {
const getResponse = await storageService.getTraces(request.params.interactionId);
@@ -315,7 +342,7 @@ export function registerChatRoutes(router: IRouter, routeOptions: RoutesOptions)
request,
response
): Promise> => {
- const chatService = createChatService(context, '');
+ const chatService = await createChatService(context, request.query.dataSourceId);
try {
chatService.abortAgentExecution(request.body.conversationId);
context.assistant_plugin.logger.info(
@@ -337,8 +364,8 @@ export function registerChatRoutes(router: IRouter, routeOptions: RoutesOptions)
response
): Promise> => {
const { conversationId, interactionId } = request.body;
- const storageService = createStorageService(context);
- const chatService = createChatService(context);
+ const storageService = await createStorageService(context, request.query.dataSourceId);
+ const chatService = await createChatService(context, request.query.dataSourceId);
let outputs: Awaited> | undefined;
@@ -386,7 +413,7 @@ export function registerChatRoutes(router: IRouter, routeOptions: RoutesOptions)
request,
response
): Promise> => {
- const storageService = createStorageService(context);
+ const storageService = await createStorageService(context, request.query.dataSourceId);
const { interactionId } = request.params;
try {
diff --git a/server/services/chat/olly_chat_service.ts b/server/services/chat/olly_chat_service.ts
index cb239079..bdebec5a 100644
--- a/server/services/chat/olly_chat_service.ts
+++ b/server/services/chat/olly_chat_service.ts
@@ -4,7 +4,7 @@
*/
import { ApiResponse } from '@opensearch-project/opensearch';
-import { RequestHandlerContext } from '../../../../../src/core/server';
+import { OpenSearchClient } from '../../../../../src/core/server';
import { IMessage, IInput } from '../../../common/types/chat_saved_object_attributes';
import { ChatService } from './chat_service';
import { ML_COMMONS_BASE_API, ROOT_AGENT_CONFIG_ID } from '../../utils/constants';
@@ -22,13 +22,12 @@ const INTERACTION_ID_FIELD = 'parent_interaction_id';
export class OllyChatService implements ChatService {
static abortControllers: Map = new Map();
- constructor(private readonly context: RequestHandlerContext) {}
+ constructor(private readonly opensearchClientTransport: OpenSearchClient['transport']) {}
private async getRootAgent(): Promise {
try {
- const opensearchClient = this.context.core.opensearch.client.asCurrentUser;
const path = `${ML_COMMONS_BASE_API}/config/${ROOT_AGENT_CONFIG_ID}`;
- const response = await opensearchClient.transport.request({
+ const response = await this.opensearchClientTransport.request({
method: 'GET',
path,
});
@@ -53,9 +52,8 @@ export class OllyChatService implements ChatService {
}
private async callExecuteAgentAPI(payload: AgentRunPayload, rootAgentId: string) {
- const opensearchClient = this.context.core.opensearch.client.asCurrentUser;
try {
- const agentFrameworkResponse = (await opensearchClient.transport.request(
+ const agentFrameworkResponse = (await this.opensearchClientTransport.request(
{
method: 'POST',
path: `${ML_COMMONS_BASE_API}/agents/${rootAgentId}/_execute`,
diff --git a/server/services/storage/agent_framework_storage_service.ts b/server/services/storage/agent_framework_storage_service.ts
index f0f1a735..05f57b19 100644
--- a/server/services/storage/agent_framework_storage_service.ts
+++ b/server/services/storage/agent_framework_storage_service.ts
@@ -28,12 +28,12 @@ export interface ConversationOptResponse {
export class AgentFrameworkStorageService implements StorageService {
constructor(
- private readonly client: OpenSearchClient,
+ private readonly clientTransport: OpenSearchClient['transport'],
private readonly messageParsers: MessageParser[] = []
) {}
async getConversation(conversationId: string): Promise {
const [interactionsResp, conversation] = await Promise.all([
- this.client.transport.request({
+ this.clientTransport.request({
method: 'GET',
path: `${ML_COMMONS_BASE_API}/memory/${encodeURIComponent(
conversationId
@@ -43,7 +43,7 @@ export class AgentFrameworkStorageService implements StorageService {
messages: InteractionFromAgentFramework[];
}>
>,
- this.client.transport.request({
+ this.clientTransport.request({
method: 'GET',
path: `${ML_COMMONS_BASE_API}/memory/${encodeURIComponent(conversationId)}`,
}) as TransportRequestPromise<
@@ -103,7 +103,7 @@ export class AgentFrameworkStorageService implements StorageService {
...(sortField && query.sortOrder && { sort: [{ [sortField]: query.sortOrder }] }),
};
- const conversations = await this.client.transport.request({
+ const conversations = await this.clientTransport.request({
method: 'GET',
path: `${ML_COMMONS_BASE_API}/memory/_search`,
body: requestParams,
@@ -146,7 +146,7 @@ export class AgentFrameworkStorageService implements StorageService {
}
async deleteConversation(conversationId: string): Promise {
- const response = await this.client.transport.request({
+ const response = await this.clientTransport.request({
method: 'DELETE',
path: `${ML_COMMONS_BASE_API}/memory/${encodeURIComponent(conversationId)}`,
});
@@ -167,7 +167,7 @@ export class AgentFrameworkStorageService implements StorageService {
conversationId: string,
title: string
): Promise {
- const response = await this.client.transport.request({
+ const response = await this.clientTransport.request({
method: 'PUT',
path: `${ML_COMMONS_BASE_API}/memory/${encodeURIComponent(conversationId)}`,
body: {
@@ -188,7 +188,7 @@ export class AgentFrameworkStorageService implements StorageService {
}
async getTraces(interactionId: string): Promise {
- const response = (await this.client.transport.request({
+ const response = (await this.clientTransport.request({
method: 'GET',
path: `${ML_COMMONS_BASE_API}/memory/message/${encodeURIComponent(interactionId)}/traces`,
})) as ApiResponse<{
@@ -216,7 +216,7 @@ export class AgentFrameworkStorageService implements StorageService {
interactionId: string,
additionalInfo: Record>
): Promise {
- const response = await this.client.transport.request({
+ const response = await this.clientTransport.request({
method: 'PUT',
path: `${ML_COMMONS_BASE_API}/memory/message/${encodeURIComponent(interactionId)}`,
body: {
@@ -260,7 +260,7 @@ export class AgentFrameworkStorageService implements StorageService {
if (!interactionId) {
throw new Error('interactionId is required');
}
- const interactionsResp = (await this.client.transport.request({
+ const interactionsResp = (await this.clientTransport.request({
method: 'GET',
path: `${ML_COMMONS_BASE_API}/memory/message/${encodeURIComponent(interactionId)}`,
})) as ApiResponse;
diff --git a/server/utils/get_opensearch_client_transport.test.ts b/server/utils/get_opensearch_client_transport.test.ts
new file mode 100644
index 00000000..0c59c842
--- /dev/null
+++ b/server/utils/get_opensearch_client_transport.test.ts
@@ -0,0 +1,43 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+import { getOpenSearchClientTransport } from './get_opensearch_client_transport';
+import { coreMock } from '../../../../src/core/server/mocks';
+import { loggerMock } from '../../../../src/core/server/logging/logger.mock';
+
+const mockedLogger = loggerMock.create();
+
+describe('getOpenSearchClientTransport', () => {
+ it('should return current user opensearch transport', async () => {
+ const core = coreMock.createRequestHandlerContext();
+
+ expect(
+ await getOpenSearchClientTransport({
+ context: { core, assistant_plugin: { logger: mockedLogger } },
+ })
+ ).toBe(core.opensearch.client.asCurrentUser.transport);
+ });
+ it('should data source id related opensearch transport', async () => {
+ const transportMock = {};
+ const core = coreMock.createRequestHandlerContext();
+ const context = {
+ core,
+ dataSource: {
+ opensearch: {
+ getClient: async (_dataSourceId: string) => ({
+ transport: transportMock,
+ }),
+ },
+ },
+ };
+
+ expect(
+ await getOpenSearchClientTransport({
+ context: { core, assistant_plugin: { logger: mockedLogger } },
+ dataSourceId: 'foo',
+ })
+ ).toBe(transportMock);
+ });
+});
diff --git a/server/utils/get_opensearch_client_transport.ts b/server/utils/get_opensearch_client_transport.ts
new file mode 100644
index 00000000..9f54a4ff
--- /dev/null
+++ b/server/utils/get_opensearch_client_transport.ts
@@ -0,0 +1,25 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+import { OpenSearchClient, RequestHandlerContext } from '../../../../src/core/server';
+
+export const getOpenSearchClientTransport = async ({
+ context,
+ dataSourceId,
+}: {
+ context: RequestHandlerContext & {
+ dataSource?: {
+ opensearch: {
+ getClient: (dataSourceId: string) => Promise;
+ };
+ };
+ };
+ dataSourceId?: string;
+}) => {
+ if (dataSourceId && context.dataSource) {
+ return (await context.dataSource.opensearch.getClient(dataSourceId)).transport;
+ }
+ return context.core.opensearch.client.asCurrentUser.transport;
+};