Skip to content

Commit

Permalink
.Net: OpenAI embeddings refactor (microsoft#4600)
Browse files Browse the repository at this point in the history
### Motivation and Context
Continuation of microsoft#3295

The OpenAI embeddings endpoint and `text-embedding-ada-002` supports
either a single item or an array of items. Rather than send text inputs
one by one, the request should send them all at once for improved
performance.

Note that this will require callers to manage their own batching
strategy (added an example to demonstrate). I view this as a worthwhile
tradeoff for improved performance out of the box.

Fixes microsoft#3294.

### Description
- Text data is sent as an array to the embeddings endpoint instead of an
individual request per text.
- Aligns embedding request functionality with how the HuggingFace
connector works.
- Remove unnecessary `ToArray()` call on the response embeddings.

### Contribution Checklist

<!-- Before submitting this PR, please make sure: -->

- [x] The code builds clean without any errors or warnings
- [x] The PR follows the [SK Contribution
Guidelines](https://github.com/microsoft/semantic-kernel/blob/main/CONTRIBUTING.md)
and the [pre-submission formatting
script](https://github.com/microsoft/semantic-kernel/blob/main/CONTRIBUTING.md#development-scripts)
raises no violations
- [x] All unit tests pass, and I have added new tests where possible
- [ ] I didn't break anyone 😄 (implementations passing more than
16 inputs will get an error from the OpenAI endpoint and need to batch
appropriately)

---------

Co-authored-by: Dmytro Struk <13853051+dmytrostruk@users.noreply.github.com>
  • Loading branch information
2 people authored and BR committed Oct 6, 2024
1 parent c6c80d6 commit b8b925f
Show file tree
Hide file tree
Showing 6 changed files with 245 additions and 10 deletions.
182 changes: 182 additions & 0 deletions dotnet/samples/KernelSyntaxExamples/Example78_TextEmbedding.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,182 @@
// Copyright (c) Microsoft. All rights reserved.

using System.Linq;
using System.Threading.Tasks;
using Microsoft.SemanticKernel.Connectors.OpenAI;
using Microsoft.SemanticKernel.Text;
using RepoUtils;
using SharpToken;
using Xunit;
using Xunit.Abstractions;

namespace Examples;

public class Example78_TextEmbedding : BaseTest
{
[Fact]
public async Task RunAsync()
{
this.WriteLine("======== Example76_TextEmbedding ========");
await RunExampleAsync();
}

private async Task RunExampleAsync()
{
const string EmbeddingModelName = "text-embedding-ada-002";
var embeddingGenerator = new AzureOpenAITextEmbeddingGenerationService(
deploymentName: EmbeddingModelName,
endpoint: TestConfiguration.AzureOpenAIEmbeddings.Endpoint,
apiKey: TestConfiguration.AzureOpenAIEmbeddings.ApiKey);

// To demonstrate batching we'll create abnormally small partitions.
var lines = TextChunker.SplitPlainTextLines(ChatTranscript, maxTokensPerLine: 10);
var paragraphs = TextChunker.SplitPlainTextParagraphs(lines, maxTokensPerParagraph: 25);

this.WriteLine($"Split transcript into {paragraphs.Count} paragraphs");

// Azure OpenAI currently supports input arrays up to 16 for text-embedding-ada-002 (Version 2).
// Both require the max input token limit per API request to remain under 8191 for this model.
var chunks = paragraphs
.ChunkByAggregate(
seed: 0,
aggregator: (tokenCount, paragraph) => tokenCount + GetTokenCount(EmbeddingModelName, paragraph),
predicate: (tokenCount, index) => tokenCount < 8191 && index < 16)
.ToList();

this.WriteLine($"Consolidated paragraphs into {chunks.Count}");

// Generate embeddings for each chunk.
for (var i = 0; i < chunks.Count; i++)
{
var chunk = chunks[i];
var embeddings = await embeddingGenerator.GenerateEmbeddingsAsync(chunk);

this.WriteLine($"Generated {embeddings.Count} embeddings from chunk {i + 1}");
}
}

// See Example55_TextChunker for more examples of how to count tokens.
private int GetTokenCount(string modelName, string text)
{
var encoding = GptEncoding.GetEncodingForModel(modelName);
var tokens = encoding.Encode(text);

return tokens.Count;
}

public Example78_TextEmbedding(ITestOutputHelper output) : base(output)
{
}

#region Transcript

private const string ChatTranscript =
@"
John: Hello, how are you?
Jane: I'm fine, thanks. How are you?
John: I'm doing well, writing some example code.
Jane: That's great! I'm writing some example code too.
John: What are you writing?
Jane: I'm writing a chatbot.
John: That's cool. I'm writing a chatbot too.
Jane: What language are you writing it in?
John: I'm writing it in C#.
Jane: I'm writing it in Python.
John: That's cool. I need to learn Python.
Jane: I need to learn C#.
John: Can I try out your chatbot?
Jane: Sure, here's the link.
John: Thanks!
Jane: You're welcome.
Jane: Look at this poem my chatbot wrote:
Jane: Roses are red
Jane: Violets are blue
Jane: I'm writing a chatbot
Jane: What about you?
John: That's cool. Let me see if mine will write a poem, too.
John: Here's a poem my chatbot wrote:
John: The singularity of the universe is a mystery.
John: The universe is a mystery.
John: The universe is a mystery.
John: The universe is a mystery.
John: Looks like I need to improve mine, oh well.
Jane: You might want to try using a different model.
Jane: I'm using the GPT-3 model.
John: I'm using the GPT-2 model. That makes sense.
John: Here is a new poem after updating the model.
John: The universe is a mystery.
John: The universe is a mystery.
John: The universe is a mystery.
John: Yikes, it's really stuck isn't it. Would you help me debug my code?
Jane: Sure, what's the problem?
John: I'm not sure. I think it's a bug in the code.
Jane: I'll take a look.
Jane: I think I found the problem.
Jane: It looks like you're not passing the right parameters to the model.
John: Thanks for the help!
Jane: I'm now writing a bot to summarize conversations. I want to make sure it works when the conversation is long.
John: So you need to keep talking with me to generate a long conversation?
Jane: Yes, that's right.
John: Ok, I'll keep talking. What should we talk about?
Jane: I don't know, what do you want to talk about?
John: I don't know, it's nice how CoPilot is doing most of the talking for us. But it definitely gets stuck sometimes.
Jane: I agree, it's nice that CoPilot is doing most of the talking for us.
Jane: But it definitely gets stuck sometimes.
John: Do you know how long it needs to be?
Jane: I think the max length is 1024 tokens. Which is approximately 1024*4= 4096 characters.
John: That's a lot of characters.
Jane: Yes, it is.
John: I'm not sure how much longer I can keep talking.
Jane: I think we're almost there. Let me check.
Jane: I have some bad news, we're only half way there.
John: Oh no, I'm not sure I can keep going. I'm getting tired.
Jane: I'm getting tired too.
John: Maybe there is a large piece of text we can use to generate a long conversation.
Jane: That's a good idea. Let me see if I can find one. Maybe Lorem Ipsum?
John: Yeah, that's a good idea.
Jane: I found a Lorem Ipsum generator.
Jane: Here's a 4096 character Lorem Ipsum text:
Jane: Lorem ipsum dolor sit amet, con
Jane: Lorem ipsum dolor sit amet, consectetur adipiscing elit. Sed euismod, nunc sit amet aliquam
Jane: Lorem ipsum dolor sit amet, consectetur adipiscing elit. Sed euismod, nunc sit amet aliquam
Jane: Darn, it's just repeating stuf now.
John: I think we're done.
Jane: We're not though! We need like 1500 more characters.
John: Oh Cananda, our home and native land.
Jane: True patriot love in all thy sons command.
John: With glowing hearts we see thee rise.
Jane: The True North strong and free.
John: From far and wide, O Canada, we stand on guard for thee.
Jane: God keep our land glorious and free.
John: O Canada, we stand on guard for thee.
Jane: O Canada, we stand on guard for thee.
Jane: That was fun, thank you. Let me check now.
Jane: I think we need about 600 more characters.
John: Oh say can you see?
Jane: By the dawn's early light.
John: What so proudly we hailed.
Jane: At the twilight's last gleaming.
John: Whose broad stripes and bright stars.
Jane: Through the perilous fight.
John: O'er the ramparts we watched.
Jane: Were so gallantly streaming.
John: And the rockets' red glare.
Jane: The bombs bursting in air.
John: Gave proof through the night.
Jane: That our flag was still there.
John: Oh say does that star-spangled banner yet wave.
Jane: O'er the land of the free.
John: And the home of the brave.
Jane: Are you a Seattle Kraken Fan?
John: Yes, I am. I love going to the games.
Jane: I'm a Seattle Kraken Fan too. Who is your favorite player?
John: I like watching all the players, but I think my favorite is Matty Beniers.
Jane: Yeah, he's a great player. I like watching him too. I also like watching Jaden Schwartz.
John: Adam Larsson is another good one. The big cat!
Jane: WE MADE IT! It's long enough. Thank you!
John: You're welcome. I'm glad we could help. Goodbye!
Jane: Goodbye!
";

#endregion
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
// Copyright (c) Microsoft. All rights reserved.

using System;
using System.Collections.Generic;

namespace RepoUtils;

public static class EnumerableExtensions
{
public static IEnumerable<List<TSource>> ChunkByAggregate<TSource, TAccumulate>(
this IEnumerable<TSource> source,
TAccumulate seed,
Func<TAccumulate, TSource, TAccumulate> aggregator,
Func<TAccumulate, int, bool> predicate)
{
using var enumerator = source.GetEnumerator();
var aggregate = seed;
var index = 0;
var chunk = new List<TSource>();

while (enumerator.MoveNext())
{
var current = enumerator.Current;

aggregate = aggregator(aggregate, current);

if (predicate(aggregate, index++))
{
chunk.Add(current);
}
else
{
if (chunk.Count > 0)
{
yield return chunk;
}

chunk = new List<TSource>() { current };
aggregate = aggregator(seed, current);
index = 1;
}
}

if (chunk.Count > 0)
{
yield return chunk;
}
}
}
18 changes: 11 additions & 7 deletions dotnet/src/Connectors/Connectors.OpenAI/AzureSdk/ClientCore.cs
Original file line number Diff line number Diff line change
Expand Up @@ -210,17 +210,21 @@ internal async Task<IList<ReadOnlyMemory<float>>> GetEmbeddingsAsync(
CancellationToken cancellationToken)
{
var result = new List<ReadOnlyMemory<float>>(data.Count);
foreach (string text in data)

if (data.Count > 0)
{
var options = new EmbeddingsOptions(this.DeploymentOrModelName, new[] { text });
var response = await RunRequestAsync(() => this.Client.GetEmbeddingsAsync(new(this.DeploymentOrModelName, data), cancellationToken)).ConfigureAwait(false);
var embeddings = response.Value.Data;

Response<Azure.AI.OpenAI.Embeddings> response = await RunRequestAsync(() => this.Client.GetEmbeddingsAsync(options, cancellationToken)).ConfigureAwait(false);
if (response.Value.Data.Count == 0)
if (embeddings.Count != data.Count)
{
throw new KernelException("Text embedding not found");
throw new KernelException($"Expected {data.Count} text embedding(s), but received {embeddings.Count}");
}

result.Add(response.Value.Data[0].Embedding.ToArray());
for (var i = 0; i < embeddings.Count; i++)
{
result.Add(embeddings[i].Embedding);
}
}

return result;
Expand Down Expand Up @@ -461,7 +465,7 @@ internal async IAsyncEnumerable<OpenAIStreamingChatMessageContent> GetStreamingC
yield return new OpenAIStreamingChatMessageContent(update, update.ChoiceIndex ?? 0, this.DeploymentOrModelName, metadata);
}

// If we don't have a function to invoke, we're done.
// If we don't have a function to invoke, we're done.
// Note that we don't check the FinishReason and instead check whether there are any tool calls, as the service
// may return a FinishReason of "stop" even if there are tool calls to be made, in particular if a required tool
// is specified.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ public async Task GenerateEmbeddingsWithEmptyResponseThrowsExceptionAsync()

// Act & Assert
var exception = await Assert.ThrowsAsync<KernelException>(() => service.GenerateEmbeddingsAsync(["test"]));
Assert.Equal("Text embedding not found", exception.Message);
Assert.Equal("Expected 1 text embedding(s), but received 0", exception.Message);
}

[Fact]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ public async Task GenerateEmbeddingsWithEmptyResponseThrowsExceptionAsync()

// Act & Assert
var exception = await Assert.ThrowsAsync<KernelException>(() => service.GenerateEmbeddingsAsync(["test"]));
Assert.Equal("Text embedding not found", exception.Message);
Assert.Equal("Expected 1 text embedding(s), but received 0", exception.Message);
}

[Fact]
Expand Down
2 changes: 1 addition & 1 deletion dotnet/src/IntegrationTests/testsettings.json
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
},
"AzureOpenAIEmbeddings": {
"ServiceId": "azure-text-embedding-ada-002",
"DeploymentName": "text-embedding-ada-002",
"DeploymentName": "ada-002",
"Endpoint": "",
"ApiKey": ""
},
Expand Down

0 comments on commit b8b925f

Please sign in to comment.