Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

.Net: Add BERT ONNX embedding generation service #5518

Merged
merged 8 commits into from
Mar 20, 2024
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions dotnet/Directory.Packages.props
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@
<PackageVersion Include="Microsoft.CodeAnalysis.Common" Version="4.3.0" />
<PackageVersion Include="Microsoft.CodeAnalysis.CSharp" Version="4.3.0" />
<PackageVersion Include="Microsoft.Bcl.TimeProvider" Version="8.0.1" />
<PackageVersion Include="Microsoft.ML.OnnxRuntime" Version="1.17.1" />
<PackageVersion Include="FastBertTokenizer" Version="0.4.67" />
<PackageVersion Include="System.Diagnostics.DiagnosticSource" Version="8.0.0" />
<PackageVersion Include="System.Linq.Async" Version="6.0.1" />
<PackageVersion Include="System.Memory.Data" Version="8.0.0" />
Expand Down
18 changes: 18 additions & 0 deletions dotnet/SK-dotnet.sln
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,8 @@ Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Functions.Grpc", "src\Funct
EndProject
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Connectors.HuggingFace", "src\Connectors\Connectors.HuggingFace\Connectors.HuggingFace.csproj", "{136823BE-8665-4D57-87E0-EF41535539E2}"
EndProject
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Connectors.Onnx", "src\Connectors\Connectors.Onnx\Connectors.Onnx.csproj", "{FBEB24A0-E4E9-44D7-B56C-48D91D39A3F9}"
EndProject
Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "InternalUtilities", "InternalUtilities", "{4D3DAE63-41C6-4E1C-A35A-E77BDFC40675}"
EndProject
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Connectors.Memory.Weaviate", "src\Connectors\Connectors.Memory.Weaviate\Connectors.Memory.Weaviate.csproj", "{6AAB0620-33A1-4A98-A63B-6560B9BA47A4}"
Expand Down Expand Up @@ -228,6 +230,8 @@ Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "HomeAutomation", "samples\H
EndProject
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "HuggingFaceImageTextExample", "samples\HuggingFaceImageTextExample\HuggingFaceImageTextExample.csproj", "{8EE10EB0-A947-49CC-BCC1-18D93415B9E4}"
EndProject
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Connectors.Onnx.UnitTests", "src\Connectors\Connectors.Onnx.UnitTests\Connectors.Onnx.UnitTests.csproj", "{D06465FA-0308-494C-920B-D502DA5690CB}"
EndProject
Global
GlobalSection(SolutionConfigurationPlatforms) = preSolution
Debug|Any CPU = Debug|Any CPU
Expand Down Expand Up @@ -372,6 +376,12 @@ Global
{136823BE-8665-4D57-87E0-EF41535539E2}.Publish|Any CPU.Build.0 = Publish|Any CPU
{136823BE-8665-4D57-87E0-EF41535539E2}.Release|Any CPU.ActiveCfg = Release|Any CPU
{136823BE-8665-4D57-87E0-EF41535539E2}.Release|Any CPU.Build.0 = Release|Any CPU
{FBEB24A0-E4E9-44D7-B56C-48D91D39A3F9}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
{FBEB24A0-E4E9-44D7-B56C-48D91D39A3F9}.Debug|Any CPU.Build.0 = Debug|Any CPU
{FBEB24A0-E4E9-44D7-B56C-48D91D39A3F9}.Publish|Any CPU.ActiveCfg = Publish|Any CPU
{FBEB24A0-E4E9-44D7-B56C-48D91D39A3F9}.Publish|Any CPU.Build.0 = Publish|Any CPU
{FBEB24A0-E4E9-44D7-B56C-48D91D39A3F9}.Release|Any CPU.ActiveCfg = Release|Any CPU
{FBEB24A0-E4E9-44D7-B56C-48D91D39A3F9}.Release|Any CPU.Build.0 = Release|Any CPU
{6AAB0620-33A1-4A98-A63B-6560B9BA47A4}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
{6AAB0620-33A1-4A98-A63B-6560B9BA47A4}.Debug|Any CPU.Build.0 = Debug|Any CPU
{6AAB0620-33A1-4A98-A63B-6560B9BA47A4}.Publish|Any CPU.ActiveCfg = Publish|Any CPU
Expand Down Expand Up @@ -533,6 +543,12 @@ Global
{8EE10EB0-A947-49CC-BCC1-18D93415B9E4}.Publish|Any CPU.Build.0 = Debug|Any CPU
{8EE10EB0-A947-49CC-BCC1-18D93415B9E4}.Release|Any CPU.ActiveCfg = Release|Any CPU
{8EE10EB0-A947-49CC-BCC1-18D93415B9E4}.Release|Any CPU.Build.0 = Release|Any CPU
{D06465FA-0308-494C-920B-D502DA5690CB}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
{D06465FA-0308-494C-920B-D502DA5690CB}.Debug|Any CPU.Build.0 = Debug|Any CPU
{D06465FA-0308-494C-920B-D502DA5690CB}.Publish|Any CPU.ActiveCfg = Debug|Any CPU
{D06465FA-0308-494C-920B-D502DA5690CB}.Publish|Any CPU.Build.0 = Debug|Any CPU
{D06465FA-0308-494C-920B-D502DA5690CB}.Release|Any CPU.ActiveCfg = Release|Any CPU
{D06465FA-0308-494C-920B-D502DA5690CB}.Release|Any CPU.Build.0 = Release|Any CPU
EndGlobalSection
GlobalSection(SolutionProperties) = preSolution
HideSolutionNode = FALSE
Expand Down Expand Up @@ -565,6 +581,7 @@ Global
{4D226C2F-AE9F-4EFB-AF2D-45C8FE5CB34E} = {24503383-A8C4-4255-9998-28D70FE8E99A}
{E52F805C-794A-4CA9-B684-DFF358B18820} = {9ECD1AA0-75B3-4E25-B0B5-9F0945B64974}
{136823BE-8665-4D57-87E0-EF41535539E2} = {1B4CBDE0-10C2-4E7D-9CD0-FE7586C96ED1}
{FBEB24A0-E4E9-44D7-B56C-48D91D39A3F9} = {1B4CBDE0-10C2-4E7D-9CD0-FE7586C96ED1}
{4D3DAE63-41C6-4E1C-A35A-E77BDFC40675} = {831DDCA2-7D2C-4C31-80DB-6BDB3E1F7AE0}
{6AAB0620-33A1-4A98-A63B-6560B9BA47A4} = {24503383-A8C4-4255-9998-28D70FE8E99A}
{50FAE231-6F24-4779-9D02-12ABBC9A49E2} = {24503383-A8C4-4255-9998-28D70FE8E99A}
Expand Down Expand Up @@ -610,6 +627,7 @@ Global
{1F96837A-61EC-4C8F-904A-07BEBD05FDEE} = {1B4CBDE0-10C2-4E7D-9CD0-FE7586C96ED1}
{13429BD6-4C4E-45EC-81AD-30BAC380AA60} = {FA3720F1-C99A-49B2-9577-A940257098BF}
{8EE10EB0-A947-49CC-BCC1-18D93415B9E4} = {FA3720F1-C99A-49B2-9577-A940257098BF}
{D06465FA-0308-494C-920B-D502DA5690CB} = {1B4CBDE0-10C2-4E7D-9CD0-FE7586C96ED1}
EndGlobalSection
GlobalSection(ExtensibilityGlobals) = postSolution
SolutionGuid = {FBDC56A3-86AD-4323-AA0F-201E59123B83}
Expand Down

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
<Project Sdk="Microsoft.NET.Sdk">

