diff --git a/dotnet/src/Connectors/Connectors.MistralAI.UnitTests/MistralAIExtensionTests.cs b/dotnet/src/Connectors/Connectors.MistralAI.UnitTests/MistralAIExtensionTests.cs index 0d6cab861ba3..009e728c6006 100644 --- a/dotnet/src/Connectors/Connectors.MistralAI.UnitTests/MistralAIExtensionTests.cs +++ b/dotnet/src/Connectors/Connectors.MistralAI.UnitTests/MistralAIExtensionTests.cs @@ -1,5 +1,8 @@ // Copyright (c) Microsoft. All rights reserved. +using System; +using System.Net.Http; +using System.Reflection; using Microsoft.Extensions.DependencyInjection; using Microsoft.SemanticKernel; using Microsoft.SemanticKernel.ChatCompletion; @@ -81,4 +84,58 @@ public void AddMistralTextEmbeddingGenerationToKernelBuilder() Assert.NotNull(service); Assert.IsType(service); } + + [Theory] + [InlineData(true)] + [InlineData(false)] + public void AddMistralChatCompletionInjectsExtraParametersHeader(bool useServiceCollection) + { + // Arrange + var collection = new ServiceCollection(); + var kernelBuilder = collection.AddKernel(); + + if (useServiceCollection) + { + // Use the service collection to add the Mistral chat completion + kernelBuilder.Services.AddMistralChatCompletion( + modelId: "model", + apiKey: "key", + endpoint: new Uri("https://example.com")); + } + else + { + // Use the kernel builder directly + kernelBuilder.AddMistralChatCompletion( + modelId: "model", + apiKey: "key", + endpoint: new Uri("https://example.com")); + } + + // Act + var kernel = collection.BuildServiceProvider().GetRequiredService(); + var service = kernel.GetRequiredService(); + + // Assert + Assert.NotNull(service); + Assert.IsType(service); + + // Use reflection to get the private 'Client' field + var clientField = typeof(MistralAIChatCompletionService) + .GetField("k__BackingField", BindingFlags.NonPublic | BindingFlags.Instance); + Assert.NotNull(clientField); + + var mistralClient = clientField.GetValue(service); + Assert.NotNull(mistralClient); + + // Use reflection to get the private '_httpClient' field from MistralClient + var httpClientField = mistralClient.GetType() + .GetField("_httpClient", BindingFlags.NonPublic | BindingFlags.Instance); + Assert.NotNull(httpClientField); + + var httpClient = (HttpClient)httpClientField.GetValue(mistralClient)!; + Assert.True(httpClient.DefaultRequestHeaders.Contains("extra-parameters")); + + var headerValues = httpClient.DefaultRequestHeaders.GetValues("extra-parameters"); + Assert.Contains("pass-through", headerValues); + } } diff --git a/dotnet/src/Connectors/Connectors.MistralAI/MistralAIKernelBuilderExtensions.cs b/dotnet/src/Connectors/Connectors.MistralAI/MistralAIKernelBuilderExtensions.cs index 90e7e762d3c3..158c3d06eff7 100644 --- a/dotnet/src/Connectors/Connectors.MistralAI/MistralAIKernelBuilderExtensions.cs +++ b/dotnet/src/Connectors/Connectors.MistralAI/MistralAIKernelBuilderExtensions.cs @@ -2,12 +2,6 @@ using System; using System.Net.Http; -using Microsoft.Extensions.DependencyInjection; -using Microsoft.Extensions.Logging; -using Microsoft.SemanticKernel.ChatCompletion; -using Microsoft.SemanticKernel.Connectors.MistralAI; -using Microsoft.SemanticKernel.Embeddings; -using Microsoft.SemanticKernel.Http; namespace Microsoft.SemanticKernel; @@ -38,8 +32,7 @@ public static IKernelBuilder AddMistralChatCompletion( Verify.NotNullOrWhiteSpace(modelId); Verify.NotNullOrWhiteSpace(apiKey); - builder.Services.AddKeyedSingleton(serviceId, (serviceProvider, _) => - new MistralAIChatCompletionService(modelId, apiKey, endpoint, HttpClientProvider.GetHttpClient(httpClient, serviceProvider), serviceProvider.GetService())); + builder.Services.AddMistralChatCompletion(modelId, apiKey, endpoint, serviceId, httpClient); return builder; } @@ -64,8 +57,7 @@ public static IKernelBuilder AddMistralTextEmbeddingGeneration( { Verify.NotNull(builder); - builder.Services.AddKeyedSingleton(serviceId, (serviceProvider, _) => - new MistralAITextEmbeddingGenerationService(modelId, apiKey, endpoint, HttpClientProvider.GetHttpClient(httpClient, serviceProvider), serviceProvider.GetService())); + builder.Services.AddMistralTextEmbeddingGeneration(modelId, apiKey, endpoint, serviceId, httpClient); return builder; } diff --git a/dotnet/src/Connectors/Connectors.MistralAI/MistralAIServiceCollectionExtensions.cs b/dotnet/src/Connectors/Connectors.MistralAI/MistralAIServiceCollectionExtensions.cs index a88aa49e7220..c838110e1dc1 100644 --- a/dotnet/src/Connectors/Connectors.MistralAI/MistralAIServiceCollectionExtensions.cs +++ b/dotnet/src/Connectors/Connectors.MistralAI/MistralAIServiceCollectionExtensions.cs @@ -1,7 +1,9 @@ // Copyright (c) Microsoft. All rights reserved. using System; +using System.Net.Http; using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.Logging; using Microsoft.SemanticKernel.ChatCompletion; using Microsoft.SemanticKernel.Connectors.MistralAI; using Microsoft.SemanticKernel.Embeddings; @@ -22,18 +24,39 @@ public static class MistralAIServiceCollectionExtensions /// The API key required for accessing the Mistral service. /// Optional uri endpoint including the port where MistralAI server is hosted. Default is https://api.mistral.ai. /// A local identifier for the given AI service. + /// The HttpClient to use with this service. /// The same instance as . public static IServiceCollection AddMistralChatCompletion( this IServiceCollection services, string modelId, string apiKey, Uri? endpoint = null, - string? serviceId = null) + string? serviceId = null, + HttpClient? httpClient = null) { Verify.NotNull(services); - return services.AddKeyedSingleton(serviceId, (serviceProvider, _) => - new MistralAIChatCompletionService(modelId, apiKey, endpoint, HttpClientProvider.GetHttpClient(serviceProvider))); + services.AddKeyedSingleton(serviceId, (serviceProvider, _) => + { + var resolvedHttpClient = HttpClientProvider.GetHttpClient(httpClient, serviceProvider); + + if (httpClient == null && serviceProvider?.GetService() == null) + { + if (!resolvedHttpClient.DefaultRequestHeaders.Contains("extra-parameters")) + { + resolvedHttpClient.DefaultRequestHeaders.Add("extra-parameters", "pass-through"); + } + } + + return new MistralAIChatCompletionService( + modelId, + apiKey, + endpoint, + resolvedHttpClient, + serviceProvider?.GetService()); + }); + + return services; } /// @@ -44,17 +67,19 @@ public static IServiceCollection AddMistralChatCompletion( /// The API key required for accessing the Mistral service. /// Optional uri endpoint including the port where MistralAI server is hosted. Default is https://api.mistral.ai. /// A local identifier for the given AI service. + /// The HttpClient to use with this service. /// The same instance as . public static IServiceCollection AddMistralTextEmbeddingGeneration( this IServiceCollection services, string modelId, string apiKey, Uri? endpoint = null, - string? serviceId = null) + string? serviceId = null, + HttpClient? httpClient = null) { Verify.NotNull(services); return services.AddKeyedSingleton(serviceId, (serviceProvider, _) => - new MistralAITextEmbeddingGenerationService(modelId, apiKey, endpoint, HttpClientProvider.GetHttpClient(serviceProvider))); + new MistralAITextEmbeddingGenerationService(modelId, apiKey, endpoint, HttpClientProvider.GetHttpClient(httpClient, serviceProvider))); } }