diff --git a/index.test.ts b/index.test.ts index 6624eb2..7e0ae22 100644 --- a/index.test.ts +++ b/index.test.ts @@ -7,8 +7,8 @@ import Replicate, { parseProgressFromLogs, } from "replicate"; import nock from "nock"; +import { Readable } from "node:stream"; import { createReadableStream } from "./lib/stream"; -import { PassThrough } from "node:stream"; let client: Replicate; const BASE_URL = "https://api.replicate.com/v1"; @@ -1187,16 +1187,17 @@ describe("Replicate client", () => { // Continue with tests for other methods describe("createReadableStream", () => { - function createStream(body: string | NodeJS.ReadableStream, status = 200) { - const streamEndpoint = "https://stream.replicate.com"; - nock(streamEndpoint) - .get("/fake_stream") - .matchHeader("Accept", "text/event-stream") - .reply(status, body); - + function createStream(body: string | ReadableStream, status = 200) { + const streamEndpoint = "https://stream.replicate.com/fake_stream"; + const fetch = jest.fn((url) => { + if (url !== streamEndpoint) { + throw new Error(`Unmocked call to fetch() with url: ${url}`); + } + return new Response(body, { status }); + }); return createReadableStream({ - url: `${streamEndpoint}/fake_stream`, - fetch: fetch, + url: streamEndpoint, + fetch: fetch as any, }); } @@ -1330,9 +1331,6 @@ describe("Replicate client", () => { }); test("supports the server writing data lines in multiple chunks", async () => { - const body = new PassThrough(); - const stream = createStream(body); - // Create a stream of data chunks split on the pipe character for readability. const data = ` event: output @@ -1348,45 +1346,47 @@ describe("Replicate client", () => { `.replace(/^[ ]+/gm, ""); const chunks = data.split("|"); + const body = new ReadableStream({ + async pull(controller) { + if (chunks.length) { + await new Promise((resolve) => setTimeout(resolve, 1)); + const chunk = chunks.shift(); + controller.enqueue(new TextEncoder().encode(chunk)); + } + }, + }); + + const stream = createStream(body); // Consume the iterator in parallel to writing it. - const reading = new Promise((resolve, reject) => { - (async () => { - const iterator = stream[Symbol.asyncIterator](); - expect(await iterator.next()).toEqual({ - done: false, - value: { - event: "output", - id: "EVENT_1", - data: "hello,\nthis is a new line,\nand this is a new line too", - }, - }); - expect(await iterator.next()).toEqual({ - done: false, - value: { event: "done", id: "EVENT_2", data: "{}" }, - }); - expect(await iterator.next()).toEqual({ done: true }); - })().then(resolve, reject); + const iterator = stream[Symbol.asyncIterator](); + expect(await iterator.next()).toEqual({ + done: false, + value: { + event: "output", + id: "EVENT_1", + data: "hello,\nthis is a new line,\nand this is a new line too", + }, }); - - // Write the chunks to the stream at an interval. - const writing = new Promise((resolve, reject) => { - (async () => { - for await (const chunk of chunks) { - body.write(chunk); - await new Promise((resolve) => setTimeout(resolve, 1)); - } - body.end(); - resolve(null); - })().then(resolve, reject); + expect(await iterator.next()).toEqual({ + done: false, + value: { event: "done", id: "EVENT_2", data: "{}" }, }); + expect(await iterator.next()).toEqual({ done: true }); // Wait for both promises to resolve. - await Promise.all([reading, writing]); }); test("supports the server writing data in a complete mess", async () => { - const body = new PassThrough(); + const body = new ReadableStream({ + async pull(controller) { + if (chunks.length) { + await new Promise((resolve) => setTimeout(resolve, 1)); + const chunk = chunks.shift(); + controller.enqueue(new TextEncoder().encode(chunk)); + } + }, + }); const stream = createStream(body); // Create a stream of data chunks split on the pipe character for readability. @@ -1407,40 +1407,20 @@ describe("Replicate client", () => { const chunks = data.split("|"); - // Consume the iterator in parallel to writing it. - const reading = new Promise((resolve, reject) => { - (async () => { - const iterator = stream[Symbol.asyncIterator](); - expect(await iterator.next()).toEqual({ - done: false, - value: { - event: "output", - id: "EVENT_1", - data: "hello,\nthis is a new line,\nand this is a new line too", - }, - }); - expect(await iterator.next()).toEqual({ - done: false, - value: { event: "done", id: "EVENT_2", data: "{}" }, - }); - expect(await iterator.next()).toEqual({ done: true }); - })().then(resolve, reject); + const iterator = stream[Symbol.asyncIterator](); + expect(await iterator.next()).toEqual({ + done: false, + value: { + event: "output", + id: "EVENT_1", + data: "hello,\nthis is a new line,\nand this is a new line too", + }, }); - - // Write the chunks to the stream at an interval. - const writing = new Promise((resolve, reject) => { - (async () => { - for await (const chunk of chunks) { - body.write(chunk); - await new Promise((resolve) => setTimeout(resolve, 1)); - } - body.end(); - resolve(null); - })().then(resolve, reject); + expect(await iterator.next()).toEqual({ + done: false, + value: { event: "done", id: "EVENT_2", data: "{}" }, }); - - // Wait for both promises to resolve. - await Promise.all([reading, writing]); + expect(await iterator.next()).toEqual({ done: true }); }); test("supports ending without a done", async () => { diff --git a/integration/cloudflare-worker/.npmrc b/integration/cloudflare-worker/.npmrc index b15cbc2..7775040 100644 --- a/integration/cloudflare-worker/.npmrc +++ b/integration/cloudflare-worker/.npmrc @@ -1,2 +1,3 @@ package-lock=false - +audit=false +fund=false diff --git a/integration/cloudflare-worker/index.test.js b/integration/cloudflare-worker/index.test.js index 0c0fc5e..932d8f5 100644 --- a/integration/cloudflare-worker/index.test.js +++ b/integration/cloudflare-worker/index.test.js @@ -3,8 +3,8 @@ import { unstable_dev as dev } from "wrangler"; import { test, after, before, describe } from "node:test"; import assert from "node:assert"; -/** @type {import("wrangler").UnstableDevWorker} */ describe("CloudFlare Worker", () => { + /** @type {import("wrangler").UnstableDevWorker} */ let worker; before(async () => { @@ -22,15 +22,20 @@ describe("CloudFlare Worker", () => { await worker.stop(); }); - test("worker streams back a response", { timeout: 1000 }, async () => { + test("worker streams back a response", { timeout: 5000 }, async () => { const resp = await worker.fetch(); const text = await resp.text(); - assert.ok(resp.ok, "status is 2xx"); - assert(text.length > 0, "body.length is greater than 0"); + assert.ok(resp.ok, `expected status to be 2xx but got ${resp.status}`); + assert( + text.length > 0, + "expected body to have content but got body.length of 0" + ); assert( text.includes("Colin CloudFlare"), - "body includes stream characters" + `expected body to include "Colin CloudFlare" but got ${JSON.stringify( + text + )}` ); }); }); diff --git a/integration/commonjs/.npmrc b/integration/commonjs/.npmrc index b15cbc2..7775040 100644 --- a/integration/commonjs/.npmrc +++ b/integration/commonjs/.npmrc @@ -1,2 +1,3 @@ package-lock=false - +audit=false +fund=false diff --git a/integration/esm/.npmrc b/integration/esm/.npmrc index b15cbc2..7775040 100644 --- a/integration/esm/.npmrc +++ b/integration/esm/.npmrc @@ -1,2 +1,3 @@ package-lock=false - +audit=false +fund=false diff --git a/integration/typescript/.npmrc b/integration/typescript/.npmrc index b15cbc2..7775040 100644 --- a/integration/typescript/.npmrc +++ b/integration/typescript/.npmrc @@ -1,2 +1,3 @@ package-lock=false - +audit=false +fund=false diff --git a/lib/stream.js b/lib/stream.js index a97642d..cd9274c 100644 --- a/lib/stream.js +++ b/lib/stream.js @@ -62,7 +62,7 @@ function createReadableStream({ url, fetch, options = {} }) { const request = new Request(url, init); controller.error( new ApiError( - `Request to ${url} failed with status ${response.status}`, + `Request to ${url} failed with status ${response.status}: ${text}`, request, response ) @@ -72,15 +72,22 @@ function createReadableStream({ url, fetch, options = {} }) { const stream = response.body .pipeThrough(new TextDecoderStream()) .pipeThrough(new EventSourceParserStream()); + for await (const event of stream) { if (event.event === "error") { controller.error(new Error(event.data)); - } else { - controller.enqueue( - new ServerSentEvent(event.event, event.data, event.id) - ); + break; + } + + controller.enqueue( + new ServerSentEvent(event.event, event.data, event.id) + ); + + if (event.event === "done") { + break; } } + controller.close(); }, });