Add embedding support (#127)
marcominerva authored Oct 16, 2023
2 parents c4ae94e + 56cee4a commit 8c11ab0
Showing 30 changed files with 386 additions and 64 deletions.
3 changes: 2 additions & 1 deletion .github/workflows/publish.yml
Expand Up @@ -16,7 +16,8 @@ jobs:
publish:
name: Publish on NuGet
runs-on: ubuntu-latest

if: ${{ github.repository_owner == vars.REPOSITORY_OWNER }}

steps:
- name: Checkout
uses: actions/checkout@v3
21 changes: 19 additions & 2 deletions README.md
Expand Up @@ -30,6 +30,7 @@ builder.Services.AddChatGpt(options =>
//options.UseAzure(resourceName: "", apiKey: "", authenticationType: AzureAuthenticationType.ApiKey);
options.DefaultModel = "my-model";
options.DefaultEmbeddingModel = "text-embedding-ada-002";
options.MessageLimit = 16; // Default: 10
options.MessageExpiration = TimeSpan.FromMinutes(5); // Default: 1 hour
});
Expand All @@ -54,9 +55,11 @@ builder.Services.AddChatGpt(options =>
- 2023-08-01-preview (default)
- _AuthenticationType_: it specifies if the key is an actual API Key or an [Azure Active Directory token](https://learn.microsoft.com/azure/cognitive-services/openai/how-to/managed-identity) (optional, default: "ApiKey").

### DefaultModel
### DefaultModel and DefaultEmbeddingModel

ChatGPT can be used with different models for chat completion, both on OpenAI and Azure OpenAI service. With the *DefaultModel* property, you can specify the default model that will be used, unless you pass an explicit value in the **AskAsync** method.
ChatGPT can be used with different models for chat completion, both on OpenAI and Azure OpenAI service. With the *DefaultModel* property, you can specify the default model that will be used, unless you pass an explicit value in the **AskAsync** or **AskStreamAsync** methods.

Although it is not strictly necessary for chat conversations, the library also supports the Embedding API, on both [OpenAI](https://platform.openai.com/docs/guides/embeddings/what-are-embeddings) and [Azure OpenAI](https://learn.microsoft.com/en-us/azure/ai-services/openai/reference#embeddings). As with chat completion, embeddings can be generated with different models. With the *DefaultEmbeddingModel* property, you can specify the default model that will be used, unless you pass an explicit value in the **GenerateEmbeddingAsync** method.

##### OpenAI

Expand Down Expand Up @@ -128,6 +131,7 @@ The configuration can be automatically read from [IConfiguration](https://learn.
"AuthenticationType": "ApiKey", // Optional, used only by Azure OpenAI Service. Allowed values: ApiKey (default) or ActiveDirectory

"DefaultModel": "my-model",
"DefaultEmbeddingModel": "text-embedding-ada-002", // Optional, set it if you want to use embeddings
"MessageLimit": 20,
"MessageExpiration": "00:30:00",
"ThrowExceptionOnError": true
Expand Down Expand Up @@ -433,6 +437,19 @@ Check out the [Function calling sample](https://github.com/marcominerva/ChatGptN

When using Azure OpenAI Service, we automatically get content filtering for free. For details about how it works, check out the [documentation](https://learn.microsoft.com/azure/ai-services/openai/concepts/content-filter). This information is returned for all scenarios when using API version `2023-06-01-preview` or later. **ChatGptNet** fully supports this object model by providing the corresponding properties in the [ChatGptResponse](https://github.com/marcominerva/ChatGptNet/blob/master/src/ChatGptNet/Models/ChatGptResponse.cs#L57) and [ChatGptChoice](https://github.com/marcominerva/ChatGptNet/blob/master/src/ChatGptNet/Models/ChatGptChoice.cs#L26) classes.

## Embeddings

[Embeddings](https://platform.openai.com/docs/guides/embeddings) allow you to transform text into a vector space. This can be useful to compare the similarity of two sentences, for example. **ChatGptNet** fully supports this feature by providing the **GenerateEmbeddingAsync** method:

```csharp
var response = await chatGptClient.GenerateEmbeddingAsync(message);
var embeddings = response.GetEmbedding();
```

This code returns a float array containing the embedding of the specified message. The length of the array depends on the model used. For example, if we use the _text-embedding-ada-002_ model, the array will contain 1536 elements.

If you need to calculate the [cosine similarity](https://en.wikipedia.org/wiki/Cosine_similarity) between two embeddings, you can use the **EmbeddingUtility.CosineSimilarity** method.
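
For reference, the cosine similarity of two vectors is their dot product divided by the product of their Euclidean norms. The following is a minimal, self-contained sketch of that formula. It is not the library's implementation: the `CosineSimilaritySketch` class and its `Compute` method are hypothetical names used only for illustration, and returning 0 for a zero vector is an assumed convention.

```csharp
using System;

// Hypothetical helper for illustration only; it is NOT part of ChatGptNet.
public static class CosineSimilaritySketch
{
    // Returns dot(x, y) / (|x| * |y|) for two vectors of the same length.
    public static float Compute(float[] x, float[] y)
    {
        ArgumentNullException.ThrowIfNull(x);
        ArgumentNullException.ThrowIfNull(y);

        if (x.Length != y.Length)
        {
            throw new ArgumentException("The vectors must have the same length.");
        }

        // Accumulate in double precision to limit rounding errors on long vectors.
        double dot = 0, squaredNormX = 0, squaredNormY = 0;
        for (var i = 0; i < x.Length; i++)
        {
            dot += (double)x[i] * y[i];
            squaredNormX += (double)x[i] * x[i];
            squaredNormY += (double)y[i] * y[i];
        }

        // Assumed convention: a zero vector has no direction, so return 0.
        if (squaredNormX == 0 || squaredNormY == 0)
        {
            return 0f;
        }

        return (float)(dot / Math.Sqrt(squaredNormX * squaredNormY));
    }
}
```

The result ranges from -1 (opposite directions) to 1 (same direction), with values close to 0 indicating unrelated vectors. In real code, prefer the **EmbeddingUtility.CosineSimilarity** method mentioned above.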

## Documentation

The full technical documentation is available [here](https://github.com/marcominerva/ChatGptNet/tree/master/docs).
2 changes: 1 addition & 1 deletion samples/ChatGptApi/ChatGptApi.csproj
Expand Up @@ -7,7 +7,7 @@
</PropertyGroup>

<ItemGroup>
<PackageReference Include="Microsoft.AspNetCore.OpenApi" Version="7.0.11" />
<PackageReference Include="Microsoft.AspNetCore.OpenApi" Version="7.0.12" />
<PackageReference Include="MinimalHelpers.OpenApi" Version="1.0.4" />
<PackageReference Include="Swashbuckle.AspNetCore" Version="6.5.0" />
</ItemGroup>
21 changes: 21 additions & 0 deletions samples/ChatGptApi/Program.cs
Expand Up @@ -132,6 +132,27 @@ await problemDetailsService.WriteAsync(new()
})
.WithOpenApi();

app.MapPost("/api/embeddings", async (EmbeddingRequest request, IChatGptClient chatGptClient) =>
{
var embeddingResponse = await chatGptClient.GenerateEmbeddingAsync(request.Message);
return TypedResults.Ok(embeddingResponse);
})
.WithOpenApi();

app.MapPost("/api/embeddings/CosineSimilarity", async (CosineSimilarityRequest request, IChatGptClient chatGptClient) =>
{
var firstEmbeddingResponse = await chatGptClient.GenerateEmbeddingAsync(request.FirstMessage);
var secondEmbeddingResponse = await chatGptClient.GenerateEmbeddingAsync(request.SecondMessage);

var similarity = firstEmbeddingResponse.CosineSimilarity(secondEmbeddingResponse);
return TypedResults.Ok(new { CosineSimilarity = similarity });
})
.WithOpenApi();

app.Run();

public record class Request(Guid ConversationId, string Message);

public record class EmbeddingRequest(string Message);

public record class CosineSimilarityRequest(string FirstMessage, string SecondMessage);
3 changes: 2 additions & 1 deletion samples/ChatGptApi/appsettings.json
Expand Up @@ -5,9 +5,10 @@
//"Organization": "", // Optional, used only by OpenAI
"ResourceName": "", // Required when using Azure OpenAI Service
"ApiVersion": "2023-08-01-preview", // Optional, used only by Azure OpenAI Service (default: 2023-08-01-preview)
"AuthenticationType": "ApiKey", // Optional, used only by Azure OpenAI Service. Allowed values : ApiKey (default) or ActiveDirectory
"AuthenticationType": "ApiKey", // Optional, used only by Azure OpenAI Service. Allowed values: ApiKey (default) or ActiveDirectory

"DefaultModel": "my-model",
"DefaultEmbeddingModel": "text-embedding-ada-002", // Optional, set it if you want to use embeddings
"MessageLimit": 20,
"MessageExpiration": "00:30:00",
"ThrowExceptionOnError": true
4 changes: 2 additions & 2 deletions samples/ChatGptBlazor.Wasm/ChatGptBlazor.Wasm.csproj
Expand Up @@ -9,8 +9,8 @@

<ItemGroup>
<PackageReference Include="Markdig" Version="0.33.0" />
<PackageReference Include="Microsoft.AspNetCore.Components.WebAssembly" Version="7.0.11" />
<PackageReference Include="Microsoft.AspNetCore.Components.WebAssembly.DevServer" Version="7.0.11" PrivateAssets="all" />
<PackageReference Include="Microsoft.AspNetCore.Components.WebAssembly" Version="7.0.12" />
<PackageReference Include="Microsoft.AspNetCore.Components.WebAssembly.DevServer" Version="7.0.12" PrivateAssets="all" />
</ItemGroup>

<ItemGroup>
2 changes: 1 addition & 1 deletion samples/ChatGptBlazor.Wasm/Program.cs
Expand Up @@ -20,7 +20,7 @@

options.DefaultModel = "my-model";
options.MessageLimit = 16; // Default: 10
options.MessageExpiration = TimeSpan.FromMinutes(5); // Default: 1 hour
options.MessageExpiration = TimeSpan.FromMinutes(5); // Default: 1 hour
});

await builder.Build().RunAsync();
1 change: 1 addition & 0 deletions samples/ChatGptConsole/appsettings.json
Expand Up @@ -8,6 +8,7 @@
"AuthenticationType": "ApiKey", // Optional, used only by Azure OpenAI Service. Allowed values: ApiKey (default) or ActiveDirectory

"DefaultModel": "my-model",
"DefaultEmbeddingModel": "text-embedding-ada-002", // Optional, set it if you want to use embeddings
"MessageLimit": 20,
"MessageExpiration": "00:30:00",
"ThrowExceptionOnError": true
1 change: 1 addition & 0 deletions samples/ChatGptFunctionCallingConsole/appsettings.json
Expand Up @@ -8,6 +8,7 @@
"AuthenticationType": "ApiKey", // Optional, used only by Azure OpenAI Service. Allowed values: ApiKey (default) or ActiveDirectory

"DefaultModel": "gpt-3.5-turbo",
"DefaultEmbeddingModel": "text-embedding-ada-002", // Optional, set it if you want to use embeddings
"MessageLimit": 20,
"MessageExpiration": "00:30:00",
"ThrowExceptionOnError": true
3 changes: 2 additions & 1 deletion samples/ChatGptStreamConsole/appsettings.json
Expand Up @@ -5,9 +5,10 @@
//"Organization": "", // Optional, used only by OpenAI
"ResourceName": "", // Required when using Azure OpenAI Service
"ApiVersion": "2023-08-01-preview", // Optional, used only by Azure OpenAI Service (default: 2023-08-01-preview)
"AuthenticationType": "ApiKey", // Optional, used only by Azure OpenAI Service. Allowed values : ApiKey (default) or ActiveDirectory
"AuthenticationType": "ApiKey", // Optional, used only by Azure OpenAI Service. Allowed values: ApiKey (default) or ActiveDirectory

"DefaultModel": "my-model",
"DefaultEmbeddingModel": "text-embedding-ada-002", // Optional, set it if you want to use embeddings
"MessageLimit": 20,
"MessageExpiration": "00:30:00",
"ThrowExceptionOnError": true
43 changes: 38 additions & 5 deletions src/ChatGptNet/ChatGptClient.cs
Expand Up @@ -6,6 +6,8 @@
using System.Text.Json.Serialization;
using ChatGptNet.Exceptions;
using ChatGptNet.Models;
using ChatGptNet.Models.Common;
using ChatGptNet.Models.Embeddings;

namespace ChatGptNet;

Expand Down Expand Up @@ -61,9 +63,9 @@ public async Task<ChatGptResponse> AskAsync(Guid conversationId, string message,
conversationId = (conversationId == Guid.Empty) ? Guid.NewGuid() : conversationId;

var messages = await CreateMessageListAsync(conversationId, message, cancellationToken);
var request = CreateRequest(messages, functionParameters, false, parameters, model);
var request = CreateChatGptRequest(messages, functionParameters, false, parameters, model);

var requestUri = options.ServiceConfiguration.GetServiceEndpoint(model ?? options.DefaultModel);
var requestUri = options.ServiceConfiguration.GetChatCompletionEndpoint(model ?? options.DefaultModel);
using var httpResponse = await httpClient.PostAsJsonAsync(requestUri, request, jsonSerializerOptions, cancellationToken);

var response = await httpResponse.Content.ReadFromJsonAsync<ChatGptResponse>(jsonSerializerOptions, cancellationToken: cancellationToken);
Expand Down Expand Up @@ -93,9 +95,9 @@ public async IAsyncEnumerable<ChatGptResponse> AskStreamAsync(Guid conversationI
conversationId = (conversationId == Guid.Empty) ? Guid.NewGuid() : conversationId;

var messages = await CreateMessageListAsync(conversationId, message, cancellationToken);
var request = CreateRequest(messages, null, true, parameters, model);
var request = CreateChatGptRequest(messages, null, true, parameters, model);

var requestUri = options.ServiceConfiguration.GetServiceEndpoint(model ?? options.DefaultModel);
var requestUri = options.ServiceConfiguration.GetChatCompletionEndpoint(model ?? options.DefaultModel);
using var requestMessage = new HttpRequestMessage(HttpMethod.Post, requestUri)
{
Content = new StringContent(JsonSerializer.Serialize(request, jsonSerializerOptions), Encoding.UTF8, MediaTypeNames.Application.Json)
Expand Down Expand Up @@ -286,6 +288,26 @@ public async Task AddFunctionResponseAsync(Guid conversationId, string functionN
await UpdateCacheAsync(conversationId, messages, cancellationToken);
}

public async Task<EmbeddingResponse> GenerateEmbeddingAsync(IEnumerable<string> messages, string? model = null, CancellationToken cancellationToken = default)
{
ArgumentNullException.ThrowIfNull(messages);

var request = CreateEmbeddingRequest(messages, model);

var requestUri = options.ServiceConfiguration.GetEmbeddingEndpoint(model ?? options.DefaultEmbeddingModel);
using var httpResponse = await httpClient.PostAsJsonAsync(requestUri, request, jsonSerializerOptions, cancellationToken);

var response = await httpResponse.Content.ReadFromJsonAsync<EmbeddingResponse>(jsonSerializerOptions, cancellationToken: cancellationToken);
NormalizeResponse(httpResponse, response!, model ?? options.DefaultEmbeddingModel);

if (!response!.IsSuccessful && options.ThrowExceptionOnError)
{
throw new EmbeddingException(response.Error, httpResponse.StatusCode);
}

return response;
}

private async Task<IList<ChatGptMessage>> CreateMessageListAsync(Guid conversationId, string message, CancellationToken cancellationToken = default)
{
// Checks whether a list of messages for the given conversationId already exists.
Expand All @@ -301,7 +323,7 @@ private async Task<IList<ChatGptMessage>> CreateMessageListAsync(Guid conversati
return messages;
}

private ChatGptRequest CreateRequest(IEnumerable<ChatGptMessage> messages, ChatGptFunctionParameters? functionParameters, bool stream, ChatGptParameters? parameters, string? model)
private ChatGptRequest CreateChatGptRequest(IEnumerable<ChatGptMessage> messages, ChatGptFunctionParameters? functionParameters, bool stream, ChatGptParameters? parameters, string? model)
=> new()
{
Model = model ?? options.DefaultModel,
Expand All @@ -322,6 +344,13 @@ private ChatGptRequest CreateRequest(IEnumerable<ChatGptMessage> messages, ChatG
User = options.User,
};

private EmbeddingRequest CreateEmbeddingRequest(IEnumerable<string> messages, string? model = null)
=> new()
{
Model = model ?? options.DefaultEmbeddingModel,
Input = messages
};

private async Task AddAssistantResponseAsync(Guid conversationId, IList<ChatGptMessage> messages, ChatGptMessage? message, CancellationToken cancellationToken = default)
{
if (!string.IsNullOrWhiteSpace(message?.Content?.Trim()) || message?.FunctionCall is not null)
Expand Down Expand Up @@ -359,7 +388,11 @@ private async Task UpdateCacheAsync(Guid conversationId, IEnumerable<ChatGptMess
private static void NormalizeResponse(HttpResponseMessage httpResponse, ChatGptResponse response, Guid conversationId, string? model)
{
response.ConversationId = conversationId;
NormalizeResponse(httpResponse, response, model);
}

private static void NormalizeResponse(HttpResponseMessage httpResponse, Response response, string? model)
{
if (string.IsNullOrWhiteSpace(response.Model) && model is not null)
{
response.Model = model;
2 changes: 1 addition & 1 deletion src/ChatGptNet/ChatGptNet.csproj
Expand Up @@ -14,7 +14,7 @@
<PackageLicenseExpression>MIT</PackageLicenseExpression>
<PackageProjectUrl>https://github.com/marcominerva/ChatGptNet</PackageProjectUrl>
<PackageIcon>Chat.png</PackageIcon>
<PackageTags>csharp dotnet net openai openai-api azure-openai azure-openai-api chatgpt</PackageTags>
<PackageTags>csharp dotnet net openai openai-api azure-openai azure-openai-api chatgpt embeddings</PackageTags>
<RepositoryType>git</RepositoryType>
<RepositoryUrl>https://github.com/marcominerva/ChatGptNet.git</RepositoryUrl>
<RepositoryBranch>master</RepositoryBranch>
7 changes: 7 additions & 0 deletions src/ChatGptNet/ChatGptOptions.cs
Expand Up @@ -41,6 +41,13 @@ public class ChatGptOptions
/// <seealso cref="OpenAIChatGptServiceConfiguration"/>
public string? DefaultModel { get; set; }

/// <summary>
/// Gets or sets the default model for embedding. (default: <see cref="OpenAIEmbeddingModels.TextEmbeddingAda002"/> when the provider is <see cref="OpenAIChatGptServiceConfiguration"> OpenAI</see>).
/// </summary>
/// <seealso cref="OpenAIEmbeddingModels"/>
/// <seealso cref="OpenAIChatGptServiceConfiguration"/>
public string? DefaultEmbeddingModel { get; set; }

/// <summary>
/// Gets or sets the default parameters for chat completion.
/// </summary>
8 changes: 8 additions & 0 deletions src/ChatGptNet/ChatGptOptionsBuilder.cs
Expand Up @@ -41,6 +41,13 @@ public class ChatGptOptionsBuilder
/// <seealso cref="OpenAIChatGptServiceConfiguration"/>
public string? DefaultModel { get; set; }

/// <summary>
/// Gets or sets the default model for embedding. (default: <see cref="OpenAIEmbeddingModels.TextEmbeddingAda002"/> when the provider is <see cref="OpenAIChatGptServiceConfiguration"> OpenAI</see>).
/// </summary>
/// <seealso cref="OpenAIEmbeddingModels"/>
/// <seealso cref="OpenAIChatGptServiceConfiguration"/>
public string? DefaultEmbeddingModel { get; set; }

/// <summary>
/// Gets or sets the default parameters for chat completion.
/// </summary>
Expand All @@ -60,6 +67,7 @@ internal ChatGptOptions Build()
{
MessageLimit = MessageLimit,
DefaultModel = DefaultModel,
DefaultEmbeddingModel = DefaultEmbeddingModel,
DefaultParameters = DefaultParameters ?? new(),
MessageExpiration = MessageExpiration,
ThrowExceptionOnError = ThrowExceptionOnError,
14 changes: 11 additions & 3 deletions src/ChatGptNet/ChatGptServiceCollectionExtensions.cs
Expand Up @@ -112,10 +112,18 @@ private static IChatGptBuilder AddChatGptCore(IServiceCollection services)

private static void SetMissingDefaults(ChatGptOptionsBuilder options)
{
// If the provider is OpenAI and no default model has been specified, uses gpt-3.5-turbo by default.
if (options.ServiceConfiguration is OpenAIChatGptServiceConfiguration && string.IsNullOrWhiteSpace(options.DefaultModel))
if (options.ServiceConfiguration is OpenAIChatGptServiceConfiguration)
{
options.DefaultModel = OpenAIChatGptModels.Gpt35Turbo;
// If the provider is OpenAI and some default models are not specified, use the default ones.
if (string.IsNullOrWhiteSpace(options.DefaultModel))
{
options.DefaultModel = OpenAIChatGptModels.Gpt35Turbo;
}

if (string.IsNullOrWhiteSpace(options.DefaultEmbeddingModel))
{
options.DefaultEmbeddingModel = OpenAIEmbeddingModels.TextEmbeddingAda002;
}
}
}
}
2 changes: 1 addition & 1 deletion src/ChatGptNet/Exceptions/ChatGptException.cs
Expand Up @@ -4,7 +4,7 @@
namespace ChatGptNet.Exceptions;

/// <summary>
/// Represents errors that occur during API invocation.
/// Represents errors that occur during ChatGPT API invocation.
/// </summary>
public class ChatGptException : HttpRequestException
{
32 changes: 32 additions & 0 deletions src/ChatGptNet/Exceptions/EmbeddingException.cs
@@ -0,0 +1,32 @@
using System.Net;
using ChatGptNet.Models;

namespace ChatGptNet.Exceptions;

/// <summary>
/// Represents errors that occur during Embeddings API invocation.
/// </summary>
public class EmbeddingException : HttpRequestException
{
/// <summary>
/// Gets the detailed error information.
/// </summary>
/// <seealso cref="ChatGptError"/>
public ChatGptError Error { get; }

/// <summary>
/// Initializes a new instance of the <see cref="EmbeddingException"/> class with the specified <paramref name="error"/> details.
/// </summary>
/// <param name="error">The detailed error information</param>
/// <param name="statusCode">The HTTP status code</param>
/// <seealso cref="ChatGptError"/>
/// <seealso cref="HttpRequestException"/>
public EmbeddingException(ChatGptError? error, HttpStatusCode statusCode) : base(!string.IsNullOrWhiteSpace(error?.Message) ? error.Message : error?.Code ?? statusCode.ToString(), null, statusCode)
{
Error = error ?? new ChatGptError
{
Message = statusCode.ToString(),
Code = ((int)statusCode).ToString()
};
}
}