Skip to content

Commit 6874d7a

Browse files
committed
Add replicate.stream method
1 parent cee886e commit 6874d7a

File tree

2 files changed

+145
-0
lines changed

2 files changed

+145
-0
lines changed

index.js

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
const ApiError = require("./lib/error");
22
const ModelVersionIdentifier = require("./lib/identifier");
3+
const { Stream } = require("./lib/stream");
34
const { withAutomaticRetries } = require("./lib/util");
45

56
const collections = require("./lib/collections");
@@ -235,6 +236,50 @@ class Replicate {
235236
return response;
236237
}
237238

239+
/**
240+
* Stream a model and wait for its output.
241+
*
242+
* @param {string} identifier - Required. The model version identifier in the format "{owner}/{name}:{version}"
243+
* @param {object} options
244+
* @param {object} options.input - Required. An object with the model inputs
245+
* @param {object} [options.wait] - Options for waiting for the prediction to finish
246+
* @param {number} [options.wait.interval] - Polling interval in milliseconds. Defaults to 250
247+
* @param {number} [options.wait.max_attempts] - Maximum number of polling attempts. Defaults to no limit
248+
* @param {string} [options.webhook] - An HTTPS URL for receiving a webhook when the prediction has new output
249+
* @param {string[]} [options.webhook_events_filter] - You can change which events trigger webhook requests by specifying webhook events (`start`|`output`|`logs`|`completed`)
250+
* @param {AbortSignal} [options.signal] - AbortSignal to cancel the prediction
251+
* @throws {Error} If the prediction failed
252+
* @yields {ServerSentEvent} Each streamed event from the prediction
253+
*/
254+
async *stream(ref, options, progress) {
255+
const { wait, ...data } = options;
256+
257+
const identifier = ModelVersionIdentifier.parse(ref);
258+
259+
let prediction;
260+
if (identifier.version) {
261+
prediction = await this.predictions.create({
262+
...data,
263+
version: identifier.version,
264+
stream: true,
265+
});
266+
} else {
267+
prediction = await this.models.predictions.create(
268+
identifier.owner,
269+
identifier.name,
270+
{ ...data, stream: true }
271+
);
272+
}
273+
274+
if (prediction.urls && prediction.urls.stream) {
275+
const { signal } = options;
276+
const stream = new Stream(prediction.urls.stream, { signal });
277+
yield* stream;
278+
} else {
279+
throw new Error("Prediction does not support streaming");
280+
}
281+
}
282+
238283
/**
239284
* Paginate through a list of results.
240285
*

lib/stream.js

Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,100 @@
1+
const { Readable } = require("stream");
2+
3+
class ServerSentEvent {
4+
constructor(event, data, id, retry) {
5+
this.event = event;
6+
this.data = data;
7+
this.id = id;
8+
this.retry = retry;
9+
}
10+
11+
toString() {
12+
if (this.event === "output") {
13+
return this.data;
14+
}
15+
16+
return "";
17+
}
18+
}
19+
20+
class Stream extends Readable {
21+
constructor(url, options) {
22+
super();
23+
this.url = url;
24+
this.options = options;
25+
26+
this.event = null;
27+
this.data = [];
28+
this.lastEventId = null;
29+
this.retry = null;
30+
}
31+
32+
decode(line) {
33+
if (!line) {
34+
if (!this.event && !this.data.length && !this.lastEventId) {
35+
return null;
36+
}
37+
38+
const sse = new ServerSentEvent(
39+
this.event,
40+
this.data.join("\n"),
41+
this.lastEventId
42+
);
43+
44+
this.event = null;
45+
this.data = [];
46+
this.retry = null;
47+
48+
return sse;
49+
}
50+
51+
if (line.startsWith(":")) {
52+
return null;
53+
}
54+
55+
const [field, value] = line.split(": ");
56+
if (field === "event") {
57+
this.event = value;
58+
} else if (field === "data") {
59+
this.data.push(value);
60+
} else if (field === "id") {
61+
this.lastEventId = value;
62+
}
63+
64+
return null;
65+
}
66+
67+
async *[Symbol.asyncIterator]() {
68+
const response = await fetch(this.url, {
69+
...this.options,
70+
headers: {
71+
Accept: "text/event-stream",
72+
},
73+
});
74+
75+
for await (const chunk of response.body) {
76+
const decoder = new TextDecoder("utf-8");
77+
const text = decoder.decode(chunk);
78+
const lines = text.split("\n");
79+
for (const line of lines) {
80+
const sse = this.decode(line);
81+
if (sse) {
82+
if (sse.event === "error") {
83+
throw new Error(sse.data);
84+
}
85+
86+
yield sse;
87+
88+
if (sse.event === "done") {
89+
return;
90+
}
91+
}
92+
}
93+
}
94+
}
95+
}
96+
97+
module.exports = {
98+
Stream,
99+
ServerSentEvent,
100+
};

0 commit comments

Comments
 (0)