
feat(openai): Support audio output (#7012)
bracesproul authored Oct 17, 2024
1 parent c9b8026 commit 2a08a03
Showing 6 changed files with 300 additions and 18 deletions.
110 changes: 110 additions & 0 deletions docs/core_docs/docs/integrations/chat/openai.ipynb
@@ -1028,6 +1028,116 @@
"console.log(\"USAGE:\", resWitCaching.response_metadata.usage);"
]
},
{
"cell_type": "markdown",
"id": "cc8b3c94",
"metadata": {},
"source": [
"## Audio output\n",
"\n",
"Some OpenAI models (such as `gpt-4o-audio-preview`) support generating audio output. This example shows how to use that feature:"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "b4d579b7",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"{\n",
" id: 'audio_67117718c6008190a3afad3e3054b9b6',\n",
" data: 'UklGRqYwBgBXQVZFZm10IBAAAAABAAEAwF0AAIC7AAACABAATElTVBoAAABJTkZPSVNGVA4AAABMYXZmNTguMjkuMTAwAGRhdGFg',\n",
" expires_at: 1729201448,\n",
" transcript: 'Sure! Why did the cat sit on the computer? Because it wanted to keep an eye on the mouse!'\n",
"}\n"
]
}
],
"source": [
"import { ChatOpenAI } from \"@langchain/openai\";\n",
"\n",
"const modelWithAudioOutput = new ChatOpenAI({\n",
" model: \"gpt-4o-audio-preview\",\n",
" // You may also pass these fields to `.bind` as a call argument.\n",
" modalities: [\"text\", \"audio\"], // Specifies that the model should output audio.\n",
" audio: {\n",
" voice: \"alloy\",\n",
" format: \"wav\",\n",
" },\n",
"});\n",
"\n",
"const audioOutputResult = await modelWithAudioOutput.invoke(\"Tell me a joke about cats.\");\n",
"const castMessageContent = audioOutputResult.content[0] as Record<string, any>;\n",
"\n",
"console.log({\n",
" ...castMessageContent,\n",
" data: castMessageContent.data.slice(0, 100) // Sliced for brevity\n",
"})"
]
},
{
"cell_type": "markdown",
"id": "bfea3608",
"metadata": {},
"source": [
    "The audio data is returned inside the `data` field. We are also given an `expires_at` field, a Unix timestamp indicating when the audio response will no longer be accessible on the server for use in multi-turn conversations.\n",
"\n",
    "### Streaming audio output\n",
"\n",
"OpenAI also supports streaming audio output. Here's an example:"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "0fa68183",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"{\n",
" id: 'audio_671177fd836c819099b0110f5180a581audio_671177fd836c819099b0110f5180a581',\n",
" transcript: 'Why was the cat sitting on the computer? Because it wanted to keep an eye on the mouse!',\n",
" index: 0,\n",
" data: 'CQAMAAMADwAHAAsADAAKAA8ADQAPAAoABQANAAUAEAAIAA0ABwAHAAoAAAAFAAMABwAJAAQABwAAAAgAAgAFAAMAAwACAAAAAwAB',\n",
" expires_at: 1729201678\n",
"}\n"
]
}
],
"source": [
"import { AIMessageChunk } from \"@langchain/core/messages\";\n",
"import { concat } from \"@langchain/core/utils/stream\"\n",
"import { ChatOpenAI } from \"@langchain/openai\";\n",
"\n",
"const modelWithStreamingAudioOutput = new ChatOpenAI({\n",
" model: \"gpt-4o-audio-preview\",\n",
" modalities: [\"text\", \"audio\"],\n",
" audio: {\n",
" voice: \"alloy\",\n",
" format: \"pcm16\", // Format must be `pcm16` for streaming\n",
" },\n",
"});\n",
"\n",
"const audioOutputStream = await modelWithStreamingAudioOutput.stream(\"Tell me a joke about cats.\");\n",
"let finalAudioOutputMsg: AIMessageChunk | undefined;\n",
"for await (const chunk of audioOutputStream) {\n",
" finalAudioOutputMsg = finalAudioOutputMsg ? concat(finalAudioOutputMsg, chunk) : chunk;\n",
"}\n",
"const castStreamedMessageContent = finalAudioOutputMsg?.content[1] as Record<string, any>;\n",
"\n",
"console.log({\n",
" ...castStreamedMessageContent,\n",
" data: castStreamedMessageContent.data.slice(0, 100) // Sliced for brevity\n",
"})"
]
},
{
"cell_type": "markdown",
"id": "3a5bb5ca-c3ae-4a58-be67-2cd18574b9a3",
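The base64 `data` payload shown in the notebook can be written straight to disk once decoded. A minimal Node.js sketch (the truncated string below is a stand-in for a real response payload, and the output filename is arbitrary):

```typescript
import { writeFileSync } from "node:fs";
import { tmpdir } from "node:os";
import { join } from "node:path";

// Stand-in for `castMessageContent.data` from the example above (truncated).
const audioData = "UklGRqYwBgBXQVZFZm10IBAAAAABAAEAwF0AAIC7AAACABAA";

// The `data` field is base64-encoded; decode it into raw bytes.
const buffer = Buffer.from(audioData, "base64");

// WAV files start with the ASCII magic bytes "RIFF".
console.log(buffer.subarray(0, 4).toString("ascii")); // "RIFF"

// Write the decoded bytes out as a .wav file.
writeFileSync(join(tmpdir(), "cat-joke.wav"), buffer);
```

Note this only applies to the non-streaming example, where `format: "wav"` produces a complete WAV container.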
2 changes: 1 addition & 1 deletion libs/langchain-openai/package.json
@@ -36,7 +36,7 @@
"license": "MIT",
"dependencies": {
"js-tiktoken": "^1.0.12",
"openai": "^4.67.2",
"openai": "^4.68.0",
"zod": "^3.22.4",
"zod-to-json-schema": "^3.22.3"
},
99 changes: 87 additions & 12 deletions libs/langchain-openai/src/chat_models.ts
@@ -16,6 +16,7 @@ import {
isAIMessage,
convertToChunk,
UsageMetadata,
MessageContent,
} from "@langchain/core/messages";
import {
type ChatGeneration,
@@ -174,8 +175,10 @@ function openAIResponseToChatMessage(
system_fingerprint: rawResponse.system_fingerprint,
};
}
const content = message.audio ? [message.audio] : message.content;

return new AIMessage({
content: message.content || "",
content: content || "",
tool_calls: toolCalls,
invalid_tool_calls: invalidToolCalls,
additional_kwargs,
@@ -196,7 +199,17 @@ function _convertDeltaToMessageChunk(
includeRawResponse?: boolean
) {
const role = delta.role ?? defaultRole;
const content = delta.content ?? "";
let content: MessageContent;
if (delta.audio) {
content = [
{
...delta.audio,
index: rawResponse.choices[0].index,
},
];
} else {
content = delta.content ?? "";
}
let additional_kwargs: Record<string, unknown>;
if (delta.function_call) {
additional_kwargs = {
@@ -372,6 +385,26 @@ export interface ChatOpenAICallOptions
* @version 0.2.6
*/
strict?: boolean;
/**
* Output types that you would like the model to generate for this request. Most
* models are capable of generating text, which is the default:
*
* `["text"]`
*
* The `gpt-4o-audio-preview` model can also be used to
* [generate audio](https://platform.openai.com/docs/guides/audio). To request that
* this model generate both text and audio responses, you can use:
*
* `["text", "audio"]`
*/
modalities?: Array<OpenAIClient.Chat.ChatCompletionModality>;

/**
* Parameters for audio output. Required when audio output is requested with
* `modalities: ["audio"]`.
* [Learn more](https://platform.openai.com/docs/guides/audio).
*/
audio?: OpenAIClient.Chat.ChatCompletionAudioParam;
}

export interface ChatOpenAIFields
@@ -842,6 +875,43 @@ export interface ChatOpenAIFields
* </details>
*
* <br />
*
* <details>
* <summary><strong>Audio Outputs</strong></summary>
*
* ```typescript
* import { ChatOpenAI } from "@langchain/openai";
*
* const modelWithAudioOutput = new ChatOpenAI({
* model: "gpt-4o-audio-preview",
* // You may also pass these fields to `.bind` as a call argument.
* modalities: ["text", "audio"], // Specifies that the model should output audio.
* audio: {
* voice: "alloy",
* format: "wav",
* },
* });
*
* const audioOutputResult = await modelWithAudioOutput.invoke("Tell me a joke about cats.");
* const castMessageContent = audioOutputResult.content[0] as Record<string, any>;
*
* console.log({
* ...castMessageContent,
* data: castMessageContent.data.slice(0, 100) // Sliced for brevity
* })
* ```
*
* ```txt
* {
* id: 'audio_67117718c6008190a3afad3e3054b9b6',
* data: 'UklGRqYwBgBXQVZFZm10IBAAAAABAAEAwF0AAIC7AAACABAATElTVBoAAABJTkZPSVNGVA4AAABMYXZmNTguMjkuMTAwAGRhdGFg',
* expires_at: 1729201448,
* transcript: 'Sure! Why did the cat sit on the computer? Because it wanted to keep an eye on the mouse!'
* }
* ```
* </details>
*
* <br />
*/
export class ChatOpenAI<
CallOptions extends ChatOpenAICallOptions = ChatOpenAICallOptions
Expand Down Expand Up @@ -958,6 +1028,10 @@ export class ChatOpenAI<
*/
supportsStrictToolCalling?: boolean;

audio?: OpenAIClient.Chat.ChatCompletionAudioParam;

modalities?: Array<OpenAIClient.Chat.ChatCompletionModality>;

constructor(
fields?: ChatOpenAIFields,
/** @deprecated */
@@ -1026,6 +1100,8 @@ export class ChatOpenAI<
this.stopSequences = this?.stop;
this.user = fields?.user;
this.__includeRawResponse = fields?.__includeRawResponse;
this.audio = fields?.audio;
this.modalities = fields?.modalities;

if (this.azureOpenAIApiKey || this.azureADTokenProvider) {
if (
@@ -1190,6 +1266,12 @@ export class ChatOpenAI<
seed: options?.seed,
...streamOptionsConfig,
parallel_tool_calls: options?.parallel_tool_calls,
...(this.audio || options?.audio
? { audio: this.audio || options?.audio }
: {}),
...(this.modalities || options?.modalities
? { modalities: this.modalities || options?.modalities }
: {}),
...this.modelKwargs,
};
return params;
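The `invocationParams` hunk above uses conditional object spread so that `audio` and `modalities` only appear in the request body when actually set (an explicit `audio: undefined` key is not the same as omitting it). A standalone sketch of the pattern, with `buildParams` as an illustrative name rather than a library function:

```typescript
// Include a key in the request body only when a value is present.
// `buildParams` is illustrative, not part of @langchain/openai.
function buildParams(audio?: { voice: string; format: string }) {
  return {
    model: "gpt-4o-audio-preview",
    ...(audio ? { audio } : {}),
  };
}

console.log("audio" in buildParams()); // false — the key is absent entirely
console.log("audio" in buildParams({ voice: "alloy", format: "wav" })); // true
```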
@@ -1241,7 +1323,7 @@
const streamIterable = await this.completionWithRetry(params, options);
let usage: OpenAIClient.Completions.CompletionUsage | undefined;
for await (const data of streamIterable) {
const choice = data?.choices[0];
const choice = data?.choices?.[0];
if (data.usage) {
usage = data.usage;
}
@@ -1264,12 +1346,6 @@
prompt: options.promptIndex ?? 0,
completion: choice.index ?? 0,
};
if (typeof chunk.content !== "string") {
console.log(
"[WARNING]: Received non-string content from OpenAI. This is currently not supported."
);
continue;
}
// eslint-disable-next-line @typescript-eslint/no-explicit-any
const generationInfo: Record<string, any> = { ...newTokenIndices };
if (choice.finish_reason != null) {
@@ -1283,7 +1359,7 @@
}
const generationChunk = new ChatGenerationChunk({
message: chunk,
text: chunk.content,
text: typeof chunk.content === "string" ? chunk.content : "",
generationInfo,
});
yield generationChunk;
@@ -1490,9 +1566,8 @@ export class ChatOpenAI<

const generations: ChatGeneration[] = [];
for (const part of data?.choices ?? []) {
const text = part.message?.content ?? "";
const generation: ChatGeneration = {
text,
text: part.message?.content ?? "",
message: openAIResponseToChatMessage(
part.message ?? { role: "assistant" },
data,
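The streaming delta change in `_convertDeltaToMessageChunk` tags each audio content block with the choice `index`, which is what lets `concat` merge partial audio chunks across deltas. A LangChain-independent sketch of that merge logic (`mergeAudioChunks` and the field shapes are illustrative, not the library's actual implementation):

```typescript
type AudioChunk = { index: number; id?: string; data?: string; transcript?: string };

// Merge streamed chunks that share an `index` by concatenating their base64
// `data` payloads and keeping the most recent value for other fields.
function mergeAudioChunks(chunks: AudioChunk[]): AudioChunk[] {
  const byIndex = new Map<number, AudioChunk>();
  for (const chunk of chunks) {
    const existing = byIndex.get(chunk.index);
    if (existing === undefined) {
      byIndex.set(chunk.index, { ...chunk });
    } else {
      byIndex.set(chunk.index, {
        ...existing,
        ...chunk,
        data: (existing.data ?? "") + (chunk.data ?? ""),
      });
    }
  }
  return [...byIndex.values()];
}

const merged = mergeAudioChunks([
  { index: 0, id: "audio_123", data: "AAAA" },
  { index: 0, data: "BBBB", transcript: "Hi!" },
]);
console.log(merged[0].data); // "AAAABBBB"
console.log(merged[0].id); // "audio_123" — carried over from the first chunk
```

Without the `index` field there would be no stable key to group partial audio blocks under, which is why the diff adds it to every streamed chunk.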
76 changes: 76 additions & 0 deletions libs/langchain-openai/src/tests/chat_models.int.test.ts
@@ -19,6 +19,7 @@ import {
import { CallbackManager } from "@langchain/core/callbacks/manager";
import { NewTokenIndices } from "@langchain/core/callbacks/base";
import { InMemoryCache } from "@langchain/core/caches";
import { concat } from "@langchain/core/utils/stream";
import { ChatOpenAI } from "../chat_models.js";

// Save the original value of the 'LANGCHAIN_CALLBACKS_BACKGROUND' environment variable
@@ -986,3 +987,78 @@ test("Test ChatOpenAI stream method", async () => {
}
expect(chunks.length).toEqual(1);
});

describe("Audio output", () => {
test("Audio output", async () => {
const model = new ChatOpenAI({
model: "gpt-4o-audio-preview",
temperature: 0,
modalities: ["text", "audio"],
audio: {
voice: "alloy",
format: "wav",
},
});

const response = await model.invoke("Make me an audio clip of you yelling");
expect(Array.isArray(response.content)).toBeTruthy();
expect(Object.keys(response.content[0]).sort()).toEqual([
"data",
"expires_at",
"id",
"transcript",
]);
});

test("Audio output can stream", async () => {
const model = new ChatOpenAI({
model: "gpt-4o-audio-preview",
temperature: 0,
modalities: ["text", "audio"],
audio: {
voice: "alloy",
format: "pcm16",
},
});

const stream = await model.stream("Make me an audio clip of you yelling");
let finalMsg: AIMessageChunk | undefined;
for await (const chunk of stream) {
finalMsg = finalMsg ? concat(finalMsg, chunk) : chunk;
}
if (!finalMsg) {
throw new Error("No final message found");
}
console.dir(finalMsg, { depth: null });
expect(Array.isArray(finalMsg.content)).toBeTruthy();
expect(Object.keys(finalMsg.content[1]).sort()).toEqual([
"data",
"expires_at",
"id",
"index",
"transcript",
]);
});

test("Can bind audio output args", async () => {
const model = new ChatOpenAI({
model: "gpt-4o-audio-preview",
temperature: 0,
}).bind({
modalities: ["text", "audio"],
audio: {
voice: "alloy",
format: "wav",
},
});

const response = await model.invoke("Make me an audio clip of you yelling");
expect(Array.isArray(response.content)).toBeTruthy();
expect(Object.keys(response.content[0]).sort()).toEqual([
"data",
"expires_at",
"id",
"transcript",
]);
});
});
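Since streaming requires `format: "pcm16"`, the streamed bytes are raw PCM frames with no container; saving them as a playable file means prepending a WAV header yourself. A hedged sketch (24 kHz mono 16-bit is an assumption here — check OpenAI's audio guide for the actual stream parameters):

```typescript
// Wrap raw little-endian PCM16 samples in a minimal 44-byte WAV header.
function pcm16ToWav(pcm: Buffer, sampleRate = 24000, channels = 1): Buffer {
  const header = Buffer.alloc(44);
  header.write("RIFF", 0, "ascii");
  header.writeUInt32LE(36 + pcm.length, 4); // RIFF chunk size
  header.write("WAVE", 8, "ascii");
  header.write("fmt ", 12, "ascii");
  header.writeUInt32LE(16, 16); // fmt chunk size
  header.writeUInt16LE(1, 20); // audio format: PCM
  header.writeUInt16LE(channels, 22);
  header.writeUInt32LE(sampleRate, 24);
  header.writeUInt32LE(sampleRate * channels * 2, 28); // byte rate
  header.writeUInt16LE(channels * 2, 32); // block align
  header.writeUInt16LE(16, 34); // bits per sample
  header.write("data", 36, "ascii");
  header.writeUInt32LE(pcm.length, 40);
  return Buffer.concat([header, pcm]);
}

// e.g. pass Buffer.from(castStreamedMessageContent.data, "base64") here.
const wav = pcm16ToWav(Buffer.alloc(4800)); // 100 ms of silence at 24 kHz mono
console.log(wav.subarray(0, 4).toString("ascii")); // "RIFF"
```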
