diff --git a/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/MistralAiEmbeddingModel.java b/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/MistralAiEmbeddingModel.java index 347128f0ea0..370d7f2eedb 100644 --- a/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/MistralAiEmbeddingModel.java +++ b/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/MistralAiEmbeddingModel.java @@ -17,6 +17,7 @@ package org.springframework.ai.mistralai; import java.util.List; +import java.util.Map; import io.micrometer.observation.ObservationRegistry; import org.slf4j.Logger; @@ -41,6 +42,9 @@ import org.springframework.retry.support.RetryTemplate; import org.springframework.util.Assert; +import static org.springframework.ai.mistralai.api.MistralAiApi.EmbeddingModel.CODESTRAL_EMBED; +import static org.springframework.ai.mistralai.api.MistralAiApi.EmbeddingModel.EMBED; + /** * Provides the Mistral AI Embedding Model. * @@ -53,6 +57,9 @@ public class MistralAiEmbeddingModel extends AbstractEmbeddingModel { private static final Logger logger = LoggerFactory.getLogger(MistralAiEmbeddingModel.class); + private static final Map KNOWN_EMBEDDING_DIMENSIONS = Map.of(EMBED.getValue(), 1024, + CODESTRAL_EMBED.getValue(), 1536); + private static final EmbeddingModelObservationConvention DEFAULT_OBSERVATION_CONVENTION = new DefaultEmbeddingModelObservationConvention(); private final MistralAiEmbeddingOptions defaultOptions; @@ -78,8 +85,7 @@ public MistralAiEmbeddingModel(MistralAiApi mistralAiApi) { } public MistralAiEmbeddingModel(MistralAiApi mistralAiApi, MetadataMode metadataMode) { - this(mistralAiApi, metadataMode, - MistralAiEmbeddingOptions.builder().withModel(MistralAiApi.EmbeddingModel.EMBED.getValue()).build(), + this(mistralAiApi, metadataMode, MistralAiEmbeddingOptions.builder().withModel(EMBED.getValue()).build(), RetryUtils.DEFAULT_RETRY_TEMPLATE); } @@ -179,6 +185,11 @@ public float[] embed(Document document) { return this.embed(document.getFormattedContent(this.metadataMode)); } + @Override + public int dimensions() { + return KNOWN_EMBEDDING_DIMENSIONS.getOrDefault(this.defaultOptions.getModel(), super.dimensions()); + } + /** * Use the provided convention for reporting observation data * @param observationConvention The provided convention diff --git a/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/api/MistralAiApi.java b/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/api/MistralAiApi.java index e2702a3a0af..a5db5dbc08c 100644 --- a/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/api/MistralAiApi.java +++ b/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/api/MistralAiApi.java @@ -1,5 +1,5 @@ /* - * Copyright 2023-2024 the original author or authors. + * Copyright 2023-2025 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -321,7 +321,8 @@ public String getName() { public enum EmbeddingModel { // @formatter:off - EMBED("mistral-embed"); + EMBED("mistral-embed"), + CODESTRAL_EMBED("codestral-embed"); // @formatter:on private final String value; diff --git a/models/spring-ai-mistral-ai/src/test/java/org/springframework/ai/mistralai/MistralAiEmbeddingIT.java b/models/spring-ai-mistral-ai/src/test/java/org/springframework/ai/mistralai/MistralAiEmbeddingIT.java index b9c91cca8a9..eb4c208800c 100644 --- a/models/spring-ai-mistral-ai/src/test/java/org/springframework/ai/mistralai/MistralAiEmbeddingIT.java +++ b/models/spring-ai-mistral-ai/src/test/java/org/springframework/ai/mistralai/MistralAiEmbeddingIT.java @@ -1,5 +1,5 @@ /* - * Copyright 2023-2024 the original author or authors. + * Copyright 2023-2025 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,50 +16,76 @@ package org.springframework.ai.mistralai; -import java.util.List; - import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; - +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.CsvSource; import org.springframework.ai.embedding.EmbeddingRequest; +import org.springframework.ai.mistralai.api.MistralAiApi; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.boot.test.context.SpringBootTest; +import java.util.List; + import static org.assertj.core.api.Assertions.assertThat; @SpringBootTest(classes = MistralAiTestConfiguration.class) @EnabledIfEnvironmentVariable(named = "MISTRAL_AI_API_KEY", matches = ".+") class MistralAiEmbeddingIT { + private static final int MISTRAL_EMBED_DIMENSIONS = 1024; + + @Autowired + private MistralAiApi mistralAiApi; + @Autowired private MistralAiEmbeddingModel mistralAiEmbeddingModel; @Test void defaultEmbedding() { - assertThat(this.mistralAiEmbeddingModel).isNotNull(); - var embeddingResponse = this.mistralAiEmbeddingModel.embedForResponse(List.of("Hello World")); + var embeddingResponse = mistralAiEmbeddingModel.embedForResponse(List.of("Hello World")); assertThat(embeddingResponse.getResults()).hasSize(1); assertThat(embeddingResponse.getResults().get(0)).isNotNull(); - assertThat(embeddingResponse.getResults().get(0).getOutput()).hasSize(1024); + assertThat(embeddingResponse.getResults().get(0).getOutput()).hasSize(MISTRAL_EMBED_DIMENSIONS); assertThat(embeddingResponse.getMetadata().getModel()).isEqualTo("mistral-embed"); assertThat(embeddingResponse.getMetadata().getUsage().getTotalTokens()).isEqualTo(4); assertThat(embeddingResponse.getMetadata().getUsage().getPromptTokens()).isEqualTo(4); - assertThat(this.mistralAiEmbeddingModel.dimensions()).isEqualTo(1024); + assertThat(mistralAiEmbeddingModel.dimensions()).isEqualTo(MISTRAL_EMBED_DIMENSIONS); } - @Test - void embeddingTest() { - assertThat(this.mistralAiEmbeddingModel).isNotNull(); - var embeddingResponse = this.mistralAiEmbeddingModel.call(new EmbeddingRequest( - List.of("Hello World", "World is big"), - MistralAiEmbeddingOptions.builder().withModel("mistral-embed").withEncodingFormat("float").build())); + @ParameterizedTest + @CsvSource({ "mistral-embed, 1024", "codestral-embed, 1536" }) + void defaultOptionsEmbedding(String model, int dimensions) { + var mistralAiEmbeddingOptions = MistralAiEmbeddingOptions.builder().withModel(model).build(); + var anotherMistralAiEmbeddingModel = new MistralAiEmbeddingModel(mistralAiApi, mistralAiEmbeddingOptions); + var embeddingResponse = anotherMistralAiEmbeddingModel.embedForResponse(List.of("Hello World", "World is big")); assertThat(embeddingResponse.getResults()).hasSize(2); - assertThat(embeddingResponse.getResults().get(0)).isNotNull(); - assertThat(embeddingResponse.getResults().get(0).getOutput()).hasSize(1024); - assertThat(embeddingResponse.getMetadata().getModel()).isEqualTo("mistral-embed"); + embeddingResponse.getResults().forEach(result -> { + assertThat(result).isNotNull(); + assertThat(result.getOutput()).hasSize(dimensions); + }); + assertThat(embeddingResponse.getMetadata().getModel()).isEqualTo(model); assertThat(embeddingResponse.getMetadata().getUsage().getTotalTokens()).isEqualTo(9); assertThat(embeddingResponse.getMetadata().getUsage().getPromptTokens()).isEqualTo(9); - assertThat(this.mistralAiEmbeddingModel.dimensions()).isEqualTo(1024); + assertThat(anotherMistralAiEmbeddingModel.dimensions()).isEqualTo(dimensions); + } + + @ParameterizedTest + @CsvSource({ "mistral-embed, 1024", "codestral-embed, 1536" }) + void calledOptionsEmbedding(String model, int dimensions) { + var mistralAiEmbeddingOptions = MistralAiEmbeddingOptions.builder().withModel(model).build(); + var embeddingRequest = new EmbeddingRequest(List.of("Hello World", "World is big", "We are small"), + mistralAiEmbeddingOptions); + var embeddingResponse = mistralAiEmbeddingModel.call(embeddingRequest); + assertThat(embeddingResponse.getResults()).hasSize(3); + embeddingResponse.getResults().forEach(result -> { + assertThat(result).isNotNull(); + assertThat(result.getOutput()).hasSize(dimensions); + }); + assertThat(embeddingResponse.getMetadata().getModel()).isEqualTo(model); + assertThat(embeddingResponse.getMetadata().getUsage().getTotalTokens()).isEqualTo(14); + assertThat(embeddingResponse.getMetadata().getUsage().getPromptTokens()).isEqualTo(14); + assertThat(mistralAiEmbeddingModel.dimensions()).isEqualTo(MISTRAL_EMBED_DIMENSIONS); } } diff --git a/models/spring-ai-mistral-ai/src/test/java/org/springframework/ai/mistralai/MistralAiTestConfiguration.java b/models/spring-ai-mistral-ai/src/test/java/org/springframework/ai/mistralai/MistralAiTestConfiguration.java index 26e6911daa8..ff7d4db80d4 100644 --- a/models/spring-ai-mistral-ai/src/test/java/org/springframework/ai/mistralai/MistralAiTestConfiguration.java +++ b/models/spring-ai-mistral-ai/src/test/java/org/springframework/ai/mistralai/MistralAiTestConfiguration.java @@ -1,5 +1,5 @@ /* - * Copyright 2023-2024 the original author or authors. + * Copyright 2023-2025 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,7 +16,6 @@ package org.springframework.ai.mistralai; -import org.springframework.ai.embedding.EmbeddingModel; import org.springframework.ai.mistralai.api.MistralAiApi; import org.springframework.ai.mistralai.api.MistralAiModerationApi; import org.springframework.ai.mistralai.moderation.MistralAiModerationModel; @@ -27,30 +26,28 @@ @SpringBootConfiguration public class MistralAiTestConfiguration { - @Bean - public MistralAiApi mistralAiApi() { + private static String retrieveApiKey() { var apiKey = System.getenv("MISTRAL_AI_API_KEY"); if (!StringUtils.hasText(apiKey)) { throw new IllegalArgumentException( "Missing MISTRAL_AI_API_KEY environment variable. Please set it to your Mistral AI API key."); } - return new MistralAiApi(apiKey); + return apiKey; + } + + @Bean + public MistralAiApi mistralAiApi() { + return new MistralAiApi(retrieveApiKey()); } @Bean public MistralAiModerationApi mistralAiModerationApi() { - var apiKey = System.getenv("MISTRAL_AI_API_KEY"); - if (!StringUtils.hasText(apiKey)) { - throw new IllegalArgumentException( - "Missing MISTRAL_AI_API_KEY environment variable. Please set it to your Mistral AI API key."); - } - return new MistralAiModerationApi(apiKey); + return new MistralAiModerationApi(retrieveApiKey()); } @Bean - public EmbeddingModel mistralAiEmbeddingModel(MistralAiApi api) { - return new MistralAiEmbeddingModel(api, - MistralAiEmbeddingOptions.builder().withModel(MistralAiApi.EmbeddingModel.EMBED.getValue()).build()); + public MistralAiEmbeddingModel mistralAiEmbeddingModel(MistralAiApi mistralAiApi) { + return new MistralAiEmbeddingModel(mistralAiApi); } @Bean