Skip to content
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
// Copyright (c) Microsoft. All rights reserved.

using System;
using System.Net.Http;
using System.Reflection;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.SemanticKernel;
using Microsoft.SemanticKernel.ChatCompletion;
Expand Down Expand Up @@ -81,4 +84,58 @@ public void AddMistralTextEmbeddingGenerationToKernelBuilder()
Assert.NotNull(service);
Assert.IsType<MistralAITextEmbeddingGenerationService>(service);
}

[Theory]
[InlineData(true)]
[InlineData(false)]
public void AddMistralChatCompletionInjectsExtraParametersHeader(bool useServiceCollection)
{
// Arrange
var collection = new ServiceCollection();
var kernelBuilder = collection.AddKernel();

if (useServiceCollection)
{
// Use the service collection to add the Mistral chat completion
kernelBuilder.Services.AddMistralChatCompletion(
modelId: "model",
apiKey: "key",
endpoint: new Uri("https://example.com"));
}
else
{
// Use the kernel builder directly
kernelBuilder.AddMistralChatCompletion(
modelId: "model",
apiKey: "key",
endpoint: new Uri("https://example.com"));
}

// Act
var kernel = collection.BuildServiceProvider().GetRequiredService<Kernel>();
var service = kernel.GetRequiredService<IChatCompletionService>();

// Assert
Assert.NotNull(service);
Assert.IsType<MistralAIChatCompletionService>(service);

// Use reflection to get the private 'Client' field
var clientField = typeof(MistralAIChatCompletionService)
.GetField("<Client>k__BackingField", BindingFlags.NonPublic | BindingFlags.Instance);
Assert.NotNull(clientField);

var mistralClient = clientField.GetValue(service);
Assert.NotNull(mistralClient);

// Use reflection to get the private '_httpClient' field from MistralClient
var httpClientField = mistralClient.GetType()
.GetField("_httpClient", BindingFlags.NonPublic | BindingFlags.Instance);
Assert.NotNull(httpClientField);

var httpClient = (HttpClient)httpClientField.GetValue(mistralClient)!;
Assert.True(httpClient.DefaultRequestHeaders.Contains("extra-parameters"));

var headerValues = httpClient.DefaultRequestHeaders.GetValues("extra-parameters");
Assert.Contains("pass-through", headerValues);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,6 @@

using System;
using System.Net.Http;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.Logging;
using Microsoft.SemanticKernel.ChatCompletion;
using Microsoft.SemanticKernel.Connectors.MistralAI;
using Microsoft.SemanticKernel.Embeddings;
using Microsoft.SemanticKernel.Http;

namespace Microsoft.SemanticKernel;

Expand Down Expand Up @@ -38,8 +32,7 @@ public static IKernelBuilder AddMistralChatCompletion(
Verify.NotNullOrWhiteSpace(modelId);
Verify.NotNullOrWhiteSpace(apiKey);

builder.Services.AddKeyedSingleton<IChatCompletionService>(serviceId, (serviceProvider, _) =>
new MistralAIChatCompletionService(modelId, apiKey, endpoint, HttpClientProvider.GetHttpClient(httpClient, serviceProvider), serviceProvider.GetService<ILoggerFactory>()));
builder.Services.AddMistralChatCompletion(modelId, apiKey, endpoint, serviceId, httpClient);

return builder;
}
Expand All @@ -64,8 +57,7 @@ public static IKernelBuilder AddMistralTextEmbeddingGeneration(
{
Verify.NotNull(builder);

builder.Services.AddKeyedSingleton<ITextEmbeddingGenerationService>(serviceId, (serviceProvider, _) =>
new MistralAITextEmbeddingGenerationService(modelId, apiKey, endpoint, HttpClientProvider.GetHttpClient(httpClient, serviceProvider), serviceProvider.GetService<ILoggerFactory>()));
builder.Services.AddMistralTextEmbeddingGeneration(modelId, apiKey, endpoint, serviceId, httpClient);

return builder;
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
// Copyright (c) Microsoft. All rights reserved.

using System;
using System.Net.Http;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.Logging;
using Microsoft.SemanticKernel.ChatCompletion;
using Microsoft.SemanticKernel.Connectors.MistralAI;
using Microsoft.SemanticKernel.Embeddings;
Expand All @@ -22,18 +24,39 @@ public static class MistralAIServiceCollectionExtensions
/// <param name="apiKey">The API key required for accessing the Mistral service.</param>
/// <param name="endpoint">Optional uri endpoint including the port where MistralAI server is hosted. Default is https://api.mistral.ai.</param>
/// <param name="serviceId">A local identifier for the given AI service.</param>
/// <param name="httpClient">The HttpClient to use with this service.</param>
/// <returns>The same instance as <paramref name="services"/>.</returns>
public static IServiceCollection AddMistralChatCompletion(
this IServiceCollection services,
string modelId,
string apiKey,
Uri? endpoint = null,
string? serviceId = null)
string? serviceId = null,
HttpClient? httpClient = null)
{
Verify.NotNull(services);

return services.AddKeyedSingleton<IChatCompletionService>(serviceId, (serviceProvider, _) =>
new MistralAIChatCompletionService(modelId, apiKey, endpoint, HttpClientProvider.GetHttpClient(serviceProvider)));
services.AddKeyedSingleton<IChatCompletionService>(serviceId, (serviceProvider, _) =>
{
var resolvedHttpClient = HttpClientProvider.GetHttpClient(httpClient, serviceProvider);

if (httpClient == null && serviceProvider?.GetService<HttpClient>() == null)
{
if (!resolvedHttpClient.DefaultRequestHeaders.Contains("extra-parameters"))
{
resolvedHttpClient.DefaultRequestHeaders.Add("extra-parameters", "pass-through");
}
}

return new MistralAIChatCompletionService(
modelId,
apiKey,
endpoint,
resolvedHttpClient,
serviceProvider?.GetService<ILoggerFactory>());
});

return services;
}

/// <summary>
Expand All @@ -44,17 +67,19 @@ public static IServiceCollection AddMistralChatCompletion(
/// <param name="apiKey">The API key required for accessing the Mistral service.</param>
/// <param name="endpoint">Optional uri endpoint including the port where MistralAI server is hosted. Default is https://api.mistral.ai.</param>
/// <param name="serviceId">A local identifier for the given AI service.</param>
/// <param name="httpClient">The HttpClient to use with this service.</param>
/// <returns>The same instance as <paramref name="services"/>.</returns>
public static IServiceCollection AddMistralTextEmbeddingGeneration(
this IServiceCollection services,
string modelId,
string apiKey,
Uri? endpoint = null,
string? serviceId = null)
string? serviceId = null,
HttpClient? httpClient = null)
{
Verify.NotNull(services);

return services.AddKeyedSingleton<ITextEmbeddingGenerationService>(serviceId, (serviceProvider, _) =>
new MistralAITextEmbeddingGenerationService(modelId, apiKey, endpoint, HttpClientProvider.GetHttpClient(serviceProvider)));
new MistralAITextEmbeddingGenerationService(modelId, apiKey, endpoint, HttpClientProvider.GetHttpClient(httpClient, serviceProvider)));
}
}
Loading