Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use Double instead of Float in the portable ChatOptions #1325

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ public class AnthropicChatModel extends AbstractToolCallSupport implements ChatM

public static final Integer DEFAULT_MAX_TOKENS = 500;

public static final Float DEFAULT_TEMPERATURE = 0.8f;
public static final Double DEFAULT_TEMPERATURE = 0.8;

/**
* The lower-level API for the Anthropic service.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,8 @@ public class AnthropicChatOptions implements ChatOptions, FunctionCallingOptions
private @JsonProperty("max_tokens") Integer maxTokens;
private @JsonProperty("metadata") ChatCompletionRequest.Metadata metadata;
private @JsonProperty("stop_sequences") List<String> stopSequences;
private @JsonProperty("temperature") Float temperature;
private @JsonProperty("top_p") Float topP;
private @JsonProperty("temperature") Double temperature;
private @JsonProperty("top_p") Double topP;
private @JsonProperty("top_k") Integer topK;

/**
Expand Down Expand Up @@ -112,12 +112,12 @@ public Builder withStopSequences(List<String> stopSequences) {
return this;
}

public Builder withTemperature(Float temperature) {
public Builder withTemperature(Double temperature) {
this.options.temperature = temperature;
return this;
}

public Builder withTopP(Float topP) {
public Builder withTopP(Double topP) {
this.options.topP = topP;
return this;
}
Expand Down Expand Up @@ -186,20 +186,20 @@ public void setStopSequences(List<String> stopSequences) {
}

@Override
public Float getTemperature() {
public Double getTemperature() {
return this.temperature;
}

public void setTemperature(Float temperature) {
public void setTemperature(Double temperature) {
this.temperature = temperature;
}

@Override
public Float getTopP() {
public Double getTopP() {
return this.topP;
}

public void setTopP(Float topP) {
public void setTopP(Double topP) {
this.topP = topP;
}

Expand Down Expand Up @@ -236,13 +236,13 @@ public void setFunctions(Set<String> functions) {

@Override
@JsonIgnore
public Float getFrequencyPenalty() {
public Double getFrequencyPenalty() {
return null;
}

@Override
@JsonIgnore
public Float getPresencePenalty() {
public Double getPresencePenalty() {
return null;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -241,19 +241,19 @@ public record ChatCompletionRequest( // @formatter:off
@JsonProperty("metadata") Metadata metadata,
@JsonProperty("stop_sequences") List<String> stopSequences,
@JsonProperty("stream") Boolean stream,
@JsonProperty("temperature") Float temperature,
@JsonProperty("top_p") Float topP,
@JsonProperty("temperature") Double temperature,
@JsonProperty("top_p") Double topP,
@JsonProperty("top_k") Integer topK,
@JsonProperty("tools") List<Tool> tools) {
// @formatter:on

public ChatCompletionRequest(String model, List<AnthropicMessage> messages, String system, Integer maxTokens,
Float temperature, Boolean stream) {
Double temperature, Boolean stream) {
this(model, messages, system, maxTokens, null, null, stream, temperature, null, null, null);
}

public ChatCompletionRequest(String model, List<AnthropicMessage> messages, String system, Integer maxTokens,
List<String> stopSequences, Float temperature, Boolean stream) {
List<String> stopSequences, Double temperature, Boolean stream) {
this(model, messages, system, maxTokens, null, stopSequences, stream, temperature, null, null, null);
}

Expand Down Expand Up @@ -292,9 +292,9 @@ public static class ChatCompletionRequestBuilder {

private Boolean stream = false;

private Float temperature;
private Double temperature;

private Float topP;
private Double topP;

private Integer topK;

Expand Down Expand Up @@ -357,12 +357,12 @@ public ChatCompletionRequestBuilder withStream(Boolean stream) {
return this;
}

public ChatCompletionRequestBuilder withTemperature(Float temperature) {
public ChatCompletionRequestBuilder withTemperature(Double temperature) {
this.temperature = temperature;
return this;
}

public ChatCompletionRequestBuilder withTopP(Float topP) {
public ChatCompletionRequestBuilder withTopP(Double topP) {
this.topP = topP;
return this;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ void roleTest(String modelName) {

@Test
void streamingWithTokenUsage() {
var promptOptions = AnthropicChatOptions.builder().withTemperature(0f).build();
var promptOptions = AnthropicChatOptions.builder().withTemperature(0.0).build();

var prompt = new Prompt("List two colors of the Polish flag. Be brief.", promptOptions);
var streamingTokenUsage = this.chatModel.stream(prompt).blockLast().getMetadata().getUsage();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -71,9 +71,9 @@ void observationForChatOperation() {
.withModel(AnthropicApi.ChatModel.CLAUDE_3_HAIKU.getValue())
.withMaxTokens(2048)
.withStopSequences(List.of("this-is-the-end"))
.withTemperature(0.7f)
.withTemperature(0.7)
.withTopK(1)
.withTopP(1f)
.withTopP(1.0)
.build();

Prompt prompt = new Prompt("Why does a raven look like a desk?", options);
Expand All @@ -93,9 +93,9 @@ void observationForStreamingChatOperation() {
.withModel(AnthropicApi.ChatModel.CLAUDE_3_HAIKU.getValue())
.withMaxTokens(2048)
.withStopSequences(List.of("this-is-the-end"))
.withTemperature(0.7f)
.withTemperature(0.7)
.withTopK(1)
.withTopP(1f)
.withTopP(1.0)
.build();

Prompt prompt = new Prompt("Why does a raven look like a desk?", options);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,24 +31,24 @@ public class ChatCompletionRequestTests {
public void createRequestWithChatOptions() {

var client = new AnthropicChatModel(new AnthropicApi("TEST"),
AnthropicChatOptions.builder().withModel("DEFAULT_MODEL").withTemperature(66.6f).build());
AnthropicChatOptions.builder().withModel("DEFAULT_MODEL").withTemperature(66.6).build());

var request = client.createRequest(new Prompt("Test message content"), false);

assertThat(request.messages()).hasSize(1);
assertThat(request.stream()).isFalse();

assertThat(request.model()).isEqualTo("DEFAULT_MODEL");
assertThat(request.temperature()).isEqualTo(66.6f);
assertThat(request.temperature()).isEqualTo(66.6);

request = client.createRequest(new Prompt("Test message content",
AnthropicChatOptions.builder().withModel("PROMPT_MODEL").withTemperature(99.9f).build()), true);
AnthropicChatOptions.builder().withModel("PROMPT_MODEL").withTemperature(99.9).build()), true);

assertThat(request.messages()).hasSize(1);
assertThat(request.stream()).isTrue();

assertThat(request.model()).isEqualTo("PROMPT_MODEL");
assertThat(request.temperature()).isEqualTo(99.9f);
assertThat(request.temperature()).isEqualTo(99.9);
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ void chatCompletionEntity() {
Role.USER);
ResponseEntity<ChatCompletionResponse> response = anthropicApi
.chatCompletionEntity(new ChatCompletionRequest(AnthropicApi.ChatModel.CLAUDE_3_OPUS.getValue(),
List.of(chatCompletionMessage), null, 100, 0.8f, false));
List.of(chatCompletionMessage), null, 100, 0.8, false));

System.out.println(response);
assertThat(response).isNotNull();
Expand All @@ -58,9 +58,8 @@ void chatCompletionStream() {
AnthropicMessage chatCompletionMessage = new AnthropicMessage(List.of(new ContentBlock("Tell me a Joke?")),
Role.USER);

Flux<ChatCompletionResponse> response = anthropicApi
.chatCompletionStream(new ChatCompletionRequest(AnthropicApi.ChatModel.CLAUDE_3_OPUS.getValue(),
List.of(chatCompletionMessage), null, 100, 0.8f, true));
Flux<ChatCompletionResponse> response = anthropicApi.chatCompletionStream(new ChatCompletionRequest(
AnthropicApi.ChatModel.CLAUDE_3_OPUS.getValue(), List.of(chatCompletionMessage), null, 100, 0.8, true));

assertThat(response).isNotNull();

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -107,8 +107,8 @@ void toolCalls() {
Role.USER);

ChatCompletionRequest chatCompletionRequest = new ChatCompletionRequest(
AnthropicApi.ChatModel.CLAUDE_3_OPUS.getValue(), List.of(chatCompletionMessage), systemPrompt, 500,
0.8f, false);
AnthropicApi.ChatModel.CLAUDE_3_OPUS.getValue(), List.of(chatCompletionMessage), systemPrompt, 500, 0.8,
false);

ResponseEntity<ChatCompletionResponse> chatCompletion = doCall(chatCompletionRequest);

Expand Down Expand Up @@ -147,7 +147,7 @@ private ResponseEntity<ChatCompletionResponse> doCall(ChatCompletionRequest chat
AnthropicMessage chatCompletionMessage2 = new AnthropicMessage(List.of(new ContentBlock(content)), Role.USER);

return doCall(new ChatCompletionRequest(AnthropicApi.ChatModel.CLAUDE_3_OPUS.getValue(),
List.of(chatCompletionMessage2), null, 500, 0.8f, false));
List.of(chatCompletionMessage2), null, 500, 0.8, false));
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ private ResponseEntity<ChatCompletionResponse> doCall(List<AnthropicMessage> mes
.withModel(AnthropicApi.ChatModel.CLAUDE_3_OPUS)
.withMessages(messageConversation)
.withMaxTokens(1500)
.withTemperature(0.8f)
.withTemperature(0.8)
.withTools(tools)
.build();

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@
* @author Christian Tzolov
* @author Grogdunn
* @author Benoit Moussaud
* @author Thomas Vitale
* @author luocongqiu
* @author timostark
* @see ChatModel
Expand All @@ -98,7 +99,7 @@ public class AzureOpenAiChatModel extends AbstractToolCallSupport implements Cha

private static final String DEFAULT_DEPLOYMENT_NAME = "gpt-4o";

private static final Float DEFAULT_TEMPERATURE = 0.7f;
private static final Double DEFAULT_TEMPERATURE = 0.7;

/**
* The {@link OpenAIClient} used to interact with the Azure OpenAI service.
Expand Down Expand Up @@ -422,22 +423,22 @@ private ChatCompletionsOptions merge(ChatCompletionsOptions fromAzureOptions,

mergedAzureOptions.setTemperature(fromAzureOptions.getTemperature());
if (mergedAzureOptions.getTemperature() == null && toSpringAiOptions.getTemperature() != null) {
mergedAzureOptions.setTemperature(toSpringAiOptions.getTemperature().doubleValue());
mergedAzureOptions.setTemperature(toSpringAiOptions.getTemperature());
}

mergedAzureOptions.setTopP(fromAzureOptions.getTopP());
if (mergedAzureOptions.getTopP() == null && toSpringAiOptions.getTopP() != null) {
mergedAzureOptions.setTopP(toSpringAiOptions.getTopP().doubleValue());
mergedAzureOptions.setTopP(toSpringAiOptions.getTopP());
}

mergedAzureOptions.setFrequencyPenalty(fromAzureOptions.getFrequencyPenalty());
if (mergedAzureOptions.getFrequencyPenalty() == null && toSpringAiOptions.getFrequencyPenalty() != null) {
mergedAzureOptions.setFrequencyPenalty(toSpringAiOptions.getFrequencyPenalty().doubleValue());
mergedAzureOptions.setFrequencyPenalty(toSpringAiOptions.getFrequencyPenalty());
}

mergedAzureOptions.setPresencePenalty(fromAzureOptions.getPresencePenalty());
if (mergedAzureOptions.getPresencePenalty() == null && toSpringAiOptions.getPresencePenalty() != null) {
mergedAzureOptions.setPresencePenalty(toSpringAiOptions.getPresencePenalty().doubleValue());
mergedAzureOptions.setPresencePenalty(toSpringAiOptions.getPresencePenalty());
}

mergedAzureOptions.setResponseFormat(fromAzureOptions.getResponseFormat());
Expand Down Expand Up @@ -486,19 +487,19 @@ private ChatCompletionsOptions merge(AzureOpenAiChatOptions fromSpringAiOptions,
}

if (fromSpringAiOptions.getTemperature() != null) {
mergedAzureOptions.setTemperature(fromSpringAiOptions.getTemperature().doubleValue());
mergedAzureOptions.setTemperature(fromSpringAiOptions.getTemperature());
}

if (fromSpringAiOptions.getTopP() != null) {
mergedAzureOptions.setTopP(fromSpringAiOptions.getTopP().doubleValue());
mergedAzureOptions.setTopP(fromSpringAiOptions.getTopP());
}

if (fromSpringAiOptions.getFrequencyPenalty() != null) {
mergedAzureOptions.setFrequencyPenalty(fromSpringAiOptions.getFrequencyPenalty().doubleValue());
mergedAzureOptions.setFrequencyPenalty(fromSpringAiOptions.getFrequencyPenalty());
}

if (fromSpringAiOptions.getPresencePenalty() != null) {
mergedAzureOptions.setPresencePenalty(fromSpringAiOptions.getPresencePenalty().doubleValue());
mergedAzureOptions.setPresencePenalty(fromSpringAiOptions.getPresencePenalty());
}

if (fromSpringAiOptions.getN() != null) {
Expand Down
Loading