<!-- Unit test project for the Connectors.Onnx connector (see the ProjectReference items below). -->
<PropertyGroup>
<AssemblyName>SemanticKernel.Connectors.Onnx.UnitTests</AssemblyName>
<RootNamespace>SemanticKernel.Connectors.Onnx.UnitTests</RootNamespace>
<TargetFramework>net6.0</TargetFramework>
<LangVersion>12</LangVersion>
<!-- Let the test host run on the highest installed major .NET runtime instead of requiring net6.0. -->
<RollForward>LatestMajor</RollForward>
<IsTestProject>true</IsTestProject>
<Nullable>enable</Nullable>
<IsPackable>false</IsPackable>
<!-- Suppressions for test code: SKEXP0020 is the experimental diagnostic ID declared on the Connectors.Onnx assembly -->
<!-- (SKEXP0001 is presumably the equivalent for another referenced assembly); CS1591 is the missing-XML-doc warning; -->
<!-- the remaining IDs are analyzer/style rules. -->
<NoWarn>$(NoWarn);SKEXP0001;SKEXP0020;CS1591;IDE1006;RCS1261;CA1308;CA1849;CA1861;CA2007;CA2234;VSTHRD111</NoWarn>
</PropertyGroup>

<ItemGroup>
<!-- Use newest available compiler to permit LangVersion 12. -->
<!-- This can be removed once we no longer target the .NET 6 SDK in CI. -->
<PackageReference Include="Microsoft.Net.Compilers.Toolset" />
</ItemGroup>

