diff --git a/packages/headless/src/api/knowledge/stream-answer-api.ts b/packages/headless/src/api/knowledge/stream-answer-api.ts index 6e63e42a321..f37c287a9fa 100644 --- a/packages/headless/src/api/knowledge/stream-answer-api.ts +++ b/packages/headless/src/api/knowledge/stream-answer-api.ts @@ -2,7 +2,13 @@ import { EventSourceMessage, fetchEventSource, } from '@microsoft/fetch-event-source'; -import {createSelector} from '@reduxjs/toolkit'; +import {createSelector, ThunkDispatch, UnknownAction} from '@reduxjs/toolkit'; +import { + setAnswerContentFormat, + updateCitations, + updateMessage, +} from '../../features/generated-answer/generated-answer-actions'; +import {logGeneratedAnswerStreamEnd} from '../../features/generated-answer/generated-answer-analytics-actions'; import {selectFieldsToIncludeInCitation} from '../../features/generated-answer/generated-answer-selectors'; import { GeneratedAnswerStyle, @@ -123,7 +129,8 @@ const handleError = ( const updateCacheWithEvent = ( event: EventSourceMessage, - draft: GeneratedAnswerStream + draft: GeneratedAnswerStream, + dispatch: ThunkDispatch ) => { const message: Required = JSON.parse(event.data); if (message.finishReason === 'ERROR' && message.errorMessage) { @@ -138,21 +145,27 @@ const updateCacheWithEvent = ( case 'genqa.headerMessageType': if (parsedPayload.answerStyle && parsedPayload.contentFormat) { handleHeaderMessage(draft, parsedPayload); + dispatch(setAnswerContentFormat(parsedPayload.contentFormat)); } break; case 'genqa.messageType': if (parsedPayload.textDelta) { handleMessage(draft, parsedPayload); + dispatch(updateMessage({textDelta: parsedPayload.textDelta})); } break; case 'genqa.citationsType': if (parsedPayload.citations) { handleCitations(draft, parsedPayload); + dispatch(updateCitations({citations: parsedPayload.citations})); } break; case 'genqa.endOfStreamType': if (draft.answer?.length || parsedPayload.answerGenerated) { handleEndOfStream(draft, parsedPayload); + dispatch( + logGeneratedAnswerStreamEnd(parsedPayload.answerGenerated ?? false) + ); } break; } @@ -176,7 +189,7 @@ export const answerApi = answerSlice.injectEndpoints({ }), async onCacheEntryAdded( args, - {getState, cacheDataLoaded, updateCachedData} + {getState, cacheDataLoaded, updateCachedData, dispatch} ) { await cacheDataLoaded; /** @@ -209,7 +222,7 @@ export const answerApi = answerSlice.injectEndpoints({ }, onmessage: (event) => { updateCachedData((draft) => { - updateCacheWithEvent(event, draft); + updateCacheWithEvent(event, draft, dispatch); }); }, onerror: (error) => {