diff --git a/auto-configurations/models/chat/memory/spring-ai-autoconfigure-model-chat-memory-cassandra/pom.xml b/auto-configurations/models/chat/memory/spring-ai-autoconfigure-model-chat-memory-cassandra/pom.xml index 1e6ad3cecb6..86c684a4366 100644 --- a/auto-configurations/models/chat/memory/spring-ai-autoconfigure-model-chat-memory-cassandra/pom.xml +++ b/auto-configurations/models/chat/memory/spring-ai-autoconfigure-model-chat-memory-cassandra/pom.xml @@ -30,6 +30,12 @@ ${project.parent.version} + + org.springframework.ai + spring-ai-autoconfigure-model-chat-memory + ${project.parent.version} + + org.springframework.boot diff --git a/auto-configurations/models/chat/memory/spring-ai-autoconfigure-model-chat-memory-cassandra/src/main/java/org/springframework/ai/model/chat/memory/cassandra/autoconfigure/CassandraChatMemoryAutoConfiguration.java b/auto-configurations/models/chat/memory/spring-ai-autoconfigure-model-chat-memory-cassandra/src/main/java/org/springframework/ai/model/chat/memory/cassandra/autoconfigure/CassandraChatMemoryAutoConfiguration.java index 30d4d69609f..58f63d08d0d 100644 --- a/auto-configurations/models/chat/memory/spring-ai-autoconfigure-model-chat-memory-cassandra/src/main/java/org/springframework/ai/model/chat/memory/cassandra/autoconfigure/CassandraChatMemoryAutoConfiguration.java +++ b/auto-configurations/models/chat/memory/spring-ai-autoconfigure-model-chat-memory-cassandra/src/main/java/org/springframework/ai/model/chat/memory/cassandra/autoconfigure/CassandraChatMemoryAutoConfiguration.java @@ -20,6 +20,7 @@ import org.springframework.ai.chat.memory.cassandra.CassandraChatMemory; import org.springframework.ai.chat.memory.cassandra.CassandraChatMemoryConfig; +import org.springframework.ai.model.chat.memory.autoconfigure.ChatMemoryAutoConfiguration; import org.springframework.boot.autoconfigure.AutoConfiguration; import org.springframework.boot.autoconfigure.cassandra.CassandraAutoConfiguration; import org.springframework.boot.autoconfigure.condition.ConditionalOnClass; @@ -34,7 +35,7 @@ * @author Jihoon Kim * @since 1.0.0 */ -@AutoConfiguration(after = CassandraAutoConfiguration.class) +@AutoConfiguration(after = CassandraAutoConfiguration.class, before = ChatMemoryAutoConfiguration.class) @ConditionalOnClass({ CassandraChatMemory.class, CqlSession.class }) @EnableConfigurationProperties(CassandraChatMemoryProperties.class) public class CassandraChatMemoryAutoConfiguration { diff --git a/auto-configurations/models/chat/memory/spring-ai-autoconfigure-model-chat-memory-jdbc/pom.xml b/auto-configurations/models/chat/memory/spring-ai-autoconfigure-model-chat-memory-jdbc/pom.xml index 19e4f0e7bb7..defad3b07b3 100644 --- a/auto-configurations/models/chat/memory/spring-ai-autoconfigure-model-chat-memory-jdbc/pom.xml +++ b/auto-configurations/models/chat/memory/spring-ai-autoconfigure-model-chat-memory-jdbc/pom.xml @@ -29,6 +29,12 @@ ${project.parent.version} + + org.springframework.ai + spring-ai-autoconfigure-model-chat-memory + ${project.parent.version} + + org.springframework.boot diff --git a/auto-configurations/models/chat/memory/spring-ai-autoconfigure-model-chat-memory-jdbc/src/main/java/org/springframework/ai/model/chat/memory/jdbc/autoconfigure/JdbcChatMemoryAutoConfiguration.java b/auto-configurations/models/chat/memory/spring-ai-autoconfigure-model-chat-memory-jdbc/src/main/java/org/springframework/ai/model/chat/memory/jdbc/autoconfigure/JdbcChatMemoryAutoConfiguration.java index bc811c3ded6..d3dc6d36fee 100644 --- a/auto-configurations/models/chat/memory/spring-ai-autoconfigure-model-chat-memory-jdbc/src/main/java/org/springframework/ai/model/chat/memory/jdbc/autoconfigure/JdbcChatMemoryAutoConfiguration.java +++ b/auto-configurations/models/chat/memory/spring-ai-autoconfigure-model-chat-memory-jdbc/src/main/java/org/springframework/ai/model/chat/memory/jdbc/autoconfigure/JdbcChatMemoryAutoConfiguration.java @@ -21,24 +21,33 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import org.springframework.ai.chat.memory.ChatMemory; +import org.springframework.ai.chat.memory.MessageWindowChatMemory; import org.springframework.ai.chat.memory.jdbc.JdbcChatMemory; import org.springframework.ai.chat.memory.jdbc.JdbcChatMemoryConfig; +import org.springframework.ai.chat.memory.jdbc.JdbcChatMemoryRepository; +import org.springframework.ai.model.chat.memory.autoconfigure.ChatMemoryAutoConfiguration; import org.springframework.boot.autoconfigure.AutoConfiguration; import org.springframework.boot.autoconfigure.condition.ConditionalOnClass; import org.springframework.boot.autoconfigure.condition.ConditionalOnMissingBean; -import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty; +import org.springframework.boot.autoconfigure.condition.ConditionMessage; +import org.springframework.boot.autoconfigure.condition.ConditionOutcome; +import org.springframework.boot.autoconfigure.condition.SpringBootCondition; import org.springframework.boot.autoconfigure.jdbc.JdbcTemplateAutoConfiguration; import org.springframework.boot.context.properties.EnableConfigurationProperties; -import org.springframework.boot.jdbc.init.DataSourceScriptDatabaseInitializer; import org.springframework.context.annotation.Bean; +import org.springframework.context.annotation.ConditionContext; +import org.springframework.context.annotation.Conditional; +import org.springframework.core.type.AnnotatedTypeMetadata; import org.springframework.jdbc.core.JdbcTemplate; /** * @author Jonathan Leijendekker + * @author Thomas Vitale * @since 1.0.0 */ -@AutoConfiguration(after = JdbcTemplateAutoConfiguration.class) -@ConditionalOnClass({ JdbcChatMemory.class, DataSource.class, JdbcTemplate.class }) +@AutoConfiguration(after = JdbcTemplateAutoConfiguration.class, before = ChatMemoryAutoConfiguration.class) +@ConditionalOnClass({ JdbcChatMemoryRepository.class, DataSource.class, JdbcTemplate.class }) @EnableConfigurationProperties(JdbcChatMemoryProperties.class) public class JdbcChatMemoryAutoConfiguration { @@ -46,20 +55,59 @@ public class JdbcChatMemoryAutoConfiguration { @Bean @ConditionalOnMissingBean - public JdbcChatMemory chatMemory(JdbcTemplate jdbcTemplate) { - var config = JdbcChatMemoryConfig.builder().jdbcTemplate(jdbcTemplate).build(); + JdbcChatMemoryRepository chatMemoryRepository(JdbcTemplate jdbcTemplate) { + return JdbcChatMemoryRepository.builder().jdbcTemplate(jdbcTemplate).build(); + } + /** + * @deprecated in favor of building a {@link MessageWindowChatMemory} (or other + * {@link ChatMemory} implementations) with a {@link JdbcChatMemoryRepository} + * instance. + */ + @Bean + @ConditionalOnMissingBean + @Deprecated + JdbcChatMemory chatMemory(JdbcTemplate jdbcTemplate) { + var config = JdbcChatMemoryConfig.builder().jdbcTemplate(jdbcTemplate).build(); return JdbcChatMemory.create(config); } @Bean @ConditionalOnMissingBean - @ConditionalOnProperty(value = "spring.ai.chat.memory.jdbc.initialize-schema", havingValue = "true", - matchIfMissing = true) - public DataSourceScriptDatabaseInitializer jdbcChatMemoryScriptDatabaseInitializer(DataSource dataSource) { - logger.debug("Initializing JdbcChatMemory schema"); - + @Conditional(OnSchemaInitializationEnabledCondition.class) + JdbcChatMemoryDataSourceScriptDatabaseInitializer jdbcChatMemoryScriptDatabaseInitializer(DataSource dataSource) { + logger.debug("Initializing schema for JdbcChatMemoryRepository"); return new JdbcChatMemoryDataSourceScriptDatabaseInitializer(dataSource); } + /** + * Condition to check if the schema initialization is enabled, supporting both + * deprecated and new property. + * + * @deprecated to be removed in 1.0.0-RC1. + */ + @Deprecated + static class OnSchemaInitializationEnabledCondition extends SpringBootCondition { + + @Override + public ConditionOutcome getMatchOutcome(ConditionContext context, AnnotatedTypeMetadata metadata) { + Boolean initializeSchemaEnabled = context.getEnvironment() + .getProperty("spring.ai.chat.memory.jdbc.initialize-schema", Boolean.class); + + if (initializeSchemaEnabled != null) { + return new ConditionOutcome(initializeSchemaEnabled, + ConditionMessage.forCondition("Enable JDBC Chat Memory Schema Initialization") + .because("spring.ai.chat.memory.jdbc.initialize-schema is " + initializeSchemaEnabled)); + } + + initializeSchemaEnabled = context.getEnvironment() + .getProperty(JdbcChatMemoryProperties.CONFIG_PREFIX + ".initialize-schema", Boolean.class, true); + + return new ConditionOutcome(initializeSchemaEnabled, ConditionMessage + .forCondition("Enable JDBC Chat Memory Schema Initialization") + .because(JdbcChatMemoryProperties.CONFIG_PREFIX + ".initialize-schema is " + initializeSchemaEnabled)); + } + + } + } diff --git a/auto-configurations/models/chat/memory/spring-ai-autoconfigure-model-chat-memory-jdbc/src/main/java/org/springframework/ai/model/chat/memory/jdbc/autoconfigure/JdbcChatMemoryDataSourceScriptDatabaseInitializer.java b/auto-configurations/models/chat/memory/spring-ai-autoconfigure-model-chat-memory-jdbc/src/main/java/org/springframework/ai/model/chat/memory/jdbc/autoconfigure/JdbcChatMemoryDataSourceScriptDatabaseInitializer.java index 716b6fb57c0..546b272d367 100644 --- a/auto-configurations/models/chat/memory/spring-ai-autoconfigure-model-chat-memory-jdbc/src/main/java/org/springframework/ai/model/chat/memory/jdbc/autoconfigure/JdbcChatMemoryDataSourceScriptDatabaseInitializer.java +++ b/auto-configurations/models/chat/memory/spring-ai-autoconfigure-model-chat-memory-jdbc/src/main/java/org/springframework/ai/model/chat/memory/jdbc/autoconfigure/JdbcChatMemoryDataSourceScriptDatabaseInitializer.java @@ -25,6 +25,11 @@ import org.springframework.boot.sql.init.DatabaseInitializationMode; import org.springframework.boot.sql.init.DatabaseInitializationSettings; +/** + * Performs database initialization for the JDBC Chat Memory Repository. + * + * @since 1.0.0 + */ class JdbcChatMemoryDataSourceScriptDatabaseInitializer extends DataSourceScriptDatabaseInitializer { private static final String SCHEMA_LOCATION = "classpath:org/springframework/ai/chat/memory/jdbc/schema-@@platform@@.sql"; diff --git a/auto-configurations/models/chat/memory/spring-ai-autoconfigure-model-chat-memory-jdbc/src/main/java/org/springframework/ai/model/chat/memory/jdbc/autoconfigure/JdbcChatMemoryProperties.java b/auto-configurations/models/chat/memory/spring-ai-autoconfigure-model-chat-memory-jdbc/src/main/java/org/springframework/ai/model/chat/memory/jdbc/autoconfigure/JdbcChatMemoryProperties.java index 1c33ffbb0a5..6ff1a2a18ba 100644 --- a/auto-configurations/models/chat/memory/spring-ai-autoconfigure-model-chat-memory-jdbc/src/main/java/org/springframework/ai/model/chat/memory/jdbc/autoconfigure/JdbcChatMemoryProperties.java +++ b/auto-configurations/models/chat/memory/spring-ai-autoconfigure-model-chat-memory-jdbc/src/main/java/org/springframework/ai/model/chat/memory/jdbc/autoconfigure/JdbcChatMemoryProperties.java @@ -20,13 +20,17 @@ /** * @author Jonathan Leijendekker + * @author Thomas Vitale * @since 1.0.0 */ @ConfigurationProperties(JdbcChatMemoryProperties.CONFIG_PREFIX) public class JdbcChatMemoryProperties { - public static final String CONFIG_PREFIX = "spring.ai.chat.memory.jdbc"; + public static final String CONFIG_PREFIX = "spring.ai.chat.memory.repository.jdbc"; + /** + * Whether to initialize the schema on startup. + */ private boolean initializeSchema = true; public boolean isInitializeSchema() { diff --git a/auto-configurations/models/chat/memory/spring-ai-autoconfigure-model-chat-memory-jdbc/src/main/resources/META-INF/additional-spring-configuration-metadata.json b/auto-configurations/models/chat/memory/spring-ai-autoconfigure-model-chat-memory-jdbc/src/main/resources/META-INF/additional-spring-configuration-metadata.json new file mode 100644 index 00000000000..eceef48bb83 --- /dev/null +++ b/auto-configurations/models/chat/memory/spring-ai-autoconfigure-model-chat-memory-jdbc/src/main/resources/META-INF/additional-spring-configuration-metadata.json @@ -0,0 +1,14 @@ +{ + "groups": [], + "properties": [ + { + "name": "spring.ai.chat.memory.jdbc.initialize-schema", + "type": "java.lang.Boolean", + "description": "Whether to initialize the schema on startup.", + "deprecation": { + "replacement": "spring.ai.chat.memory.repository.jdbc.initialize-schema" + } + } + ], + "hints": [] +} \ No newline at end of file diff --git a/auto-configurations/models/chat/memory/spring-ai-autoconfigure-model-chat-memory-jdbc/src/test/java/org/springframework/ai/model/chat/memory/jdbc/autoconfigure/JdbcChatMemoryAutoConfigurationIT.java b/auto-configurations/models/chat/memory/spring-ai-autoconfigure-model-chat-memory-jdbc/src/test/java/org/springframework/ai/model/chat/memory/jdbc/autoconfigure/JdbcChatMemoryAutoConfigurationIT.java deleted file mode 100644 index df9a49d85b9..00000000000 --- a/auto-configurations/models/chat/memory/spring-ai-autoconfigure-model-chat-memory-jdbc/src/test/java/org/springframework/ai/model/chat/memory/jdbc/autoconfigure/JdbcChatMemoryAutoConfigurationIT.java +++ /dev/null @@ -1,99 +0,0 @@ -/* - * Copyright 2024-2025 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * https://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.springframework.ai.model.chat.memory.jdbc.autoconfigure; - -import java.util.List; -import java.util.UUID; - -import org.junit.jupiter.api.Test; -import org.testcontainers.containers.PostgreSQLContainer; -import org.testcontainers.junit.jupiter.Container; -import org.testcontainers.junit.jupiter.Testcontainers; -import org.testcontainers.utility.DockerImageName; - -import org.springframework.ai.chat.memory.jdbc.JdbcChatMemory; -import org.springframework.ai.chat.messages.AssistantMessage; -import org.springframework.ai.chat.messages.Message; -import org.springframework.ai.chat.messages.UserMessage; -import org.springframework.boot.autoconfigure.AutoConfigurations; -import org.springframework.boot.autoconfigure.jdbc.DataSourceAutoConfiguration; -import org.springframework.boot.autoconfigure.jdbc.JdbcTemplateAutoConfiguration; -import org.springframework.boot.test.context.runner.ApplicationContextRunner; - -import static org.assertj.core.api.Assertions.assertThat; - -/** - * @author Jonathan Leijendekker - */ -@Testcontainers -class JdbcChatMemoryAutoConfigurationIT { - - static final DockerImageName DEFAULT_IMAGE_NAME = DockerImageName.parse("postgres:17"); - - @Container - @SuppressWarnings("resource") - static PostgreSQLContainer postgresContainer = new PostgreSQLContainer<>(DEFAULT_IMAGE_NAME) - .withDatabaseName("chat_memory_auto_configuration_test") - .withUsername("postgres") - .withPassword("postgres"); - - private final ApplicationContextRunner contextRunner = new ApplicationContextRunner() - .withConfiguration(AutoConfigurations.of(JdbcChatMemoryAutoConfiguration.class, - JdbcTemplateAutoConfiguration.class, DataSourceAutoConfiguration.class)) - .withPropertyValues(String.format("spring.datasource.url=%s", postgresContainer.getJdbcUrl()), - String.format("spring.datasource.username=%s", postgresContainer.getUsername()), - String.format("spring.datasource.password=%s", postgresContainer.getPassword())); - - @Test - void jdbcChatMemoryScriptDatabaseInitializer_shouldBeLoaded() { - this.contextRunner.withPropertyValues("spring.ai.chat.memory.jdbc.initialize-schema=true") - .run(context -> assertThat(context.containsBean("jdbcChatMemoryScriptDatabaseInitializer")).isTrue()); - } - - @Test - void jdbcChatMemoryScriptDatabaseInitializer_shouldNotBeLoaded() { - this.contextRunner.withPropertyValues("spring.ai.chat.memory.jdbc.initialize-schema=false") - .run(context -> assertThat(context.containsBean("jdbcChatMemoryScriptDatabaseInitializer")).isFalse()); - } - - @Test - void addGetAndClear_shouldAllExecute() { - this.contextRunner.withPropertyValues("spring.ai.chat.memory.jdbc.initialize-schema=true").run(context -> { - var chatMemory = context.getBean(JdbcChatMemory.class); - var conversationId = UUID.randomUUID().toString(); - var userMessage = new UserMessage("Message from the user"); - - chatMemory.add(conversationId, userMessage); - - assertThat(chatMemory.get(conversationId, Integer.MAX_VALUE)).hasSize(1); - assertThat(chatMemory.get(conversationId, Integer.MAX_VALUE)).isEqualTo(List.of(userMessage)); - - chatMemory.clear(conversationId); - - assertThat(chatMemory.get(conversationId, Integer.MAX_VALUE)).isEmpty(); - - var multipleMessages = List.of(new UserMessage("Message from the user 1"), - new AssistantMessage("Message from the assistant 1")); - - chatMemory.add(conversationId, multipleMessages); - - assertThat(chatMemory.get(conversationId, Integer.MAX_VALUE)).hasSize(multipleMessages.size()); - assertThat(chatMemory.get(conversationId, Integer.MAX_VALUE)).isEqualTo(multipleMessages); - }); - } - -} diff --git a/auto-configurations/models/chat/memory/spring-ai-autoconfigure-model-chat-memory-jdbc/src/test/java/org/springframework/ai/model/chat/memory/jdbc/autoconfigure/JdbcChatMemoryDataSourceScriptDatabaseInitializerTests.java b/auto-configurations/models/chat/memory/spring-ai-autoconfigure-model-chat-memory-jdbc/src/test/java/org/springframework/ai/model/chat/memory/jdbc/autoconfigure/JdbcChatMemoryDataSourceScriptDatabaseInitializerPostgresqlTests.java similarity index 97% rename from auto-configurations/models/chat/memory/spring-ai-autoconfigure-model-chat-memory-jdbc/src/test/java/org/springframework/ai/model/chat/memory/jdbc/autoconfigure/JdbcChatMemoryDataSourceScriptDatabaseInitializerTests.java rename to auto-configurations/models/chat/memory/spring-ai-autoconfigure-model-chat-memory-jdbc/src/test/java/org/springframework/ai/model/chat/memory/jdbc/autoconfigure/JdbcChatMemoryDataSourceScriptDatabaseInitializerPostgresqlTests.java index f563c67cdf1..cd53f2bd8dd 100644 --- a/auto-configurations/models/chat/memory/spring-ai-autoconfigure-model-chat-memory-jdbc/src/test/java/org/springframework/ai/model/chat/memory/jdbc/autoconfigure/JdbcChatMemoryDataSourceScriptDatabaseInitializerTests.java +++ b/auto-configurations/models/chat/memory/spring-ai-autoconfigure-model-chat-memory-jdbc/src/test/java/org/springframework/ai/model/chat/memory/jdbc/autoconfigure/JdbcChatMemoryDataSourceScriptDatabaseInitializerPostgresqlTests.java @@ -35,7 +35,7 @@ * @author Jonathan Leijendekker */ @Testcontainers -class JdbcChatMemoryDataSourceScriptDatabaseInitializerTests { +class JdbcChatMemoryDataSourceScriptDatabaseInitializerPostgresqlTests { static final DockerImageName DEFAULT_IMAGE_NAME = DockerImageName.parse("postgres:17"); diff --git a/auto-configurations/models/chat/memory/spring-ai-autoconfigure-model-chat-memory-jdbc/src/test/java/org/springframework/ai/model/chat/memory/jdbc/autoconfigure/JdbcChatMemoryPostgresqlAutoConfigurationIT.java b/auto-configurations/models/chat/memory/spring-ai-autoconfigure-model-chat-memory-jdbc/src/test/java/org/springframework/ai/model/chat/memory/jdbc/autoconfigure/JdbcChatMemoryPostgresqlAutoConfigurationIT.java new file mode 100644 index 00000000000..30caa5a23ee --- /dev/null +++ b/auto-configurations/models/chat/memory/spring-ai-autoconfigure-model-chat-memory-jdbc/src/test/java/org/springframework/ai/model/chat/memory/jdbc/autoconfigure/JdbcChatMemoryPostgresqlAutoConfigurationIT.java @@ -0,0 +1,159 @@ +/* + * Copyright 2024-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.model.chat.memory.jdbc.autoconfigure; + +import org.junit.jupiter.api.Test; +import org.springframework.ai.chat.memory.ChatMemory; +import org.springframework.ai.chat.memory.jdbc.JdbcChatMemory; +import org.springframework.ai.chat.memory.jdbc.JdbcChatMemoryRepository; +import org.springframework.ai.chat.messages.AssistantMessage; +import org.springframework.ai.chat.messages.Message; +import org.springframework.ai.chat.messages.UserMessage; +import org.springframework.ai.model.chat.memory.autoconfigure.ChatMemoryAutoConfiguration; +import org.springframework.boot.autoconfigure.AutoConfigurations; +import org.springframework.boot.autoconfigure.jdbc.DataSourceAutoConfiguration; +import org.springframework.boot.autoconfigure.jdbc.JdbcTemplateAutoConfiguration; +import org.springframework.boot.test.context.runner.ApplicationContextRunner; + +import java.util.List; +import java.util.UUID; + +import static org.assertj.core.api.Assertions.assertThat; + +/** + * @author Jonathan Leijendekker + * @author Thomas Vitale + */ +class JdbcChatMemoryPostgresqlAutoConfigurationIT { + + private final ApplicationContextRunner contextRunner = new ApplicationContextRunner() + .withConfiguration(AutoConfigurations.of(JdbcChatMemoryAutoConfiguration.class, + JdbcTemplateAutoConfiguration.class, DataSourceAutoConfiguration.class)) + .withPropertyValues("spring.datasource.url=jdbc:tc:postgresql:17:///"); + + @Test + void jdbcChatMemoryScriptDatabaseInitializer_shouldBeLoaded() { + this.contextRunner.withPropertyValues("spring.ai.chat.memory.jdbc.initialize-schema=true") + .run(context -> assertThat(context.containsBean("jdbcChatMemoryScriptDatabaseInitializer")).isTrue()); + this.contextRunner.withPropertyValues("spring.ai.chat.memory.repository.jdbc.initialize-schema=true") + .run(context -> assertThat(context.containsBean("jdbcChatMemoryScriptDatabaseInitializer")).isTrue()); + } + + @Test + void jdbcChatMemoryScriptDatabaseInitializer_shouldNotBeLoaded() { + this.contextRunner.withPropertyValues("spring.ai.chat.memory.jdbc.initialize-schema=false") + .run(context -> assertThat(context.containsBean("jdbcChatMemoryScriptDatabaseInitializer")).isFalse()); + this.contextRunner.withPropertyValues("spring.ai.chat.memory.repository.jdbc.initialize-schema=false") + .run(context -> assertThat(context.containsBean("jdbcChatMemoryScriptDatabaseInitializer")).isFalse()); + } + + @Test + void initializeSchemaEnabledWithDeprecatedProperty() { + this.contextRunner + .withPropertyValues("spring.ai.chat.memory.jdbc.initialize-schema=true", + "spring.ai.chat.memory.repository.jdbc.initialize-schema=false") + .run(context -> assertThat(context.containsBean("jdbcChatMemoryScriptDatabaseInitializer")).isTrue()); + } + + @Test + void initializeSchemaEnabledWithNewProperty() { + this.contextRunner.withPropertyValues("spring.ai.chat.memory.repository.jdbc.initialize-schema=true") + .run(context -> assertThat(context.containsBean("jdbcChatMemoryScriptDatabaseInitializer")).isTrue()); + } + + @Test + void addGetAndClear_shouldAllExecute() { + this.contextRunner.withPropertyValues("spring.ai.chat.memory.jdbc.initialize-schema=true").run(context -> { + var chatMemory = context.getBean(JdbcChatMemory.class); + var conversationId = UUID.randomUUID().toString(); + var userMessage = new UserMessage("Message from the user"); + + chatMemory.add(conversationId, userMessage); + + assertThat(chatMemory.get(conversationId, Integer.MAX_VALUE)).hasSize(1); + assertThat(chatMemory.get(conversationId, Integer.MAX_VALUE)).isEqualTo(List.of(userMessage)); + + chatMemory.clear(conversationId); + + assertThat(chatMemory.get(conversationId, Integer.MAX_VALUE)).isEmpty(); + + var multipleMessages = List.of(new UserMessage("Message from the user 1"), + new AssistantMessage("Message from the assistant 1")); + + chatMemory.add(conversationId, multipleMessages); + + assertThat(chatMemory.get(conversationId, Integer.MAX_VALUE)).hasSize(multipleMessages.size()); + assertThat(chatMemory.get(conversationId, Integer.MAX_VALUE)).isEqualTo(multipleMessages); + }); + } + + @Test + void useAutoConfiguredJdbcChatMemoryRepository() { + this.contextRunner.run(context -> { + var chatMemoryRepository = context.getBean(JdbcChatMemoryRepository.class); + var conversationId = UUID.randomUUID().toString(); + var userMessage = new UserMessage("Message from the user"); + + chatMemoryRepository.saveAll(conversationId, List.of(userMessage)); + + assertThat(chatMemoryRepository.findByConversationId(conversationId)).hasSize(1); + assertThat(chatMemoryRepository.findByConversationId(conversationId)).isEqualTo(List.of(userMessage)); + + chatMemoryRepository.deleteByConversationId(conversationId); + + assertThat(chatMemoryRepository.findByConversationId(conversationId)).isEmpty(); + + var multipleMessages = List.of(new UserMessage("Message from the user 1"), + new AssistantMessage("Message from the assistant 1")); + + chatMemoryRepository.saveAll(conversationId, multipleMessages); + + assertThat(chatMemoryRepository.findByConversationId(conversationId)).hasSize(multipleMessages.size()); + assertThat(chatMemoryRepository.findByConversationId(conversationId)).isEqualTo(multipleMessages); + }); + } + + @Test + void useAutoConfiguredChatMemoryWithJdbc() { + this.contextRunner.withConfiguration(AutoConfigurations.of(ChatMemoryAutoConfiguration.class)).run(context -> { + assertThat(context).hasSingleBean(ChatMemory.class); + assertThat(context).hasSingleBean(JdbcChatMemoryRepository.class); + + var chatMemory = context.getBean(ChatMemory.class); + var conversationId = UUID.randomUUID().toString(); + var userMessage = new UserMessage("Message from the user"); + + chatMemory.add(conversationId, userMessage); + + assertThat(chatMemory.get(conversationId)).hasSize(1); + assertThat(chatMemory.get(conversationId)).isEqualTo(List.of(userMessage)); + + chatMemory.clear(conversationId); + + assertThat(chatMemory.get(conversationId)).isEmpty(); + + var multipleMessages = List.of(new UserMessage("Message from the user 1"), + new AssistantMessage("Message from the assistant 1")); + + chatMemory.add(conversationId, multipleMessages); + + assertThat(chatMemory.get(conversationId)).hasSize(multipleMessages.size()); + assertThat(chatMemory.get(conversationId)).isEqualTo(multipleMessages); + }); + } + +} diff --git a/auto-configurations/models/chat/memory/spring-ai-autoconfigure-model-chat-memory-neo4j/pom.xml b/auto-configurations/models/chat/memory/spring-ai-autoconfigure-model-chat-memory-neo4j/pom.xml index ac407bd82bd..09614d0594a 100644 --- a/auto-configurations/models/chat/memory/spring-ai-autoconfigure-model-chat-memory-neo4j/pom.xml +++ b/auto-configurations/models/chat/memory/spring-ai-autoconfigure-model-chat-memory-neo4j/pom.xml @@ -29,6 +29,12 @@ ${project.parent.version} + + org.springframework.ai + spring-ai-autoconfigure-model-chat-memory + ${project.parent.version} + + org.springframework.boot diff --git a/auto-configurations/models/chat/memory/spring-ai-autoconfigure-model-chat-memory-neo4j/src/main/java/org/springframework/ai/model/chat/memory/neo4j/autoconfigure/Neo4jChatMemoryAutoConfiguration.java b/auto-configurations/models/chat/memory/spring-ai-autoconfigure-model-chat-memory-neo4j/src/main/java/org/springframework/ai/model/chat/memory/neo4j/autoconfigure/Neo4jChatMemoryAutoConfiguration.java index dd94209eee2..554b2a3d83e 100644 --- a/auto-configurations/models/chat/memory/spring-ai-autoconfigure-model-chat-memory-neo4j/src/main/java/org/springframework/ai/model/chat/memory/neo4j/autoconfigure/Neo4jChatMemoryAutoConfiguration.java +++ b/auto-configurations/models/chat/memory/spring-ai-autoconfigure-model-chat-memory-neo4j/src/main/java/org/springframework/ai/model/chat/memory/neo4j/autoconfigure/Neo4jChatMemoryAutoConfiguration.java @@ -20,6 +20,7 @@ import org.springframework.ai.chat.memory.neo4j.Neo4jChatMemory; import org.springframework.ai.chat.memory.neo4j.Neo4jChatMemoryConfig; +import org.springframework.ai.model.chat.memory.autoconfigure.ChatMemoryAutoConfiguration; import org.springframework.boot.autoconfigure.AutoConfiguration; import org.springframework.boot.autoconfigure.condition.ConditionalOnClass; import org.springframework.boot.autoconfigure.condition.ConditionalOnMissingBean; @@ -33,7 +34,7 @@ * @author Enrico Rampazzo * @since 1.0.0 */ -@AutoConfiguration(after = Neo4jAutoConfiguration.class) +@AutoConfiguration(after = Neo4jAutoConfiguration.class, before = ChatMemoryAutoConfiguration.class) @ConditionalOnClass({ Neo4jChatMemory.class, Driver.class }) @EnableConfigurationProperties(Neo4jChatMemoryProperties.class) public class Neo4jChatMemoryAutoConfiguration { diff --git a/auto-configurations/models/chat/memory/spring-ai-autoconfigure-model-chat-memory/pom.xml b/auto-configurations/models/chat/memory/spring-ai-autoconfigure-model-chat-memory/pom.xml new file mode 100644 index 00000000000..76267fe9622 --- /dev/null +++ b/auto-configurations/models/chat/memory/spring-ai-autoconfigure-model-chat-memory/pom.xml @@ -0,0 +1,58 @@ + + + 4.0.0 + + org.springframework.ai + spring-ai-parent + 1.0.0-SNAPSHOT + ../../../../../pom.xml + + spring-ai-autoconfigure-model-chat-memory + jar + Spring AI Chat Memory Auto Configuration + Spring AI Chat Memory Auto Configuration + https://github.com/spring-projects/spring-ai + + + https://github.com/spring-projects/spring-ai + git://github.com/spring-projects/spring-ai.git + git@github.com:spring-projects/spring-ai.git + + + + + + org.springframework.ai + spring-ai-model + ${project.parent.version} + + + + + org.springframework.boot + spring-boot-starter + + + + org.springframework.boot + spring-boot-configuration-processor + true + + + + + org.springframework.boot + spring-boot-starter-test + test + + + + org.testcontainers + junit-jupiter + test + + + + diff --git a/auto-configurations/models/chat/memory/spring-ai-autoconfigure-model-chat-memory/src/main/java/org/springframework/ai/model/chat/memory/autoconfigure/ChatMemoryAutoConfiguration.java b/auto-configurations/models/chat/memory/spring-ai-autoconfigure-model-chat-memory/src/main/java/org/springframework/ai/model/chat/memory/autoconfigure/ChatMemoryAutoConfiguration.java new file mode 100644 index 00000000000..79aef937123 --- /dev/null +++ b/auto-configurations/models/chat/memory/spring-ai-autoconfigure-model-chat-memory/src/main/java/org/springframework/ai/model/chat/memory/autoconfigure/ChatMemoryAutoConfiguration.java @@ -0,0 +1,50 @@ +/* + * Copyright 2023-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.model.chat.memory.autoconfigure; + +import org.springframework.ai.chat.memory.ChatMemory; +import org.springframework.ai.chat.memory.ChatMemoryRepository; +import org.springframework.ai.chat.memory.InMemoryChatMemoryRepository; +import org.springframework.ai.chat.memory.MessageWindowChatMemory; +import org.springframework.boot.autoconfigure.AutoConfiguration; +import org.springframework.boot.autoconfigure.condition.ConditionalOnClass; +import org.springframework.boot.autoconfigure.condition.ConditionalOnMissingBean; +import org.springframework.context.annotation.Bean; + +/** + * Auto-configuration for {@link ChatMemory}. + * + * @author Thomas Vitale + * @since 1.0.0 + */ +@AutoConfiguration +@ConditionalOnClass({ ChatMemory.class, ChatMemoryRepository.class }) +public class ChatMemoryAutoConfiguration { + + @Bean + @ConditionalOnMissingBean + ChatMemoryRepository chatMemoryRepository() { + return new InMemoryChatMemoryRepository(); + } + + @Bean + @ConditionalOnMissingBean + ChatMemory chatMemory(ChatMemoryRepository chatMemoryRepository) { + return MessageWindowChatMemory.builder().chatMemoryRepository(chatMemoryRepository).build(); + } + +} diff --git a/auto-configurations/models/chat/memory/spring-ai-autoconfigure-model-chat-memory/src/main/resources/META-INF/spring/org.springframework.boot.autoconfigure.AutoConfiguration.imports b/auto-configurations/models/chat/memory/spring-ai-autoconfigure-model-chat-memory/src/main/resources/META-INF/spring/org.springframework.boot.autoconfigure.AutoConfiguration.imports new file mode 100644 index 00000000000..8864bea46e7 --- /dev/null +++ b/auto-configurations/models/chat/memory/spring-ai-autoconfigure-model-chat-memory/src/main/resources/META-INF/spring/org.springframework.boot.autoconfigure.AutoConfiguration.imports @@ -0,0 +1 @@ +org.springframework.ai.model.chat.memory.autoconfigure.ChatMemoryAutoConfiguration diff --git a/auto-configurations/models/chat/memory/spring-ai-autoconfigure-model-chat-memory/src/test/java/org/springframework/ai/model/chat/memory/autoconfigure/ChatMemoryAutoConfigurationTests.java b/auto-configurations/models/chat/memory/spring-ai-autoconfigure-model-chat-memory/src/test/java/org/springframework/ai/model/chat/memory/autoconfigure/ChatMemoryAutoConfigurationTests.java new file mode 100644 index 00000000000..e40ba588808 --- /dev/null +++ b/auto-configurations/models/chat/memory/spring-ai-autoconfigure-model-chat-memory/src/test/java/org/springframework/ai/model/chat/memory/autoconfigure/ChatMemoryAutoConfigurationTests.java @@ -0,0 +1,91 @@ +/* + * Copyright 2024-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.model.chat.memory.autoconfigure; + +import org.junit.jupiter.api.Test; +import org.springframework.ai.chat.memory.ChatMemory; +import org.springframework.ai.chat.memory.ChatMemoryRepository; +import org.springframework.ai.chat.memory.InMemoryChatMemoryRepository; +import org.springframework.ai.chat.memory.MessageWindowChatMemory; +import org.springframework.boot.autoconfigure.AutoConfigurations; +import org.springframework.boot.test.context.runner.ApplicationContextRunner; +import org.springframework.context.annotation.Bean; +import org.springframework.context.annotation.Configuration; + +import static org.assertj.core.api.Assertions.assertThat; + +/** + * Unit tests for {@link ChatMemoryAutoConfiguration}. + * + * @author Thomas Vitale + */ +class ChatMemoryAutoConfigurationTests { + + private final ApplicationContextRunner contextRunner = new ApplicationContextRunner() + .withConfiguration(AutoConfigurations.of(ChatMemoryAutoConfiguration.class)); + + @Test + void defaultConfiguration() { + contextRunner.run(context -> { + assertThat(context).hasSingleBean(ChatMemoryRepository.class); + assertThat(context).hasSingleBean(ChatMemory.class); + }); + } + + @Test + void whenChatMemoryRepositoryExists() { + contextRunner.withUserConfiguration(CustomChatMemoryRepositoryConfiguration.class).run(context -> { + assertThat(context).hasSingleBean(ChatMemoryRepository.class); + assertThat(context).hasBean("customChatMemoryRepository"); + assertThat(context).doesNotHaveBean("chatMemoryRepository"); + }); + } + + @Test + void whenChatMemoryExists() { + contextRunner.withUserConfiguration(CustomChatMemoryRepositoryConfiguration.class).run(context -> { + assertThat(context).hasSingleBean(ChatMemoryRepository.class); + assertThat(context).hasBean("customChatMemoryRepository"); + assertThat(context).doesNotHaveBean("chatMemoryRepository"); + }); + } + + @Configuration(proxyBeanMethods = false) + static class CustomChatMemoryRepositoryConfiguration { + + private final ChatMemoryRepository customChatMemoryRepository = new InMemoryChatMemoryRepository(); + + @Bean + ChatMemoryRepository customChatMemoryRepository() { + return customChatMemoryRepository; + } + + } + + @Configuration(proxyBeanMethods = false) + static class CustomChatMemoryConfiguration { + + private final ChatMemory customChatMemory = MessageWindowChatMemory.builder().build(); + + @Bean + ChatMemory customChatMemory() { + return customChatMemory; + } + + } + +} diff --git a/memory/spring-ai-model-chat-memory-jdbc/README.md b/memory/spring-ai-model-chat-memory-jdbc/README.md index 8e100ad20a3..9fb420a9787 100644 --- a/memory/spring-ai-model-chat-memory-jdbc/README.md +++ b/memory/spring-ai-model-chat-memory-jdbc/README.md @@ -1 +1 @@ -[Chat Memory Documentation](https://docs.spring.io/spring-ai/reference/api/chatclient.html#_chat_memory) +[Chat Memory Documentation](https://docs.spring.io/spring-ai/reference/api/chatmemory.html) diff --git a/memory/spring-ai-model-chat-memory-jdbc/pom.xml b/memory/spring-ai-model-chat-memory-jdbc/pom.xml index c8a734d38b8..a2e86e527a7 100644 --- a/memory/spring-ai-model-chat-memory-jdbc/pom.xml +++ b/memory/spring-ai-model-chat-memory-jdbc/pom.xml @@ -41,7 +41,7 @@ org.springframework.ai - spring-ai-client-chat + spring-ai-model ${project.version} diff --git a/memory/spring-ai-model-chat-memory-jdbc/src/main/java/org/springframework/ai/chat/memory/jdbc/JdbcChatMemory.java b/memory/spring-ai-model-chat-memory-jdbc/src/main/java/org/springframework/ai/chat/memory/jdbc/JdbcChatMemory.java index 6c9825bac1b..51827ed0f61 100644 --- a/memory/spring-ai-model-chat-memory-jdbc/src/main/java/org/springframework/ai/chat/memory/jdbc/JdbcChatMemory.java +++ b/memory/spring-ai-model-chat-memory-jdbc/src/main/java/org/springframework/ai/chat/memory/jdbc/JdbcChatMemory.java @@ -22,6 +22,7 @@ import java.util.List; import org.springframework.ai.chat.memory.ChatMemory; +import org.springframework.ai.chat.memory.MessageWindowChatMemory; import org.springframework.ai.chat.messages.AssistantMessage; import org.springframework.ai.chat.messages.Message; import org.springframework.ai.chat.messages.MessageType; @@ -38,14 +39,17 @@ * * @author Jonathan Leijendekker * @since 1.0.0 + * @deprecated in favor of building a {@link MessageWindowChatMemory} (or other + * {@link ChatMemory} implementations) with a {@link JdbcChatMemoryRepository} instance. */ +@Deprecated public class JdbcChatMemory implements ChatMemory { private static final String QUERY_ADD = """ INSERT INTO ai_chat_memory (conversation_id, content, type) VALUES (?, ?, ?)"""; private static final String QUERY_GET = """ - SELECT content, type FROM ai_chat_memory WHERE conversation_id = ? ORDER BY "timestamp" DESC LIMIT ?"""; + SELECT content, type FROM ai_chat_memory WHERE conversation_id = ? ORDER BY "timestamp" LIMIT ?"""; private static final String QUERY_CLEAR = "DELETE FROM ai_chat_memory WHERE conversation_id = ?"; diff --git a/memory/spring-ai-model-chat-memory-jdbc/src/main/java/org/springframework/ai/chat/memory/jdbc/JdbcChatMemoryConfig.java b/memory/spring-ai-model-chat-memory-jdbc/src/main/java/org/springframework/ai/chat/memory/jdbc/JdbcChatMemoryConfig.java index 5a503aef051..0211f05b853 100644 --- a/memory/spring-ai-model-chat-memory-jdbc/src/main/java/org/springframework/ai/chat/memory/jdbc/JdbcChatMemoryConfig.java +++ b/memory/spring-ai-model-chat-memory-jdbc/src/main/java/org/springframework/ai/chat/memory/jdbc/JdbcChatMemoryConfig.java @@ -24,7 +24,9 @@ * * @author Jonathan Leijendekker * @since 1.0.0 + * @deprecated in favor of using {@link JdbcChatMemoryRepository#builder()}. */ +@Deprecated public final class JdbcChatMemoryConfig { private final JdbcTemplate jdbcTemplate; diff --git a/memory/spring-ai-model-chat-memory-jdbc/src/main/java/org/springframework/ai/chat/memory/jdbc/JdbcChatMemoryRepository.java b/memory/spring-ai-model-chat-memory-jdbc/src/main/java/org/springframework/ai/chat/memory/jdbc/JdbcChatMemoryRepository.java new file mode 100644 index 00000000000..09ba5f1653b --- /dev/null +++ b/memory/spring-ai-model-chat-memory-jdbc/src/main/java/org/springframework/ai/chat/memory/jdbc/JdbcChatMemoryRepository.java @@ -0,0 +1,156 @@ +/* + * Copyright 2023-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.chat.memory.jdbc; + +import org.springframework.ai.chat.memory.ChatMemoryRepository; +import org.springframework.ai.chat.messages.*; +import org.springframework.jdbc.core.BatchPreparedStatementSetter; +import org.springframework.jdbc.core.JdbcTemplate; +import org.springframework.jdbc.core.RowMapper; +import org.springframework.lang.Nullable; +import org.springframework.util.Assert; + +import java.sql.PreparedStatement; +import java.sql.ResultSet; +import java.sql.SQLException; +import java.util.ArrayList; +import java.util.List; + +/** + * An implementation of {@link ChatMemoryRepository} for JDBC. + * + * @author Jonathan Leijendekker + * @author Thomas Vitale + * @since 1.0.0 + */ +public class JdbcChatMemoryRepository implements ChatMemoryRepository { + + private static final String QUERY_GET_IDS = """ + SELECT DISTINCT conversation_id FROM ai_chat_memory + """; + + private static final String QUERY_ADD = """ + INSERT INTO ai_chat_memory (conversation_id, content, type) VALUES (?, ?, ?) + """; + + private static final String QUERY_GET = """ + SELECT content, type FROM ai_chat_memory WHERE conversation_id = ? ORDER BY "timestamp" + """; + + private static final String QUERY_CLEAR = "DELETE FROM ai_chat_memory WHERE conversation_id = ?"; + + private final JdbcTemplate jdbcTemplate; + + private JdbcChatMemoryRepository(JdbcTemplate jdbcTemplate) { + Assert.notNull(jdbcTemplate, "jdbcTemplate cannot be null"); + this.jdbcTemplate = jdbcTemplate; + } + + @Override + public List findConversationIds() { + List conversationIds = this.jdbcTemplate.query(QUERY_GET_IDS, rs -> { + var ids = new ArrayList(); + while (rs.next()) { + ids.add(rs.getString(1)); + } + return ids; + }); + return conversationIds != null ? conversationIds : List.of(); + } + + @Override + public List findByConversationId(String conversationId) { + Assert.hasText(conversationId, "conversationId cannot be null or empty"); + return this.jdbcTemplate.query(QUERY_GET, new MessageRowMapper(), conversationId); + } + + @Override + public void saveAll(String conversationId, List messages) { + Assert.hasText(conversationId, "conversationId cannot be null or empty"); + Assert.notNull(messages, "messages cannot be null"); + Assert.noNullElements(messages, "messages cannot contain null elements"); + this.deleteByConversationId(conversationId); + this.jdbcTemplate.batchUpdate(QUERY_ADD, new AddBatchPreparedStatement(conversationId, messages)); + } + + @Override + public void deleteByConversationId(String conversationId) { + Assert.hasText(conversationId, "conversationId cannot be null or empty"); + this.jdbcTemplate.update(QUERY_CLEAR, conversationId); + } + + private record AddBatchPreparedStatement(String conversationId, + List messages) implements BatchPreparedStatementSetter { + @Override + public void setValues(PreparedStatement ps, int i) throws SQLException { + var message = this.messages.get(i); + + ps.setString(1, this.conversationId); + ps.setString(2, message.getText()); + ps.setString(3, message.getMessageType().name()); + } + + @Override + public int getBatchSize() { + return this.messages.size(); + } + } + + private static class MessageRowMapper implements RowMapper { + + @Override + @Nullable + public Message mapRow(ResultSet rs, int i) throws SQLException { + var content = rs.getString(1); + var type = MessageType.valueOf(rs.getString(2)); + + return switch (type) { + case USER -> new UserMessage(content); + case ASSISTANT -> new AssistantMessage(content); + case SYSTEM -> new SystemMessage(content); + // The content is always stored empty for ToolResponseMessages. + // If we want to capture the actual content, we need to extend + // AddBatchPreparedStatement to support it. + case TOOL -> new ToolResponseMessage(List.of()); + }; + } + + } + + public static Builder builder() { + return new Builder(); + } + + public static class Builder { + + private JdbcTemplate jdbcTemplate; + + private Builder() { + } + + public Builder jdbcTemplate(JdbcTemplate jdbcTemplate) { + this.jdbcTemplate = jdbcTemplate; + return this; + } + + public JdbcChatMemoryRepository build() { + return new JdbcChatMemoryRepository(this.jdbcTemplate); + } + + } + +} diff --git a/memory/spring-ai-model-chat-memory-jdbc/src/main/java/org/springframework/ai/chat/memory/jdbc/aot/hint/JdbcChatMemoryRuntimeHints.java b/memory/spring-ai-model-chat-memory-jdbc/src/main/java/org/springframework/ai/chat/memory/jdbc/aot/hint/JdbcChatMemoryRepositoryRuntimeHints.java similarity index 94% rename from memory/spring-ai-model-chat-memory-jdbc/src/main/java/org/springframework/ai/chat/memory/jdbc/aot/hint/JdbcChatMemoryRuntimeHints.java rename to memory/spring-ai-model-chat-memory-jdbc/src/main/java/org/springframework/ai/chat/memory/jdbc/aot/hint/JdbcChatMemoryRepositoryRuntimeHints.java index 6740602e3f8..3b518733ffb 100644 --- a/memory/spring-ai-model-chat-memory-jdbc/src/main/java/org/springframework/ai/chat/memory/jdbc/aot/hint/JdbcChatMemoryRuntimeHints.java +++ b/memory/spring-ai-model-chat-memory-jdbc/src/main/java/org/springframework/ai/chat/memory/jdbc/aot/hint/JdbcChatMemoryRepositoryRuntimeHints.java @@ -27,7 +27,7 @@ * * @author Jonathan Leijendekker */ -class JdbcChatMemoryRuntimeHints implements RuntimeHintsRegistrar { +class JdbcChatMemoryRepositoryRuntimeHints implements RuntimeHintsRegistrar { @Override public void registerHints(RuntimeHints hints, ClassLoader classLoader) { diff --git a/memory/spring-ai-model-chat-memory-jdbc/src/main/java/org/springframework/ai/chat/memory/jdbc/package-info.java b/memory/spring-ai-model-chat-memory-jdbc/src/main/java/org/springframework/ai/chat/memory/jdbc/package-info.java new file mode 100644 index 00000000000..a26f200dec1 --- /dev/null +++ b/memory/spring-ai-model-chat-memory-jdbc/src/main/java/org/springframework/ai/chat/memory/jdbc/package-info.java @@ -0,0 +1,22 @@ +/* + * Copyright 2023-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +@NonNullApi +@NonNullFields +package org.springframework.ai.chat.memory.jdbc; + +import org.springframework.lang.NonNullApi; +import org.springframework.lang.NonNullFields; diff --git a/memory/spring-ai-model-chat-memory-jdbc/src/main/resources/META-INF/spring/aot.factories b/memory/spring-ai-model-chat-memory-jdbc/src/main/resources/META-INF/spring/aot.factories index 4b6f4a8f5ce..7169645c1e9 100644 --- a/memory/spring-ai-model-chat-memory-jdbc/src/main/resources/META-INF/spring/aot.factories +++ b/memory/spring-ai-model-chat-memory-jdbc/src/main/resources/META-INF/spring/aot.factories @@ -1,2 +1,2 @@ org.springframework.aot.hint.RuntimeHintsRegistrar=\ -org.springframework.ai.chat.memory.jdbc.aot.hint.JdbcChatMemoryRuntimeHints +org.springframework.ai.chat.memory.jdbc.aot.hint.JdbcChatMemoryRepositoryRuntimeHints diff --git a/memory/spring-ai-model-chat-memory-jdbc/src/main/resources/org/springframework/ai/chat/memory/jdbc/schema-mariadb.sql b/memory/spring-ai-model-chat-memory-jdbc/src/main/resources/org/springframework/ai/chat/memory/jdbc/schema-mariadb.sql index 88c0ea11ba0..174c3b545fd 100644 --- a/memory/spring-ai-model-chat-memory-jdbc/src/main/resources/org/springframework/ai/chat/memory/jdbc/schema-mariadb.sql +++ b/memory/spring-ai-model-chat-memory-jdbc/src/main/resources/org/springframework/ai/chat/memory/jdbc/schema-mariadb.sql @@ -7,4 +7,4 @@ CREATE TABLE IF NOT EXISTS ai_chat_memory ( ); CREATE INDEX IF NOT EXISTS ai_chat_memory_conversation_id_timestamp_idx -ON ai_chat_memory(conversation_id, `timestamp` DESC); +ON ai_chat_memory(conversation_id, `timestamp`); diff --git a/memory/spring-ai-model-chat-memory-jdbc/src/main/resources/org/springframework/ai/chat/memory/jdbc/schema-postgresql.sql b/memory/spring-ai-model-chat-memory-jdbc/src/main/resources/org/springframework/ai/chat/memory/jdbc/schema-postgresql.sql index 11e60194b60..31a9f301e03 100644 --- a/memory/spring-ai-model-chat-memory-jdbc/src/main/resources/org/springframework/ai/chat/memory/jdbc/schema-postgresql.sql +++ b/memory/spring-ai-model-chat-memory-jdbc/src/main/resources/org/springframework/ai/chat/memory/jdbc/schema-postgresql.sql @@ -6,4 +6,4 @@ CREATE TABLE IF NOT EXISTS ai_chat_memory ( ); CREATE INDEX IF NOT EXISTS ai_chat_memory_conversation_id_timestamp_idx -ON ai_chat_memory(conversation_id, "timestamp" DESC); +ON ai_chat_memory(conversation_id, "timestamp"); diff --git a/memory/spring-ai-model-chat-memory-jdbc/src/test/java/org/springframework/ai/chat/memory/jdbc/JdbcChatMemoryRepositoryPostgresqlIT.java b/memory/spring-ai-model-chat-memory-jdbc/src/test/java/org/springframework/ai/chat/memory/jdbc/JdbcChatMemoryRepositoryPostgresqlIT.java new file mode 100644 index 00000000000..5e5abc6ac41 --- /dev/null +++ b/memory/spring-ai-model-chat-memory-jdbc/src/test/java/org/springframework/ai/chat/memory/jdbc/JdbcChatMemoryRepositoryPostgresqlIT.java @@ -0,0 +1,168 @@ +/* + * Copyright 2023-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.chat.memory.jdbc; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.CsvSource; +import org.springframework.ai.chat.memory.ChatMemoryRepository; +import org.springframework.ai.chat.messages.Message; +import org.springframework.ai.chat.messages.MessageType; +import org.springframework.ai.chat.messages.AssistantMessage; +import org.springframework.ai.chat.messages.SystemMessage; +import org.springframework.ai.chat.messages.UserMessage; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.boot.SpringBootConfiguration; +import org.springframework.boot.autoconfigure.ImportAutoConfiguration; +import org.springframework.boot.autoconfigure.jdbc.DataSourceAutoConfiguration; +import org.springframework.boot.autoconfigure.jdbc.JdbcTemplateAutoConfiguration; +import org.springframework.boot.test.context.SpringBootTest; +import org.springframework.context.annotation.Bean; +import org.springframework.jdbc.core.JdbcTemplate; +import org.springframework.test.context.TestPropertySource; +import org.springframework.test.context.jdbc.Sql; + +import java.sql.Timestamp; +import java.util.List; +import java.util.UUID; + +import static org.assertj.core.api.Assertions.assertThat; + +/** + * Integration tests for {@link JdbcChatMemoryRepository}. + * + * @author Jonathan Leijendekker + * @author Thomas Vitale + */ +@SpringBootTest(classes = JdbcChatMemoryRepositoryPostgresqlIT.TestConfiguration.class) +@TestPropertySource(properties = "spring.datasource.url=jdbc:tc:postgresql:17:///") +@Sql(scripts = "classpath:org/springframework/ai/chat/memory/jdbc/schema-postgresql.sql") +class JdbcChatMemoryRepositoryPostgresqlIT { + + @Autowired + private ChatMemoryRepository chatMemoryRepository; + + @Autowired + private JdbcTemplate jdbcTemplate; + + @Test + void correctChatMemoryRepositoryInstance() { + assertThat(chatMemoryRepository).isInstanceOf(ChatMemoryRepository.class); + } + + @ParameterizedTest + @CsvSource({ "Message from assistant,ASSISTANT", "Message from user,USER", "Message from system,SYSTEM" }) + void saveMessagesSingleMessage(String content, MessageType messageType) { + var conversationId = UUID.randomUUID().toString(); + var message = switch (messageType) { + case ASSISTANT -> new AssistantMessage(content + " - " + conversationId); + case USER -> new UserMessage(content + " - " + conversationId); + case SYSTEM -> new SystemMessage(content + " - " + conversationId); + default -> throw new IllegalArgumentException("Type not supported: " + messageType); + }; + + chatMemoryRepository.saveAll(conversationId, List.of(message)); + + var query = "SELECT conversation_id, content, type, \"timestamp\" FROM ai_chat_memory WHERE conversation_id = ?"; + var result = jdbcTemplate.queryForMap(query, conversationId); + + assertThat(result.size()).isEqualTo(4); + assertThat(result.get("conversation_id")).isEqualTo(conversationId); + assertThat(result.get("content")).isEqualTo(message.getText()); + assertThat(result.get("type")).isEqualTo(messageType.name()); + assertThat(result.get("timestamp")).isInstanceOf(Timestamp.class); + } + + @Test + void saveMessagesMultipleMessages() { + var conversationId = UUID.randomUUID().toString(); + var messages = List.of(new AssistantMessage("Message from assistant - " + conversationId), + new UserMessage("Message from user - " + conversationId), + new SystemMessage("Message from system - " + conversationId)); + + chatMemoryRepository.saveAll(conversationId, messages); + + var query = "SELECT conversation_id, content, type, \"timestamp\" FROM ai_chat_memory WHERE conversation_id = ?"; + var results = jdbcTemplate.queryForList(query, conversationId); + + assertThat(results.size()).isEqualTo(messages.size()); + + for (var i = 0; i < messages.size(); i++) { + var message = messages.get(i); + var result = results.get(i); + + assertThat(result.get("conversation_id")).isNotNull(); + assertThat(result.get("conversation_id")).isEqualTo(conversationId); + assertThat(result.get("content")).isEqualTo(message.getText()); + assertThat(result.get("type")).isEqualTo(message.getMessageType().name()); + assertThat(result.get("timestamp")).isInstanceOf(Timestamp.class); + } + + var count = chatMemoryRepository.findByConversationId(conversationId).size(); + assertThat(count).isEqualTo(messages.size()); + + chatMemoryRepository.saveAll(conversationId, List.of(new UserMessage("Hello"))); + + count = chatMemoryRepository.findByConversationId(conversationId).size(); + assertThat(count).isEqualTo(1); + } + + @Test + void findMessagesByConversationId() { + var conversationId = UUID.randomUUID().toString(); + var messages = List.of(new AssistantMessage("Message from assistant 1 - " + conversationId), + new AssistantMessage("Message from assistant 2 - " + conversationId), + new UserMessage("Message from user - " + conversationId), + new SystemMessage("Message from system - " + conversationId)); + + chatMemoryRepository.saveAll(conversationId, messages); + + var results = chatMemoryRepository.findByConversationId(conversationId); + + assertThat(results.size()).isEqualTo(messages.size()); + assertThat(results).isEqualTo(messages); + } + + @Test + void deleteMessagesByConversationId() { + var conversationId = UUID.randomUUID().toString(); + var messages = List.of(new AssistantMessage("Message from assistant - " + conversationId), + new UserMessage("Message from user - " + conversationId), + new SystemMessage("Message from system - " + conversationId)); + + chatMemoryRepository.saveAll(conversationId, messages); + + chatMemoryRepository.deleteByConversationId(conversationId); + + var count = jdbcTemplate.queryForObject("SELECT COUNT(*) FROM ai_chat_memory WHERE conversation_id = ?", + Integer.class, conversationId); + + assertThat(count).isZero(); + } + + @SpringBootConfiguration + @ImportAutoConfiguration({ DataSourceAutoConfiguration.class, JdbcTemplateAutoConfiguration.class }) + static class TestConfiguration { + + @Bean + ChatMemoryRepository chatMemoryRepository(JdbcTemplate jdbcTemplate) { + return JdbcChatMemoryRepository.builder().jdbcTemplate(jdbcTemplate).build(); + } + + } + +} diff --git a/memory/spring-ai-model-chat-memory-jdbc/src/test/java/org/springframework/ai/chat/memory/jdbc/aot/hint/JdbcChatMemoryRuntimeHintsTest.java b/memory/spring-ai-model-chat-memory-jdbc/src/test/java/org/springframework/ai/chat/memory/jdbc/aot/hint/JdbcChatMemoryRepositoryRuntimeHintsTest.java similarity index 83% rename from memory/spring-ai-model-chat-memory-jdbc/src/test/java/org/springframework/ai/chat/memory/jdbc/aot/hint/JdbcChatMemoryRuntimeHintsTest.java rename to memory/spring-ai-model-chat-memory-jdbc/src/test/java/org/springframework/ai/chat/memory/jdbc/aot/hint/JdbcChatMemoryRepositoryRuntimeHintsTest.java index 90c65272d72..d507c712eda 100644 --- a/memory/spring-ai-model-chat-memory-jdbc/src/test/java/org/springframework/ai/chat/memory/jdbc/aot/hint/JdbcChatMemoryRuntimeHintsTest.java +++ b/memory/spring-ai-model-chat-memory-jdbc/src/test/java/org/springframework/ai/chat/memory/jdbc/aot/hint/JdbcChatMemoryRepositoryRuntimeHintsTest.java @@ -38,18 +38,18 @@ /** * @author Jonathan Leijendekker */ -class JdbcChatMemoryRuntimeHintsTest { +class JdbcChatMemoryRepositoryRuntimeHintsTest { private final RuntimeHints hints = new RuntimeHints(); - private final JdbcChatMemoryRuntimeHints jdbcChatMemoryRuntimeHints = new JdbcChatMemoryRuntimeHints(); + private final JdbcChatMemoryRepositoryRuntimeHints jdbcChatMemoryRepositoryRuntimeHints = new JdbcChatMemoryRepositoryRuntimeHints(); @Test void aotFactoriesContainsRegistrar() { var match = SpringFactoriesLoader.forResourceLocation("META-INF/spring/aot.factories") .load(RuntimeHintsRegistrar.class) .stream() - .anyMatch(registrar -> registrar instanceof JdbcChatMemoryRuntimeHints); + .anyMatch(registrar -> registrar instanceof JdbcChatMemoryRepositoryRuntimeHints); assertThat(match).isTrue(); } @@ -57,7 +57,7 @@ void aotFactoriesContainsRegistrar() { @ParameterizedTest @MethodSource("getSchemaFileNames") void jdbcSchemasHasHints(String schemaFileName) { - this.jdbcChatMemoryRuntimeHints.registerHints(this.hints, getClass().getClassLoader()); + this.jdbcChatMemoryRepositoryRuntimeHints.registerHints(this.hints, getClass().getClassLoader()); var predicate = RuntimeHintsPredicates.resource() .forResource("org/springframework/ai/chat/memory/jdbc/" + schemaFileName); @@ -67,7 +67,7 @@ void jdbcSchemasHasHints(String schemaFileName) { @Test void dataSourceHasHints() { - this.jdbcChatMemoryRuntimeHints.registerHints(this.hints, getClass().getClassLoader()); + this.jdbcChatMemoryRepositoryRuntimeHints.registerHints(this.hints, getClass().getClassLoader()); assertThat(RuntimeHintsPredicates.reflection().onType(DataSource.class)).accepts(this.hints); } diff --git a/models/spring-ai-mistral-ai/src/test/java/org/springframework/ai/mistralai/MistralAiChatModelIT.java b/models/spring-ai-mistral-ai/src/test/java/org/springframework/ai/mistralai/MistralAiChatModelIT.java index 658dce1ec3d..f4c9ba162d3 100644 --- a/models/spring-ai-mistral-ai/src/test/java/org/springframework/ai/mistralai/MistralAiChatModelIT.java +++ b/models/spring-ai-mistral-ai/src/test/java/org/springframework/ai/mistralai/MistralAiChatModelIT.java @@ -1,5 +1,5 @@ /* - * Copyright 2023-2024 the original author or authors. + * Copyright 2023-2025 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -22,6 +22,7 @@ import java.util.Arrays; import java.util.List; import java.util.Map; +import java.util.UUID; import java.util.stream.Collectors; import org.junit.jupiter.api.Test; @@ -32,6 +33,9 @@ import org.slf4j.LoggerFactory; import reactor.core.publisher.Flux; +import org.springframework.ai.chat.memory.ChatMemory; +import org.springframework.ai.chat.memory.MessageWindowChatMemory; +import org.springframework.ai.chat.messages.SystemMessage; import org.springframework.ai.chat.messages.AssistantMessage; import org.springframework.ai.chat.messages.Message; import org.springframework.ai.chat.messages.UserMessage; @@ -48,6 +52,12 @@ import org.springframework.ai.converter.ListOutputConverter; import org.springframework.ai.converter.MapOutputConverter; import org.springframework.ai.mistralai.api.MistralAiApi; +import org.springframework.ai.model.tool.DefaultToolCallingManager; +import org.springframework.ai.model.tool.ToolCallingChatOptions; +import org.springframework.ai.model.tool.ToolCallingManager; +import org.springframework.ai.model.tool.ToolExecutionResult; +import org.springframework.ai.tool.ToolCallbacks; +import org.springframework.ai.tool.annotation.Tool; import org.springframework.ai.tool.function.FunctionToolCallback; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.beans.factory.annotation.Value; @@ -330,6 +340,79 @@ void streamFunctionCallUsageTest() { assertThat(chatResponse.getMetadata().getUsage().getTotalTokens()).isLessThan(1050).isGreaterThan(650); } + @Test + void chatMemory() { + ChatMemory memory = MessageWindowChatMemory.builder().build(); + String conversationId = UUID.randomUUID().toString(); + + UserMessage userMessage1 = new UserMessage("My name is James Bond"); + memory.add(conversationId, userMessage1); + ChatResponse response1 = chatModel.call(new Prompt(memory.get(conversationId))); + + assertThat(response1).isNotNull(); + memory.add(conversationId, response1.getResult().getOutput()); + + UserMessage userMessage2 = new UserMessage("What is my name?"); + memory.add(conversationId, userMessage2); + ChatResponse response2 = chatModel.call(new Prompt(memory.get(conversationId))); + + assertThat(response2).isNotNull(); + memory.add(conversationId, response2.getResult().getOutput()); + + assertThat(response2.getResults()).hasSize(1); + assertThat(response2.getResult().getOutput().getText()).contains("James Bond"); + } + + @Test + void chatMemoryWithTools() { + ToolCallingManager toolCallingManager = DefaultToolCallingManager.builder().build(); + ChatMemory chatMemory = MessageWindowChatMemory.builder().build(); + String conversationId = UUID.randomUUID().toString(); + + ChatOptions chatOptions = ToolCallingChatOptions.builder() + .toolCallbacks(ToolCallbacks.from(new MathTools())) + .internalToolExecutionEnabled(false) + .build(); + Prompt prompt = new Prompt( + List.of(new SystemMessage("You are a helpful assistant."), new UserMessage("What is 6 * 8?")), + chatOptions); + chatMemory.add(conversationId, prompt.getInstructions()); + + Prompt promptWithMemory = new Prompt(chatMemory.get(conversationId), chatOptions); + ChatResponse chatResponse = chatModel.call(promptWithMemory); + chatMemory.add(conversationId, chatResponse.getResult().getOutput()); + + while (chatResponse.hasToolCalls()) { + ToolExecutionResult toolExecutionResult = toolCallingManager.executeToolCalls(promptWithMemory, + chatResponse); + chatMemory.add(conversationId, toolExecutionResult.conversationHistory() + .get(toolExecutionResult.conversationHistory().size() - 1)); + promptWithMemory = new Prompt(chatMemory.get(conversationId), chatOptions); + chatResponse = chatModel.call(promptWithMemory); + chatMemory.add(conversationId, chatResponse.getResult().getOutput()); + } + + assertThat(chatResponse).isNotNull(); + assertThat(chatResponse.getResult().getOutput().getText()).contains("48"); + + UserMessage newUserMessage = new UserMessage("What did I ask you earlier?"); + chatMemory.add(conversationId, newUserMessage); + + ChatResponse newResponse = chatModel.call(new Prompt(chatMemory.get(conversationId))); + + assertThat(newResponse).isNotNull(); + assertThat(newResponse.getResult().getOutput().getText()).contains("6").contains("8"); + } + + static class MathTools { + + @Tool(description = "Multiply the two numbers") + double multiply(double a, double b) { + return a * b; + } + + } + record ActorsFilmsRecord(String actor, List movies) { } diff --git a/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaChatModelIT.java b/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaChatModelIT.java index 8709a5b8b3a..276f429a75c 100644 --- a/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaChatModelIT.java +++ b/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaChatModelIT.java @@ -1,5 +1,5 @@ /* - * Copyright 2023-2024 the original author or authors. + * Copyright 2023-2025 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -18,6 +18,7 @@ import java.util.List; import java.util.Map; +import java.util.UUID; import java.util.stream.Collectors; import com.fasterxml.jackson.annotation.JsonProperty; @@ -25,8 +26,11 @@ import org.junit.jupiter.api.Test; import org.springframework.ai.chat.client.ChatClient; +import org.springframework.ai.chat.memory.ChatMemory; +import org.springframework.ai.chat.memory.MessageWindowChatMemory; import org.springframework.ai.chat.messages.AssistantMessage; import org.springframework.ai.chat.messages.Message; +import org.springframework.ai.chat.messages.SystemMessage; import org.springframework.ai.chat.messages.UserMessage; import org.springframework.ai.chat.metadata.Usage; import org.springframework.ai.chat.model.ChatResponse; @@ -38,12 +42,18 @@ import org.springframework.ai.converter.BeanOutputConverter; import org.springframework.ai.converter.ListOutputConverter; import org.springframework.ai.converter.MapOutputConverter; +import org.springframework.ai.model.tool.DefaultToolCallingManager; +import org.springframework.ai.model.tool.ToolCallingChatOptions; +import org.springframework.ai.model.tool.ToolCallingManager; +import org.springframework.ai.model.tool.ToolExecutionResult; import org.springframework.ai.ollama.api.OllamaApi; import org.springframework.ai.ollama.api.OllamaModel; import org.springframework.ai.ollama.api.OllamaOptions; import org.springframework.ai.ollama.management.ModelManagementOptions; import org.springframework.ai.ollama.management.OllamaModelManager; import org.springframework.ai.ollama.management.PullModelStrategy; +import org.springframework.ai.tool.ToolCallbacks; +import org.springframework.ai.tool.annotation.Tool; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.boot.SpringBootConfiguration; import org.springframework.boot.test.context.SpringBootTest; @@ -252,6 +262,79 @@ void jsonSchemaFormatStructuredOutput() { assertThat(countryInfo.capital()).isEqualToIgnoringCase("Copenhagen"); } + @Test + void chatMemory() { + ChatMemory memory = MessageWindowChatMemory.builder().build(); + String conversationId = UUID.randomUUID().toString(); + + UserMessage userMessage1 = new UserMessage("My name is James Bond"); + memory.add(conversationId, userMessage1); + ChatResponse response1 = chatModel.call(new Prompt(memory.get(conversationId))); + + assertThat(response1).isNotNull(); + memory.add(conversationId, response1.getResult().getOutput()); + + UserMessage userMessage2 = new UserMessage("What is my name?"); + memory.add(conversationId, userMessage2); + ChatResponse response2 = chatModel.call(new Prompt(memory.get(conversationId))); + + assertThat(response2).isNotNull(); + memory.add(conversationId, response2.getResult().getOutput()); + + assertThat(response2.getResults()).hasSize(1); + assertThat(response2.getResult().getOutput().getText()).contains("James Bond"); + } + + @Test + void chatMemoryWithTools() { + ToolCallingManager toolCallingManager = DefaultToolCallingManager.builder().build(); + ChatMemory chatMemory = MessageWindowChatMemory.builder().build(); + String conversationId = UUID.randomUUID().toString(); + + ChatOptions chatOptions = ToolCallingChatOptions.builder() + .toolCallbacks(ToolCallbacks.from(new MathTools())) + .internalToolExecutionEnabled(false) + .build(); + Prompt prompt = new Prompt( + List.of(new SystemMessage("You are a helpful assistant."), new UserMessage("What is 6 * 8?")), + chatOptions); + chatMemory.add(conversationId, prompt.getInstructions()); + + Prompt promptWithMemory = new Prompt(chatMemory.get(conversationId), chatOptions); + ChatResponse chatResponse = chatModel.call(promptWithMemory); + chatMemory.add(conversationId, chatResponse.getResult().getOutput()); + + while (chatResponse.hasToolCalls()) { + ToolExecutionResult toolExecutionResult = toolCallingManager.executeToolCalls(promptWithMemory, + chatResponse); + chatMemory.add(conversationId, toolExecutionResult.conversationHistory() + .get(toolExecutionResult.conversationHistory().size() - 1)); + promptWithMemory = new Prompt(chatMemory.get(conversationId), chatOptions); + chatResponse = chatModel.call(promptWithMemory); + chatMemory.add(conversationId, chatResponse.getResult().getOutput()); + } + + assertThat(chatResponse).isNotNull(); + assertThat(chatResponse.getResult().getOutput().getText()).contains("48"); + + UserMessage newUserMessage = new UserMessage("What did I ask you earlier?"); + chatMemory.add(conversationId, newUserMessage); + + ChatResponse newResponse = chatModel.call(new Prompt(chatMemory.get(conversationId))); + + assertThat(newResponse).isNotNull(); + assertThat(newResponse.getResult().getOutput().getText()).contains("6").contains("8"); + } + + static class MathTools { + + @Tool(description = "Multiply the two numbers") + double multiply(double a, double b) { + return a * b; + } + + } + record CountryInfo(@JsonProperty(required = true) String name, @JsonProperty(required = true) String capital, @JsonProperty(required = true) List languages) { } diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiChatModelIT.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiChatModelIT.java index 5dbd922602c..e0c1a4da8bb 100644 --- a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiChatModelIT.java +++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiChatModelIT.java @@ -22,6 +22,7 @@ import java.util.Arrays; import java.util.List; import java.util.Map; +import java.util.UUID; import java.util.concurrent.CountDownLatch; import java.util.concurrent.TimeUnit; import java.util.stream.Collectors; @@ -37,14 +38,23 @@ import reactor.core.publisher.Flux; import org.springframework.ai.chat.client.ChatClient; +import org.springframework.ai.chat.memory.ChatMemory; +import org.springframework.ai.chat.memory.MessageWindowChatMemory; import org.springframework.ai.chat.messages.AssistantMessage; import org.springframework.ai.chat.messages.Message; +import org.springframework.ai.chat.messages.SystemMessage; import org.springframework.ai.chat.messages.UserMessage; import org.springframework.ai.chat.metadata.DefaultUsage; import org.springframework.ai.chat.metadata.EmptyUsage; import org.springframework.ai.chat.metadata.Usage; import org.springframework.ai.chat.model.ChatResponse; import org.springframework.ai.chat.model.Generation; +import org.springframework.ai.model.tool.DefaultToolCallingManager; +import org.springframework.ai.model.tool.ToolCallingChatOptions; +import org.springframework.ai.model.tool.ToolCallingManager; +import org.springframework.ai.model.tool.ToolExecutionResult; +import org.springframework.ai.tool.ToolCallbacks; +import org.springframework.ai.tool.annotation.Tool; import org.springframework.ai.chat.prompt.ChatOptions; import org.springframework.ai.chat.prompt.Prompt; import org.springframework.ai.chat.prompt.PromptTemplate; @@ -628,6 +638,79 @@ void validateStoreAndMetadata() { assertThat(response).isNotNull(); } + @Test + void chatMemory() { + ChatMemory memory = MessageWindowChatMemory.builder().build(); + String conversationId = UUID.randomUUID().toString(); + + UserMessage userMessage1 = new UserMessage("My name is James Bond"); + memory.add(conversationId, userMessage1); + ChatResponse response1 = chatModel.call(new Prompt(memory.get(conversationId))); + + assertThat(response1).isNotNull(); + memory.add(conversationId, response1.getResult().getOutput()); + + UserMessage userMessage2 = new UserMessage("What is my name?"); + memory.add(conversationId, userMessage2); + ChatResponse response2 = chatModel.call(new Prompt(memory.get(conversationId))); + + assertThat(response2).isNotNull(); + memory.add(conversationId, response2.getResult().getOutput()); + + assertThat(response2.getResults()).hasSize(1); + assertThat(response2.getResult().getOutput().getText()).contains("James Bond"); + } + + @Test + void chatMemoryWithTools() { + ToolCallingManager toolCallingManager = DefaultToolCallingManager.builder().build(); + ChatMemory chatMemory = MessageWindowChatMemory.builder().build(); + String conversationId = UUID.randomUUID().toString(); + + ChatOptions chatOptions = ToolCallingChatOptions.builder() + .toolCallbacks(ToolCallbacks.from(new MathTools())) + .internalToolExecutionEnabled(false) + .build(); + Prompt prompt = new Prompt( + List.of(new SystemMessage("You are a helpful assistant."), new UserMessage("What is 6 * 8?")), + chatOptions); + chatMemory.add(conversationId, prompt.getInstructions()); + + Prompt promptWithMemory = new Prompt(chatMemory.get(conversationId), chatOptions); + ChatResponse chatResponse = chatModel.call(promptWithMemory); + chatMemory.add(conversationId, chatResponse.getResult().getOutput()); + + while (chatResponse.hasToolCalls()) { + ToolExecutionResult toolExecutionResult = toolCallingManager.executeToolCalls(promptWithMemory, + chatResponse); + chatMemory.add(conversationId, toolExecutionResult.conversationHistory() + .get(toolExecutionResult.conversationHistory().size() - 1)); + promptWithMemory = new Prompt(chatMemory.get(conversationId), chatOptions); + chatResponse = chatModel.call(promptWithMemory); + chatMemory.add(conversationId, chatResponse.getResult().getOutput()); + } + + assertThat(chatResponse).isNotNull(); + assertThat(chatResponse.getResult().getOutput().getText()).contains("48"); + + UserMessage newUserMessage = new UserMessage("What did I ask you earlier?"); + chatMemory.add(conversationId, newUserMessage); + + ChatResponse newResponse = chatModel.call(new Prompt(chatMemory.get(conversationId))); + + assertThat(newResponse).isNotNull(); + assertThat(newResponse.getResult().getOutput().getText()).contains("6").contains("8"); + } + + static class MathTools { + + @Tool(description = "Multiply the two numbers") + double multiply(double a, double b) { + return a * b; + } + + } + record ActorsFilmsRecord(String actor, List movies) { } diff --git a/pom.xml b/pom.xml index b80480f88d7..2f9dd335a1b 100644 --- a/pom.xml +++ b/pom.xml @@ -51,6 +51,7 @@ auto-configurations/models/chat/client/spring-ai-autoconfigure-model-chat-client + auto-configurations/models/chat/memory/spring-ai-autoconfigure-model-chat-memory auto-configurations/models/chat/memory/spring-ai-autoconfigure-model-chat-memory-cassandra auto-configurations/models/chat/memory/spring-ai-autoconfigure-model-chat-memory-jdbc auto-configurations/models/chat/memory/spring-ai-autoconfigure-model-chat-memory-neo4j diff --git a/spring-ai-bom/pom.xml b/spring-ai-bom/pom.xml index fc7e6f5c7c4..1f344dd0143 100644 --- a/spring-ai-bom/pom.xml +++ b/spring-ai-bom/pom.xml @@ -148,6 +148,12 @@ + + org.springframework.ai + spring-ai-model-chat-memory + ${project.version} + + org.springframework.ai spring-ai-model-chat-memory-cassandra diff --git a/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/advisor/AbstractChatMemoryAdvisor.java b/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/advisor/AbstractChatMemoryAdvisor.java index 72a69f657a3..3d778937647 100644 --- a/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/advisor/AbstractChatMemoryAdvisor.java +++ b/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/advisor/AbstractChatMemoryAdvisor.java @@ -19,6 +19,7 @@ import java.util.Map; import java.util.function.Function; +import org.springframework.ai.chat.memory.ChatMemory; import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; import reactor.core.scheduler.Schedulers; @@ -53,7 +54,9 @@ public abstract class AbstractChatMemoryAdvisor implements CallAroundAdvisor, /** * The default conversation id to use when no conversation id is provided. + * @deprecated in favor of {@link ChatMemory#DEFAULT_CONVERSATION_ID}. */ + @Deprecated public static final String DEFAULT_CHAT_MEMORY_CONVERSATION_ID = "default"; /** @@ -91,7 +94,7 @@ public abstract class AbstractChatMemoryAdvisor implements CallAroundAdvisor, * @param chatMemory the chat memory store */ protected AbstractChatMemoryAdvisor(T chatMemory) { - this(chatMemory, DEFAULT_CHAT_MEMORY_CONVERSATION_ID, DEFAULT_CHAT_MEMORY_RESPONSE_SIZE, true); + this(chatMemory, ChatMemory.DEFAULT_CONVERSATION_ID, DEFAULT_CHAT_MEMORY_RESPONSE_SIZE, true); } /** @@ -204,7 +207,7 @@ public static abstract class AbstractBuilder { /** * The conversation id. */ - protected String conversationId = DEFAULT_CHAT_MEMORY_CONVERSATION_ID; + protected String conversationId = ChatMemory.DEFAULT_CONVERSATION_ID; /** * The chat memory retrieve size. diff --git a/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/memory/InMemoryChatMemory.java b/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/memory/InMemoryChatMemory.java index 73318fc8c4b..3056bb5de18 100644 --- a/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/memory/InMemoryChatMemory.java +++ b/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/memory/InMemoryChatMemory.java @@ -34,7 +34,10 @@ * @see ChatMemory * @author Christian Tzolov * @since 1.0.0 M1 + * @deprecated in favor of {@link MessageWindowChatMemory}, which internally uses + * {@link InMemoryChatMemoryRepository}. */ +@Deprecated public class InMemoryChatMemory implements ChatMemory { Map> conversationHistory = new ConcurrentHashMap<>(); diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/nav.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/nav.adoc index 86e963deb0d..34d211039b1 100644 --- a/spring-ai-docs/src/main/antora/modules/ROOT/nav.adoc +++ b/spring-ai-docs/src/main/antora/modules/ROOT/nav.adoc @@ -94,6 +94,7 @@ * xref:api/retrieval-augmented-generation.adoc[Retrieval Augmented Generation (RAG)] ** xref:api/etl-pipeline.adoc[] * xref:api/structured-output-converter.adoc[Structured Output] +* xref:api/chat-memory.adoc[Chat Memory] * xref:api/tools.adoc[Tool Calling] ** xref:api/tools-migration.adoc[Migrating to ToolCallback API] * xref:api/mcp/mcp-overview.adoc[Model Context Protocol (MCP)] diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat-memory.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat-memory.adoc new file mode 100644 index 00000000000..df5201495b4 --- /dev/null +++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat-memory.adoc @@ -0,0 +1,188 @@ +[[ChatMemory]] += Chat Memory + +Large language models (LLMs) are stateless, meaning they do not retain information about previous interactions. This can be a limitation when you want to maintain context or state across multiple interactions. To address this, Spring AI provides a `ChatMemory` abstraction that allows you to store and retrieve information across multiple interactions with the LLM. + +== Quick Start + +Spring AI auto-configures a `ChatMemory` bean that you can use directly in your application. By default, it uses an in-memory repository to store messages (`InMemoryChatMemoryRepository`) and a `MessageWindowChatMemory` implementation to manage the conversation history. If a different repository is already configured (e.g., Cassandra, JDBC, or Neo4j), Spring AI will use that instead. + +[source,java] +---- +@Autowired +ChatMemory chatMemory; +---- + +The following sections will describe further the different memory types and repositories available in Spring AI. + +== Memory Types + +The `ChatMemory` abstraction allows you to implement various types of memory to suit different use cases. The choice of memory type can significantly impact the performance and behavior of your application. This section describes the built-in memory types provided by Spring AI and their characteristics. + +=== Message Window Chat Memory + +`MessageWindowChatMemory` maintains a window of messages up to a specified maximum size. When the number of messages exceeds the maximum, older messages are removed while preserving system messages. The default window size is 20 messages. + +[source,java] +---- +MessageWindowChatMemory memory = MessageWindowChatMemory.builder() + .maxMessages(10) + .build(); +---- + +This is the default message type used by Spring AI to auto-configure a `ChatMemory` bean. + +== Memory Storage + +Spring AI offers the `ChatMemoryRepository` abstraction for storing chat memory. This section describes the built-in repositories provided by Spring AI and how to use them, but you can also implement your own repository if needed. + +=== In-Memory Repository + +`InMemoryChatMemoryRepository` stores messages in memory using a `ConcurrentHashMap`. + +By default, if no other repository is already configured, Spring AI auto-configures a `ChatMemoryRepository` bean of type `InMemoryChatMemoryRepository` that you can use directly in your application. + +[source,java] +---- +@Autowired +ChatMemoryRepository chatMemoryRepository; +---- + +If you'd rather create the `InMemoryChatMemoryRepository` manually, you can do so as follows: + +[source,java] +---- +ChatMemoryRepository repository = new InMemoryChatMemoryRepository(); +---- + +=== JDBC Repository + +`JdbcChatMemoryRepository` is a built-in implementation that uses JDBC to store messages in a relational database. It is suitable for applications that require persistent storage of chat memory. + +First, add the following dependency to your project: + +[tabs] +====== +Maven:: ++ +[source, xml] +---- + + org.springframework.ai + spring-ai-starter-model-chat-memory-jdbc + +---- + +Gradle:: ++ +[source,groovy] +---- +dependencies { + implementation 'org.springframework.ai:spring-ai-starter-model-chat-memory-jdbc' +} +---- +====== + +Spring AI provides auto-configuration for the `JdbcChatMemoryRepository`, that you can use directly in your application. + +[source,java] +---- +@Autowired +JdbcChatMemoryRepository chatMemoryRepository; + +ChatMemory chatMemory = MessageWindowChatMemory.builder() + .chatMemoryRepository(chatMemoryRepository) + .maxMessages(10) + .build(); +---- + +If you'd rather create the `JdbcChatMemoryRepository` manually, you can do so by providing a `JdbcTemplate` instance: + +[source,java] +---- +ChatMemoryRepository chatMemoryRepository = JdbcChatMemoryRepository.builder() + .jdbcTemplate(jdbcTemplate) + .build(); + +ChatMemory chatMemory = MessageWindowChatMemory.builder() + .chatMemoryRepository(chatMemoryRepository) + .maxMessages(10) + .build(); +---- + +==== Configuration Properties + +[cols="2,5,1",stripes=even] +|=== +|Property | Description | Default Value +| `spring.ai.chat.memory.repository.jdbc.initialize-schema` | Whether to initialize the schema on startup. | `true` +|=== + +==== Schema Initialization + +The auto-configuration will automatically create the `ai_chat_memory` table using the JDBC driver. Currently, only PostgreSQL and MariaDB are supported. + +You can disable the schema initialization by setting the property `spring.ai.chat.memory.repository.jdbc.initialize-schema` to `false`. + +If your project uses a tool like Flyway or Liquibase to manage your database schemas, you can disable the schema initialization and refer to link:https://github.com/spring-projects/spring-ai/tree/main/memory/spring-ai-model-chat-memory-jdbc/src/main/resources/org/springframework/ai/chat/memory/jdbc[these SQL scripts] for configuring those tools to create the `ai_chat_memory` table. + +== Memory in Chat Client + +When using the ChatClient API, you can provide a `ChatMemory` implementation to maintain conversation context across multiple interactions. + +Spring AI provides a few built-in Advisors that you can use to configure the memory behavior of the `ChatClient`, based on your needs. + +WARNING: Currently, the intermediate messages exchanged with a large-language model when performing tool calls are not stored in the memory. This is a limitation of the current implementation and will be addressed in future releases. If you need to store these messages, refer to the instructions for the xref:api/tools.adoc#_user_controlled_tool_execution[User Controlled Tool Execution]. + +* `MessageChatMemoryAdvisor`. This advisor manages the conversation memory using the provided `ChatMemory` implementation. On each interaction, it retrieves the conversation history from the memory and includes it in the prompt as a collection of messages. +* `PromptChatMemoryAdvisor`. This advisor manages the conversation memory using the provided `ChatMemory` implementation. On each interaction, it retrieves the conversation history from the memory and appends it to the system prompt as plain text. +* `VectorStoreChatMemoryAdvisor`. This advisor manages the conversation memory using the provided `VectorStore` implementation. On each interaction, it retrieves the conversation history from the vector store and appends it to the system message as plain text. + +For example, if you want to use `MessageWindowChatMemory` with the `MessageChatMemoryAdvisor`, you can configure it as follows: + +[source,java] +---- +ChatMemory chatMemory = MessageChatMemoryAdvisor.builder().build(); + +ChatClient chatClient = ChatClient.builder(chatModel) + .defaultAdvisors(MessageChatMemoryAdvisor.builder(chatMemory).build()) + .build(); +---- + +When performing a call to the `ChatClient`, the memory will be automatically managed by the `MessageChatMemoryAdvisor`. The conversation history will be retrieved from the memory based on the specified conversation ID: + +[source,java] +---- +String conversationId = "007"; + +chatClient.prompt() + .user("Do I have license to code?") + .advisors(a -> a.param(AbstractChatMemoryAdvisor.CHAT_MEMORY_CONVERSATION_ID_KEY, conversationId)) + .call() + .content(); +---- + +== Memory in Chat Model + +If you're working directly with a `ChatModel` instead of a `ChatClient`, you can manage the memory explicitly: + +[source,java] +---- +// Create a memory instance +ChatMemory chatMemory = MessageWindowChatMemory.builder().build(); +String conversationId = "007"; + +// First interaction +UserMessage userMessage1 = new UserMessage("My name is James Bond"); +chatMemory.add(conversationId, userMessage1); +ChatResponse response1 = chatModel.call(new Prompt(chatMemory.get(conversationId))); +chatMemory.add(conversationId, response1.getResult().getOutput()); + +// Second interaction +UserMessage userMessage2 = new UserMessage("What is my name?"); +chatMemory.add(conversationId, userMessage2); +ChatResponse response2 = chatModel.call(new Prompt(chatMemory.get(conversationId))); +chatMemory.add(conversationId, response2.getResult().getOutput()); + +// The response will contain "James Bond" +---- diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chatclient.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chatclient.adoc index 7b445d067b3..bbe03c1a577 100644 --- a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chatclient.adoc +++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chatclient.adoc @@ -366,6 +366,8 @@ In this configuration, the `MessageChatMemoryAdvisor` will be executed first, ad xref:ROOT:api/retrieval-augmented-generation.adoc#_questionansweradvisor[Learn about Question Answer Advisor] +TIP: Refer to the xref:api/chat-memory.adoc[Chat Memory] documentation for more information on how to use the `ChatMemory` interface to manage conversation history in combination with the advisors. + The following advisor implementations use the `ChatMemory` interface to advice the prompt with conversation history which differ in the details of how the memory is added to the prompt * `MessageChatMemoryAdvisor` : Memory is retrieved and added as a collection of messages to the prompt @@ -485,7 +487,9 @@ This allows you to tailor the logged information to your specific needs. TIP: Be cautious about logging sensitive information in production environments. -== Chat Memory +== Chat Memory (Deprecated) + +IMPORTANT: Refer to the new xref:api/chat-memory.adoc[Chat Memory] documentation for the current features and capabilities. The interface `ChatMemory` represents a storage for chat conversation history. It provides methods to add messages to a conversation, retrieve messages from a conversation, and clear the conversation history. diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/tools.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/tools.adoc index 7013ece276c..461228ffa9e 100644 --- a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/tools.adoc +++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/tools.adoc @@ -1088,11 +1088,8 @@ image::tools/framework-manager.jpg[Framework-controlled tool execution lifecycle WARNING: Currently, the internal messages exchanged with the model regarding the tool execution are not exposed to the user. If you need to access these messages, you should use the user-controlled tool execution approach. -=== User-Controlled Tool Execution +The logic determining whether a tool call is eligible for execution is handled by the `ToolExecutionEligibilityPredicate` interface. By default, the tool execution eligibility is determined by checking if the `internalToolExecutionEnabled` attribute of `ToolCallingChatOptions` is set to `true` (the default value), and if the `ChatResponse` contains any tool calls. -There are cases where you'd rather control the tool execution lifecycle yourself. You can do so by setting the `internalToolExecutionEnabled` attribute of `ToolCallingChatOptions` to `false`. -Alternatevly you can implement your `ToolExecutionEligibilityPredicate` predicate to control the tool execution eligibility. -The default predicate implementation looks like this: [source,java] ---- public class DefaultToolExecutionEligibilityPredicate implements ToolExecutionEligibilityPredicate { @@ -1102,9 +1099,16 @@ public class DefaultToolExecutionEligibilityPredicate implements ToolExecutionEl return ToolCallingChatOptions.isInternalToolExecutionEnabled(promptOptions) && chatResponse != null && chatResponse.hasToolCalls(); } + } ---- +You can provide your custom implementation of `ToolExecutionEligibilityPredicate` when creating the `ChatModel` bean. + +=== User-Controlled Tool Execution + +There are cases where you'd rather control the tool execution lifecycle yourself. You can do so by setting the `internalToolExecutionEnabled` attribute of `ToolCallingChatOptions` to `false`. + When you invoke a `ChatModel` with this option, the tool execution will be delegated to the caller, giving you full control over the tool execution lifecycle. It's your responsibility checking for tool calls in the `ChatResponse` and executing them using the `ToolCallingManager`. The following example demonstrates a minimal implementation of the user-controlled tool execution approach: @@ -1135,6 +1139,43 @@ System.out.println(chatResponse.getResult().getOutput().getText()); NOTE: When choosing the user-controlled tool execution approach, we recommend using a `ToolCallingManager` to manage the tool calling operations. This way, you can benefit from the built-in support provided by Spring AI for tool execution. However, nothing prevents you from implementing your own tool execution logic. +The next examples shows a minimal implementation of the user-controlled tool execution approach combined with the usage of the `ChatMemory` API: + +[source,java] +---- +ToolCallingManager toolCallingManager = DefaultToolCallingManager.builder().build(); +ChatMemory chatMemory = MessageWindowChatMemory.builder().build(); +String conversationId = UUID.randomUUID().toString(); + +ChatOptions chatOptions = ToolCallingChatOptions.builder() + .toolCallbacks(ToolCallbacks.from(new MathTools())) + .internalToolExecutionEnabled(false) + .build(); +Prompt prompt = new Prompt( + List.of(new SystemMessage("You are a helpful assistant."), new UserMessage("What is 6 * 8?")), + chatOptions); +chatMemory.add(conversationId, prompt.getInstructions()); + +Prompt promptWithMemory = new Prompt(chatMemory.get(conversationId), chatOptions); +ChatResponse chatResponse = chatModel.call(promptWithMemory); +chatMemory.add(conversationId, chatResponse.getResult().getOutput()); + +while (chatResponse.hasToolCalls()) { + ToolExecutionResult toolExecutionResult = toolCallingManager.executeToolCalls(promptWithMemory, + chatResponse); + chatMemory.add(conversationId, toolExecutionResult.conversationHistory() + .get(toolExecutionResult.conversationHistory().size() - 1)); + promptWithMemory = new Prompt(chatMemory.get(conversationId), chatOptions); + chatResponse = chatModel.call(promptWithMemory); + chatMemory.add(conversationId, chatResponse.getResult().getOutput()); +} + +UserMessage newUserMessage = new UserMessage("What did I ask you earlier?"); +chatMemory.add(conversationId, newUserMessage); + +ChatResponse newResponse = chatModel.call(new Prompt(chatMemory.get(conversationId))); +---- + === Exception Handling When a tool call fails, the exception is propagated as a `ToolExecutionException` which can be caught to handle the error. A `ToolExecutionExceptionProcessor` can be used to handle a `ToolExecutionException` with two outcomes: either producing an error message to be sent back to the AI model or throwing an exception to be handled by the caller. diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/upgrade-notes.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/upgrade-notes.adoc index bd067cc5bb9..e9cecc0d42a 100644 --- a/spring-ai-docs/src/main/antora/modules/ROOT/pages/upgrade-notes.adoc +++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/upgrade-notes.adoc @@ -46,6 +46,17 @@ This approach can save time and reduce the chance of errors when upgrading multi [[upgrading-to-1-0-0-m8]] == Upgrading to 1.0.0-M8 +=== Chat Memory + +* A `ChatMemory` bean is auto-configured for you whenever using one of the Spring AI Model starters. By default, it uses the `MessageWindowChatMemory` implementation and stores the conversation history in memory. +* The `ChatMemory` API has been enhanced to support a more flexible and extensible way of managing conversation history. The storage mechanism has been decoupled from the `ChatMemory` interface and is now handled by a new `ChatMemoryRepository` interface. The `ChatMemory` API now can be used to implement different memory strategies without being tied to a specific storage mechanism. By default, Spring AI provides a `MessageWindowChatMemory` implementation that maintains a window of messages up to a specified maximum size. +* The `get(String conversationId, int lastN)` method in `ChatMemory` has been deprecated in favour of using `MessageWindowChatMemory` when it's needed to keep messages in memory up to a certain limit. The `get(String conversationId)` method is now the preferred way to retrieve messages from the memory whereas the specific implementation of `ChatMemory` can decide the strategy for filtering, processing, and returning messages. +* The `JdbcChatMemory` has been deprecated in favour of using `JdbcChatMemoryRepository` together with a `ChatMemory` implementation such `MessageWindowChatMemory`. If you were relying on an auto-configured `JdbcChatMemory` bean, you can replace that by auto-wiring a `ChatMemory` bean that is auto-configured to use the `JdbcChatMemoryRepository` internally for storing messages whenever the related dependency is in the classpath. +* The `spring.ai.chat.memory.jdbc.initialize-schema` property has been deprecated in favor of `spring.ai.chat.memory.repository.jdbc.initialize-schema`. +* Refer to the new xref:api/chat-memory.adoc[Chat Memory] documentation for more details on the new API and how to use it. + +=== Prompt Templating + * The `PromptTemplate` API has been redesigned to support a more flexible and extensible way of templating prompts, relying on a new `TemplateRenderer` API. As part of this change, the `getInputVariables()` and `validate()` methods have been deprecated and will throw an `UnsupportedOperationException` if called. Any logic specific to a template engine should be available through the `TemplateRenderer` API. [[upgrading-to-1-0-0-m7]] diff --git a/spring-ai-model/src/main/java/org/springframework/ai/chat/memory/ChatMemory.java b/spring-ai-model/src/main/java/org/springframework/ai/chat/memory/ChatMemory.java new file mode 100644 index 00000000000..5d99b43392a --- /dev/null +++ b/spring-ai-model/src/main/java/org/springframework/ai/chat/memory/ChatMemory.java @@ -0,0 +1,68 @@ +/* + * Copyright 2023-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.chat.memory; + +import org.springframework.ai.chat.messages.Message; +import org.springframework.util.Assert; + +import java.util.List; + +/** + * The contract for storing and managing the memory of chat conversations. + * + * @author Christian Tzolov + * @author Thomas Vitale + * @since 1.0.0 + */ +public interface ChatMemory { + + String DEFAULT_CONVERSATION_ID = "default"; + + /** + * Save the specified message in the chat memory for the specified conversation. + */ + default void add(String conversationId, Message message) { + Assert.hasText(conversationId, "conversationId cannot be null or empty"); + Assert.notNull(message, "message cannot be null"); + this.add(conversationId, List.of(message)); + } + + /** + * Save the specified messages in the chat memory for the specified conversation. + */ + void add(String conversationId, List messages); + + /** + * Get the messages in the chat memory for the specified conversation. + */ + default List get(String conversationId) { + Assert.hasText(conversationId, "conversationId cannot be null or empty"); + return get(conversationId, Integer.MAX_VALUE); + } + + /** + * @deprecated in favor of using {@link MessageWindowChatMemory}. + */ + @Deprecated + List get(String conversationId, int lastN); + + /** + * Clear the chat memory for the specified conversation. + */ + void clear(String conversationId); + +} diff --git a/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/memory/ChatMemory.java b/spring-ai-model/src/main/java/org/springframework/ai/chat/memory/ChatMemoryRepository.java similarity index 52% rename from spring-ai-client-chat/src/main/java/org/springframework/ai/chat/memory/ChatMemory.java rename to spring-ai-model/src/main/java/org/springframework/ai/chat/memory/ChatMemoryRepository.java index 7003457df16..ff4b823a174 100644 --- a/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/memory/ChatMemory.java +++ b/spring-ai-model/src/main/java/org/springframework/ai/chat/memory/ChatMemoryRepository.java @@ -1,5 +1,5 @@ /* - * Copyright 2023-2024 the original author or authors. + * Copyright 2023-2025 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,30 +16,28 @@ package org.springframework.ai.chat.memory; -import java.util.List; - import org.springframework.ai.chat.messages.Message; +import java.util.List; + /** - * The ChatMemory interface represents a storage for chat conversation history. It - * provides methods to add messages to a conversation, retrieve messages from a - * conversation, and clear the conversation history. + * A repository for storing and retrieving chat messages. * - * @author Christian Tzolov + * @author Thomas Vitale * @since 1.0.0 */ -public interface ChatMemory { - - // TODO: consider a non-blocking interface for streaming usages +public interface ChatMemoryRepository { - default void add(String conversationId, Message message) { - this.add(conversationId, List.of(message)); - } + List findConversationIds(); - void add(String conversationId, List messages); + List findByConversationId(String conversationId); - List get(String conversationId, int lastN); + /** + * Replaces all the existing messages for the given conversation ID with the provided + * messages. + */ + void saveAll(String conversationId, List messages); - void clear(String conversationId); + void deleteByConversationId(String conversationId); } diff --git a/spring-ai-model/src/main/java/org/springframework/ai/chat/memory/InMemoryChatMemoryRepository.java b/spring-ai-model/src/main/java/org/springframework/ai/chat/memory/InMemoryChatMemoryRepository.java new file mode 100644 index 00000000000..290ccfb4174 --- /dev/null +++ b/spring-ai-model/src/main/java/org/springframework/ai/chat/memory/InMemoryChatMemoryRepository.java @@ -0,0 +1,63 @@ +/* + * Copyright 2023-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.chat.memory; + +import org.springframework.ai.chat.messages.Message; +import org.springframework.util.Assert; + +import java.util.ArrayList; +import java.util.List; +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; + +/** + * An in-memory implementation of {@link ChatMemoryRepository}. + * + * @author Thomas Vitale + * @since 1.0.0 + */ +public final class InMemoryChatMemoryRepository implements ChatMemoryRepository { + + Map> chatMemoryStore = new ConcurrentHashMap<>(); + + @Override + public List findConversationIds() { + return new ArrayList<>(this.chatMemoryStore.keySet()); + } + + @Override + public List findByConversationId(String conversationId) { + Assert.hasText(conversationId, "conversationId cannot be null or empty"); + List messages = this.chatMemoryStore.get(conversationId); + return messages != null ? new ArrayList<>(messages) : List.of(); + } + + @Override + public void saveAll(String conversationId, List messages) { + Assert.hasText(conversationId, "conversationId cannot be null or empty"); + Assert.notNull(messages, "messages cannot be null"); + Assert.noNullElements(messages, "messages cannot contain null elements"); + this.chatMemoryStore.put(conversationId, messages); + } + + @Override + public void deleteByConversationId(String conversationId) { + Assert.hasText(conversationId, "conversationId cannot be null or empty"); + this.chatMemoryStore.remove(conversationId); + } + +} diff --git a/spring-ai-model/src/main/java/org/springframework/ai/chat/memory/MessageWindowChatMemory.java b/spring-ai-model/src/main/java/org/springframework/ai/chat/memory/MessageWindowChatMemory.java new file mode 100644 index 00000000000..8b289d1b276 --- /dev/null +++ b/spring-ai-model/src/main/java/org/springframework/ai/chat/memory/MessageWindowChatMemory.java @@ -0,0 +1,150 @@ +/* + * Copyright 2023-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.chat.memory; + +import org.springframework.ai.chat.messages.Message; +import org.springframework.ai.chat.messages.SystemMessage; +import org.springframework.util.Assert; + +import java.util.ArrayList; +import java.util.HashSet; +import java.util.List; +import java.util.Set; + +/** + * A chat memory implementation that maintains a message window of a specified size, + * ensuring that the total number of messages does not exceed the specified limit. When + * the number of messages exceeds the maximum size, older messages are evicted. + *

+ * Messages of type {@link SystemMessage} are treated specially: if a new + * {@link SystemMessage} is added, all previous {@link SystemMessage} instances are + * removed from the memory. Also, if the total number of messages exceeds the limit, the + * {@link SystemMessage} messages are preserved while evicting other types of messages. + * + * @author Thomas Vitale + * @since 1.0.0 + */ +public final class MessageWindowChatMemory implements ChatMemory { + + private static final int DEFAULT_MAX_MESSAGES = 200; + + private static final ChatMemoryRepository DEFAULT_CHAT_MEMORY_REPOSITORY = new InMemoryChatMemoryRepository(); + + private final ChatMemoryRepository chatMemoryRepository; + + private final int maxMessages; + + private MessageWindowChatMemory(ChatMemoryRepository chatMemoryRepository, int maxMessages) { + Assert.notNull(chatMemoryRepository, "chatMemoryRepository cannot be null"); + Assert.isTrue(maxMessages > 0, "maxMessages must be greater than 0"); + this.chatMemoryRepository = chatMemoryRepository; + this.maxMessages = maxMessages; + } + + @Override + public void add(String conversationId, List messages) { + Assert.hasText(conversationId, "conversationId cannot be null or empty"); + Assert.notNull(messages, "messages cannot be null"); + Assert.noNullElements(messages, "messages cannot contain null elements"); + + List memoryMessages = this.chatMemoryRepository.findByConversationId(conversationId); + List processedMessages = process(memoryMessages, messages); + this.chatMemoryRepository.saveAll(conversationId, processedMessages); + } + + @Override + public List get(String conversationId) { + Assert.hasText(conversationId, "conversationId cannot be null or empty"); + return this.chatMemoryRepository.findByConversationId(conversationId); + } + + @Override + @Deprecated // in favor of get(conversationId) + public List get(String conversationId, int lastN) { + return get(conversationId); + } + + @Override + public void clear(String conversationId) { + Assert.hasText(conversationId, "conversationId cannot be null or empty"); + this.chatMemoryRepository.deleteByConversationId(conversationId); + } + + private List process(List memoryMessages, List newMessages) { + List processedMessages = new ArrayList<>(); + + Set memoryMessagesSet = new HashSet<>(memoryMessages); + boolean hasNewSystemMessage = newMessages.stream() + .filter(SystemMessage.class::isInstance) + .anyMatch(message -> !memoryMessagesSet.contains(message)); + + memoryMessages.stream() + .filter(message -> !(hasNewSystemMessage && message instanceof SystemMessage)) + .forEach(processedMessages::add); + + processedMessages.addAll(newMessages); + + if (processedMessages.size() <= this.maxMessages) { + return processedMessages; + } + + int messagesToRemove = processedMessages.size() - this.maxMessages; + + List trimmedMessages = new ArrayList<>(); + int removed = 0; + for (Message message : processedMessages) { + if (message instanceof SystemMessage || removed >= messagesToRemove) { + trimmedMessages.add(message); + } + else { + removed++; + } + } + + return trimmedMessages; + } + + public static Builder builder() { + return new Builder(); + } + + public static class Builder { + + private ChatMemoryRepository chatMemoryRepository = DEFAULT_CHAT_MEMORY_REPOSITORY; + + private int maxMessages = DEFAULT_MAX_MESSAGES; + + private Builder() { + } + + public Builder chatMemoryRepository(ChatMemoryRepository chatMemoryRepository) { + this.chatMemoryRepository = chatMemoryRepository; + return this; + } + + public Builder maxMessages(int maxMessages) { + this.maxMessages = maxMessages; + return this; + } + + public MessageWindowChatMemory build() { + return new MessageWindowChatMemory(chatMemoryRepository, maxMessages); + } + + } + +} diff --git a/spring-ai-model/src/main/java/org/springframework/ai/chat/memory/package-info.java b/spring-ai-model/src/main/java/org/springframework/ai/chat/memory/package-info.java new file mode 100644 index 00000000000..2dd55ff2556 --- /dev/null +++ b/spring-ai-model/src/main/java/org/springframework/ai/chat/memory/package-info.java @@ -0,0 +1,22 @@ +/* + * Copyright 2023-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +@NonNullApi +@NonNullFields +package org.springframework.ai.chat.memory; + +import org.springframework.lang.NonNullApi; +import org.springframework.lang.NonNullFields; diff --git a/spring-ai-model/src/test/java/org/springframework/ai/chat/memory/InMemoryChatMemoryRepositoryTests.java b/spring-ai-model/src/test/java/org/springframework/ai/chat/memory/InMemoryChatMemoryRepositoryTests.java new file mode 100644 index 00000000000..6fc4315c19e --- /dev/null +++ b/spring-ai-model/src/test/java/org/springframework/ai/chat/memory/InMemoryChatMemoryRepositoryTests.java @@ -0,0 +1,154 @@ +/* + * Copyright 2023-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.chat.memory; + +import org.junit.jupiter.api.Test; +import org.springframework.ai.chat.messages.AssistantMessage; +import org.springframework.ai.chat.messages.Message; +import org.springframework.ai.chat.messages.UserMessage; + +import java.util.ArrayList; +import java.util.List; +import java.util.UUID; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +/** + * Unit tests for {@link InMemoryChatMemoryRepository}. + * + * @author Thomas Vitale + */ +public class InMemoryChatMemoryRepositoryTests { + + private final InMemoryChatMemoryRepository chatMemoryRepository = new InMemoryChatMemoryRepository(); + + @Test + void findConversationIds() { + String conversationId1 = UUID.randomUUID().toString(); + String conversationId2 = UUID.randomUUID().toString(); + List messages1 = List.of(new UserMessage("Hello")); + List messages2 = List.of(new AssistantMessage("Hi there")); + + chatMemoryRepository.saveAll(conversationId1, messages1); + chatMemoryRepository.saveAll(conversationId2, messages2); + + assertThat(chatMemoryRepository.findConversationIds()).containsExactlyInAnyOrder(conversationId1, + conversationId2); + + chatMemoryRepository.deleteByConversationId(conversationId1); + assertThat(chatMemoryRepository.findConversationIds()).containsExactlyInAnyOrder(conversationId2); + } + + @Test + void saveMessagesAndFindMultipleMessagesInConversation() { + String conversationId = UUID.randomUUID().toString(); + List messages = List.of(new AssistantMessage("I, Robot"), new UserMessage("Hello")); + + chatMemoryRepository.saveAll(conversationId, messages); + + assertThat(chatMemoryRepository.findByConversationId(conversationId)).containsAll(messages); + + chatMemoryRepository.deleteByConversationId(conversationId); + + assertThat(chatMemoryRepository.findByConversationId(conversationId)).isEmpty(); + } + + @Test + void saveMessagesAndFindSingleMessageInConversation() { + String conversationId = UUID.randomUUID().toString(); + Message message = new UserMessage("Hello"); + List messages = List.of(message); + + chatMemoryRepository.saveAll(conversationId, messages); + + assertThat(chatMemoryRepository.findByConversationId(conversationId)).contains(message); + + chatMemoryRepository.deleteByConversationId(conversationId); + + assertThat(chatMemoryRepository.findByConversationId(conversationId)).isEmpty(); + } + + @Test + void findNonExistingConversation() { + String conversationId = UUID.randomUUID().toString(); + + assertThat(chatMemoryRepository.findByConversationId(conversationId)).isEmpty(); + } + + @Test + void subsequentSaveOverwritesPreviousVersion() { + String conversationId = UUID.randomUUID().toString(); + List firstMessages = List.of(new UserMessage("Hello")); + List secondMessages = List.of(new AssistantMessage("Hi there")); + + chatMemoryRepository.saveAll(conversationId, firstMessages); + chatMemoryRepository.saveAll(conversationId, secondMessages); + + assertThat(chatMemoryRepository.findByConversationId(conversationId)).containsExactlyElementsOf(secondMessages); + } + + @Test + void nullConversationIdNotAllowed() { + assertThatThrownBy(() -> chatMemoryRepository.saveAll(null, List.of(new UserMessage("Hello")))) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("conversationId cannot be null or empty"); + + assertThatThrownBy(() -> chatMemoryRepository.findByConversationId(null)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("conversationId cannot be null or empty"); + + assertThatThrownBy(() -> chatMemoryRepository.deleteByConversationId(null)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("conversationId cannot be null or empty"); + } + + @Test + void emptyConversationIdNotAllowed() { + assertThatThrownBy(() -> chatMemoryRepository.saveAll("", List.of(new UserMessage("Hello")))) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("conversationId cannot be null or empty"); + + assertThatThrownBy(() -> chatMemoryRepository.findByConversationId("")) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("conversationId cannot be null or empty"); + + assertThatThrownBy(() -> chatMemoryRepository.deleteByConversationId("")) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("conversationId cannot be null or empty"); + } + + @Test + void nullMessagesNotAllowed() { + String conversationId = UUID.randomUUID().toString(); + assertThatThrownBy(() -> chatMemoryRepository.saveAll(conversationId, null)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("messages cannot be null"); + } + + @Test + void messagesWithNullElementsNotAllowed() { + String conversationId = UUID.randomUUID().toString(); + List messagesWithNull = new ArrayList<>(); + messagesWithNull.add(null); + + assertThatThrownBy(() -> chatMemoryRepository.saveAll(conversationId, messagesWithNull)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("messages cannot contain null elements"); + } + +} diff --git a/spring-ai-model/src/test/java/org/springframework/ai/chat/memory/MessageWindowChatMemoryTests.java b/spring-ai-model/src/test/java/org/springframework/ai/chat/memory/MessageWindowChatMemoryTests.java new file mode 100644 index 00000000000..6684e48d3d5 --- /dev/null +++ b/spring-ai-model/src/test/java/org/springframework/ai/chat/memory/MessageWindowChatMemoryTests.java @@ -0,0 +1,294 @@ +/* + * Copyright 2023-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.chat.memory; + +import org.junit.jupiter.api.Test; +import org.springframework.ai.chat.messages.AssistantMessage; +import org.springframework.ai.chat.messages.Message; +import org.springframework.ai.chat.messages.SystemMessage; +import org.springframework.ai.chat.messages.UserMessage; + +import java.util.ArrayList; +import java.util.List; +import java.util.UUID; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +/** + * Unit tests for {@link MessageWindowChatMemory}. + * + * @author Thomas Vitale + */ +public class MessageWindowChatMemoryTests { + + private final MessageWindowChatMemory chatMemory = MessageWindowChatMemory.builder().build(); + + @Test + void zeroMaxMessagesNotAllowed() { + assertThatThrownBy(() -> MessageWindowChatMemory.builder().maxMessages(0).build()) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("maxMessages must be greater than 0"); + } + + @Test + void negativeMaxMessagesNotAllowed() { + assertThatThrownBy(() -> MessageWindowChatMemory.builder().maxMessages(-1).build()) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("maxMessages must be greater than 0"); + } + + @Test + void handleMultipleMessagesInConversation() { + String conversationId = UUID.randomUUID().toString(); + List messages = List.of(new AssistantMessage("I, Robot"), new UserMessage("Hello")); + + chatMemory.add(conversationId, messages); + + assertThat(chatMemory.get(conversationId)).containsAll(messages); + + chatMemory.clear(conversationId); + + assertThat(chatMemory.get(conversationId)).isEmpty(); + } + + @Test + void handleSingleMessageInConversation() { + String conversationId = UUID.randomUUID().toString(); + Message message = new UserMessage("Hello"); + + chatMemory.add(conversationId, message); + + assertThat(chatMemory.get(conversationId)).contains(message); + + chatMemory.clear(conversationId); + + assertThat(chatMemory.get(conversationId)).isEmpty(); + } + + @Test + void nullConversationIdNotAllowed() { + assertThatThrownBy(() -> chatMemory.add(null, List.of(new UserMessage("Hello")))) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("conversationId cannot be null or empty"); + + assertThatThrownBy(() -> chatMemory.add(null, new UserMessage("Hello"))) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("conversationId cannot be null or empty"); + + assertThatThrownBy(() -> chatMemory.get(null)).isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("conversationId cannot be null or empty"); + + assertThatThrownBy(() -> chatMemory.clear(null)).isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("conversationId cannot be null or empty"); + } + + @Test + void emptyConversationIdNotAllowed() { + assertThatThrownBy(() -> chatMemory.add("", List.of(new UserMessage("Hello")))) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("conversationId cannot be null or empty"); + + assertThatThrownBy(() -> chatMemory.add(null, new UserMessage("Hello"))) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("conversationId cannot be null or empty"); + + assertThatThrownBy(() -> chatMemory.get("")).isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("conversationId cannot be null or empty"); + + assertThatThrownBy(() -> chatMemory.clear("")).isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("conversationId cannot be null or empty"); + } + + @Test + void nullMessagesNotAllowed() { + String conversationId = UUID.randomUUID().toString(); + assertThatThrownBy(() -> chatMemory.add(conversationId, (List) null)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("messages cannot be null"); + } + + @Test + void nullMessageNotAllowed() { + String conversationId = UUID.randomUUID().toString(); + assertThatThrownBy(() -> chatMemory.add(conversationId, (Message) null)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("message cannot be null"); + } + + @Test + void messagesWithNullElementsNotAllowed() { + String conversationId = UUID.randomUUID().toString(); + List messagesWithNull = new ArrayList<>(); + messagesWithNull.add(null); + + assertThatThrownBy(() -> chatMemory.add(conversationId, messagesWithNull)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("messages cannot contain null elements"); + } + + @Test + void customMaxMessages() { + String conversationId = UUID.randomUUID().toString(); + int customMaxMessages = 2; + + MessageWindowChatMemory customChatMemory = MessageWindowChatMemory.builder() + .maxMessages(customMaxMessages) + .build(); + + List messages = List.of(new UserMessage("Message 1"), new AssistantMessage("Response 1"), + new UserMessage("Message 2"), new AssistantMessage("Response 2"), new UserMessage("Message 3")); + + customChatMemory.add(conversationId, messages); + List result = customChatMemory.get(conversationId); + + assertThat(result).hasSize(2); + } + + @Test + void noEvictionWhenMessagesWithinLimit() { + int limit = 3; + MessageWindowChatMemory customChatMemory = MessageWindowChatMemory.builder().maxMessages(limit).build(); + + String conversationId = UUID.randomUUID().toString(); + List memoryMessages = new ArrayList<>( + List.of(new UserMessage("Hello"), new AssistantMessage("Hi there"))); + customChatMemory.add(conversationId, memoryMessages); + + List newMessages = new ArrayList<>(List.of(new UserMessage("How are you?"))); + customChatMemory.add(conversationId, newMessages); + + List result = customChatMemory.get(conversationId); + + assertThat(result).hasSize(limit); + assertThat(result).containsExactly(new UserMessage("Hello"), new AssistantMessage("Hi there"), + new UserMessage("How are you?")); + } + + @Test + void evictionWhenMessagesExceedLimit() { + int limit = 2; + MessageWindowChatMemory customChatMemory = MessageWindowChatMemory.builder().maxMessages(limit).build(); + + String conversationId = UUID.randomUUID().toString(); + List memoryMessages = new ArrayList<>( + List.of(new UserMessage("Message 1"), new AssistantMessage("Response 1"))); + customChatMemory.add(conversationId, memoryMessages); + + List newMessages = new ArrayList<>( + List.of(new UserMessage("Message 2"), new AssistantMessage("Response 2"))); + customChatMemory.add(conversationId, newMessages); + + List result = customChatMemory.get(conversationId); + + assertThat(result).hasSize(limit); + assertThat(result).containsExactly(new UserMessage("Message 2"), new AssistantMessage("Response 2")); + } + + @Test + void systemMessageIsPreservedDuringEviction() { + int limit = 3; + MessageWindowChatMemory customChatMemory = MessageWindowChatMemory.builder().maxMessages(limit).build(); + + String conversationId = UUID.randomUUID().toString(); + List memoryMessages = new ArrayList<>(List.of(new SystemMessage("System instruction"), + new UserMessage("Message 1"), new AssistantMessage("Response 1"))); + customChatMemory.add(conversationId, memoryMessages); + + List newMessages = new ArrayList<>( + List.of(new UserMessage("Message 2"), new AssistantMessage("Response 2"))); + customChatMemory.add(conversationId, newMessages); + + List result = customChatMemory.get(conversationId); + + assertThat(result).hasSize(limit); + assertThat(result).containsExactly(new SystemMessage("System instruction"), new UserMessage("Message 2"), + new AssistantMessage("Response 2")); + } + + @Test + void multipleSystemMessagesArePreservedDuringEviction() { + int limit = 3; + MessageWindowChatMemory customChatMemory = MessageWindowChatMemory.builder().maxMessages(limit).build(); + + String conversationId = UUID.randomUUID().toString(); + List memoryMessages = new ArrayList<>( + List.of(new SystemMessage("System instruction 1"), new SystemMessage("System instruction 2"), + new UserMessage("Message 1"), new AssistantMessage("Response 1"))); + customChatMemory.add(conversationId, memoryMessages); + + List newMessages = new ArrayList<>( + List.of(new UserMessage("Message 2"), new AssistantMessage("Response 2"))); + customChatMemory.add(conversationId, newMessages); + + List result = customChatMemory.get(conversationId); + + assertThat(result).hasSize(limit); + assertThat(result).containsExactly(new SystemMessage("System instruction 1"), + new SystemMessage("System instruction 2"), new AssistantMessage("Response 2")); + } + + @Test + void emptyMessageList() { + String conversationId = UUID.randomUUID().toString(); + + List result = this.chatMemory.get(conversationId); + + assertThat(result).isEmpty(); + } + + @Test + void oldSystemMessagesAreRemovedWhenNewOneAdded() { + int limit = 2; + MessageWindowChatMemory customChatMemory = MessageWindowChatMemory.builder().maxMessages(limit).build(); + + String conversationId = UUID.randomUUID().toString(); + List memoryMessages = new ArrayList<>( + List.of(new SystemMessage("System instruction 1"), new SystemMessage("System instruction 2"))); + customChatMemory.add(conversationId, memoryMessages); + + List newMessages = new ArrayList<>(List.of(new SystemMessage("System instruction 3"))); + customChatMemory.add(conversationId, newMessages); + + List result = customChatMemory.get(conversationId); + + assertThat(result).hasSize(1); + assertThat(result).containsExactly(new SystemMessage("System instruction 3")); + } + + @Test + void mixedMessagesWithLimitEqualToSystemMessageCount() { + int limit = 2; + MessageWindowChatMemory customChatMemory = MessageWindowChatMemory.builder().maxMessages(limit).build(); + + String conversationId = UUID.randomUUID().toString(); + List memoryMessages = new ArrayList<>( + List.of(new SystemMessage("System instruction 1"), new SystemMessage("System instruction 2"))); + customChatMemory.add(conversationId, memoryMessages); + + List newMessages = new ArrayList<>( + List.of(new UserMessage("Message 1"), new AssistantMessage("Response 1"))); + customChatMemory.add(conversationId, newMessages); + + List result = customChatMemory.get(conversationId); + + assertThat(result).hasSize(2); + assertThat(result).containsExactly(new SystemMessage("System instruction 1"), + new SystemMessage("System instruction 2")); + } + +} diff --git a/spring-ai-spring-boot-starters/spring-ai-starter-model-anthropic/pom.xml b/spring-ai-spring-boot-starters/spring-ai-starter-model-anthropic/pom.xml index 6a80aa6cd39..0129169caf6 100644 --- a/spring-ai-spring-boot-starters/spring-ai-starter-model-anthropic/pom.xml +++ b/spring-ai-spring-boot-starters/spring-ai-starter-model-anthropic/pom.xml @@ -59,6 +59,12 @@ spring-ai-autoconfigure-model-chat-client ${project.parent.version} + + + org.springframework.ai + spring-ai-autoconfigure-model-chat-memory + ${project.parent.version} + diff --git a/spring-ai-spring-boot-starters/spring-ai-starter-model-azure-openai/pom.xml b/spring-ai-spring-boot-starters/spring-ai-starter-model-azure-openai/pom.xml index e6a8fcae730..2b2a2bf128a 100644 --- a/spring-ai-spring-boot-starters/spring-ai-starter-model-azure-openai/pom.xml +++ b/spring-ai-spring-boot-starters/spring-ai-starter-model-azure-openai/pom.xml @@ -59,6 +59,12 @@ spring-ai-autoconfigure-model-chat-client ${project.parent.version} + + + org.springframework.ai + spring-ai-autoconfigure-model-chat-memory + ${project.parent.version} + diff --git a/spring-ai-spring-boot-starters/spring-ai-starter-model-bedrock-converse/pom.xml b/spring-ai-spring-boot-starters/spring-ai-starter-model-bedrock-converse/pom.xml index 6d9c5b93413..ac8512df0ac 100644 --- a/spring-ai-spring-boot-starters/spring-ai-starter-model-bedrock-converse/pom.xml +++ b/spring-ai-spring-boot-starters/spring-ai-starter-model-bedrock-converse/pom.xml @@ -59,6 +59,12 @@ spring-ai-autoconfigure-model-chat-client ${project.parent.version} + + + org.springframework.ai + spring-ai-autoconfigure-model-chat-memory + ${project.parent.version} + diff --git a/spring-ai-spring-boot-starters/spring-ai-starter-model-chat-memory-cassandra/pom.xml b/spring-ai-spring-boot-starters/spring-ai-starter-model-chat-memory-cassandra/pom.xml index 48f67bb8bc7..698974711a5 100644 --- a/spring-ai-spring-boot-starters/spring-ai-starter-model-chat-memory-cassandra/pom.xml +++ b/spring-ai-spring-boot-starters/spring-ai-starter-model-chat-memory-cassandra/pom.xml @@ -42,6 +42,12 @@ spring-boot-starter + + org.springframework.ai + spring-ai-autoconfigure-model-chat-memory + ${project.parent.version} + + org.springframework.ai spring-ai-autoconfigure-model-chat-memory-cassandra diff --git a/spring-ai-spring-boot-starters/spring-ai-starter-model-chat-memory-jdbc/pom.xml b/spring-ai-spring-boot-starters/spring-ai-starter-model-chat-memory-jdbc/pom.xml index 112be3cc248..5f510a08750 100644 --- a/spring-ai-spring-boot-starters/spring-ai-starter-model-chat-memory-jdbc/pom.xml +++ b/spring-ai-spring-boot-starters/spring-ai-starter-model-chat-memory-jdbc/pom.xml @@ -42,6 +42,12 @@ spring-boot-starter + + org.springframework.ai + spring-ai-autoconfigure-model-chat-memory + ${project.parent.version} + + org.springframework.ai spring-ai-autoconfigure-model-chat-memory-jdbc diff --git a/spring-ai-spring-boot-starters/spring-ai-starter-model-chat-memory-neo4j/pom.xml b/spring-ai-spring-boot-starters/spring-ai-starter-model-chat-memory-neo4j/pom.xml index bd6208ff4c7..63256cbbc36 100644 --- a/spring-ai-spring-boot-starters/spring-ai-starter-model-chat-memory-neo4j/pom.xml +++ b/spring-ai-spring-boot-starters/spring-ai-starter-model-chat-memory-neo4j/pom.xml @@ -42,6 +42,12 @@ spring-boot-starter + + org.springframework.ai + spring-ai-autoconfigure-model-chat-memory + ${project.parent.version} + + org.springframework.ai spring-ai-autoconfigure-model-chat-memory-neo4j diff --git a/spring-ai-spring-boot-starters/spring-ai-starter-model-huggingface/pom.xml b/spring-ai-spring-boot-starters/spring-ai-starter-model-huggingface/pom.xml index 7dd3ca95f5f..f984a23ab29 100644 --- a/spring-ai-spring-boot-starters/spring-ai-starter-model-huggingface/pom.xml +++ b/spring-ai-spring-boot-starters/spring-ai-starter-model-huggingface/pom.xml @@ -59,6 +59,12 @@ spring-ai-autoconfigure-model-chat-client ${project.parent.version} + + + org.springframework.ai + spring-ai-autoconfigure-model-chat-memory + ${project.parent.version} + diff --git a/spring-ai-spring-boot-starters/spring-ai-starter-model-minimax/pom.xml b/spring-ai-spring-boot-starters/spring-ai-starter-model-minimax/pom.xml index 177f5ddaca7..d8e554a8b2b 100644 --- a/spring-ai-spring-boot-starters/spring-ai-starter-model-minimax/pom.xml +++ b/spring-ai-spring-boot-starters/spring-ai-starter-model-minimax/pom.xml @@ -59,6 +59,12 @@ spring-ai-autoconfigure-model-chat-client ${project.parent.version} + + + org.springframework.ai + spring-ai-autoconfigure-model-chat-memory + ${project.parent.version} + diff --git a/spring-ai-spring-boot-starters/spring-ai-starter-model-mistral-ai/pom.xml b/spring-ai-spring-boot-starters/spring-ai-starter-model-mistral-ai/pom.xml index d8d4c521fb7..3d30900fec0 100644 --- a/spring-ai-spring-boot-starters/spring-ai-starter-model-mistral-ai/pom.xml +++ b/spring-ai-spring-boot-starters/spring-ai-starter-model-mistral-ai/pom.xml @@ -59,6 +59,12 @@ spring-ai-autoconfigure-model-chat-client ${project.parent.version} + + + org.springframework.ai + spring-ai-autoconfigure-model-chat-memory + ${project.parent.version} + diff --git a/spring-ai-spring-boot-starters/spring-ai-starter-model-oci-genai/pom.xml b/spring-ai-spring-boot-starters/spring-ai-starter-model-oci-genai/pom.xml index dd910300f72..93ed73f86f9 100644 --- a/spring-ai-spring-boot-starters/spring-ai-starter-model-oci-genai/pom.xml +++ b/spring-ai-spring-boot-starters/spring-ai-starter-model-oci-genai/pom.xml @@ -58,5 +58,11 @@ spring-ai-autoconfigure-model-chat-client ${project.parent.version} + + + org.springframework.ai + spring-ai-autoconfigure-model-chat-memory + ${project.parent.version} + diff --git a/spring-ai-spring-boot-starters/spring-ai-starter-model-ollama/pom.xml b/spring-ai-spring-boot-starters/spring-ai-starter-model-ollama/pom.xml index 627a7474b9b..8e981140f29 100644 --- a/spring-ai-spring-boot-starters/spring-ai-starter-model-ollama/pom.xml +++ b/spring-ai-spring-boot-starters/spring-ai-starter-model-ollama/pom.xml @@ -59,6 +59,12 @@ spring-ai-autoconfigure-model-chat-client ${project.parent.version} + + + org.springframework.ai + spring-ai-autoconfigure-model-chat-memory + ${project.parent.version} + diff --git a/spring-ai-spring-boot-starters/spring-ai-starter-model-openai/pom.xml b/spring-ai-spring-boot-starters/spring-ai-starter-model-openai/pom.xml index 5195cf03510..f2d003c250e 100644 --- a/spring-ai-spring-boot-starters/spring-ai-starter-model-openai/pom.xml +++ b/spring-ai-spring-boot-starters/spring-ai-starter-model-openai/pom.xml @@ -59,6 +59,12 @@ spring-ai-autoconfigure-model-chat-client ${project.parent.version} + + + org.springframework.ai + spring-ai-autoconfigure-model-chat-memory + ${project.parent.version} + diff --git a/spring-ai-spring-boot-starters/spring-ai-starter-model-vertex-ai-gemini/pom.xml b/spring-ai-spring-boot-starters/spring-ai-starter-model-vertex-ai-gemini/pom.xml index 9a1b6cdb4cd..ea917658973 100644 --- a/spring-ai-spring-boot-starters/spring-ai-starter-model-vertex-ai-gemini/pom.xml +++ b/spring-ai-spring-boot-starters/spring-ai-starter-model-vertex-ai-gemini/pom.xml @@ -59,6 +59,12 @@ spring-ai-autoconfigure-model-chat-client ${project.parent.version} + + + org.springframework.ai + spring-ai-autoconfigure-model-chat-memory + ${project.parent.version} + diff --git a/spring-ai-spring-boot-starters/spring-ai-starter-model-watsonx-ai/pom.xml b/spring-ai-spring-boot-starters/spring-ai-starter-model-watsonx-ai/pom.xml index d3014bfa361..af67e9fcd4a 100644 --- a/spring-ai-spring-boot-starters/spring-ai-starter-model-watsonx-ai/pom.xml +++ b/spring-ai-spring-boot-starters/spring-ai-starter-model-watsonx-ai/pom.xml @@ -59,6 +59,12 @@ spring-ai-autoconfigure-model-chat-client ${project.parent.version} + + + org.springframework.ai + spring-ai-autoconfigure-model-chat-memory + ${project.parent.version} + diff --git a/spring-ai-spring-boot-starters/spring-ai-starter-model-zhipuai/pom.xml b/spring-ai-spring-boot-starters/spring-ai-starter-model-zhipuai/pom.xml index cef98cdedbe..d5abba2b358 100644 --- a/spring-ai-spring-boot-starters/spring-ai-starter-model-zhipuai/pom.xml +++ b/spring-ai-spring-boot-starters/spring-ai-starter-model-zhipuai/pom.xml @@ -59,6 +59,12 @@ spring-ai-autoconfigure-model-chat-client ${project.parent.version} + + + org.springframework.ai + spring-ai-autoconfigure-model-chat-memory + ${project.parent.version} +