Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions .changeset/brown-lions-double.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
'ai': patch
---

feat(ai): add websocket transport support
55 changes: 55 additions & 0 deletions examples/next-openai/app/ws-chat/page.tsx
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
'use client';

import { useChat } from '@ai-sdk/react';
import { WebSocketChatTransport } from 'ai';
import ChatInput from '@/components/chat-input';

export default function WsChatPage() {
// Point to the local demo WS server (scripts/ws-chat-server.js)
const { error, status, sendMessage, messages, regenerate, stop } = useChat({
transport: new WebSocketChatTransport({ url: 'wss://localhost:8787' }),
});

return (
<div className="flex flex-col w-full max-w-md py-24 mx-auto stretch">
<h1 className="mb-4 text-xl font-bold">WebSocket Chat Transport Demo</h1>

{messages.map(m => (
<div key={m.id} className="whitespace-pre-wrap">
{m.role === 'user' ? 'User: ' : 'AI: '}
{m.parts.map(part => {
if (part.type === 'text') return part.text;
})}
</div>
))}

{(status === 'submitted' || status === 'streaming') && (
<div className="mt-4 text-gray-500">
{status === 'submitted' && <div>Loading...</div>}
<button
type="button"
className="px-4 py-2 mt-4 text-blue-500 border border-blue-500 rounded-md"
onClick={stop}
>
Stop
</button>
</div>
)}

{error && (
<div className="mt-4">
<div className="text-red-500">An error occurred.</div>
<button
type="button"
className="px-4 py-2 mt-4 text-blue-500 border border-blue-500 rounded-md"
onClick={() => regenerate()}
>
Retry
</button>
</div>
)}

<ChatInput status={status} onSubmit={text => sendMessage({ text })} />
</div>
);
}
4 changes: 3 additions & 1 deletion examples/next-openai/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
"private": true,
"scripts": {
"dev": "next dev",
"dev:ws": "node ./scripts/ws-chat-server.js",
"build": "next build",
"start": "next start",
"lint": "next lint"
Expand Down Expand Up @@ -44,7 +45,8 @@
"tailwind-merge": "^3.0.2",
"tailwindcss-animate": "^1.0.7",
"valibot": "1.1.0",
"zod": "3.25.76"
"zod": "3.25.76",
"ws": "^8.18.0"
},
"devDependencies": {
"@types/node": "20.17.24",
Expand Down
81 changes: 81 additions & 0 deletions examples/next-openai/scripts/ws-chat-server.js
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
/* Minimal WebSocket server for demoing WebSocketChatTransport. */
const { WebSocketServer } = require('ws');

const PORT = process.env.WS_PORT ? Number(process.env.WS_PORT) : 8787;
const wss = new WebSocketServer({ port: PORT });

function send(ws, message) {
try {
ws.send(JSON.stringify(message));
} catch {}
}

wss.on('connection', ws => {
ws.on('message', data => {
let inbound;
try {
inbound = JSON.parse(data.toString());
} catch {
return;
}

const requestId = inbound.requestId || '';

switch (inbound.type) {
case 'send': {
// Use AI SDK to stream model output and forward UIMessageChunks over WS
(async () => {
try {
const { streamText, convertToModelMessages } = await import('ai');
const { openai } = await import('@ai-sdk/openai');

const uiMessages = Array.isArray(inbound.messages)
? inbound.messages
: [];
const result = streamText({
model: openai('gpt-4o-mini'),
messages: convertToModelMessages(uiMessages),
});

const stream = result.toUIMessageStream();
const reader = stream.getReader();

try {
while (true) {
const { value, done } = await reader.read();
if (done) break;
send(ws, { type: 'chunk', requestId, chunk: value });
}
send(ws, { type: 'end', requestId });
} finally {
reader.releaseLock?.();
}
} catch (err) {
send(ws, {
type: 'error',
requestId,
errorText: String(err?.message || err),
});
}
})();
break;
}

case 'resume': {
// Minimal behavior: no active stream
send(ws, { type: 'no-active', requestId });
break;
}

case 'abort': {
// No-op in this minimal server
break;
}

default:
break;
}
});
});

console.log(`[ws-chat-server] listening on wss://localhost:${PORT}`);
6 changes: 6 additions & 0 deletions packages/ai/src/ui/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,12 @@ export {
export { lastAssistantMessageIsCompleteWithApprovalResponses } from './last-assistant-message-is-complete-with-approval-responses';
export { lastAssistantMessageIsCompleteWithToolCalls } from './last-assistant-message-is-complete-with-tool-calls';
export { TextStreamChatTransport } from './text-stream-chat-transport';
export {
WebSocketChatTransport,
type WebSocketChatTransportInitOptions,
type PrepareReconnectToStreamRequest as WebSocketPrepareReconnectToStreamRequest,
type PrepareSendMessagesRequest as WebSocketPrepareSendMessagesRequest,
} from './websocket-chat-transport';
export {
getToolName,
getToolOrDynamicToolName,
Expand Down
135 changes: 135 additions & 0 deletions packages/ai/src/ui/websocket-chat-transport.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,135 @@
import { describe, it, expect, vi } from 'vitest';
import { WebSocketChatTransport } from './websocket-chat-transport';
import { UIMessage } from './ui-messages';

class MockWebSocket {
static OPEN = 1;
static CONNECTING = 0;
readyState = MockWebSocket.OPEN;
sent: string[] = [];
private _onopen: (() => void) | null = null;
onmessage: ((evt: { data: string }) => void) | null = null;
onclose: (() => void) | null = null;
onerror: (() => void) | null = null;

constructor(public url: string) {}

// Auto-trigger onopen when assigned if already open
set onopen(handler: (() => void) | null) {
this._onopen = handler;
if (this.readyState === MockWebSocket.OPEN && handler) {
// Use setTimeout to make it async like real WebSocket
setTimeout(() => handler(), 0);
}
}

get onopen() {
return this._onopen;
}

send(data: string) {
this.sent.push(data);
}
triggerOpen() {
this._onopen?.();
}
triggerMessage(data: any) {
this.onmessage?.({ data: JSON.stringify(data) });
}
triggerClose() {
this.onclose?.();
}
triggerError() {
this.onerror?.();
}
}

describe('WebSocketChatTransport', () => {
it('sends headers and body in outbound send message', async () => {
const ws = new MockWebSocket('wss://example.test');
const transport = new WebSocketChatTransport<UIMessage>({
url: 'wss://example.test',
headers: { 'X-Test': 'yes' },
body: { someData: true },
makeWebSocket: () => ws as unknown as WebSocket,
});

const stream = await transport.sendMessages({
chatId: 'c1',
messageId: 'm1',
trigger: 'submit-message',
messages: [
{ id: 'm1', role: 'user', parts: [{ type: 'text', text: 'hi' }] },
],
abortSignal: new AbortController().signal,
});

const outbound = JSON.parse(ws.sent[0]);
expect(outbound.type).toBe('send');
expect(outbound.id).toBe('c1');
expect(outbound.messageId).toBe('m1');
expect(outbound.headers['X-Test']).toBe('yes');
expect(outbound.body.someData).toBe(true);
});

it('resume returns null when server indicates no-active', async () => {
const ws = new MockWebSocket('wss://example.test');
const transport = new WebSocketChatTransport<UIMessage>({
url: 'wss://example.test',
makeWebSocket: () => ws as unknown as WebSocket,
});

const resumePromise = transport.reconnectToStream({
chatId: 'c2',
});

// Wait for connection and message to be sent
await new Promise(resolve => setTimeout(resolve, 10));

// capture requestId from sent message
const sent = JSON.parse(ws.sent[0]);
expect(sent.type).toBe('resume');

// server responds with no-active
ws.triggerMessage({ type: 'no-active', requestId: sent.requestId });

const result = await resumePromise;
expect(result).toBeNull();
});

it('sendMessages stream closes on end', async () => {
const ws = new MockWebSocket('wss://example.test');
const transport = new WebSocketChatTransport<UIMessage>({
url: 'wss://example.test',
makeWebSocket: () => ws as unknown as WebSocket,
});

const stream = await transport.sendMessages({
chatId: 'c3',
messageId: 'm3',
trigger: 'submit-message',
messages: [
{ id: 'm3', role: 'user', parts: [{ type: 'text', text: 'hi' }] },
],
abortSignal: new AbortController().signal,
});

const sent = JSON.parse(ws.sent[0]);
const reader = stream.getReader();

// push a chunk and then end
ws.triggerMessage({
type: 'chunk',
requestId: sent.requestId,
chunk: { type: 'text-start', id: 'id1' },
});

const { value, done } = await reader.read();
expect(done).toBe(false);
expect(value).toEqual({ type: 'text-start', id: 'id1' });

ws.triggerMessage({ type: 'end', requestId: sent.requestId });
const r2 = await reader.read();
expect(r2.done).toBe(true);
});
});
Loading
Loading