Skip to content

Commit

Permalink
fix(website): resolve listChatMessages inconsistencies
Browse files Browse the repository at this point in the history
  • Loading branch information
JeremyJonas committed Oct 13, 2023
1 parent 501a326 commit afc9c85
Show file tree
Hide file tree
Showing 6 changed files with 72 additions and 121 deletions.
13 changes: 7 additions & 6 deletions demo/website/src/components/chat/ConversationView.tsx
Original file line number Diff line number Diff line change
@@ -1,13 +1,10 @@
/*! Copyright [Amazon.com](http://amazon.com/), Inc. or its affiliates. All Rights Reserved.
PDX-License-Identifier: Apache-2.0 */
import { Alert, Spinner } from "@cloudscape-design/components";
import {
ChatMessage,
useListChatMessages,
} from "api-typescript-react-query-hooks";
import { ChatMessage } from "api-typescript-react-query-hooks";
import { forwardRef, useEffect, useMemo } from "react";
import Message from "./Message";
import { CHAT_MESSAGE_PARAMS } from "../../hooks/chats";
import { CHAT_MESSAGE_PARAMS, useListChatMessages } from "../../hooks/chats";
import EmptyState from "../Empty";

type ConversationViewProps = {
Expand Down Expand Up @@ -42,7 +39,11 @@ export const ConversationView = forwardRef(

return (
<>
{error && <Alert type="error">{error.message}</Alert>}
{error && (
<Alert type="error">
{(error as Error).message || String(error)}
</Alert>
)}
<div
ref={ref}
style={{
Expand Down
24 changes: 15 additions & 9 deletions demo/website/src/hooks/chats.test.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ import {
import { rest } from "msw";
import { setupServer } from "msw/node";
import React, { FC } from "react";
import { useInfiniteChatMessages } from "./chats";
import { useListChatMessages } from "./chats";

export const mswServer = setupServer();

Expand Down Expand Up @@ -70,12 +70,14 @@ describe("useInfiniteChatMessage", () => {
})
);

const { result } = renderHook(() => useInfiniteChatMessages(chatId), {
const { result } = renderHook(() => useListChatMessages({ chatId }), {
wrapper,
});

await waitFor(() => expect(result.current.status).toBe("success"));
const allMessages = result.current.data?.pages.flatMap((d) => d.data);
const allMessages = result.current.data?.pages.flatMap(
(d) => d.chatMessages
);
expect(allMessages).toStrictEqual(records);
});

Expand All @@ -100,12 +102,14 @@ describe("useInfiniteChatMessage", () => {
})
);

const { result } = renderHook(() => useInfiniteChatMessages(chatId), {
const { result } = renderHook(() => useListChatMessages({ chatId }), {
wrapper,
});

await waitFor(() => expect(result.current.status).toBe("success"));
const allMessages = result.current.data?.pages.flatMap((d) => d.data);
const allMessages = result.current.data?.pages.flatMap(
(d) => d.chatMessages
);
expect(allMessages).toStrictEqual(records);
});

Expand Down Expand Up @@ -171,22 +175,24 @@ describe("useInfiniteChatMessage", () => {
);

const { result } = renderHook(
() => useInfiniteChatMessages(chatId, pageSize),
() => useListChatMessages({ chatId, pageSize }),
{
wrapper,
}
);

// get the initial result which should be 2 records
await waitFor(() => expect(result.current.status).toBe("success"));
const firstPage = result.current.data?.pages.flatMap((d) => d.data);
const firstPage = result.current.data?.pages.flatMap((d) => d.chatMessages);
expect(firstPage).toStrictEqual(records.slice(0, pageSize));

// fetch the next page
expect(result.current.hasNextPage).toBe(true);
await result.current.fetchNextPage();
await waitFor(() => expect(result.current.isFetching).toBe(false));
const secondPage = result.current.data?.pages.flatMap((d) => d.data);
const secondPage = result.current.data?.pages.flatMap(
(d) => d.chatMessages
);
expect(secondPage).toStrictEqual(records.slice(0, pageSize * 2));

// fetch the last page
Expand All @@ -195,7 +201,7 @@ describe("useInfiniteChatMessage", () => {
await result.current.fetchNextPage();
});
await waitFor(() => expect(result.current.isFetching).toBe(false));
const lastPage = result.current.data?.pages.flatMap((d) => d.data);
const lastPage = result.current.data?.pages.flatMap((d) => d.chatMessages);
expect(lastPage).toStrictEqual(records.slice(0, pageSize * 3));

expect(result.current.hasNextPage).toBe(false);
Expand Down
141 changes: 42 additions & 99 deletions demo/website/src/hooks/chats.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,34 +2,30 @@
PDX-License-Identifier: Apache-2.0 */
import {
InfiniteData,
QueryFunction,
UseQueryResult,
useInfiniteQuery,
useQuery,
useQueryClient,
} from "@tanstack/react-query";
import {
ChatMessageSource,
CreateChatResponseContent,
ListChatMessageSourcesResponseContent,
ListChatMessagesRequest,
ListChatMessagesResponseContent,
ListChatsResponseContent,
useCreateChatMessage,
useUpdateChat,
useListChatMessages as _useListChatMessages,
useListChats as _useListChats,
useCreateChat,
useListChats as useOriginalListChats,
CreateChatResponseContent,
useCreateChatMessage,
useDeleteChat,
useDeleteChatMessage,
useListChatMessageSources,
ListChatMessageSourcesResponseContent,
ChatMessageSource,
DefaultApiClientContext,
ChatMessage,
ListChatMessagesRequest,
useUpdateChat,
} from "api-typescript-react-query-hooks";
import produce from "immer";
import { last } from "lodash";
import { useCallback, useContext } from "react";

type ListChatMessagesData = InfiniteData<FetchMessagesResponse>;
type PaginatedListChatMessagesResponse =
InfiniteData<ListChatMessagesResponseContent>;

export const CHAT_MESSAGE_PARAMS: Partial<ListChatMessagesRequest> = {
ascending: true,
Expand All @@ -51,8 +47,8 @@ export const queryKeyGenerators = {
],
};

export function useListChats(): ReturnType<typeof useOriginalListChats> {
return useOriginalListChats({
export function useListChats(): ReturnType<typeof _useListChats> {
return _useListChats({
select: (
chatsResponse: ListChatsResponseContent
): ListChatsResponseContent => {
Expand Down Expand Up @@ -94,9 +90,11 @@ export function useCreateChatMutation(

queryClient.setQueryData(
listChatMessagesQueryKey,
(_old: ListChatMessagesData | undefined): ListChatMessagesData => {
(
_old: PaginatedListChatMessagesResponse | undefined
): PaginatedListChatMessagesResponse => {
return {
pages: [{ data: [], nextCursor: undefined }],
pages: [{ chatMessages: [] }],
pageParams: [null],
};
}
Expand All @@ -110,10 +108,17 @@ export function useCreateChatMutation(
return createChat;
}

type ListChatMessagesDataPage = ListChatMessagesData["pages"][number];
type ListChatMessagePage =
| ListChatMessagesDataPage
| ListChatMessagesResponseContent;
export function useListChatMessages(
...args: Parameters<typeof _useListChatMessages>
): ReturnType<typeof _useListChatMessages> {
return _useListChatMessages(
{
...CHAT_MESSAGE_PARAMS,
...args[0],
},
args[1]
);
}

export function useCreateChatMessageMutation(
chatId: string,
Expand All @@ -131,19 +136,14 @@ export function useCreateChatMessageMutation(
// listChatMessages query cache
queryClient.setQueryData(
listChatMessagesQueryKey,
(old: ListChatMessagesData | undefined) => {
(old: PaginatedListChatMessagesResponse | undefined) => {
return produce(old, (draft) => {
if (question && answer) {
const lastPage: ListChatMessagePage | undefined = last(
draft?.pages || []
) as any;
const lastPage: ListChatMessagesResponseContent | undefined =
last(draft?.pages || []) as any;

if (lastPage) {
// empty chat (new) page contains "data" while non-empty container "chatMessages"
const chatMessages =
("data" in lastPage && lastPage.data) ||
("chatMessages" in lastPage && lastPage.chatMessages) ||
undefined;
const chatMessages = lastPage.chatMessages;

if (chatMessages == null) {
// unable to inject new chat messages, just reset to resolve
Expand All @@ -163,12 +163,11 @@ export function useCreateChatMessageMutation(
return {
pages: [
{
data: [question, answer],
nextCursor: undefined,
chatMessages: [question, answer],
},
],
pageParams: [null],
} as ListChatMessagesData;
} as PaginatedListChatMessagesResponse;
}

onSuccess && onSuccess();
Expand Down Expand Up @@ -303,26 +302,20 @@ export function useDeleteChatMessageMutation(
variables.chatId
);

queryClient.setQueryData<ListChatMessagesData>(
queryClient.setQueryData<PaginatedListChatMessagesResponse>(
listChatMessagesQueryKey,
(old) =>
produce(old, (listChatMessagesDraft) => {
if (listChatMessagesDraft && listChatMessagesDraft.pages) {
for (let page of listChatMessagesDraft.pages) {
page.data = page.data.filter(
(message) => message.messageId !== variables.messageId
);
produce(old, (draft) => {
if (draft && draft.pages) {
for (let page of draft.pages) {
page.chatMessages =
page.chatMessages?.filter(
(message) => message.messageId !== variables.messageId
) || [];
}
} else if (old && "chatMessages" in old) {
const filtered = (old.chatMessages as ChatMessage[]).filter(
(v) => v.messageId !== _data.messageId
) as any;
return {
chatMessages: filtered,
} as any;
}

return listChatMessagesDraft;
return draft;
})
);
},
Expand All @@ -349,53 +342,3 @@ export function useMessageSources(
}
);
}

type FetchMessagesResponse = {
data: ChatMessage[];
nextCursor: string | undefined;
};
export function useInfiniteChatMessages(
chatId: string,
pageSize: number = 100
) {
const key = queryKeyGenerators.listChatMessages(chatId);

const api = useContext(DefaultApiClientContext);
const fetchMessages: QueryFunction<FetchMessagesResponse> = useCallback(
async ({ pageParam }) => {
const result = await api.listChatMessages({
chatId,
nextToken: pageParam,
pageSize: pageSize,
reverse: true,
ascending: true,
});
return {
data: result.chatMessages || [],
nextCursor: result.nextToken,
};
},
[api]
);

return useInfiniteQuery(key, {
queryFn: fetchMessages,
retry() {
return false;
},
getNextPageParam: (lastPage) => lastPage.nextCursor,
});
}

// TODO remove once we change over completely to infinite scrolling
export function useListChatMessages(chatId: string) {
const api = useContext(DefaultApiClientContext);
const key = queryKeyGenerators.listChatMessages(chatId);
async function queryFn(): Promise<ListChatMessagesResponseContent> {
// We want messages in reverse order, since results are descending order
return api.listChatMessages({ chatId, reverse: true });
}
return useQuery(key, {
queryFn,
});
}
3 changes: 2 additions & 1 deletion demo/website/tsconfig.json

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 1 addition & 2 deletions packages/galileo-cli/src/lib/prompts/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -236,8 +236,7 @@ namespace galileoPrompts {
value: x,
})),
initial: () => {
const _initial =
context.cache.getItem("defaultModelId")
const _initial = context.cache.getItem("defaultModelId");
if (_initial && availableModelIds.includes(_initial)) {
return availableModelIds.indexOf(_initial);
}
Expand Down
9 changes: 5 additions & 4 deletions projenrc/demo/website.ts
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
import * as path from "node:path";
import { CloudscapeReactTsWebsiteProject } from "@aws/pdk/cloudscape-react-ts-website";
import { MonorepoTsProject, NxProject } from "@aws/pdk/monorepo";
import * as path from "node:path";
import { javascript } from "projen";
import { withStorybook } from "../helpers/withStorybook";
import { Api } from "./api";
import { TypeScriptModuleResolution } from "projen/lib/javascript";
import { DEFAULT_RELEASE_BRANCH, VERSIONS } from "../constants";
import { GalileoSdk } from "../framework";
import { TypeScriptModuleResolution } from "projen/lib/javascript";
import { withStorybook } from "../helpers/withStorybook";
import { Api } from "./api";

export interface WebsiteOptions {
readonly monorepo: MonorepoTsProject;
Expand Down Expand Up @@ -66,6 +66,7 @@ export class Website {
},
},
});
this.project.tsconfig?.addInclude("src/**/*.tsx");
this.project.addGitIgnore("public/api.html");
this.project.addGitIgnore("runtime-config.*");
this.project.addGitIgnore("!runtime-config.example.json");
Expand Down

0 comments on commit afc9c85

Please sign in to comment.