diff --git a/packages/react/src/hooks/__tests__/use-render-activity-message.test.tsx b/packages/react/src/hooks/__tests__/use-render-activity-message.test.tsx new file mode 100644 index 0000000..e614ef8 --- /dev/null +++ b/packages/react/src/hooks/__tests__/use-render-activity-message.test.tsx @@ -0,0 +1,330 @@ +import React from "react"; +import { renderHook, waitFor } from "@testing-library/react"; +import { describe, it, expect, vi, beforeEach, afterEach } from "vitest"; +import { CopilotKitProvider } from "@/providers/CopilotKitProvider"; +import { useRenderActivityMessage } from "../use-render-activity-message"; +import { ActivityMessage } from "@ag-ui/core"; +import { ReactActivityMessageRenderer } from "@/types"; +import { z } from "zod"; + +// Mock console methods +const originalConsoleError = console.error; +const originalConsoleWarn = console.warn; + +describe("useRenderActivityMessage", () => { + let consoleErrorSpy: ReturnType; + let consoleWarnSpy: ReturnType; + + beforeEach(() => { + consoleErrorSpy = vi.spyOn(console, "error").mockImplementation(() => {}); + consoleWarnSpy = vi.spyOn(console, "warn").mockImplementation(() => {}); + }); + + afterEach(() => { + consoleErrorSpy.mockRestore(); + consoleWarnSpy.mockRestore(); + }); + + describe("Basic rendering", () => { + it("should render activity messages with registered renderer", () => { + const TestRenderer: React.FC = ({ content }) => ( +
{content.message}
+ ); + + const renderers: ReactActivityMessageRenderer[] = [ + { + activityType: "test-activity", + content: z.object({ message: z.string() }), + render: TestRenderer, + }, + ]; + + const { result } = renderHook(() => useRenderActivityMessage(), { + wrapper: ({ children }) => ( + + {children} + + ), + }); + + const message: ActivityMessage = { + id: "msg-1", + role: "activity", + activityType: "test-activity", + content: { message: "Hello World" }, + }; + + const rendered = result.current(message); + expect(rendered).toBeTruthy(); + expect(rendered?.type).toBe(TestRenderer); + }); + + it("should return null for messages without matching renderer", () => { + const { result } = renderHook(() => useRenderActivityMessage(), { + wrapper: ({ children }) => ( + + {children} + + ), + }); + + const message: ActivityMessage = { + id: "msg-1", + role: "activity", + activityType: "unknown-activity", + content: {}, + }; + + const rendered = result.current(message); + expect(rendered).toBeNull(); + }); + + it("should work when provider is initialized", () => { + const TestRenderer: React.FC = () =>
Test
; + + const renderers: ReactActivityMessageRenderer[] = [ + { + activityType: "delayed-activity", + content: z.object({}), + render: TestRenderer, + }, + ]; + + const { result } = renderHook(() => useRenderActivityMessage(), { + wrapper: ({ children }) => ( + + {children} + + ), + }); + + expect(result.current).toBeDefined(); + expect(typeof result.current).toBe("function"); + }); + }); + + describe("Wildcard renderer", () => { + it("should use wildcard renderer for unknown activity types", () => { + const WildcardRenderer: React.FC = ({ activityType }) => ( +
{activityType}
+ ); + + const renderers: ReactActivityMessageRenderer[] = [ + { + activityType: "*", + content: z.any(), + render: WildcardRenderer, + }, + ]; + + const { result } = renderHook(() => useRenderActivityMessage(), { + wrapper: ({ children }) => ( + + {children} + + ), + }); + + const message: ActivityMessage = { + id: "msg-1", + role: "activity", + activityType: "any-type", + content: { foo: "bar" }, + }; + + const rendered = result.current(message); + expect(rendered).toBeTruthy(); + expect(rendered?.type).toBe(WildcardRenderer); + }); + }); + + describe("Agent-specific renderers", () => { + it("should prioritize agent-specific renderer over global", () => { + const GlobalRenderer: React.FC = () =>
Global
; + const AgentRenderer: React.FC = () =>
Agent-specific
; + + const renderers: ReactActivityMessageRenderer[] = [ + { + activityType: "test-activity", + content: z.any(), + render: GlobalRenderer, + }, + { + activityType: "test-activity", + content: z.any(), + render: AgentRenderer, + agentId: "agent-1", + }, + ]; + + const { result } = renderHook(() => useRenderActivityMessage(), { + wrapper: ({ children }) => ( + + {children} + + ), + }); + + const message: ActivityMessage = { + id: "msg-1", + role: "activity", + activityType: "test-activity", + content: {}, + }; + + const rendered = result.current(message); + expect(rendered).toBeTruthy(); + // Should use global renderer since we're using default agent + expect(rendered?.type).toBe(GlobalRenderer); + }); + }); + + describe("Content validation", () => { + it("should warn when content fails validation", () => { + const TestRenderer: React.FC = () =>
Test
; + + const renderers: ReactActivityMessageRenderer[] = [ + { + activityType: "strict-activity", + content: z.object({ requiredField: z.string() }), + render: TestRenderer, + }, + ]; + + const { result } = renderHook(() => useRenderActivityMessage(), { + wrapper: ({ children }) => ( + + {children} + + ), + }); + + const message: ActivityMessage = { + id: "msg-1", + role: "activity", + activityType: "strict-activity", + content: { wrongField: "value" }, + }; + + const rendered = result.current(message); + expect(rendered).toBeNull(); + expect(consoleWarnSpy).toHaveBeenCalledWith( + expect.stringContaining("Failed to parse content"), + expect.anything() + ); + }); + + it("should render when content passes validation", () => { + const TestRenderer: React.FC = ({ content }) => ( +
{content.value}
+ ); + + const renderers: ReactActivityMessageRenderer[] = [ + { + activityType: "valid-activity", + content: z.object({ value: z.string() }), + render: TestRenderer, + }, + ]; + + const { result } = renderHook(() => useRenderActivityMessage(), { + wrapper: ({ children }) => ( + + {children} + + ), + }); + + const message: ActivityMessage = { + id: "msg-1", + role: "activity", + activityType: "valid-activity", + content: { value: "test" }, + }; + + const rendered = result.current(message); + expect(rendered).toBeTruthy(); + // Don't check for no warnings - provider may emit warnings about missing runtime URL + // The important thing is that no parse warnings are emitted + expect(consoleWarnSpy).not.toHaveBeenCalledWith( + expect.stringContaining("Failed to parse content"), + expect.anything() + ); + }); + }); + + describe("Regression: Provider initialization timing", () => { + it("should not crash when called immediately after provider mount", async () => { + const TestRenderer: React.FC = () =>
Test
; + + const renderers: ReactActivityMessageRenderer[] = [ + { + activityType: "immediate-activity", + content: z.any(), + render: TestRenderer, + }, + ]; + + // This tests the scenario where useRenderActivityMessage is called + // during the same render cycle as the provider initialization + const { result } = renderHook(() => useRenderActivityMessage(), { + wrapper: ({ children }) => ( + + {children} + + ), + }); + + // Should not throw + expect(result.current).toBeDefined(); + expect(typeof result.current).toBe("function"); + + const message: ActivityMessage = { + id: "msg-1", + role: "activity", + activityType: "immediate-activity", + content: {}, + }; + + // Should be able to render immediately + const rendered = result.current(message); + expect(rendered).toBeTruthy(); + }); + + it("should handle activity messages arriving before full initialization", async () => { + const TestRenderer: React.FC = ({ content }) => ( +
{content.data}
+ ); + + const renderers: ReactActivityMessageRenderer[] = [ + { + activityType: "early-activity", + content: z.object({ data: z.string() }), + render: TestRenderer, + }, + ]; + + const { result } = renderHook(() => useRenderActivityMessage(), { + wrapper: ({ children }) => ( + + {children} + + ), + }); + + // Simulate an activity message arriving immediately + const message: ActivityMessage = { + id: "msg-1", + role: "activity", + activityType: "early-activity", + content: { data: "early message" }, + }; + + // Should not crash and should render + await waitFor(() => { + const rendered = result.current(message); + expect(rendered).toBeTruthy(); + }); + }); + }); +}); diff --git a/packages/react/src/providers/CopilotKitProvider.tsx b/packages/react/src/providers/CopilotKitProvider.tsx index 4e52c21..a36bba4 100644 --- a/packages/react/src/providers/CopilotKitProvider.tsx +++ b/packages/react/src/providers/CopilotKitProvider.tsx @@ -310,7 +310,7 @@ export const useCopilotKit = (): CopilotKitContextValue => { const context = useContext(CopilotKitContext); const [, forceUpdate] = useReducer((x) => x + 1, 0); - if (!context) { + if (!context || !context.copilotkit) { throw new Error("useCopilotKit must be used within CopilotKitProvider"); } useEffect(() => { @@ -322,8 +322,7 @@ export const useCopilotKit = (): CopilotKitContextValue => { return () => { subscription.unsubscribe(); }; - // eslint-disable-next-line react-hooks/exhaustive-deps - }, []); + }, [context.copilotkit]); return context; }; diff --git a/packages/react/src/providers/__tests__/CopilotKitProvider.subscription.test.tsx b/packages/react/src/providers/__tests__/CopilotKitProvider.subscription.test.tsx new file mode 100644 index 0000000..222f7de --- /dev/null +++ b/packages/react/src/providers/__tests__/CopilotKitProvider.subscription.test.tsx @@ -0,0 +1,258 @@ +import React, { useState } from "react"; +import { renderHook, waitFor, act } from "@testing-library/react"; +import { describe, it, expect, vi, beforeEach, afterEach } from "vitest"; +import { CopilotKitProvider, useCopilotKit } from "../CopilotKitProvider"; + +// Mock console methods +const originalConsoleError = console.error; +const originalConsoleWarn = console.warn; + +describe("CopilotKitProvider Subscription", () => { + let consoleErrorSpy: ReturnType; + let consoleWarnSpy: ReturnType; + + beforeEach(() => { + consoleErrorSpy = vi.spyOn(console, "error").mockImplementation(() => {}); + consoleWarnSpy = vi.spyOn(console, "warn").mockImplementation(() => {}); + }); + + afterEach(() => { + consoleErrorSpy.mockRestore(); + consoleWarnSpy.mockRestore(); + }); + + describe("useCopilotKit subscription behavior", () => { + it("should subscribe when copilotkit becomes available", async () => { + const { result } = renderHook(() => useCopilotKit(), { + wrapper: ({ children }) => ( + {children} + ), + }); + + // Verify subscription is set up + expect(result.current.copilotkit).toBeDefined(); + + // Get the initial subscription count + const initialSubCount = (result.current.copilotkit as any).subscribers?.size; + expect(initialSubCount).toBeGreaterThan(0); + }); + + it("should re-subscribe when copilotkit instance changes", async () => { + let setShowProvider: (show: boolean) => void; + + const Wrapper = ({ children }: { children: React.ReactNode }) => { + const [showProvider, setShow] = useState(true); + setShowProvider = setShow; + + if (!showProvider) { + return
{children}
; + } + + return {children}; + }; + + const { result } = renderHook(() => { + try { + return useCopilotKit(); + } catch (e) { + return null; + } + }, { + wrapper: Wrapper, + }); + + // Initially should have copilotkit + expect(result.current?.copilotkit).toBeDefined(); + const firstInstance = result.current?.copilotkit; + + // Remove and re-add provider + await act(async () => { + setShowProvider(false); + }); + + await waitFor(() => { + expect(result.current).toBeNull(); + }); + + await act(async () => { + setShowProvider(true); + }); + + await waitFor(() => { + expect(result.current?.copilotkit).toBeDefined(); + }); + + // Should be a new instance + const secondInstance = result.current?.copilotkit; + expect(secondInstance).not.toBe(firstInstance); + }); + + it("should handle runtime connection status changes", async () => { + const { result } = renderHook(() => useCopilotKit(), { + wrapper: ({ children }) => ( + + {children} + + ), + }); + + expect(result.current.copilotkit).toBeDefined(); + expect(result.current.copilotkit.runtimeConnectionStatus).toBeDefined(); + }); + + it("should maintain subscription across re-renders", async () => { + const { result, rerender } = renderHook(() => useCopilotKit(), { + wrapper: ({ children }) => ( + {children} + ), + }); + + const firstCopilotkit = result.current.copilotkit; + + // Force a re-render + rerender(); + + // Should maintain the same instance + expect(result.current.copilotkit).toBe(firstCopilotkit); + }); + }); + + describe("Activity message rendering with delayed initialization", () => { + it("should handle components that use useCopilotKit mounting before provider is ready", async () => { + // This tests the race condition where a component tries to use + // useCopilotKit before the provider has fully initialized + + let providerReady = false; + + const DelayedProvider = ({ children }: { children: React.ReactNode }) => { + const [isReady, setIsReady] = useState(false); + + React.useEffect(() => { + const timer = setTimeout(() => { + setIsReady(true); + providerReady = true; + }, 50); + return () => clearTimeout(timer); + }, []); + + if (!isReady) { + return
{children}
; + } + + return {children}; + }; + + const { result } = renderHook(() => { + try { + return useCopilotKit(); + } catch (e) { + return { error: (e as Error).message }; + } + }, { + wrapper: DelayedProvider, + }); + + // Initially should throw error + expect(result.current).toHaveProperty("error"); + expect((result.current as any).error).toContain("must be used within CopilotKitProvider"); + + // Wait for provider to be ready + await waitFor(() => { + expect(providerReady).toBe(true); + }, { timeout: 100 }); + + // After provider is ready, should work + await waitFor(() => { + expect(result.current).toHaveProperty("copilotkit"); + }, { timeout: 100 }); + }); + }); + + describe("Subscription cleanup", () => { + it("should unsubscribe when component unmounts", async () => { + const { result, unmount } = renderHook(() => useCopilotKit(), { + wrapper: ({ children }) => ( + {children} + ), + }); + + const copilotkit = result.current.copilotkit; + const initialSubCount = (copilotkit as any).subscribers?.size; + + // Unmount the hook + unmount(); + + // Subscription count should decrease + await waitFor(() => { + const finalSubCount = (copilotkit as any).subscribers?.size; + expect(finalSubCount).toBeLessThan(initialSubCount); + }); + }); + + it("should handle multiple subscribers", async () => { + const { result: result1 } = renderHook(() => useCopilotKit(), { + wrapper: ({ children }) => ( + {children} + ), + }); + + const { result: result2, unmount: unmount2 } = renderHook(() => useCopilotKit(), { + wrapper: ({ children }) => ( + {children} + ), + }); + + // Both should work independently + expect(result1.current.copilotkit).toBeDefined(); + expect(result2.current.copilotkit).toBeDefined(); + + // Unmounting one shouldn't affect the other + unmount2(); + + expect(result1.current.copilotkit).toBeDefined(); + }); + }); + + describe("Edge cases", () => { + it("should handle rapid provider re-initialization", async () => { + let toggleProvider: () => void; + + const TogglingWrapper = ({ children }: { children: React.ReactNode }) => { + const [showProvider, setShowProvider] = useState(true); + toggleProvider = () => setShowProvider(prev => !prev); + + if (!showProvider) { + return
{children}
; + } + + return {children}; + }; + + const { result } = renderHook(() => { + try { + return useCopilotKit(); + } catch (e) { + return null; + } + }, { + wrapper: TogglingWrapper, + }); + + // Rapidly toggle provider + await act(async () => { + toggleProvider(); + await new Promise(resolve => setTimeout(resolve, 10)); + toggleProvider(); + await new Promise(resolve => setTimeout(resolve, 10)); + toggleProvider(); + await new Promise(resolve => setTimeout(resolve, 10)); + toggleProvider(); + }); + + // Should eventually stabilize with a working copilotkit + await waitFor(() => { + expect(result.current?.copilotkit).toBeDefined(); + }); + }); + }); +});