<!-- Test framework, mocking, coverage and runtime packages. No Version attributes are given here; -->
<!-- versions are expected to come from central package management (Directory.Packages.props). -->
<ItemGroup>
<PackageReference Include="Microsoft.NET.Test.Sdk" />
<PackageReference Include="Moq" />
<PackageReference Include="xunit" />
<PackageReference Include="xunit.runner.visualstudio">
<IncludeAssets>runtime; build; native; contentfiles; analyzers; buildtransitive</IncludeAssets>
<PrivateAssets>all</PrivateAssets>
</PackageReference>
<PackageReference Include="coverlet.collector">
<IncludeAssets>runtime; build; native; contentfiles; analyzers; buildtransitive</IncludeAssets>
<PrivateAssets>all</PrivateAssets>
</PackageReference>
<PackageReference Include="System.Numerics.Tensors" />
<PackageReference Include="System.Text.Json" />
<PackageReference Include="Microsoft.Extensions.DependencyInjection" />
</ItemGroup>

<!-- Compile the shared AssertExtensions test helper from InternalUtilities directly into this project. -->
<ItemGroup>
<Compile Include="$(RepoRoot)/dotnet/src/InternalUtilities/test/AssertExtensions.cs" Link="%(RecursiveDir)%(Filename)%(Extension)" />
</ItemGroup>

<!-- Projects under test. -->
<ItemGroup>
<ProjectReference Include="..\..\SemanticKernel.Core\SemanticKernel.Core.csproj" />
<ProjectReference Include="..\Connectors.Onnx\Connectors.Onnx.csproj" />
</ItemGroup>

</Project>
6 changes: 6 additions & 0 deletions dotnet/src/Connectors/Connectors.Onnx/AssemblyInfo.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
// Copyright (c) Microsoft. All rights reserved.

using System.Diagnostics.CodeAnalysis;

// This assembly is currently experimental. The assembly-level attribute below tags it with the
// SKEXP0020 diagnostic ID, so consumers must explicitly suppress SKEXP0020 to use its APIs.
[assembly: Experimental("SKEXP0020")]
stephentoub marked this conversation as resolved.
Show resolved Hide resolved
101 changes: 101 additions & 0 deletions dotnet/src/Connectors/Connectors.Onnx/BertOnnxOptions.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
// Copyright (c) Microsoft. All rights reserved.

using System;
using System.Text;

namespace Microsoft.SemanticKernel.Connectors.Onnx;

/// <summary>Provides an options bag used to configure <see cref="BertOnnxTextEmbeddingGenerationService"/>.</summary>
/// <remarks>
/// All properties except <see cref="NormalizeEmbeddings"/> are init-only and are validated when assigned,
/// so an instance can never hold an invalid token string, token limit, or pooling mode.
/// </remarks>
public sealed class BertOnnxOptions
{
    private int _maximumTokens = 512;
    private string _clsToken = "[CLS]";
    private string _unknownToken = "[UNK]";
    private string _sepToken = "[SEP]";
    private string _padToken = "[PAD]";
    private EmbeddingPoolingMode _poolingMode = EmbeddingPoolingMode.Mean;

