Skip to content

Close the stream when receiving "done" event from the server #219

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Mar 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
130 changes: 55 additions & 75 deletions index.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@ import Replicate, {
parseProgressFromLogs,
} from "replicate";
import nock from "nock";
import { Readable } from "node:stream";
import { createReadableStream } from "./lib/stream";
import { PassThrough } from "node:stream";

let client: Replicate;
const BASE_URL = "https://api.replicate.com/v1";
Expand Down Expand Up @@ -1187,16 +1187,17 @@ describe("Replicate client", () => {
// Continue with tests for other methods

describe("createReadableStream", () => {
function createStream(body: string | NodeJS.ReadableStream, status = 200) {
const streamEndpoint = "https://stream.replicate.com";
nock(streamEndpoint)
.get("/fake_stream")
.matchHeader("Accept", "text/event-stream")
.reply(status, body);

function createStream(body: string | ReadableStream, status = 200) {
const streamEndpoint = "https://stream.replicate.com/fake_stream";
const fetch = jest.fn((url) => {
if (url !== streamEndpoint) {
throw new Error(`Unmocked call to fetch() with url: ${url}`);
}
return new Response(body, { status });
});
return createReadableStream({
url: `${streamEndpoint}/fake_stream`,
fetch: fetch,
url: streamEndpoint,
fetch: fetch as any,
});
}

Expand Down Expand Up @@ -1330,9 +1331,6 @@ describe("Replicate client", () => {
});

test("supports the server writing data lines in multiple chunks", async () => {
const body = new PassThrough();
const stream = createStream(body);

// Create a stream of data chunks split on the pipe character for readability.
const data = `
event: output
Expand All @@ -1348,45 +1346,47 @@ describe("Replicate client", () => {
`.replace(/^[ ]+/gm, "");

const chunks = data.split("|");
const body = new ReadableStream({
async pull(controller) {
if (chunks.length) {
await new Promise((resolve) => setTimeout(resolve, 1));
const chunk = chunks.shift();
controller.enqueue(new TextEncoder().encode(chunk));
}
},
});

const stream = createStream(body);

// Consume the iterator in parallel to writing it.
const reading = new Promise((resolve, reject) => {
(async () => {
const iterator = stream[Symbol.asyncIterator]();
expect(await iterator.next()).toEqual({
done: false,
value: {
event: "output",
id: "EVENT_1",
data: "hello,\nthis is a new line,\nand this is a new line too",
},
});
expect(await iterator.next()).toEqual({
done: false,
value: { event: "done", id: "EVENT_2", data: "{}" },
});
expect(await iterator.next()).toEqual({ done: true });
})().then(resolve, reject);
const iterator = stream[Symbol.asyncIterator]();
expect(await iterator.next()).toEqual({
done: false,
value: {
event: "output",
id: "EVENT_1",
data: "hello,\nthis is a new line,\nand this is a new line too",
},
});

// Write the chunks to the stream at an interval.
const writing = new Promise((resolve, reject) => {
(async () => {
for await (const chunk of chunks) {
body.write(chunk);
await new Promise((resolve) => setTimeout(resolve, 1));
}
body.end();
resolve(null);
})().then(resolve, reject);
expect(await iterator.next()).toEqual({
done: false,
value: { event: "done", id: "EVENT_2", data: "{}" },
});
expect(await iterator.next()).toEqual({ done: true });

// Wait for both promises to resolve.
await Promise.all([reading, writing]);
});

test("supports the server writing data in a complete mess", async () => {
const body = new PassThrough();
const body = new ReadableStream({
async pull(controller) {
if (chunks.length) {
await new Promise((resolve) => setTimeout(resolve, 1));
const chunk = chunks.shift();
controller.enqueue(new TextEncoder().encode(chunk));
}
},
});
const stream = createStream(body);

// Create a stream of data chunks split on the pipe character for readability.
Expand All @@ -1407,40 +1407,20 @@ describe("Replicate client", () => {

const chunks = data.split("|");

// Consume the iterator in parallel to writing it.
const reading = new Promise((resolve, reject) => {
(async () => {
const iterator = stream[Symbol.asyncIterator]();
expect(await iterator.next()).toEqual({
done: false,
value: {
event: "output",
id: "EVENT_1",
data: "hello,\nthis is a new line,\nand this is a new line too",
},
});
expect(await iterator.next()).toEqual({
done: false,
value: { event: "done", id: "EVENT_2", data: "{}" },
});
expect(await iterator.next()).toEqual({ done: true });
})().then(resolve, reject);
const iterator = stream[Symbol.asyncIterator]();
expect(await iterator.next()).toEqual({
done: false,
value: {
event: "output",
id: "EVENT_1",
data: "hello,\nthis is a new line,\nand this is a new line too",
},
});

// Write the chunks to the stream at an interval.
const writing = new Promise((resolve, reject) => {
(async () => {
for await (const chunk of chunks) {
body.write(chunk);
await new Promise((resolve) => setTimeout(resolve, 1));
}
body.end();
resolve(null);
})().then(resolve, reject);
expect(await iterator.next()).toEqual({
done: false,
value: { event: "done", id: "EVENT_2", data: "{}" },
});

// Wait for both promises to resolve.
await Promise.all([reading, writing]);
expect(await iterator.next()).toEqual({ done: true });
});

test("supports ending without a done", async () => {
Expand Down
3 changes: 2 additions & 1 deletion integration/cloudflare-worker/.npmrc
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
package-lock=false

audit=false
fund=false
15 changes: 10 additions & 5 deletions integration/cloudflare-worker/index.test.js
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@ import { unstable_dev as dev } from "wrangler";
import { test, after, before, describe } from "node:test";
import assert from "node:assert";

/** @type {import("wrangler").UnstableDevWorker} */
describe("CloudFlare Worker", () => {
/** @type {import("wrangler").UnstableDevWorker} */
let worker;

before(async () => {
Expand All @@ -22,15 +22,20 @@ describe("CloudFlare Worker", () => {
await worker.stop();
});

test("worker streams back a response", { timeout: 1000 }, async () => {
test("worker streams back a response", { timeout: 5000 }, async () => {
const resp = await worker.fetch();
const text = await resp.text();

assert.ok(resp.ok, "status is 2xx");
assert(text.length > 0, "body.length is greater than 0");
assert.ok(resp.ok, `expected status to be 2xx but got ${resp.status}`);
assert(
text.length > 0,
"expected body to have content but got body.length of 0"
);
assert(
text.includes("Colin CloudFlare"),
"body includes stream characters"
`expected body to include "Colin CloudFlare" but got ${JSON.stringify(
text
)}`
);
});
});
3 changes: 2 additions & 1 deletion integration/commonjs/.npmrc
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
package-lock=false

audit=false
fund=false
3 changes: 2 additions & 1 deletion integration/esm/.npmrc
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
package-lock=false

audit=false
fund=false
3 changes: 2 additions & 1 deletion integration/typescript/.npmrc
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
package-lock=false

audit=false
fund=false
17 changes: 12 additions & 5 deletions lib/stream.js
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ function createReadableStream({ url, fetch, options = {} }) {
const request = new Request(url, init);
controller.error(
new ApiError(
`Request to ${url} failed with status ${response.status}`,
`Request to ${url} failed with status ${response.status}: ${text}`,
request,
response
)
Expand All @@ -72,15 +72,22 @@ function createReadableStream({ url, fetch, options = {} }) {
const stream = response.body
.pipeThrough(new TextDecoderStream())
.pipeThrough(new EventSourceParserStream());

for await (const event of stream) {
if (event.event === "error") {
controller.error(new Error(event.data));
} else {
controller.enqueue(
new ServerSentEvent(event.event, event.data, event.id)
);
break;
}

controller.enqueue(
new ServerSentEvent(event.event, event.data, event.id)
);

if (event.event === "done") {
break;
}
}

controller.close();
},
});
Expand Down