@@ -1,5 +1,6 @@
 const ApiError = require("./lib/error");
 const ModelVersionIdentifier = require("./lib/identifier");
+const { Stream } = require("./lib/stream");
 const { withAutomaticRetries } = require("./lib/util");
 
 const collections = require("./lib/collections");
@@ -235,6 +236,48 @@ class Replicate {
     return response;
   }
 
+  /**
+   * Run a model and stream its output.
+   *
+   * @param {string} ref - Required. The model version identifier in the format "{owner}/{name}:{version}"
+   * @param {object} options
+   * @param {object} options.input - Required. An object with the model inputs
+   * @param {string} [options.webhook] - An HTTPS URL for receiving a webhook when the prediction has new output
+   * @param {string[]} [options.webhook_events_filter] - You can change which events trigger webhook requests by specifying webhook events (`start`|`output`|`logs`|`completed`)
+   * @param {AbortSignal} [options.signal] - AbortSignal to cancel the prediction
+   * @throws {Error} If the prediction failed or does not support streaming
+   * @yields {ServerSentEvent} Each streamed event from the prediction
+   */
+  async *stream(ref, options) {
+    // Drop `wait` so it is not forwarded to the API with the rest of the options.
+    const { wait, ...data } = options;
+
+    const identifier = ModelVersionIdentifier.parse(ref);
+
+    let prediction;
+    if (identifier.version) {
+      prediction = await this.predictions.create({
+        ...data,
+        version: identifier.version,
+        stream: true,
+      });
+    } else {
+      prediction = await this.models.predictions.create(
+        identifier.owner,
+        identifier.name,
+        { ...data, stream: true }
+      );
+    }
+
+    if (prediction.urls && prediction.urls.stream) {
+      const { signal } = options;
+      const stream = new Stream(prediction.urls.stream, { signal });
+      yield* stream;
+    } else {
+      throw new Error("Prediction does not support streaming");
+    }
+  }
+
   /**
    * Paginate through a list of results.
    *
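A minimal usage sketch of the new `stream()` method. It assumes the client is constructed with an API token, that the model reference shown is one that supports streaming, and that each yielded `ServerSentEvent` exposes `event` and `data` properties; none of that is shown in this excerpt, so treat the names as illustrative.

```js
// Hypothetical consumer of the new stream() method (Node 18+).
// The model name and prompt are placeholders.
const Replicate = require("replicate");

const replicate = new Replicate({ auth: process.env.REPLICATE_API_TOKEN });

async function main() {
  for await (const event of replicate.stream("meta/llama-2-70b-chat", {
    input: { prompt: "Write a haiku about streaming." },
  })) {
    // Assumes "output" events carry incremental text in `event.data`.
    if (event.event === "output") {
      process.stdout.write(event.data);
    }
  }
}

main().catch(console.error);
```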
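The `Stream` class imported from `./lib/stream` is not part of this excerpt. As a rough sketch of the shape `yield* stream` relies on, an async-iterable that fetches the stream URL and parses server-sent events could look like the following; this is an illustration under those assumptions, not the library's implementation.

```js
// Illustrative only: a simplified async-iterable SSE reader.
// The real ./lib/stream is not shown in this diff.
class ServerSentEvent {
  constructor(event, data, id) {
    this.event = event;
    this.data = data;
    this.id = id;
  }
}

class Stream {
  constructor(url, { signal } = {}) {
    this.url = url;
    this.signal = signal;
  }

  async *[Symbol.asyncIterator]() {
    // Node 18+ global fetch; response.body is an async-iterable ReadableStream.
    const response = await fetch(this.url, {
      headers: { Accept: "text/event-stream" },
      signal: this.signal,
    });

    const decoder = new TextDecoder();
    let buffer = "";

    for await (const chunk of response.body) {
      buffer += decoder.decode(chunk, { stream: true });

      // Events are separated by a blank line ("\n\n").
      let boundary;
      while ((boundary = buffer.indexOf("\n\n")) !== -1) {
        const raw = buffer.slice(0, boundary);
        buffer = buffer.slice(boundary + 2);

        let event = "message";
        let id = null;
        const data = [];
        for (const line of raw.split("\n")) {
          if (line.startsWith("event:")) event = line.slice("event:".length).trim();
          else if (line.startsWith("data:")) data.push(line.slice("data:".length).trimStart());
          else if (line.startsWith("id:")) id = line.slice("id:".length).trim();
        }

        // Assumes the server ends the stream with a terminal "done" event.
        if (event === "done") return;
        yield new ServerSentEvent(event, data.join("\n"), id);
      }
    }
  }
}

module.exports = { Stream, ServerSentEvent };
```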