feat: add OpenAI transcribe and TTS models #2540

Open · wants to merge 1 commit into base: main
@@ -79,7 +79,7 @@ public class OpenAiAudioSpeechModel implements SpeechModel, StreamingSpeechModel
public OpenAiAudioSpeechModel(OpenAiAudioApi audioApi) {
this(audioApi,
OpenAiAudioSpeechOptions.builder()
- .model(OpenAiAudioApi.TtsModel.TTS_1.getValue())
+ .model(OpenAiAudioApi.TtsModel.GPT_4O_MINI_TTS.getValue())
.responseFormat(AudioResponseFormat.MP3)
.voice(OpenAiAudioApi.SpeechRequest.Voice.ALLOY)
.speed(SPEED)
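Note on the hunk above: the single-argument OpenAiAudioSpeechModel constructor now defaults to gpt-4o-mini-tts instead of tts-1. Below is a minimal sketch (not part of the diff) of how a caller could keep the previous default explicitly, assuming only the builder API already shown in this hunk; the audioApi parameter stands in for an existing OpenAiAudioApi instance.

import org.springframework.ai.openai.OpenAiAudioSpeechModel;
import org.springframework.ai.openai.OpenAiAudioSpeechOptions;
import org.springframework.ai.openai.api.OpenAiAudioApi;

class LegacyTtsDefault {

	// Hypothetical helper: pins tts-1 explicitly so behavior stays the same
	// after the built-in default switches to gpt-4o-mini-tts.
	static OpenAiAudioSpeechModel legacySpeechModel(OpenAiAudioApi audioApi) {
		OpenAiAudioSpeechOptions options = OpenAiAudioSpeechOptions.builder()
			.model(OpenAiAudioApi.TtsModel.TTS_1.getValue())
			.responseFormat(OpenAiAudioApi.SpeechRequest.AudioResponseFormat.MP3)
			.voice(OpenAiAudioApi.SpeechRequest.Voice.ALLOY)
			.speed(1.0f)
			.build();
		return new OpenAiAudioSpeechModel(audioApi, options);
	}

}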
@@ -63,7 +63,7 @@ public class OpenAiAudioTranscriptionModel implements Model<AudioTranscriptionPr
public OpenAiAudioTranscriptionModel(OpenAiAudioApi audioApi) {
this(audioApi,
OpenAiAudioTranscriptionOptions.builder()
- .model(OpenAiAudioApi.WhisperModel.WHISPER_1.getValue())
+ .model(OpenAiAudioApi.TranscriptionModels.WHISPER_1.getValue())
.responseFormat(OpenAiAudioApi.TranscriptResponseFormat.JSON)
.temperature(0.7f)
.build());
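Note on the hunk above: the transcription default remains whisper-1, but it is now sourced from the new TranscriptionModels enum. A minimal sketch (not part of the diff) of opting into one of the newly added GPT-4o based transcription models through the same builder, again assuming an existing OpenAiAudioApi instance:

import org.springframework.ai.openai.OpenAiAudioTranscriptionModel;
import org.springframework.ai.openai.OpenAiAudioTranscriptionOptions;
import org.springframework.ai.openai.api.OpenAiAudioApi;

class Gpt4oMiniTranscription {

	// Hypothetical helper: selects the newly added gpt-4o-mini-transcribe model
	// instead of the whisper-1 default kept by the constructor above.
	static OpenAiAudioTranscriptionModel gpt4oMiniTranscriptionModel(OpenAiAudioApi audioApi) {
		OpenAiAudioTranscriptionOptions options = OpenAiAudioTranscriptionOptions.builder()
			.model(OpenAiAudioApi.TranscriptionModels.GPT_4O_MINI_TRANSCRIBE.getValue())
			.responseFormat(OpenAiAudioApi.TranscriptResponseFormat.JSON)
			.temperature(0.7f)
			.build();
		return new OpenAiAudioTranscriptionModel(audioApi, options);
	}

}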
@@ -26,6 +26,7 @@
import reactor.core.publisher.Mono;

import org.springframework.ai.model.ApiKey;
+ import org.springframework.ai.model.ChatModelDescription;
import org.springframework.ai.model.NoopApiKey;
import org.springframework.ai.model.SimpleApiKey;
import org.springframework.ai.openai.api.common.OpenAiApiConstants;
@@ -212,22 +213,25 @@ public String getFilename() {
* different model variants, tts-1 is optimized for real time text to speech use cases
* and tts-1-hd is optimized for quality. These models can be used with the Speech
* endpoint in the Audio API. Reference:
- * <a href="https://platform.openai.com/docs/models/tts">TTS</a>
+ * <a href="https://platform.openai.com/docs/models#tts">TTS</a>
*/
public enum TtsModel {

- // @formatter:off
/**
- * The latest text to speech model, optimized for speed.
+ * Text-to-speech model optimized for speed
*/
@JsonProperty("tts-1")
TTS_1("tts-1"),
/**
- * The latest text to speech model, optimized for quality.
+ * Text-to-speech model optimized for quality.
*/
@JsonProperty("tts-1-hd")
TTS_1_HD("tts-1-hd");
// @formatter:on
TTS_1_HD("tts-1-hd"),
/**
* Text-to-speech model powered by GPT-4o mini
*/
@JsonProperty("gpt-4o-mini-tts")
GPT_4O_MINI_TTS("gpt-4o-mini-tts");

public final String value;

@@ -249,6 +253,7 @@ public String getValue() {
* v2-large model is currently available through our API with the whisper-1 model
* name.
*/
+ @Deprecated
public enum WhisperModel {

// @formatter:off
@@ -268,6 +273,45 @@ public String getValue() {

}

+ /**
+ * The available models for the transcriptions API. Reference:
+ * <a href="https://platform.openai.com/docs/models#transcription">Transcription models</a>
+ */
+ public enum TranscriptionModels implements ChatModelDescription {
+
+ /**
+ * Speech-to-text model powered by GPT-4o
+ */
+ @JsonProperty("gpt-4o-transcribe")
+ GPT_4O_TRANSCRIBE("gpt-4o-transcribe"),
+ /**
+ * Speech-to-text model powered by GPT-4o mini
+ */
+ @JsonProperty("gpt-4o-mini-transcribe")
+ GPT_4O_MINI_TRANSCRIBE("gpt-4o-mini-transcribe"),
+ /**
+ * General-purpose speech recognition model
+ */
+ @JsonProperty("whisper-1")
+ WHISPER_1("whisper-1");
+
+ public final String value;
+
+ TranscriptionModels(String value) {
+ this.value = value;
+ }
+
+ public String getValue() {
+ return this.value;
+ }
+
+ @Override
+ public String getName() {
+ return this.value;
+ }
+
+ }
+
/**
* The format of the transcript and translation outputs, in one of these options:
* json, text, srt, verbose_json, or vtt. Defaults to json.
@@ -411,7 +455,7 @@ public String getValue() {
*/
public static class Builder {

- private String model = TtsModel.TTS_1.getValue();
+ private String model = TtsModel.GPT_4O_MINI_TTS.getValue();

private String input;

@@ -521,7 +565,7 @@ public static class Builder {

private byte[] file;

- private String model = WhisperModel.WHISPER_1.getValue();
+ private String model = TranscriptionModels.WHISPER_1.getValue();

private String language;

@@ -614,7 +658,7 @@ public static class Builder {

private byte[] file;

- private String model = WhisperModel.WHISPER_1.getValue();
+ private String model = TranscriptionModels.WHISPER_1.getValue();

private String prompt;

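Note on the changes above: WhisperModel is now deprecated in favour of the new TranscriptionModels enum, which exposes the same whisper-1 value plus the GPT-4o based transcription models. A small illustrative sketch (not part of the diff) of the values the new enum resolves to:

import org.springframework.ai.openai.api.OpenAiAudioApi.TranscriptionModels;

class TranscriptionModelNames {

	// The new enum keeps whisper-1 and adds the GPT-4o based transcription models,
	// so call sites can drop the deprecated WhisperModel enum without changing
	// the model names sent to the API.
	public static void main(String[] args) {
		System.out.println(TranscriptionModels.WHISPER_1.getValue());              // whisper-1
		System.out.println(TranscriptionModels.GPT_4O_TRANSCRIBE.getValue());      // gpt-4o-transcribe
		System.out.println(TranscriptionModels.GPT_4O_MINI_TRANSCRIBE.getValue()); // gpt-4o-mini-transcribe
	}

}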
@@ -29,8 +29,8 @@
import org.springframework.ai.openai.api.OpenAiAudioApi.StructuredResponse;
import org.springframework.ai.openai.api.OpenAiAudioApi.TranscriptionRequest;
import org.springframework.ai.openai.api.OpenAiAudioApi.TranslationRequest;
+ import org.springframework.ai.openai.api.OpenAiAudioApi.TranscriptionModels;
import org.springframework.ai.openai.api.OpenAiAudioApi.TtsModel;
- import org.springframework.ai.openai.api.OpenAiAudioApi.WhisperModel;
import org.springframework.util.FileCopyUtils;

import static org.assertj.core.api.Assertions.assertThat;
@@ -51,7 +51,7 @@ void speechTranscriptionAndTranslation() throws IOException {

byte[] speech = this.audioApi
.createSpeech(SpeechRequest.builder()
- .model(TtsModel.TTS_1_HD.getValue())
+ .model(TtsModel.GPT_4O_MINI_TTS.getValue())
.input("Hello, my name is Chris and I love Spring A.I.")
.voice(Voice.ONYX)
.build())
@@ -63,15 +63,15 @@ void speechTranscriptionAndTranslation() {

StructuredResponse translation = this.audioApi
.createTranslation(
- TranslationRequest.builder().model(WhisperModel.WHISPER_1.getValue()).file(speech).build(),
+ TranslationRequest.builder().model(TranscriptionModels.WHISPER_1.getValue()).file(speech).build(),
StructuredResponse.class)
.getBody();

assertThat(translation.text().replaceAll(",", "")).isEqualTo("Hello my name is Chris and I love Spring AI.");

StructuredResponse transcriptionEnglish = this.audioApi
.createTranscription(
- TranscriptionRequest.builder().model(WhisperModel.WHISPER_1.getValue()).file(speech).build(),
+ TranscriptionRequest.builder().model(TranscriptionModels.WHISPER_1.getValue()).file(speech).build(),
StructuredResponse.class)
.getBody();

@@ -44,7 +44,7 @@ void checkNoOpKey() {
assertThatThrownBy(() -> {
this.audioApi
.createSpeech(OpenAiAudioApi.SpeechRequest.builder()
- .model(OpenAiAudioApi.TtsModel.TTS_1_HD.getValue())
+ .model(OpenAiAudioApi.TtsModel.GPT_4O_MINI_TTS.getValue())
.input("Hello, my name is Chris and I love Spring A.I.")
.voice(OpenAiAudioApi.SpeechRequest.Voice.ONYX)
.build())
@@ -60,7 +60,7 @@ void shouldGenerateNonEmptyMp3AudioFromSpeechPrompt() {
.voice(OpenAiAudioApi.SpeechRequest.Voice.ALLOY)
.speed(SPEED)
.responseFormat(OpenAiAudioApi.SpeechRequest.AudioResponseFormat.MP3)
- .model(OpenAiAudioApi.TtsModel.TTS_1.value)
+ .model(OpenAiAudioApi.TtsModel.GPT_4O_MINI_TTS.value)
.build();
SpeechPrompt speechPrompt = new SpeechPrompt("Today is a wonderful day to build something people love!",
speechOptions);
@@ -78,7 +78,7 @@ void speechRateLimitTest() {
.voice(OpenAiAudioApi.SpeechRequest.Voice.ALLOY)
.speed(SPEED)
.responseFormat(OpenAiAudioApi.SpeechRequest.AudioResponseFormat.MP3)
- .model(OpenAiAudioApi.TtsModel.TTS_1.value)
+ .model(OpenAiAudioApi.TtsModel.GPT_4O_MINI_TTS.value)
.build();
SpeechPrompt speechPrompt = new SpeechPrompt("Today is a wonderful day to build something people love!",
speechOptions);
@@ -98,7 +98,7 @@ void shouldStreamNonEmptyResponsesForValidSpeechPrompts() {
.voice(OpenAiAudioApi.SpeechRequest.Voice.ALLOY)
.speed(SPEED)
.responseFormat(OpenAiAudioApi.SpeechRequest.AudioResponseFormat.MP3)
- .model(OpenAiAudioApi.TtsModel.TTS_1.value)
+ .model(OpenAiAudioApi.TtsModel.GPT_4O_MINI_TTS.value)
.build();

SpeechPrompt speechPrompt = new SpeechPrompt("Today is a wonderful day to build something people love!",
@@ -120,7 +120,7 @@ void speechVoicesTest(String voice) {
.voice(OpenAiAudioApi.SpeechRequest.Voice.valueOf(voice.toUpperCase()))
.speed(SPEED)
.responseFormat(OpenAiAudioApi.SpeechRequest.AudioResponseFormat.MP3)
- .model(OpenAiAudioApi.TtsModel.TTS_1.value)
+ .model(OpenAiAudioApi.TtsModel.GPT_4O_MINI_TTS.value)
.build();
SpeechPrompt speechPrompt = new SpeechPrompt("Today is a wonderful day to build something people love!",
speechOptions);
@@ -74,7 +74,7 @@ void aiResponseContainsImageResponseMetadata() {
.voice(OpenAiAudioApi.SpeechRequest.Voice.ALLOY)
.speed(SPEED)
.responseFormat(OpenAiAudioApi.SpeechRequest.AudioResponseFormat.MP3)
- .model(OpenAiAudioApi.TtsModel.TTS_1.value)
+ .model(OpenAiAudioApi.TtsModel.GPT_4O_MINI_TTS.value)
.build();

SpeechPrompt speechPrompt = new SpeechPrompt("Today is a wonderful day to build something people love!",
@@ -35,7 +35,7 @@ public class OpenAiAudioSpeechProperties extends OpenAiParentProperties {

public static final String CONFIG_PREFIX = "spring.ai.openai.audio.speech";

- public static final String DEFAULT_SPEECH_MODEL = OpenAiAudioApi.TtsModel.TTS_1.getValue();
+ public static final String DEFAULT_SPEECH_MODEL = OpenAiAudioApi.TtsModel.GPT_4O_MINI_TTS.getValue();

private static final Float SPEED = 1.0f;

@@ -26,7 +26,7 @@ public class OpenAiAudioTranscriptionProperties extends OpenAiParentProperties {

public static final String CONFIG_PREFIX = "spring.ai.openai.audio.transcription";

- public static final String DEFAULT_TRANSCRIPTION_MODEL = OpenAiAudioApi.WhisperModel.WHISPER_1.getValue();
+ public static final String DEFAULT_TRANSCRIPTION_MODEL = OpenAiAudioApi.TranscriptionModels.WHISPER_1.getValue();

private static final Double DEFAULT_TEMPERATURE = 0.7;
