feat(ollama): add retry template integration to OllamaChatModel
apappascs committed Dec 10, 2024
1 parent a474b12 commit 55d38fd
Showing 11 changed files with 56 additions and 15 deletions.
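In short, OllamaChatModel now takes a RetryTemplate (the old constructor delegates to RetryUtils.DEFAULT_RETRY_TEMPLATE) and wraps the blocking chat call in it, while OllamaApi registers its error handler once on the RestClient builder instead of on every request. Note that the builder does not default the new field, so code going through OllamaChatModel.builder() has to supply a template, which is why the tests below add withRetryTemplate(RetryUtils.DEFAULT_RETRY_TEMPLATE). A minimal usage sketch of the new builder option follows; the model name and the retry policy values are illustrative assumptions, not part of this commit.

// Sketch only: wiring a custom retry policy into the chat model (values are assumptions).
RetryTemplate retryTemplate = RetryTemplate.builder()
    .maxAttempts(3)                     // give up after three failed attempts
    .exponentialBackoff(500, 2.0, 5000) // 500 ms initial backoff, doubling, capped at 5 s
    .build();

OllamaChatModel chatModel = OllamaChatModel.builder()
    .withOllamaApi(new OllamaApi())
    .withDefaultOptions(OllamaOptions.create().withModel("llama3")) // example model name
    .withRetryTemplate(retryTemplate)
    .build();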
@@ -62,6 +62,8 @@
import org.springframework.ai.ollama.management.OllamaModelManager;
import org.springframework.ai.ollama.management.PullModelStrategy;
import org.springframework.ai.ollama.metadata.OllamaChatUsage;
import org.springframework.ai.retry.RetryUtils;
import org.springframework.retry.support.RetryTemplate;
import org.springframework.util.Assert;
import org.springframework.util.CollectionUtils;
import org.springframework.util.StringUtils;
@@ -77,6 +79,7 @@
* @author luocongqiu
* @author Thomas Vitale
* @author Jihoon Kim
* @author Alexandros Pappas
* @since 1.0.0
*/
public class OllamaChatModel extends AbstractToolCallSupport implements ChatModel {
@@ -107,20 +110,32 @@ public class OllamaChatModel extends AbstractToolCallSupport implements ChatMode

private final OllamaModelManager modelManager;

private final RetryTemplate retryTemplate;

private ChatModelObservationConvention observationConvention = DEFAULT_OBSERVATION_CONVENTION;

public OllamaChatModel(OllamaApi ollamaApi, OllamaOptions defaultOptions,
FunctionCallbackResolver functionCallbackResolver, List<FunctionCallback> toolFunctionCallbacks,
ObservationRegistry observationRegistry, ModelManagementOptions modelManagementOptions) {
this(ollamaApi, defaultOptions, functionCallbackResolver, toolFunctionCallbacks, observationRegistry,
modelManagementOptions, RetryUtils.DEFAULT_RETRY_TEMPLATE);
}

public OllamaChatModel(OllamaApi ollamaApi, OllamaOptions defaultOptions,
FunctionCallbackResolver functionCallbackResolver, List<FunctionCallback> toolFunctionCallbacks,
ObservationRegistry observationRegistry, ModelManagementOptions modelManagementOptions,
RetryTemplate retryTemplate) {
super(functionCallbackResolver, defaultOptions, toolFunctionCallbacks);
Assert.notNull(ollamaApi, "ollamaApi must not be null");
Assert.notNull(defaultOptions, "defaultOptions must not be null");
Assert.notNull(observationRegistry, "observationRegistry must not be null");
Assert.notNull(modelManagementOptions, "modelManagementOptions must not be null");
Assert.notNull(retryTemplate, "retryTemplate must not be null");
this.chatApi = ollamaApi;
this.defaultOptions = defaultOptions;
this.observationRegistry = observationRegistry;
this.modelManager = new OllamaModelManager(this.chatApi, modelManagementOptions);
this.retryTemplate = retryTemplate;
initializeModel(defaultOptions.getModel(), modelManagementOptions.pullModelStrategy());
}

@@ -198,7 +213,7 @@ private ChatResponse internalCall(Prompt prompt, ChatResponse previousChatRespon
this.observationRegistry)
.observe(() -> {

OllamaApi.ChatResponse ollamaResponse = this.chatApi.chat(request);
OllamaApi.ChatResponse ollamaResponse = this.retryTemplate.execute(ctx -> this.chatApi.chat(request));

List<AssistantMessage.ToolCall> toolCalls = ollamaResponse.message().toolCalls() == null ? List.of()
: ollamaResponse.message()
@@ -470,6 +485,8 @@ public static final class Builder {

private ModelManagementOptions modelManagementOptions = ModelManagementOptions.defaults();

private RetryTemplate retryTemplate;

private Builder() {
}

@@ -513,9 +530,15 @@ public Builder withModelManagementOptions(ModelManagementOptions modelManagement
return this;
}

public Builder withRetryTemplate(RetryTemplate retryTemplate) {
this.retryTemplate = retryTemplate;
return this;
}

public OllamaChatModel build() {
return new OllamaChatModel(this.ollamaApi, this.defaultOptions, this.functionCallbackResolver,
this.toolFunctionCallbacks, this.observationRegistry, this.modelManagementOptions);
this.toolFunctionCallbacks, this.observationRegistry, this.modelManagementOptions,
this.retryTemplate);
}

}
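Within OllamaChatModel, the only change to the synchronous call path is that the blocking request is now routed through the template, retryTemplate.execute(ctx -> chatApi.chat(request)). As a reminder of the spring-retry contract, execute re-invokes the callback when it throws a retryable exception and rethrows the last failure once the policy is exhausted; the sketch below is generic spring-retry usage, not code from this commit.

// Generic spring-retry behaviour; chatApi and request stand in for the fields used above.
OllamaApi.ChatResponse response = retryTemplate.execute(ctx -> {
    // ctx.getRetryCount() reports how many attempts have already failed
    return chatApi.chat(request);
});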
@@ -53,6 +53,7 @@
*
* @author Christian Tzolov
* @author Thomas Vitale
* @author Alexandros Pappas
* @since 0.8.0
*/
// @formatter:off
@@ -66,8 +67,6 @@ public class OllamaApi {

private static final String DEFAULT_BASE_URL = "http://localhost:11434";

private final ResponseErrorHandler responseErrorHandler;

private final RestClient restClient;

private final WebClient webClient;
@@ -95,14 +94,16 @@ public OllamaApi(String baseUrl) {
*/
public OllamaApi(String baseUrl, RestClient.Builder restClientBuilder, WebClient.Builder webClientBuilder) {

this.responseErrorHandler = new OllamaResponseErrorHandler();
ResponseErrorHandler responseErrorHandler = new OllamaResponseErrorHandler();

Consumer<HttpHeaders> defaultHeaders = headers -> {
headers.setContentType(MediaType.APPLICATION_JSON);
headers.setAccept(List.of(MediaType.APPLICATION_JSON));
};

this.restClient = restClientBuilder.baseUrl(baseUrl).defaultHeaders(defaultHeaders).build();
this.restClient = restClientBuilder.baseUrl(baseUrl)
.defaultStatusHandler(responseErrorHandler)
.defaultHeaders(defaultHeaders).build();

this.webClient = webClientBuilder.baseUrl(baseUrl).defaultHeaders(defaultHeaders).build();
}
@@ -123,7 +124,6 @@ public ChatResponse chat(ChatRequest chatRequest) {
.uri("/api/chat")
.body(chatRequest)
.retrieve()
.onStatus(this.responseErrorHandler)
.body(ChatResponse.class);
}

@@ -190,7 +190,6 @@ public EmbeddingsResponse embed(EmbeddingsRequest embeddingsRequest) {
.uri("/api/embed")
.body(embeddingsRequest)
.retrieve()
.onStatus(this.responseErrorHandler)
.body(EmbeddingsResponse.class);
}

@@ -201,7 +200,6 @@ public ListModelResponse listModels() {
return this.restClient.get()
.uri("/api/tags")
.retrieve()
.onStatus(this.responseErrorHandler)
.body(ListModelResponse.class);
}

@@ -214,7 +212,6 @@ public ShowModelResponse showModel(ShowModelRequest showModelRequest) {
.uri("/api/show")
.body(showModelRequest)
.retrieve()
.onStatus(this.responseErrorHandler)
.body(ShowModelResponse.class);
}

@@ -227,7 +224,6 @@ public ResponseEntity<Void> copyModel(CopyModelRequest copyModelRequest) {
.uri("/api/copy")
.body(copyModelRequest)
.retrieve()
.onStatus(this.responseErrorHandler)
.toBodilessEntity();
}

@@ -240,7 +236,6 @@ public ResponseEntity<Void> deleteModel(DeleteModelRequest deleteModelRequest) {
.uri("/api/delete")
.body(deleteModelRequest)
.retrieve()
.onStatus(this.responseErrorHandler)
.toBodilessEntity();
}
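Net effect of the OllamaApi changes: the OllamaResponseErrorHandler is now registered once through defaultStatusHandler on the RestClient builder, which lets the repeated .onStatus(this.responseErrorHandler) calls on chat, embed, listModels, showModel, copyModel and deleteModel be dropped while keeping the same error handling for those endpoints. The WebClient used for streaming is left as it was in this hunk.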

@@ -37,6 +37,7 @@
import org.springframework.ai.ollama.api.OllamaApi;
import org.springframework.ai.ollama.api.OllamaOptions;
import org.springframework.ai.ollama.api.tool.MockWeatherService;
import org.springframework.ai.retry.RetryUtils;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.boot.SpringBootConfiguration;
import org.springframework.boot.test.context.SpringBootTest;
@@ -123,6 +124,7 @@ public OllamaChatModel ollamaChat(OllamaApi ollamaApi) {
return OllamaChatModel.builder()
.withOllamaApi(ollamaApi)
.withDefaultOptions(OllamaOptions.create().withModel(MODEL).withTemperature(0.9))
.withRetryTemplate(RetryUtils.DEFAULT_RETRY_TEMPLATE)
.build();
}

@@ -42,6 +42,7 @@
import org.springframework.ai.ollama.management.ModelManagementOptions;
import org.springframework.ai.ollama.management.OllamaModelManager;
import org.springframework.ai.ollama.management.PullModelStrategy;
import org.springframework.ai.retry.RetryUtils;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.boot.SpringBootConfiguration;
import org.springframework.boot.test.context.SpringBootTest;
@@ -249,6 +250,7 @@ public OllamaChatModel ollamaChat(OllamaApi ollamaApi) {
.withPullModelStrategy(PullModelStrategy.WHEN_MISSING)
.withAdditionalModels(List.of(ADDITIONAL_MODEL))
.build())
.withRetryTemplate(RetryUtils.DEFAULT_RETRY_TEMPLATE)
.build();
}

@@ -27,6 +27,7 @@
import org.springframework.ai.model.Media;
import org.springframework.ai.ollama.api.OllamaApi;
import org.springframework.ai.ollama.api.OllamaOptions;
import org.springframework.ai.retry.RetryUtils;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.boot.SpringBootConfiguration;
import org.springframework.boot.test.context.SpringBootTest;
@@ -84,6 +85,7 @@ public OllamaChatModel ollamaChat(OllamaApi ollamaApi) {
return OllamaChatModel.builder()
.withOllamaApi(ollamaApi)
.withDefaultOptions(OllamaOptions.create().withModel(MODEL).withTemperature(0.9))
.withRetryTemplate(RetryUtils.DEFAULT_RETRY_TEMPLATE)
.build();
}

@@ -34,6 +34,7 @@
import org.springframework.ai.ollama.api.OllamaApi;
import org.springframework.ai.ollama.api.OllamaModel;
import org.springframework.ai.ollama.api.OllamaOptions;
import org.springframework.ai.retry.RetryUtils;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.boot.SpringBootConfiguration;
import org.springframework.boot.test.context.SpringBootTest;
@@ -47,6 +48,7 @@
* Integration tests for observation instrumentation in {@link OllamaChatModel}.
*
* @author Thomas Vitale
* @author Alexandros Pappas
*/
@SpringBootTest(classes = OllamaChatModelObservationIT.Config.class)
public class OllamaChatModelObservationIT extends BaseOllamaIT {
@@ -172,6 +174,7 @@ public OllamaChatModel openAiChatModel(OllamaApi ollamaApi, TestObservationRegis
return OllamaChatModel.builder()
.withOllamaApi(ollamaApi)
.withObservationRegistry(observationRegistry)
.withRetryTemplate(RetryUtils.DEFAULT_RETRY_TEMPLATE)
.build();
}

@@ -31,6 +31,7 @@
import org.springframework.ai.ollama.api.OllamaApi;
import org.springframework.ai.ollama.api.OllamaModel;
import org.springframework.ai.ollama.api.OllamaOptions;
import org.springframework.ai.retry.RetryUtils;

import static org.assertj.core.api.Assertions.assertThat;
import static org.junit.jupiter.api.Assertions.assertEquals;
@@ -39,6 +40,7 @@
/**
* @author Jihoon Kim
* @author Christian Tzolov
* @author Alexandros Pappas
* @since 1.0.0
*/
@ExtendWith(MockitoExtension.class)
@@ -53,6 +55,7 @@ public void buildOllamaChatModel() {
() -> OllamaChatModel.builder()
.withOllamaApi(this.ollamaApi)
.withDefaultOptions(OllamaOptions.create().withModel(OllamaModel.LLAMA2))
.withRetryTemplate(RetryUtils.DEFAULT_RETRY_TEMPLATE)
.withModelManagementOptions(null)
.build());
assertEquals("modelManagementOptions must not be null", exception.getMessage());
@@ -23,19 +23,22 @@
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.ollama.api.OllamaApi;
import org.springframework.ai.ollama.api.OllamaOptions;
import org.springframework.ai.retry.RetryUtils;

import static org.assertj.core.api.Assertions.assertThat;

/**
* @author Christian Tzolov
* @author Thomas Vitale
* @author Alexandros Pappas
*/
public class OllamaChatRequestTests {

OllamaChatModel chatModel = OllamaChatModel.builder()
.withOllamaApi(new OllamaApi())
.withDefaultOptions(
OllamaOptions.create().withModel("MODEL_NAME").withTopK(99).withTemperature(66.6).withNumGPU(1))
.withRetryTemplate(RetryUtils.DEFAULT_RETRY_TEMPLATE)
.build();

@Test
@@ -113,6 +116,7 @@ public void createRequestWithDefaultOptionsModelOverride() {
OllamaChatModel chatModel = OllamaChatModel.builder()
.withOllamaApi(new OllamaApi())
.withDefaultOptions(OllamaOptions.create().withModel("DEFAULT_OPTIONS_MODEL"))
.withRetryTemplate(RetryUtils.DEFAULT_RETRY_TEMPLATE)
.build();

var request = chatModel.ollamaChatRequest(new Prompt("Test message content"), true);
@@ -41,6 +41,7 @@
import org.springframework.boot.context.properties.EnableConfigurationProperties;
import org.springframework.context.ApplicationContext;
import org.springframework.context.annotation.Bean;
import org.springframework.retry.support.RetryTemplate;
import org.springframework.web.client.RestClient;
import org.springframework.web.reactive.function.client.WebClient;

@@ -82,7 +83,7 @@ public OllamaApi ollamaApi(OllamaConnectionDetails connectionDetails,
public OllamaChatModel ollamaChatModel(OllamaApi ollamaApi, OllamaChatProperties properties,
OllamaInitializationProperties initProperties, List<FunctionCallback> toolFunctionCallbacks,
FunctionCallbackResolver functionCallbackResolver, ObjectProvider<ObservationRegistry> observationRegistry,
ObjectProvider<ChatModelObservationConvention> observationConvention) {
ObjectProvider<ChatModelObservationConvention> observationConvention, RetryTemplate retryTemplate) {
var chatModelPullStrategy = initProperties.getChat().isInclude() ? initProperties.getPullModelStrategy()
: PullModelStrategy.NEVER;

Expand All @@ -95,6 +96,7 @@ public OllamaChatModel ollamaChatModel(OllamaApi ollamaApi, OllamaChatProperties
.withModelManagementOptions(
new ModelManagementOptions(chatModelPullStrategy, initProperties.getChat().getAdditionalModels(),
initProperties.getTimeout(), initProperties.getMaxRetries()))
.withRetryTemplate(retryTemplate)
.build();

observationConvention.ifAvailable(chatModel::setObservationConvention);
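Because ollamaChatModel(...) now declares a RetryTemplate parameter, the auto-configuration expects a RetryTemplate bean in the application context; the two test classes below add SpringAiRetryAutoConfiguration next to RestClientAutoConfiguration for that reason. An application that wants a different policy should be able to define its own RetryTemplate bean, assuming the retry auto-configuration backs off when a user-supplied bean is present; the sketch below is a hypothetical override with illustrative values.

// Hypothetical override of the retry policy injected into the auto-configured OllamaChatModel.
@Configuration
class OllamaRetryConfig {

    @Bean
    RetryTemplate retryTemplate() {
        return RetryTemplate.builder()
            .maxAttempts(5)      // retry failing Ollama calls up to five times
            .fixedBackoff(1000)  // wait one second between attempts
            .build();
    }
}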
@@ -18,6 +18,7 @@

import org.junit.jupiter.api.Test;

import org.springframework.ai.autoconfigure.retry.SpringAiRetryAutoConfiguration;
import org.springframework.boot.autoconfigure.AutoConfigurations;
import org.springframework.boot.autoconfigure.web.client.RestClientAutoConfiguration;
import org.springframework.boot.test.context.runner.ApplicationContextRunner;
@@ -41,7 +42,8 @@ public void propertiesTest() {
"spring.ai.ollama.chat.options.topP=0.56",
"spring.ai.ollama.chat.options.topK=123")
// @formatter:on
.withConfiguration(AutoConfigurations.of(RestClientAutoConfiguration.class, OllamaAutoConfiguration.class))
.withConfiguration(AutoConfigurations.of(SpringAiRetryAutoConfiguration.class,
RestClientAutoConfiguration.class, OllamaAutoConfiguration.class))
.run(context -> {
var chatProperties = context.getBean(OllamaChatProperties.class);
var connectionProperties = context.getBean(OllamaConnectionProperties.class);
@@ -18,6 +18,7 @@

import org.junit.jupiter.api.Test;

import org.springframework.ai.autoconfigure.retry.SpringAiRetryAutoConfiguration;
import org.springframework.boot.autoconfigure.AutoConfigurations;
import org.springframework.boot.autoconfigure.web.client.RestClientAutoConfiguration;
import org.springframework.boot.test.context.runner.ApplicationContextRunner;
@@ -26,6 +27,7 @@

/**
* @author Christian Tzolov
* @author Alexandros Pappas
* @since 0.8.0
*/
public class OllamaEmbeddingAutoConfigurationTests {
@@ -41,7 +43,8 @@ public void propertiesTest() {
"spring.ai.ollama.embedding.options.topK=13"
// @formatter:on
)
.withConfiguration(AutoConfigurations.of(RestClientAutoConfiguration.class, OllamaAutoConfiguration.class))
.withConfiguration(AutoConfigurations.of(SpringAiRetryAutoConfiguration.class,
RestClientAutoConfiguration.class, OllamaAutoConfiguration.class))
.run(context -> {
var embeddingProperties = context.getBean(OllamaEmbeddingProperties.class);
var connectionProperties = context.getBean(OllamaConnectionProperties.class);