    /// <summary>Gets or initializes whether the vocabulary employed by the model is case-sensitive. Defaults to false.</summary>
    public bool CaseSensitive { get; init; } = false;

    /// <summary>Gets or initializes the maximum number of tokens to encode. Defaults to 512.</summary>
    /// <exception cref="ArgumentOutOfRangeException">The assigned value is less than 1.</exception>
    public int MaximumTokens
    {
        get => this._maximumTokens;
        init
        {
            if (value < 1)
            {
                // Report the rejected value and the constraint so a misconfiguration is easy to diagnose.
                throw new ArgumentOutOfRangeException(
                    nameof(this.MaximumTokens),
                    value,
                    "The maximum number of tokens must be at least 1.");
            }

            this._maximumTokens = value;
        }
    }

    /// <summary>Gets or initializes the cls token. Defaults to "[CLS]".</summary>
    /// <remarks>The assigned value is checked with <c>Verify.NotNullOrWhiteSpace</c> before it is stored.</remarks>
    public string ClsToken
    {
        get => this._clsToken;
        init
        {
            Verify.NotNullOrWhiteSpace(value);
            this._clsToken = value;
        }
    }

    /// <summary>Gets or initializes the unknown token. Defaults to "[UNK]".</summary>
    /// <remarks>The assigned value is checked with <c>Verify.NotNullOrWhiteSpace</c> before it is stored.</remarks>
    public string UnknownToken
    {
        get => this._unknownToken;
        init
        {
            Verify.NotNullOrWhiteSpace(value);
            this._unknownToken = value;
        }
    }

    /// <summary>Gets or initializes the sep token. Defaults to "[SEP]".</summary>
    /// <remarks>The assigned value is checked with <c>Verify.NotNullOrWhiteSpace</c> before it is stored.</remarks>
    public string SepToken
    {
        get => this._sepToken;
        init
        {
            Verify.NotNullOrWhiteSpace(value);
            this._sepToken = value;
        }
    }

    /// <summary>Gets or initializes the pad token. Defaults to "[PAD]".</summary>
    /// <remarks>The assigned value is checked with <c>Verify.NotNullOrWhiteSpace</c> before it is stored.</remarks>
    public string PadToken
    {
        get => this._padToken;
        init
        {
            Verify.NotNullOrWhiteSpace(value);
            this._padToken = value;
        }
    }

    /// <summary>Gets or initializes the type of Unicode normalization to perform on input text. Defaults to <see cref="NormalizationForm.FormD"/>.</summary>
    public NormalizationForm UnicodeNormalization { get; init; } = NormalizationForm.FormD;

    /// <summary>Gets or initializes the pooling mode to use when generating the fixed-length embedding result. Defaults to <see cref="EmbeddingPoolingMode.Mean"/>.</summary>
    /// <exception cref="ArgumentOutOfRangeException">
    /// The assigned value is not one of <see cref="EmbeddingPoolingMode.Max"/>, <see cref="EmbeddingPoolingMode.Mean"/>,
    /// or <see cref="EmbeddingPoolingMode.MeanSquareRootTokensLength"/>.
    /// </exception>
    public EmbeddingPoolingMode PoolingMode
    {
        get => this._poolingMode;
        init
        {
            // Enums accept any underlying integer via a cast, so reject anything outside the supported set.
            if (value is not (EmbeddingPoolingMode.Max or EmbeddingPoolingMode.Mean or EmbeddingPoolingMode.MeanSquareRootTokensLength))
            {
                throw new ArgumentOutOfRangeException(
                    nameof(this.PoolingMode),
                    value,
                    "The pooling mode is not a supported value.");
            }

            this._poolingMode = value;
        }
    }

    /// <summary>Gets or sets whether the resulting embedding vectors should be explicitly normalized. Defaults to false.</summary>
    /// <remarks>
    /// Normalized embeddings may be compared more efficiently, such as by using a dot product rather than cosine similarity.
    /// Unlike the other options on this type, this property has a regular setter and can be changed after construction.
    /// </remarks>
    public bool NormalizeEmbeddings { get; set; } = false;
}
Loading
Loading