Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Relay client-side fetch requests to the server using the Storybook channel API #331

Merged
merged 8 commits into from
Sep 6, 2024
10 changes: 5 additions & 5 deletions src/Panel.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@ import {
IS_OFFLINE,
IS_OUTDATED,
LOCAL_BUILD_PROGRESS,
PANEL_ID,
REMOVE_ADDON,
TELEMETRY,
} from "./constants";
Expand All @@ -28,9 +27,10 @@ import { ControlsProvider } from "./screens/VisualTests/ControlsContext";
import { RunBuildProvider } from "./screens/VisualTests/RunBuildContext";
import { VisualTests } from "./screens/VisualTests/VisualTests";
import { GitInfoPayload, LocalBuildProgress, UpdateStatusFunction } from "./types";
import { client, Provider, useAccessToken } from "./utils/graphQLClient";
import { createClient, GraphQLClientProvider, useAccessToken } from "./utils/graphQLClient";
import { TelemetryProvider } from "./utils/TelemetryContext";
import { useBuildEvents } from "./utils/useBuildEvents";
import { useChannelFetch } from "./utils/useChannelFetch";
import { useProjectId } from "./utils/useProjectId";
import { clearSessionState, useSessionState } from "./utils/useSessionState";
import { useSharedState } from "./utils/useSharedState";
Expand Down Expand Up @@ -93,8 +93,9 @@ export const Panel = ({ active, api }: PanelProps) => {
const trackEvent = useCallback((data: any) => emit(TELEMETRY, data), [emit]);
const { isRunning, startBuild, stopBuild } = useBuildEvents({ localBuildProgress, accessToken });

const fetch = useChannelFetch();
const withProviders = (children: React.ReactNode) => (
<Provider key={PANEL_ID} value={client}>
<GraphQLClientProvider value={createClient({ fetch })}>
<TelemetryProvider value={trackEvent}>
<AuthProvider value={{ accessToken, setAccessToken }}>
<UninstallProvider
Expand All @@ -111,7 +112,7 @@ export const Panel = ({ active, api }: PanelProps) => {
</UninstallProvider>
</AuthProvider>
</TelemetryProvider>
</Provider>
</GraphQLClientProvider>
);

if (!active) {
Expand All @@ -134,7 +135,6 @@ export const Panel = ({ active, api }: PanelProps) => {
if (!accessToken) {
return withProviders(
<Authentication
key={PANEL_ID}
setAccessToken={setAccessToken}
setCreatedProjectId={setCreatedProjectId}
hasProjectId={!!projectId}
Expand Down
4 changes: 4 additions & 0 deletions src/constants.ts
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,10 @@ export const ENABLE_FILTER = `${ADDON_ID}/enableFilter`;
export const REMOVE_ADDON = `${ADDON_ID}/removeAddon`;
export const PARAM_KEY = "chromatic";

export const FETCH_ABORTED = `${ADDON_ID}/ChannelFetch/aborted`;
export const FETCH_REQUEST = `${ADDON_ID}ChannelFetch/request`;
export const FETCH_RESPONSE = `${ADDON_ID}ChannelFetch/response`;

export const CONFIG_OVERRIDES = {
// Local changes should never be auto-accepted
autoAcceptChanges: false,
Expand Down
4 changes: 4 additions & 0 deletions src/preset.ts
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ import {
LocalBuildProgress,
ProjectInfoPayload,
} from "./types";
import { ChannelFetch } from "./utils/ChannelFetch";
import { SharedState } from "./utils/SharedState";
import { updateChromaticConfig } from "./utils/updateChromaticConfig";

Expand Down Expand Up @@ -160,6 +161,9 @@ const watchConfigFile = async (
async function serverChannel(channel: Channel, options: Options & { configFile?: string }) {
const { configFile, presets } = options;

// Handle relayed fetch requests from the client
ChannelFetch.subscribe(ADDON_ID, channel);

// Lazy load these APIs since we don't need them right away
const apiPromise = presets.apply<any>("experimental_serverAPI");
const corePromise = presets.apply("core");
Expand Down
96 changes: 96 additions & 0 deletions src/utils/ChannelFetch.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
import { beforeEach, describe, expect, it, vi } from "vitest";

import { FETCH_ABORTED, FETCH_REQUEST, FETCH_RESPONSE } from "../constants";
import { ChannelFetch } from "./ChannelFetch";
import { MockChannel } from "./MockChannel";

const resolveAfter = (ms: number, value: any) =>
new Promise((resolve) => setTimeout(resolve, ms, value));

const rejectAfter = (ms: number, reason: any) =>
new Promise((_, reject) => setTimeout(reject, ms, reason));

describe("ChannelFetch", () => {
let channel: MockChannel;

beforeEach(() => {
channel = new MockChannel();
});

it("should handle fetch requests", async () => {
const fetch = vi.fn(() => resolveAfter(100, { headers: [], text: async () => "data" }));
ChannelFetch.subscribe("req", channel, fetch as any);

channel.emit(FETCH_REQUEST, {
requestId: "req",
input: "https://example.com",
init: { headers: { foo: "bar" } },
});

await vi.waitFor(() => {
expect(fetch).toHaveBeenCalledWith("https://example.com", {
headers: { foo: "bar" },
signal: expect.any(AbortSignal),
});
});
});

it("should send fetch responses", async () => {
const fetch = vi.fn(() => resolveAfter(100, { headers: [], text: async () => "data" }));
const instance = ChannelFetch.subscribe("res", channel, fetch as any);

const promise = new Promise<void>((resolve) => {
channel.on(FETCH_RESPONSE, ({ response, error }) => {
expect(response.body).toBe("data");
expect(error).toBeUndefined();
resolve();
});
});

channel.emit(FETCH_REQUEST, { requestId: "res", input: "https://example.com" });
await vi.waitFor(() => {
expect(instance.abortControllers.size).toBe(1);
});

await promise;

expect(instance.abortControllers.size).toBe(0);
});

it("should send fetch error responses", async () => {
const fetch = vi.fn(() => rejectAfter(100, new Error("oops")));
const instance = ChannelFetch.subscribe("err", channel, fetch as any);

const promise = new Promise<void>((resolve) => {
channel.on(FETCH_RESPONSE, ({ response, error }) => {
expect(response).toBeUndefined();
expect(error).toMatch(/oops/);
resolve();
});
});

channel.emit(FETCH_REQUEST, { requestId: "err", input: "https://example.com" });
await vi.waitFor(() => {
expect(instance.abortControllers.size).toBe(1);
});

await promise;
expect(instance.abortControllers.size).toBe(0);
});

it("should abort fetch requests", async () => {
const fetch = vi.fn((input, init) => new Promise<Response>(() => {}));
const instance = ChannelFetch.subscribe("abort", channel, fetch);

channel.emit(FETCH_REQUEST, { requestId: "abort", input: "https://example.com" });
await vi.waitFor(() => {
expect(instance.abortControllers.size).toBe(1);
});

channel.emit(FETCH_ABORTED, { requestId: "abort" });
await vi.waitFor(() => {
expect(fetch.mock.lastCall?.[1].signal.aborted).toBe(true);
expect(instance.abortControllers.size).toBe(0);
});
});
});
47 changes: 47 additions & 0 deletions src/utils/ChannelFetch.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
import type { Channel } from "@storybook/channels";

import { FETCH_ABORTED, FETCH_REQUEST, FETCH_RESPONSE } from "../constants";

type ChannelLike = Pick<Channel, "emit" | "on" | "off">;

const instances = new Map<string, ChannelFetch>();

export class ChannelFetch {
channel: ChannelLike;

abortControllers: Map<string, AbortController>;

constructor(channel: ChannelLike, _fetch = fetch) {
this.channel = channel;
this.abortControllers = new Map<string, AbortController>();

this.channel.on(FETCH_ABORTED, ({ requestId }) => {
this.abortControllers.get(requestId)?.abort();
this.abortControllers.delete(requestId);
});

this.channel.on(FETCH_REQUEST, async ({ requestId, input, init }) => {
const controller = new AbortController();
this.abortControllers.set(requestId, controller);

try {
const res = await _fetch(input as RequestInfo, { ...init, signal: controller.signal });
const body = await res.text();
const headers = Array.from(res.headers as any);
const response = { body, headers, status: res.status, statusText: res.statusText };
this.channel.emit(FETCH_RESPONSE, { requestId, response });
} catch (err) {
const error = err instanceof Error ? err.message : String(err);
this.channel.emit(FETCH_RESPONSE, { requestId, error });
} finally {
this.abortControllers.delete(requestId);
}
});
}

static subscribe(key: string, channel: ChannelLike, _fetch = fetch) {
const instance = instances.get(key) || new ChannelFetch(channel, _fetch);
if (!instances.has(key)) instances.set(key, instance);
return instance;
}
}
16 changes: 16 additions & 0 deletions src/utils/MockChannel.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
export class MockChannel {
private listeners: Record<string, ((...args: any[]) => void)[]> = {};

on(event: string, listener: (...args: any[]) => void) {
this.listeners[event] = [...(this.listeners[event] ?? []), listener];
}

off(event: string, listener: (...args: any[]) => void) {
this.listeners[event] = (this.listeners[event] ?? []).filter((l) => l !== listener);
}

emit(event: string, ...args: any[]) {
// setTimeout is used to simulate the asynchronous nature of the real channel
(this.listeners[event] || []).forEach((listener) => setTimeout(() => listener(...args)));
}
}
18 changes: 1 addition & 17 deletions src/utils/SharedState.test.ts
Original file line number Diff line number Diff line change
@@ -1,24 +1,8 @@
import { beforeEach, describe, expect, it } from "vitest";

import { MockChannel } from "./MockChannel";
import { SharedState } from "./SharedState";

class MockChannel {
private listeners: Record<string, ((...args: any[]) => void)[]> = {};

on(event: string, listener: (...args: any[]) => void) {
this.listeners[event] = [...(this.listeners[event] ?? []), listener];
}

off(event: string, listener: (...args: any[]) => void) {
this.listeners[event] = (this.listeners[event] ?? []).filter((l) => l !== listener);
}

emit(event: string, ...args: any[]) {
// setTimeout is used to simulate the asynchronous nature of the real channel
(this.listeners[event] || []).forEach((listener) => setTimeout(() => listener(...args)));
}
}

const tick = () => new Promise((resolve) => setTimeout(resolve, 0));

describe("SharedState", () => {
Expand Down
104 changes: 54 additions & 50 deletions src/utils/graphQLClient.tsx
Original file line number Diff line number Diff line change
@@ -1,13 +1,11 @@
import { useAddonState } from "@storybook/manager-api";
import { authExchange } from "@urql/exchange-auth";
import React from "react";
import { Client, fetchExchange, mapExchange, Provider } from "urql";
import { Client, ClientOptions, fetchExchange, mapExchange, Provider } from "urql";
import { v4 as uuid } from "uuid";

import { ACCESS_TOKEN_KEY, ADDON_ID, CHROMATIC_API_URL } from "../constants";

export { Provider };

let currentToken: string | null;
let currentTokenExpiration: number | null;
const setCurrentToken = (token: string | null) => {
Expand Down Expand Up @@ -56,56 +54,62 @@ export const getFetchOptions = (token?: string) => ({
},
});

export const client = new Client({
url: CHROMATIC_API_URL,
exchanges: [
// We don't use cacheExchange, because it would inadvertently share data between stories.
mapExchange({
onResult(result) {
// Not all queries contain the `viewer` field, in which case it will be `undefined`.
// When we do retrieve the field but the token is invalid, it will be `null`.
if (result.data?.viewer === null) setCurrentToken(null);
},
}),
authExchange(async (utils) => {
return {
addAuthToOperation(operation) {
if (!currentToken) return operation;
return utils.appendHeaders(operation, { Authorization: `Bearer ${currentToken}` });
export const createClient = (options?: Partial<ClientOptions>) =>
new Client({
url: CHROMATIC_API_URL,
exchanges: [
// We don't use cacheExchange, because it would inadvertently share data between stories.
mapExchange({
onResult(result) {
// Not all queries contain the `viewer` field, in which case it will be `undefined`.
// When we do retrieve the field but the token is invalid, it will be `null`.
if (result.data?.viewer === null) setCurrentToken(null);
},
}),
authExchange(async (utils) => {
return {
addAuthToOperation(operation) {
if (!currentToken) return operation;
return utils.appendHeaders(operation, { Authorization: `Bearer ${currentToken}` });
},

// Determine if the current error is an authentication error.
didAuthError: (error) =>
error.response.status === 401 ||
error.graphQLErrors.some((e) => e.message.includes("Must login")),
// Determine if the current error is an authentication error.
didAuthError: (error) =>
error.response.status === 401 ||
error.graphQLErrors.some((e) => e.message.includes("Must login")),

// If didAuthError returns true, clear the token. Ideally we should refresh the token here.
// The operation will be retried automatically.
async refreshAuth() {
setCurrentToken(null);
},
// If didAuthError returns true, clear the token. Ideally we should refresh the token here.
// The operation will be retried automatically.
async refreshAuth() {
setCurrentToken(null);
},

// Prevent making a request if we know the token is missing, invalid or expired.
// This handler is called repeatedly so we avoid parsing the token each time.
willAuthError() {
if (!currentToken) return true;
try {
if (!currentTokenExpiration) {
const { exp } = JSON.parse(atob(currentToken.split(".")[1]));
currentTokenExpiration = exp;
// Prevent making a request if we know the token is missing, invalid or expired.
// This handler is called repeatedly so we avoid parsing the token each time.
willAuthError() {
if (!currentToken) return true;
try {
if (!currentTokenExpiration) {
const { exp } = JSON.parse(atob(currentToken.split(".")[1]));
currentTokenExpiration = exp;
}
return Date.now() / 1000 > (currentTokenExpiration || 0);
} catch (e) {
return true;
}
return Date.now() / 1000 > (currentTokenExpiration || 0);
} catch (e) {
return true;
}
},
};
}),
fetchExchange,
],
fetchOptions: getFetchOptions(), // Auth header (token) is handled by authExchange
});
},
};
}),
fetchExchange,
],
fetchOptions: getFetchOptions(), // Auth header (token) is handled by authExchange
...options,
});

export const GraphQLClientProvider = ({ children }: { children: React.ReactNode }) => {
return <Provider value={client}>{children}</Provider>;
};
export const GraphQLClientProvider = ({
children,
value = createClient(),
}: {
children: React.ReactNode;
value?: Client;
}) => <Provider value={value}>{children}</Provider>;
Loading
Loading