Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -99,16 +99,16 @@ public static TService GetRequiredService<TService>(
/// <exception cref="ArgumentNullException"><paramref name="value"/> is <see langword="null"/>.</exception>
/// <exception cref="InvalidOperationException">The generator did not produce exactly one embedding.</exception>
/// <remarks>
/// This operation is equivalent to using <see cref="GenerateEmbeddingAsync"/> and returning the
/// This operation is equivalent to using <see cref="GenerateAsync"/> and returning the
/// resulting <see cref="Embedding{T}"/>'s <see cref="Embedding{T}.Vector"/> property.
/// </remarks>
public static async Task<ReadOnlyMemory<TEmbeddingElement>> GenerateEmbeddingVectorAsync<TInput, TEmbeddingElement>(
public static async Task<ReadOnlyMemory<TEmbeddingElement>> GenerateVectorAsync<TInput, TEmbeddingElement>(
this IEmbeddingGenerator<TInput, Embedding<TEmbeddingElement>> generator,
TInput value,
EmbeddingGenerationOptions? options = null,
CancellationToken cancellationToken = default)
{
var embedding = await GenerateEmbeddingAsync(generator, value, options, cancellationToken).ConfigureAwait(false);
var embedding = await GenerateAsync(generator, value, options, cancellationToken).ConfigureAwait(false);
return embedding.Vector;
}

Expand All @@ -130,7 +130,7 @@ public static async Task<ReadOnlyMemory<TEmbeddingElement>> GenerateEmbeddingVec
/// collection composed of the single <paramref name="value"/> and then returning the first embedding element from the
/// resulting <see cref="GeneratedEmbeddings{TEmbedding}"/> collection.
/// </remarks>
public static async Task<TEmbedding> GenerateEmbeddingAsync<TInput, TEmbedding>(
public static async Task<TEmbedding> GenerateAsync<TInput, TEmbedding>(
this IEmbeddingGenerator<TInput, TEmbedding> generator,
TInput value,
EmbeddingGenerationOptions? options = null,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -522,7 +522,7 @@ using Microsoft.Extensions.AI;
IEmbeddingGenerator<string, Embedding<float>> generator =
new SampleEmbeddingGenerator(new Uri("http://coolsite.ai"), "my-custom-model");

ReadOnlyMemory<float> vector = generator.GenerateEmbeddingVectorAsync("What is AI?");
ReadOnlyMemory<float> vector = generator.GenerateVectorAsync("What is AI?");
```

#### Pipelines of Functionality
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ public class SemanticSearch(
{
public async Task<IReadOnlyList<SemanticSearchRecord>> SearchAsync(string text, string? filenameFilter, int maxResults)
{
var queryEmbedding = await embeddingGenerator.GenerateEmbeddingVectorAsync(text);
var queryEmbedding = await embeddingGenerator.GenerateVectorAsync(text);
#if (UseQdrant)
var vectorCollection = vectorStore.GetCollection<Guid, SemanticSearchRecord>("data-ChatWithCustomData-CSharp.Web-ingestion");
#else
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -78,8 +78,8 @@ public void GetService_ValidService_Returned()
[Fact]
public async Task GenerateAsync_InvalidArgs_ThrowsAsync()
{
await Assert.ThrowsAsync<ArgumentNullException>("generator", () => ((TestEmbeddingGenerator)null!).GenerateEmbeddingAsync("hello"));
await Assert.ThrowsAsync<ArgumentNullException>("generator", () => ((TestEmbeddingGenerator)null!).GenerateEmbeddingVectorAsync("hello"));
await Assert.ThrowsAsync<ArgumentNullException>("generator", () => ((TestEmbeddingGenerator)null!).GenerateAsync("hello"));
await Assert.ThrowsAsync<ArgumentNullException>("generator", () => ((TestEmbeddingGenerator)null!).GenerateVectorAsync("hello"));
await Assert.ThrowsAsync<ArgumentNullException>("generator", () => ((TestEmbeddingGenerator)null!).GenerateAndZipAsync(["hello"]));
}

Expand All @@ -94,8 +94,8 @@ public async Task GenerateAsync_ReturnsSingleEmbeddingAsync()
Task.FromResult<GeneratedEmbeddings<Embedding<float>>>([result])
};

Assert.Same(result, await service.GenerateEmbeddingAsync("hello"));
Assert.Equal(result.Vector, await service.GenerateEmbeddingVectorAsync("hello"));
Assert.Same(result, await service.GenerateAsync("hello"));
Assert.Equal(result.Vector, await service.GenerateVectorAsync("hello"));
}

[Theory]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -88,10 +88,10 @@ public virtual async Task Caching_SameOutputsForSameInput()
.Build();

string input = "Red, White, and Blue";
var embedding1 = await generator.GenerateEmbeddingAsync(input);
var embedding2 = await generator.GenerateEmbeddingAsync(input);
var embedding3 = await generator.GenerateEmbeddingAsync(input + "... and Green");
var embedding4 = await generator.GenerateEmbeddingAsync(input);
var embedding1 = await generator.GenerateAsync(input);
var embedding2 = await generator.GenerateAsync(input);
var embedding3 = await generator.GenerateAsync(input + "... and Green");
var embedding4 = await generator.GenerateAsync(input);

var callCounter = generator.GetService<CallCountingEmbeddingGenerator>();
Assert.NotNull(callCounter);
Expand All @@ -116,7 +116,7 @@ public virtual async Task OpenTelemetry_CanEmitTracesAndMetrics()
.UseOpenTelemetry(sourceName: sourceName)
.Build();

_ = await embeddingGenerator.GenerateEmbeddingAsync("Hello, world!");
_ = await embeddingGenerator.GenerateAsync("Hello, world!");

Assert.Single(activities);
var activity = activities.Single();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,12 +43,12 @@ public async Task CachesSuccessResultsAsync()
};

// Make the initial request and do a quick sanity check
var result1 = await outer.GenerateEmbeddingAsync("abc");
var result1 = await outer.GenerateAsync("abc");
AssertEmbeddingsEqual(_expectedEmbedding, result1);
Assert.Equal(1, innerCallCount);

// Act
var result2 = await outer.GenerateEmbeddingAsync("abc");
var result2 = await outer.GenerateAsync("abc");

// Assert
Assert.Equal(1, innerCallCount);
Expand Down Expand Up @@ -134,8 +134,8 @@ public async Task AllowsConcurrentCallsAsync()
};

// Act 1: Concurrent calls before resolution are passed into the inner client
var result1 = outer.GenerateEmbeddingAsync("abc");
var result2 = outer.GenerateEmbeddingAsync("abc");
var result1 = outer.GenerateAsync("abc");
var result2 = outer.GenerateAsync("abc");

// Assert 1
Assert.Equal(2, innerCallCount);
Expand All @@ -146,7 +146,7 @@ public async Task AllowsConcurrentCallsAsync()
AssertEmbeddingsEqual(_expectedEmbedding, await result2);

// Act 2: Subsequent calls after completion are resolved from the cache
var result3 = await outer.GenerateEmbeddingAsync("abc");
var result3 = await outer.GenerateAsync("abc");
Assert.Equal(2, innerCallCount);
AssertEmbeddingsEqual(_expectedEmbedding, await result1);
}
Expand All @@ -169,12 +169,12 @@ public async Task DoesNotCacheExceptionResultsAsync()
JsonSerializerOptions = TestJsonSerializerContext.Default.Options,
};

var ex1 = await Assert.ThrowsAsync<InvalidTimeZoneException>(() => outer.GenerateEmbeddingAsync("abc"));
var ex1 = await Assert.ThrowsAsync<InvalidTimeZoneException>(() => outer.GenerateAsync("abc"));
Assert.Equal("some failure", ex1.Message);
Assert.Equal(1, innerCallCount);

// Act
var ex2 = await Assert.ThrowsAsync<InvalidTimeZoneException>(() => outer.GenerateEmbeddingAsync("abc"));
var ex2 = await Assert.ThrowsAsync<InvalidTimeZoneException>(() => outer.GenerateAsync("abc"));

// Assert
Assert.NotSame(ex1, ex2);
Expand Down Expand Up @@ -207,15 +207,15 @@ public async Task DoesNotCacheCanceledResultsAsync()
};

// First call gets cancelled
var result1 = outer.GenerateEmbeddingAsync("abc");
var result1 = outer.GenerateAsync("abc");
Assert.False(result1.IsCompleted);
Assert.Equal(1, innerCallCount);
resolutionTcs.SetCanceled();
await Assert.ThrowsAnyAsync<OperationCanceledException>(() => result1);
Assert.True(result1.IsCanceled);

// Act/Assert: Second call can succeed
var result2 = await outer.GenerateEmbeddingAsync("abc");
var result2 = await outer.GenerateAsync("abc");
Assert.Equal(2, innerCallCount);
AssertEmbeddingsEqual(_expectedEmbedding, result2);
}
Expand All @@ -241,11 +241,11 @@ public async Task CacheKeyVariesByEmbeddingOptionsAsync()
};

// Act: Call with two different EmbeddingGenerationOptions that have the same values
var result1 = await outer.GenerateEmbeddingAsync("abc", new EmbeddingGenerationOptions
var result1 = await outer.GenerateAsync("abc", new EmbeddingGenerationOptions
{
AdditionalProperties = new() { ["someKey"] = "value 1" }
});
var result2 = await outer.GenerateEmbeddingAsync("abc", new EmbeddingGenerationOptions
var result2 = await outer.GenerateAsync("abc", new EmbeddingGenerationOptions
{
AdditionalProperties = new() { ["someKey"] = "value 1" }
});
Expand All @@ -256,11 +256,11 @@ public async Task CacheKeyVariesByEmbeddingOptionsAsync()
AssertEmbeddingsEqual(new("value 1".Select(c => (float)c).ToArray()), result2);

// Act: Call with two different EmbeddingGenerationOptions that have different values
var result3 = await outer.GenerateEmbeddingAsync("abc", new EmbeddingGenerationOptions
var result3 = await outer.GenerateAsync("abc", new EmbeddingGenerationOptions
{
AdditionalProperties = new() { ["someKey"] = "value 1" }
});
var result4 = await outer.GenerateEmbeddingAsync("abc", new EmbeddingGenerationOptions
var result4 = await outer.GenerateAsync("abc", new EmbeddingGenerationOptions
{
AdditionalProperties = new() { ["someKey"] = "value 2" }
});
Expand Down Expand Up @@ -292,11 +292,11 @@ public async Task SubclassCanOverrideCacheKeyToVaryByOptionsAsync()
};

// Act: Call with two different options
var result1 = await outer.GenerateEmbeddingAsync("abc", new EmbeddingGenerationOptions
var result1 = await outer.GenerateAsync("abc", new EmbeddingGenerationOptions
{
AdditionalProperties = new() { ["someKey"] = "value 1" }
});
var result2 = await outer.GenerateEmbeddingAsync("abc", new EmbeddingGenerationOptions
var result2 = await outer.GenerateAsync("abc", new EmbeddingGenerationOptions
{
AdditionalProperties = new() { ["someKey"] = "value 2" }
});
Expand Down Expand Up @@ -331,7 +331,7 @@ public async Task CanResolveIDistributedCacheFromDI()

// Act: Make a request that should populate the cache
Assert.Empty(_storage.Keys);
var result = await outer.GenerateEmbeddingAsync("abc");
var result = await outer.GenerateAsync("abc");

// Assert
Assert.NotNull(result);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ public async Task GetResponseAsync_LogsResponseInvocationAndCompletion(LogLevel
.UseLogging()
.Build(services);

await generator.GenerateEmbeddingAsync("Blue whale");
await generator.GenerateAsync("Blue whale");

var logs = collector.GetSnapshot();
if (level is LogLevel.Trace)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ public async Task ExpectedInformationLogged_Async(string? perRequestModelId)
},
};

await generator.GenerateEmbeddingVectorAsync("hello", options);
await generator.GenerateVectorAsync("hello", options);

var activity = Assert.Single(activities);
var expectedModelName = perRequestModelId ?? "defaultmodel";
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ public class SemanticSearch(
{
public async Task<IReadOnlyList<SemanticSearchRecord>> SearchAsync(string text, string? filenameFilter, int maxResults)
{
var queryEmbedding = await embeddingGenerator.GenerateEmbeddingVectorAsync(text);
var queryEmbedding = await embeddingGenerator.GenerateVectorAsync(text);
var vectorCollection = vectorStore.GetCollection<string, SemanticSearchRecord>("data-aichatweb-ingested");

var nearest = await vectorCollection.VectorizedSearchAsync(queryEmbedding, new VectorSearchOptions<SemanticSearchRecord>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ public class SemanticSearch(
{
public async Task<IReadOnlyList<SemanticSearchRecord>> SearchAsync(string text, string? filenameFilter, int maxResults)
{
var queryEmbedding = await embeddingGenerator.GenerateEmbeddingVectorAsync(text);
var queryEmbedding = await embeddingGenerator.GenerateVectorAsync(text);
var vectorCollection = vectorStore.GetCollection<string, SemanticSearchRecord>("data-aichatweb-ingested");

var nearest = await vectorCollection.VectorizedSearchAsync(queryEmbedding, new VectorSearchOptions<SemanticSearchRecord>
Expand Down
Loading