forked from microsoft/semantic-kernel
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
.Net: Add BERT ONNX embedding generation service (microsoft#5518)
Adds a new Microsoft.SemanticKernel.Connectors.Onnx component. As of this PR, it contains one service, BertOnnxTextEmbeddingGenerationService, for using BERT-based ONNX models to generate embeddings. But in time we can add more ONNX-based implementations for using local models. This is in part based on https://onnxruntime.ai/docs/tutorials/csharp/bert-nlp-csharp-console-app.html and https://github.com/dotnet-smartcomponents/smartcomponents. It doesn't support everything that's supported via sentence-transformers, but we should be able to extend it as needed. cc: @luisquintanilla, @SteveSandersonMS, @JakeRadMSFT
- Loading branch information
1 parent
20748d9
commit 0050e2c
Showing
12 changed files
with
996 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
382 changes: 382 additions & 0 deletions
382
dotnet/src/Connectors/Connectors.Onnx.UnitTests/BertOnnxTextEmbeddingGenerationTests.cs
Large diffs are not rendered by default.
Oops, something went wrong.
47 changes: 47 additions & 0 deletions
47
dotnet/src/Connectors/Connectors.Onnx.UnitTests/Connectors.Onnx.UnitTests.csproj
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,47 @@ | ||
<Project Sdk="Microsoft.NET.Sdk"> | ||
|
||
<PropertyGroup> | ||
<AssemblyName>SemanticKernel.Connectors.Onnx.UnitTests</AssemblyName> | ||
<RootNamespace>SemanticKernel.Connectors.Onnx.UnitTests</RootNamespace> | ||
<TargetFramework>net6.0</TargetFramework> | ||
<LangVersion>12</LangVersion> | ||
<RollForward>LatestMajor</RollForward> | ||
<IsTestProject>true</IsTestProject> | ||
<Nullable>enable</Nullable> | ||
<IsPackable>false</IsPackable> | ||
<NoWarn>$(NoWarn);SKEXP0001;SKEXP0070;CS1591;IDE1006;RCS1261;CA1031;CA1308;CA1849;CA1861;CA2007;CA2234;VSTHRD111</NoWarn> | ||
</PropertyGroup> | ||
|
||
<ItemGroup> | ||
<!-- Use newest available compiler to permit LangVersion 12. --> | ||
<!-- This can be removed once we no longer target the .NET 6 SDK in CI. --> | ||
<PackageReference Include="Microsoft.Net.Compilers.Toolset" /> | ||
</ItemGroup> | ||
|
||
<ItemGroup> | ||
<PackageReference Include="Microsoft.NET.Test.Sdk" /> | ||
<PackageReference Include="Moq" /> | ||
<PackageReference Include="xunit" /> | ||
<PackageReference Include="xunit.runner.visualstudio"> | ||
<IncludeAssets>runtime; build; native; contentfiles; analyzers; buildtransitive</IncludeAssets> | ||
<PrivateAssets>all</PrivateAssets> | ||
</PackageReference> | ||
<PackageReference Include="coverlet.collector"> | ||
<IncludeAssets>runtime; build; native; contentfiles; analyzers; buildtransitive</IncludeAssets> | ||
<PrivateAssets>all</PrivateAssets> | ||
</PackageReference> | ||
<PackageReference Include="System.Numerics.Tensors" /> | ||
<PackageReference Include="System.Text.Json" /> | ||
<PackageReference Include="Microsoft.Extensions.DependencyInjection" /> | ||
</ItemGroup> | ||
|
||
<ItemGroup> | ||
<Compile Include="$(RepoRoot)/dotnet/src/InternalUtilities/test/AssertExtensions.cs" Link="%(RecursiveDir)%(Filename)%(Extension)" /> | ||
</ItemGroup> | ||
|
||
<ItemGroup> | ||
<ProjectReference Include="..\..\SemanticKernel.Core\SemanticKernel.Core.csproj" /> | ||
<ProjectReference Include="..\Connectors.Onnx\Connectors.Onnx.csproj" /> | ||
</ItemGroup> | ||
|
||
</Project> |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
// Copyright (c) Microsoft. All rights reserved. | ||
|
||
using System.Diagnostics.CodeAnalysis; | ||
|
||
// This assembly is currently experimental. | ||
[assembly: Experimental("SKEXP0070")] |
101 changes: 101 additions & 0 deletions
101
dotnet/src/Connectors/Connectors.Onnx/BertOnnxOptions.cs
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,101 @@ | ||
// Copyright (c) Microsoft. All rights reserved. | ||
|
||
using System; | ||
using System.Text; | ||
|
||
namespace Microsoft.SemanticKernel.Connectors.Onnx; | ||
|
||
/// <summary>Provides an options bag used to configure <see cref="BertOnnxTextEmbeddingGenerationService"/>.</summary>
/// <remarks>
/// All options other than <see cref="NormalizeEmbeddings"/> are init-only: they can be assigned in an
/// object initializer and are validated at that point, so an instance never holds an invalid value.
/// </remarks>
public sealed class BertOnnxOptions
{
    // Backing fields exist only for the properties whose init accessors validate the incoming value.
    private int _maximumTokens = 512;
    private string _clsToken = "[CLS]";
    private string _unknownToken = "[UNK]";
    private string _sepToken = "[SEP]";
    private string _padToken = "[PAD]";
    private EmbeddingPoolingMode _poolingMode = EmbeddingPoolingMode.Mean;

    /// <summary>Gets or initializes whether the vocabulary employed by the model is case-sensitive. Defaults to false.</summary>
    public bool CaseSensitive { get; init; }

    /// <summary>Gets or initializes the maximum number of tokens to encode. Defaults to 512.</summary>
    /// <exception cref="ArgumentOutOfRangeException">The supplied value is less than 1.</exception>
    public int MaximumTokens
    {
        get => this._maximumTokens;
        init
        {
            if (value < 1)
            {
                // Include the rejected value and the valid range so the failure is self-explanatory.
                throw new ArgumentOutOfRangeException(
                    nameof(this.MaximumTokens),
                    value,
                    "The maximum number of tokens must be at least 1.");
            }

            this._maximumTokens = value;
        }
    }

    /// <summary>Gets or initializes the cls token. Defaults to "[CLS]".</summary>
    /// <remarks>The supplied value is checked with <c>Verify.NotNullOrWhiteSpace</c> before it is stored.</remarks>
    public string ClsToken
    {
        get => this._clsToken;
        init
        {
            Verify.NotNullOrWhiteSpace(value);
            this._clsToken = value;
        }
    }

    /// <summary>Gets or initializes the unknown token. Defaults to "[UNK]".</summary>
    /// <remarks>The supplied value is checked with <c>Verify.NotNullOrWhiteSpace</c> before it is stored.</remarks>
    public string UnknownToken
    {
        get => this._unknownToken;
        init
        {
            Verify.NotNullOrWhiteSpace(value);
            this._unknownToken = value;
        }
    }

    /// <summary>Gets or initializes the sep token. Defaults to "[SEP]".</summary>
    /// <remarks>The supplied value is checked with <c>Verify.NotNullOrWhiteSpace</c> before it is stored.</remarks>
    public string SepToken
    {
        get => this._sepToken;
        init
        {
            Verify.NotNullOrWhiteSpace(value);
            this._sepToken = value;
        }
    }

    /// <summary>Gets or initializes the pad token. Defaults to "[PAD]".</summary>
    /// <remarks>The supplied value is checked with <c>Verify.NotNullOrWhiteSpace</c> before it is stored.</remarks>
    public string PadToken
    {
        get => this._padToken;
        init
        {
            Verify.NotNullOrWhiteSpace(value);
            this._padToken = value;
        }
    }

    /// <summary>Gets or initializes the type of Unicode normalization to perform on input text. Defaults to <see cref="NormalizationForm.FormD"/>.</summary>
    public NormalizationForm UnicodeNormalization { get; init; } = NormalizationForm.FormD;

    /// <summary>Gets or initializes the pooling mode to use when generating the fixed-length embedding result. Defaults to <see cref="EmbeddingPoolingMode.Mean"/>.</summary>
    /// <exception cref="ArgumentOutOfRangeException">
    /// The supplied value is not <see cref="EmbeddingPoolingMode.Max"/>, <see cref="EmbeddingPoolingMode.Mean"/>,
    /// or <see cref="EmbeddingPoolingMode.MeanSquareRootTokensLength"/>.
    /// </exception>
    public EmbeddingPoolingMode PoolingMode
    {
        get => this._poolingMode;
        init
        {
            // Reject anything outside the explicitly accepted members, including arbitrary integers cast to the enum.
            if (value is not (EmbeddingPoolingMode.Max or EmbeddingPoolingMode.Mean or EmbeddingPoolingMode.MeanSquareRootTokensLength))
            {
                throw new ArgumentOutOfRangeException(
                    nameof(this.PoolingMode),
                    value,
                    "The pooling mode must be Max, Mean, or MeanSquareRootTokensLength.");
            }

            this._poolingMode = value;
        }
    }

    /// <summary>Gets or sets whether the resulting embedding vectors should be explicitly normalized. Defaults to false.</summary>
    /// <remarks>
    /// Normalized embeddings may be compared more efficiently, such as by using a dot product rather than cosine similarity.
    /// Unlike the other options on this type, this property has a regular setter and can be changed after construction.
    /// </remarks>
    public bool NormalizeEmbeddings { get; set; }
}
Oops, something went wrong.