Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -36,12 +36,12 @@ Task<KernelSearchResults<TextSearchResult>> GetTextSearchResultsAsync(
CancellationToken cancellationToken = default);

/// <summary>
/// Perform a search for content related to the specified query and return <see cref="object"/> values representing the search results.
/// Perform a search for content related to the specified query and return strongly-typed <typeparamref name="TRecord"/> values representing the search results.
/// </summary>
/// <param name="query">What to search for.</param>
/// <param name="searchOptions">Options used when executing a text search.</param>
/// <param name="cancellationToken">The <see cref="CancellationToken"/> to monitor for cancellation requests. The default is <see cref="CancellationToken.None"/>.</param>
Task<KernelSearchResults<object>> GetSearchResultsAsync(
Task<KernelSearchResults<TRecord>> GetSearchResultsAsync(
string query,
TextSearchOptions<TRecord>? searchOptions = null,
CancellationToken cancellationToken = default);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -213,11 +213,11 @@ Task<KernelSearchResults<TextSearchResult>> ITextSearch<TRecord>.GetTextSearchRe
}

/// <inheritdoc/>
Task<KernelSearchResults<object>> ITextSearch<TRecord>.GetSearchResultsAsync(string query, TextSearchOptions<TRecord>? searchOptions, CancellationToken cancellationToken)
Task<KernelSearchResults<TRecord>> ITextSearch<TRecord>.GetSearchResultsAsync(string query, TextSearchOptions<TRecord>? searchOptions, CancellationToken cancellationToken)
{
var searchResponse = this.ExecuteVectorSearchAsync(query, searchOptions, cancellationToken);

return Task.FromResult(new KernelSearchResults<object>(this.GetResultsAsRecordAsync(searchResponse, cancellationToken)));
return Task.FromResult(new KernelSearchResults<TRecord>(this.GetResultsAsTRecordAsync(searchResponse, cancellationToken)));
}

#region private
Expand Down Expand Up @@ -367,6 +367,28 @@ private async IAsyncEnumerable<object> GetResultsAsRecordAsync(IAsyncEnumerable<
}
}

/// <summary>
/// Return the search results as strongly-typed <typeparamref name="TRecord"/> instances.
/// </summary>
/// <param name="searchResponse">Response containing the records matching the query.</param>
/// <param name="cancellationToken">Cancellation token</param>
private async IAsyncEnumerable<TRecord> GetResultsAsTRecordAsync(IAsyncEnumerable<VectorSearchResult<TRecord>>? searchResponse, [EnumeratorCancellation] CancellationToken cancellationToken)
{
if (searchResponse is null)
{
yield break;
}

await foreach (var result in searchResponse.WithCancellation(cancellationToken).ConfigureAwait(false))
{
if (result.Record is not null)
{
yield return result.Record;
await Task.Yield();
}
}
}

/// <summary>
/// Return the search results as instances of <see cref="TextSearchResult"/>.
/// </summary>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -78,12 +78,14 @@ public async Task CanGetSearchResultAsync()
{
// Arrange.
var sut = await CreateVectorStoreTextSearchAsync();
ITextSearch<DataModel> typeSafeInterface = sut;

// Act.
KernelSearchResults<object> searchResults = await sut.GetSearchResultsAsync("What is the Semantic Kernel?", new() { Top = 2, Skip = 0 });
KernelSearchResults<DataModel> searchResults = await typeSafeInterface.GetSearchResultsAsync("What is the Semantic Kernel?", new TextSearchOptions<DataModel> { Top = 2, Skip = 0 });
var results = await searchResults.Results.ToListAsync();

Assert.Equal(2, results.Count);
Assert.All(results, result => Assert.IsType<DataModel>(result));
}

[Fact]
Expand Down Expand Up @@ -117,12 +119,14 @@ public async Task CanGetSearchResultsWithEmbeddingGeneratorAsync()
{
// Arrange.
var sut = await CreateVectorStoreTextSearchWithEmbeddingGeneratorAsync();
ITextSearch<DataModelWithRawEmbedding> typeSafeInterface = sut;

// Act.
KernelSearchResults<object> searchResults = await sut.GetSearchResultsAsync("What is the Semantic Kernel?", new() { Top = 2, Skip = 0 });
KernelSearchResults<DataModelWithRawEmbedding> searchResults = await typeSafeInterface.GetSearchResultsAsync("What is the Semantic Kernel?", new TextSearchOptions<DataModelWithRawEmbedding> { Top = 2, Skip = 0 });
var results = await searchResults.Results.ToListAsync();

Assert.Equal(2, results.Count);
Assert.All(results, result => Assert.IsType<DataModelWithRawEmbedding>(result));
}

#pragma warning disable CS0618 // VectorStoreTextSearch with ITextEmbeddingGenerationService is obsolete
Expand Down Expand Up @@ -270,17 +274,16 @@ public async Task LinqGetSearchResultsAsync()
Filter = r => r.Tag == "Even"
};

KernelSearchResults<object> searchResults = await typeSafeInterface.GetSearchResultsAsync(
KernelSearchResults<DataModel> searchResults = await typeSafeInterface.GetSearchResultsAsync(
"What is the Semantic Kernel?",
searchOptions);
var results = await searchResults.Results.ToListAsync();

// Assert - Results should be DataModel objects with Tag == "Even"
// Assert - Results should be strongly-typed DataModel objects with Tag == "Even"
Assert.NotEmpty(results);
Assert.All(results, result =>
{
var dataModel = Assert.IsType<DataModel>(result);
Assert.Equal("Even", dataModel.Tag);
Assert.Equal("Even", result.Tag); // Direct property access - no cast needed!
});
}

Expand Down
Loading