Skip to content

Commit 1113eec

Browse files
committed
Fix tests
1 parent a6a7c46 commit 1113eec

File tree

6 files changed

+152
-51
lines changed

6 files changed

+152
-51
lines changed

packages/core/src/v3/realtimeStreams/types.ts

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import { AnyZodFetchOptions, ApiRequestOptions } from "../apiClient/core.js";
22
import { AsyncIterableStream } from "../streams/asyncIterableStream.js";
3+
import { Prettify } from "../types/utils.js";
34

45
export type RealtimeStreamOperationOptions = {
56
signal?: AbortSignal;
@@ -134,9 +135,11 @@ export type AppendStreamOptions = {
134135
requestOptions?: ApiRequestOptions;
135136
};
136137

137-
export type WriterStreamOptions<TPart> = {
138-
execute: (options: {
139-
write: (part: TPart) => void;
140-
merge(stream: ReadableStream<TPart>): void;
141-
}) => Promise<void> | void;
142-
};
138+
export type WriterStreamOptions<TPart> = Prettify<
139+
PipeStreamOptions & {
140+
execute: (options: {
141+
write: (part: TPart) => void;
142+
merge(stream: ReadableStream<TPart>): void;
143+
}) => Promise<void> | void;
144+
}
145+
>;

packages/core/test/streamsWriterV1.test.ts

Lines changed: 24 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ import { describe, it, expect, beforeEach, afterEach } from "vitest";
22
import { createServer, Server, IncomingMessage, ServerResponse } from "node:http";
33
import { AddressInfo } from "node:net";
44
import { StreamsWriterV1 } from "../src/v3/realtimeStreams/streamsWriterV1.js";
5+
import { ensureReadableStream } from "../src/v3/streams/asyncIterableStream.js";
56

67
type RequestHandler = (req: IncomingMessage, res: ServerResponse) => void;
78

@@ -71,7 +72,7 @@ describe("StreamsWriterV1", () => {
7172
baseUrl,
7273
runId: "run_123",
7374
key: "test-stream",
74-
source: generateChunks(),
75+
source: ensureReadableStream(generateChunks()),
7576
});
7677

7778
await metadataStream.wait();
@@ -99,7 +100,7 @@ describe("StreamsWriterV1", () => {
99100
baseUrl,
100101
runId: "run_123",
101102
key: "test-stream",
102-
source: generateChunks(),
103+
source: ensureReadableStream(generateChunks()),
103104
clientId: "custom-client-123",
104105
});
105106

@@ -142,7 +143,7 @@ describe("StreamsWriterV1", () => {
142143
baseUrl,
143144
runId: "run_123",
144145
key: "test-stream",
145-
source: generateChunks(),
146+
source: ensureReadableStream(generateChunks()),
146147
});
147148

148149
await metadataStream.wait();
@@ -191,7 +192,7 @@ describe("StreamsWriterV1", () => {
191192
baseUrl,
192193
runId: "run_123",
193194
key: "test-stream",
194-
source: generateChunks(),
195+
source: ensureReadableStream(generateChunks()),
195196
});
196197

197198
await metadataStream.wait();
@@ -235,7 +236,7 @@ describe("StreamsWriterV1", () => {
235236
baseUrl,
236237
runId: "run_123",
237238
key: "test-stream",
238-
source: generateChunks(),
239+
source: ensureReadableStream(generateChunks()),
239240
});
240241

241242
await metadataStream.wait();
@@ -278,7 +279,7 @@ describe("StreamsWriterV1", () => {
278279
baseUrl,
279280
runId: "run_123",
280281
key: "test-stream",
281-
source: generateChunks(),
282+
source: ensureReadableStream(generateChunks()),
282283
maxBufferSize: 100, // Small buffer for testing
283284
});
284285

@@ -323,7 +324,7 @@ describe("StreamsWriterV1", () => {
323324
baseUrl,
324325
runId: "run_123",
325326
key: "test-stream",
326-
source: generateChunks(),
327+
source: ensureReadableStream(generateChunks()),
327328
maxRetries: 3, // Low retry count for faster test
328329
});
329330

@@ -370,7 +371,7 @@ describe("StreamsWriterV1", () => {
370371
baseUrl,
371372
runId: "run_123",
372373
key: "test-stream",
373-
source: generateChunks(),
374+
source: ensureReadableStream(generateChunks()),
374375
});
375376

376377
await metadataStream.wait();
@@ -418,7 +419,7 @@ describe("StreamsWriterV1", () => {
418419
baseUrl,
419420
runId: "run_123",
420421
key: "test-stream",
421-
source: generateChunks(),
422+
source: ensureReadableStream(generateChunks()),
422423
});
423424

424425
await metadataStream.wait();
@@ -459,7 +460,7 @@ describe("StreamsWriterV1", () => {
459460
baseUrl,
460461
runId: "run_123",
461462
key: "test-stream",
462-
source: generateChunks(),
463+
source: ensureReadableStream(generateChunks()),
463464
});
464465

465466
await metadataStream.wait();
@@ -482,7 +483,7 @@ describe("StreamsWriterV1", () => {
482483
baseUrl,
483484
runId: "run_123",
484485
key: "test-stream",
485-
source: generateChunks(),
486+
source: ensureReadableStream(generateChunks()),
486487
});
487488

488489
await metadataStream.wait();
@@ -532,7 +533,7 @@ describe("StreamsWriterV1", () => {
532533
baseUrl,
533534
runId: "run_123",
534535
key: "test-stream",
535-
source: generateChunks(),
536+
source: ensureReadableStream(generateChunks()),
536537
maxBufferSize: 100, // Large enough to hold all chunks
537538
});
538539

@@ -591,7 +592,7 @@ describe("StreamsWriterV1", () => {
591592
baseUrl,
592593
runId: "run_123",
593594
key: "test-stream",
594-
source: generateChunks(),
595+
source: ensureReadableStream(generateChunks()),
595596
});
596597

597598
await metadataStream.wait();
@@ -641,7 +642,7 @@ describe("StreamsWriterV1", () => {
641642
baseUrl,
642643
runId: "run_123",
643644
key: "test-stream",
644-
source: generateChunks(),
645+
source: ensureReadableStream(generateChunks()),
645646
maxBufferSize: 50, // Small buffer - will overflow
646647
});
647648

@@ -663,7 +664,7 @@ describe("StreamsWriterV1", () => {
663664
baseUrl,
664665
runId: "run_123",
665666
key: "test-stream",
666-
source: generateChunks(),
667+
source: ensureReadableStream(generateChunks()),
667668
});
668669

669670
// Consumer reads from the stream
@@ -701,7 +702,7 @@ describe("StreamsWriterV1", () => {
701702
baseUrl,
702703
runId: "run_123",
703704
key: "test-stream",
704-
source: generateChunks(),
705+
source: ensureReadableStream(generateChunks()),
705706
});
706707

707708
await expect(metadataStream.wait()).rejects.toThrow("HTTP error! status: 400");
@@ -743,7 +744,7 @@ describe("StreamsWriterV1", () => {
743744
baseUrl,
744745
runId: "run_123",
745746
key: "test-stream",
746-
source: generateChunks(),
747+
source: ensureReadableStream(generateChunks()),
747748
});
748749

749750
await metadataStream.wait();
@@ -774,7 +775,7 @@ describe("StreamsWriterV1", () => {
774775
baseUrl,
775776
runId: "run_123",
776777
key: "test-stream",
777-
source: generateChunks(),
778+
source: ensureReadableStream(generateChunks()),
778779
signal: abortController.signal,
779780
});
780781

@@ -798,7 +799,7 @@ describe("StreamsWriterV1", () => {
798799
baseUrl,
799800
runId: "run_123",
800801
key: "test-stream",
801-
source: generateChunks(),
802+
source: ensureReadableStream(generateChunks()),
802803
});
803804

804805
await metadataStream.wait();
@@ -824,7 +825,7 @@ describe("StreamsWriterV1", () => {
824825
baseUrl,
825826
runId: "run_123",
826827
key: "test-stream",
827-
source: generateChunks(),
828+
source: ensureReadableStream(generateChunks()),
828829
});
829830

830831
await metadataStream.wait();
@@ -865,7 +866,7 @@ describe("StreamsWriterV1", () => {
865866
baseUrl,
866867
runId: "run_123",
867868
key: "test-stream",
868-
source: generateChunks(),
869+
source: ensureReadableStream(generateChunks()),
869870
});
870871

871872
await metadataStream.wait();
@@ -915,7 +916,7 @@ describe("StreamsWriterV1", () => {
915916
baseUrl,
916917
runId: "run_123",
917918
key: "test-stream",
918-
source: generateChunks(),
919+
source: ensureReadableStream(generateChunks()),
919920
});
920921

921922
await metadataStream.wait();
@@ -961,7 +962,7 @@ describe("StreamsWriterV1", () => {
961962
baseUrl,
962963
runId: "run_123",
963964
key: "test-stream",
964-
source: generateChunks(),
965+
source: ensureReadableStream(generateChunks()),
965966
maxBufferSize: 50, // Small buffer
966967
});
967968

packages/trigger-sdk/src/v3/streams.ts

Lines changed: 84 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,20 @@ function pipe<T>(
125125
value = keyOrValue;
126126
opts = valueOrOptions as PipeStreamOptions | undefined;
127127
}
128+
129+
return pipeInternal(key, value, opts, "streams.pipe()");
130+
}
131+
132+
/**
133+
* Internal pipe implementation that allows customizing the span name.
134+
* This is used by both the public `pipe` method and the `writer` method.
135+
*/
136+
function pipeInternal<T>(
137+
key: string,
138+
value: AsyncIterable<T> | ReadableStream<T>,
139+
opts: PipeStreamOptions | undefined,
140+
spanName: string
141+
): PipeStreamResult<T> {
128142
const runId = getRunIdForOptions(opts);
129143

130144
if (!runId) {
@@ -133,7 +147,7 @@ function pipe<T>(
133147
);
134148
}
135149

136-
const span = tracer.startSpan("streams.pipe()", {
150+
const span = tracer.startSpan(spanName, {
137151
attributes: {
138152
key,
139153
runId,
@@ -376,7 +390,74 @@ function isAppendStreamOptions(val: unknown): val is AppendStreamOptions {
376390
);
377391
}
378392

379-
function writer<TPart>(options: WriterStreamOptions<TPart>) {}
393+
function writer<TPart>(key: string, options: WriterStreamOptions<TPart>) {
394+
let controller!: ReadableStreamDefaultController<TPart>;
395+
396+
const ongoingStreamPromises: Promise<void>[] = [];
397+
398+
const stream = new ReadableStream({
399+
start(controllerArg) {
400+
controller = controllerArg;
401+
},
402+
});
403+
404+
function safeEnqueue(data: TPart) {
405+
try {
406+
controller.enqueue(data);
407+
} catch (error) {
408+
// suppress errors when the stream has been closed
409+
}
410+
}
411+
412+
try {
413+
const result = options.execute({
414+
write(part) {
415+
safeEnqueue(part);
416+
},
417+
merge(streamArg) {
418+
ongoingStreamPromises.push(
419+
(async () => {
420+
const reader = streamArg.getReader();
421+
while (true) {
422+
const { done, value } = await reader.read();
423+
if (done) break;
424+
safeEnqueue(value);
425+
}
426+
})().catch((error) => {
427+
console.error(error);
428+
})
429+
);
430+
},
431+
});
432+
433+
if (result) {
434+
ongoingStreamPromises.push(
435+
result.catch((error) => {
436+
console.error(error);
437+
})
438+
);
439+
}
440+
} catch (error) {
441+
console.error(error);
442+
}
443+
444+
const waitForStreams: Promise<void> = new Promise(async (resolve) => {
445+
while (ongoingStreamPromises.length > 0) {
446+
await ongoingStreamPromises.shift();
447+
}
448+
resolve();
449+
});
450+
451+
waitForStreams.finally(() => {
452+
try {
453+
controller.close();
454+
} catch (error) {
455+
// suppress errors when the stream has been closed
456+
}
457+
});
458+
459+
return pipeInternal(key, stream, options, "streams.writer()");
460+
}
380461

381462
export type RealtimeDefineStreamOptions = {
382463
id: string;
@@ -395,7 +476,7 @@ function define<TPart>(opts: RealtimeDefineStreamOptions): RealtimeDefinedStream
395476
return append(opts.id, value as BodyInit, options);
396477
},
397478
writer(options) {
398-
return writer(options);
479+
return writer(opts.id, options);
399480
},
400481
};
401482
}

references/realtime-streams/src/components/ai-chat.tsx

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,8 @@ function AIChatStats({ accessToken, runId }: { accessToken: string; runId: strin
3232
);
3333
}
3434

35+
console.log(parts);
36+
3537
// Calculate statistics
3638
const stats = {
3739
totalChunks: parts.length,

references/realtime-streams/src/trigger/ai-chat.ts

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,14 @@
11
import { aiStream } from "@/app/streams";
22
import { openai } from "@ai-sdk/openai";
33
import { logger, streams, task } from "@trigger.dev/sdk";
4-
import { convertToModelMessages, readUIMessageStream, streamText, tool, UIMessage } from "ai";
4+
import {
5+
convertToModelMessages,
6+
readUIMessageStream,
7+
stepCountIs,
8+
streamText,
9+
tool,
10+
UIMessage,
11+
} from "ai";
512
import { z } from "zod/v4";
613

714
export type AIChatPayload = {
@@ -20,6 +27,7 @@ export const aiChatTask = task({
2027
model: openai("gpt-4o"),
2128
system: "You are a helpful assistant.",
2229
messages: convertToModelMessages(payload.messages),
30+
stopWhen: stepCountIs(20),
2331
tools: {
2432
getCommonUseCases: tool({
2533
description: "Get common use cases",

0 commit comments

Comments
 (0)