Skip to content

Commit

Permalink
.Net: Add BERT ONNX embedding generation service (microsoft#5518)
Browse files Browse the repository at this point in the history
Adds a new Microsoft.SemanticKernel.Connectors.Onnx component. As of
this PR, it contains one service,
BertOnnxTextEmbeddingGenerationService, for using BERT-based ONNX models
to generate embeddings. But in time we can add more ONNX-based
implementations for using local models.

This is in part based on
https://onnxruntime.ai/docs/tutorials/csharp/bert-nlp-csharp-console-app.html
and https://github.com/dotnet-smartcomponents/smartcomponents. It
doesn't support everything that's supported via sentence-transformers,
but we should be able to extend it as needed.

cc: @luisquintanilla, @SteveSandersonMS, @JakeRadMSFT
  • Loading branch information
stephentoub authored Mar 20, 2024
1 parent 20748d9 commit 0050e2c
Show file tree
Hide file tree
Showing 12 changed files with 996 additions and 0 deletions.
2 changes: 2 additions & 0 deletions dotnet/Directory.Packages.props
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@
<PackageVersion Include="Microsoft.CodeAnalysis.Common" Version="4.3.0" />
<PackageVersion Include="Microsoft.CodeAnalysis.CSharp" Version="4.3.0" />
<PackageVersion Include="Microsoft.Bcl.TimeProvider" Version="8.0.1" />
<PackageVersion Include="Microsoft.ML.OnnxRuntime" Version="1.17.1" />
<PackageVersion Include="FastBertTokenizer" Version="0.4.67" />
<PackageVersion Include="System.Diagnostics.DiagnosticSource" Version="8.0.0" />
<PackageVersion Include="System.Linq.Async" Version="6.0.1" />
<PackageVersion Include="System.Memory.Data" Version="8.0.0" />
Expand Down
18 changes: 18 additions & 0 deletions dotnet/SK-dotnet.sln
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,8 @@ Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Functions.Grpc", "src\Funct
EndProject
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Connectors.HuggingFace", "src\Connectors\Connectors.HuggingFace\Connectors.HuggingFace.csproj", "{136823BE-8665-4D57-87E0-EF41535539E2}"
EndProject
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Connectors.Onnx", "src\Connectors\Connectors.Onnx\Connectors.Onnx.csproj", "{FBEB24A0-E4E9-44D7-B56C-48D91D39A3F9}"
EndProject
Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "InternalUtilities", "InternalUtilities", "{4D3DAE63-41C6-4E1C-A35A-E77BDFC40675}"
EndProject
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Connectors.Memory.Weaviate", "src\Connectors\Connectors.Memory.Weaviate\Connectors.Memory.Weaviate.csproj", "{6AAB0620-33A1-4A98-A63B-6560B9BA47A4}"
Expand Down Expand Up @@ -228,6 +230,8 @@ Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "HomeAutomation", "samples\H
EndProject
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "HuggingFaceImageTextExample", "samples\HuggingFaceImageTextExample\HuggingFaceImageTextExample.csproj", "{8EE10EB0-A947-49CC-BCC1-18D93415B9E4}"
EndProject
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Connectors.Onnx.UnitTests", "src\Connectors\Connectors.Onnx.UnitTests\Connectors.Onnx.UnitTests.csproj", "{D06465FA-0308-494C-920B-D502DA5690CB}"
EndProject
Global
GlobalSection(SolutionConfigurationPlatforms) = preSolution
Debug|Any CPU = Debug|Any CPU
Expand Down Expand Up @@ -372,6 +376,12 @@ Global
{136823BE-8665-4D57-87E0-EF41535539E2}.Publish|Any CPU.Build.0 = Publish|Any CPU
{136823BE-8665-4D57-87E0-EF41535539E2}.Release|Any CPU.ActiveCfg = Release|Any CPU
{136823BE-8665-4D57-87E0-EF41535539E2}.Release|Any CPU.Build.0 = Release|Any CPU
{FBEB24A0-E4E9-44D7-B56C-48D91D39A3F9}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
{FBEB24A0-E4E9-44D7-B56C-48D91D39A3F9}.Debug|Any CPU.Build.0 = Debug|Any CPU
{FBEB24A0-E4E9-44D7-B56C-48D91D39A3F9}.Publish|Any CPU.ActiveCfg = Publish|Any CPU
{FBEB24A0-E4E9-44D7-B56C-48D91D39A3F9}.Publish|Any CPU.Build.0 = Publish|Any CPU
{FBEB24A0-E4E9-44D7-B56C-48D91D39A3F9}.Release|Any CPU.ActiveCfg = Release|Any CPU
{FBEB24A0-E4E9-44D7-B56C-48D91D39A3F9}.Release|Any CPU.Build.0 = Release|Any CPU
{6AAB0620-33A1-4A98-A63B-6560B9BA47A4}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
{6AAB0620-33A1-4A98-A63B-6560B9BA47A4}.Debug|Any CPU.Build.0 = Debug|Any CPU
{6AAB0620-33A1-4A98-A63B-6560B9BA47A4}.Publish|Any CPU.ActiveCfg = Publish|Any CPU
Expand Down Expand Up @@ -533,6 +543,12 @@ Global
{8EE10EB0-A947-49CC-BCC1-18D93415B9E4}.Publish|Any CPU.Build.0 = Debug|Any CPU
{8EE10EB0-A947-49CC-BCC1-18D93415B9E4}.Release|Any CPU.ActiveCfg = Release|Any CPU
{8EE10EB0-A947-49CC-BCC1-18D93415B9E4}.Release|Any CPU.Build.0 = Release|Any CPU
{D06465FA-0308-494C-920B-D502DA5690CB}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
{D06465FA-0308-494C-920B-D502DA5690CB}.Debug|Any CPU.Build.0 = Debug|Any CPU
{D06465FA-0308-494C-920B-D502DA5690CB}.Publish|Any CPU.ActiveCfg = Debug|Any CPU
{D06465FA-0308-494C-920B-D502DA5690CB}.Publish|Any CPU.Build.0 = Debug|Any CPU
{D06465FA-0308-494C-920B-D502DA5690CB}.Release|Any CPU.ActiveCfg = Release|Any CPU
{D06465FA-0308-494C-920B-D502DA5690CB}.Release|Any CPU.Build.0 = Release|Any CPU
EndGlobalSection
GlobalSection(SolutionProperties) = preSolution
HideSolutionNode = FALSE
Expand Down Expand Up @@ -565,6 +581,7 @@ Global
{4D226C2F-AE9F-4EFB-AF2D-45C8FE5CB34E} = {24503383-A8C4-4255-9998-28D70FE8E99A}
{E52F805C-794A-4CA9-B684-DFF358B18820} = {9ECD1AA0-75B3-4E25-B0B5-9F0945B64974}
{136823BE-8665-4D57-87E0-EF41535539E2} = {1B4CBDE0-10C2-4E7D-9CD0-FE7586C96ED1}
{FBEB24A0-E4E9-44D7-B56C-48D91D39A3F9} = {1B4CBDE0-10C2-4E7D-9CD0-FE7586C96ED1}
{4D3DAE63-41C6-4E1C-A35A-E77BDFC40675} = {831DDCA2-7D2C-4C31-80DB-6BDB3E1F7AE0}
{6AAB0620-33A1-4A98-A63B-6560B9BA47A4} = {24503383-A8C4-4255-9998-28D70FE8E99A}
{50FAE231-6F24-4779-9D02-12ABBC9A49E2} = {24503383-A8C4-4255-9998-28D70FE8E99A}
Expand Down Expand Up @@ -610,6 +627,7 @@ Global
{1F96837A-61EC-4C8F-904A-07BEBD05FDEE} = {1B4CBDE0-10C2-4E7D-9CD0-FE7586C96ED1}
{13429BD6-4C4E-45EC-81AD-30BAC380AA60} = {FA3720F1-C99A-49B2-9577-A940257098BF}
{8EE10EB0-A947-49CC-BCC1-18D93415B9E4} = {FA3720F1-C99A-49B2-9577-A940257098BF}
{D06465FA-0308-494C-920B-D502DA5690CB} = {1B4CBDE0-10C2-4E7D-9CD0-FE7586C96ED1}
EndGlobalSection
GlobalSection(ExtensibilityGlobals) = postSolution
SolutionGuid = {FBDC56A3-86AD-4323-AA0F-201E59123B83}
Expand Down
1 change: 1 addition & 0 deletions dotnet/docs/EXPERIMENTS.md
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ You can use the following diagnostic IDs to ignore warnings or errors for a part
| SKEXP0070 | Ollama AI connector | | | | | |
| SKEXP0070 | Gemini AI connector | | | | | |
| SKEXP0070 | Mistral AI connector | | | | | |
| SKEXP0070 | ONNX AI connector | | | | | |
| | | | | | | |
| SKEXP0101 | Experiment with Assistants | | | | | |
| SKEXP0101 | Experiment with Flow Orchestration | | | | | |

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
<Project Sdk="Microsoft.NET.Sdk">

<PropertyGroup>
<AssemblyName>SemanticKernel.Connectors.Onnx.UnitTests</AssemblyName>
<RootNamespace>SemanticKernel.Connectors.Onnx.UnitTests</RootNamespace>
<TargetFramework>net6.0</TargetFramework>
<LangVersion>12</LangVersion>
<RollForward>LatestMajor</RollForward>
<IsTestProject>true</IsTestProject>
<Nullable>enable</Nullable>
<IsPackable>false</IsPackable>
<NoWarn>$(NoWarn);SKEXP0001;SKEXP0070;CS1591;IDE1006;RCS1261;CA1031;CA1308;CA1849;CA1861;CA2007;CA2234;VSTHRD111</NoWarn>
</PropertyGroup>

<ItemGroup>
<!-- Use newest available compiler to permit LangVersion 12. -->
<!-- This can be removed once we no longer target the .NET 6 SDK in CI. -->
<PackageReference Include="Microsoft.Net.Compilers.Toolset" />
</ItemGroup>

<ItemGroup>
<PackageReference Include="Microsoft.NET.Test.Sdk" />
<PackageReference Include="Moq" />
<PackageReference Include="xunit" />
<PackageReference Include="xunit.runner.visualstudio">
<IncludeAssets>runtime; build; native; contentfiles; analyzers; buildtransitive</IncludeAssets>
<PrivateAssets>all</PrivateAssets>
</PackageReference>
<PackageReference Include="coverlet.collector">
<IncludeAssets>runtime; build; native; contentfiles; analyzers; buildtransitive</IncludeAssets>
<PrivateAssets>all</PrivateAssets>
</PackageReference>
<PackageReference Include="System.Numerics.Tensors" />
<PackageReference Include="System.Text.Json" />
<PackageReference Include="Microsoft.Extensions.DependencyInjection" />
</ItemGroup>

<ItemGroup>
<Compile Include="$(RepoRoot)/dotnet/src/InternalUtilities/test/AssertExtensions.cs" Link="%(RecursiveDir)%(Filename)%(Extension)" />
</ItemGroup>

<ItemGroup>
<ProjectReference Include="..\..\SemanticKernel.Core\SemanticKernel.Core.csproj" />
<ProjectReference Include="..\Connectors.Onnx\Connectors.Onnx.csproj" />
</ItemGroup>

</Project>
6 changes: 6 additions & 0 deletions dotnet/src/Connectors/Connectors.Onnx/AssemblyInfo.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
// Copyright (c) Microsoft. All rights reserved.

using System.Diagnostics.CodeAnalysis;

// This assembly is currently experimental.
[assembly: Experimental("SKEXP0070")]
101 changes: 101 additions & 0 deletions dotnet/src/Connectors/Connectors.Onnx/BertOnnxOptions.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
// Copyright (c) Microsoft. All rights reserved.

using System;
using System.Text;

namespace Microsoft.SemanticKernel.Connectors.Onnx;

/// <summary>Provides an options bag used to configure <see cref="BertOnnxTextEmbeddingGenerationService"/>.</summary>
public sealed class BertOnnxOptions
{
private int _maximumTokens = 512;
private string _clsToken = "[CLS]";
private string _unknownToken = "[UNK]";
private string _sepToken = "[SEP]";
private string _padToken = "[PAD]";
private EmbeddingPoolingMode _poolingMode = EmbeddingPoolingMode.Mean;

/// <summary>Gets or sets whether the vocabulary employed by the model is case-sensitive.</summary>
public bool CaseSensitive { get; init; } = false;

/// <summary>Gets or sets the maximum number of tokens to encode. Defaults to 512.</summary>
public int MaximumTokens
{
get => this._maximumTokens;
init
{
if (value < 1)
{
throw new ArgumentOutOfRangeException(nameof(this.MaximumTokens));
}

this._maximumTokens = value;
}
}

/// <summary>Gets or sets the cls token. Defaults to "[CLS]".</summary>
public string ClsToken
{
get => this._clsToken;
init
{
Verify.NotNullOrWhiteSpace(value);
this._clsToken = value;
}
}

/// <summary>Gets or sets the unknown token. Defaults to "[UNK]".</summary>
public string UnknownToken
{
get => this._unknownToken;
init
{
Verify.NotNullOrWhiteSpace(value);
this._unknownToken = value;
}
}

/// <summary>Gets or sets the sep token. Defaults to "[SEP]".</summary>
public string SepToken
{
get => this._sepToken;
init
{
Verify.NotNullOrWhiteSpace(value);
this._sepToken = value;
}
}

/// <summary>Gets or sets the pad token. Defaults to "[PAD]".</summary>
public string PadToken
{
get => this._padToken;
init
{
Verify.NotNullOrWhiteSpace(value);
this._padToken = value;
}
}

/// <summary>Gets or sets the type of Unicode normalization to perform on input text. Defaults to <see cref="NormalizationForm.FormD"/>.</summary>
public NormalizationForm UnicodeNormalization { get; init; } = NormalizationForm.FormD;

/// <summary>Gets or sets the pooling mode to use when generating the fixed-length embedding result. Defaults to "mean".</summary>
public EmbeddingPoolingMode PoolingMode
{
get => this._poolingMode;
init
{
if (value is not (EmbeddingPoolingMode.Max or EmbeddingPoolingMode.Mean or EmbeddingPoolingMode.MeanSquareRootTokensLength))
{
throw new ArgumentOutOfRangeException(nameof(this.PoolingMode));
}

this._poolingMode = value;
}
}

/// <summary>Gets or sets whether the resulting embedding vectors should be explicitly normalized. Defaults to false.</summary>
/// <remarks>Normalized embeddings may be compared more efficiently, such as by using a dot product rather than cosine similarity.</remarks>
public bool NormalizeEmbeddings { get; set; } = false;
}
Loading

0 comments on commit 0050e2c

Please sign in to comment.