Skip to content

Commit f01cf49

Browse files
Refactored the RAG process (#287)
1 parent 77d4276 commit f01cf49

File tree

7 files changed

+349
-194
lines changed

7 files changed

+349
-194
lines changed

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ Things we are currently working on:
1919
- [ ] App: Implement the process to vectorize one local file using embeddings
2020
- [ ] Runtime: Integration of the vector database [LanceDB](https://github.com/lancedb/lancedb)
2121
- [ ] App: Implement the continuous process of vectorizing data
22-
- [x] ~~App: Define a common retrieval context interface for the integration of RAG processes in chats (PR [#281](https://github.com/MindWorkAI/AI-Studio/pull/281), [#284](https://github.com/MindWorkAI/AI-Studio/pull/284), [#286](https://github.com/MindWorkAI/AI-Studio/pull/286))~~
22+
- [x] ~~App: Define a common retrieval context interface for the integration of RAG processes in chats (PR [#281](https://github.com/MindWorkAI/AI-Studio/pull/281), [#284](https://github.com/MindWorkAI/AI-Studio/pull/284), [#286](https://github.com/MindWorkAI/AI-Studio/pull/286), [#287](https://github.com/MindWorkAI/AI-Studio/pull/287))~~
2323
- [ ] App: Define a common augmentation interface for the integration of RAG processes in chats
2424
- [x] ~~App: Integrate data sources in chats (PR [#282](https://github.com/MindWorkAI/AI-Studio/pull/282))~~
2525

app/MindWork AI Studio/Chat/ContentText.cs

Lines changed: 9 additions & 193 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,8 @@
11
using System.Text.Json.Serialization;
22

3-
using AIStudio.Agents;
4-
using AIStudio.Components;
53
using AIStudio.Provider;
64
using AIStudio.Settings;
7-
using AIStudio.Tools.RAG;
8-
using AIStudio.Tools.Services;
5+
using AIStudio.Tools.RAG.RAGProcesses;
96

107
namespace AIStudio.Chat;
118

@@ -43,204 +40,23 @@ public async Task CreateFromProviderAsync(IProvider provider, Model chatModel, I
4340
{
4441
if(chatThread is null)
4542
return;
46-
47-
var logger = Program.SERVICE_PROVIDER.GetService<ILogger<ContentText>>()!;
48-
var settings = Program.SERVICE_PROVIDER.GetService<SettingsManager>()!;
49-
var dataSourceService = Program.SERVICE_PROVIDER.GetService<DataSourceService>()!;
5043

51-
//
52-
// 1. Check if the user wants to bind any data sources to the chat:
53-
//
54-
if (chatThread.DataSourceOptions.IsEnabled() && lastPrompt is not null)
44+
// Call the RAG process. Right now, we only have one RAG process:
45+
if (lastPrompt is not null)
5546
{
56-
logger.LogInformation("Data sources are enabled for this chat.");
57-
58-
// Across the different code-branches, we keep track of whether it
59-
// makes sense to proceed with the RAG process:
60-
var proceedWithRAG = true;
61-
62-
//
63-
// When the user wants to bind data sources to the chat, we
64-
// have to check if the data sources are available for the
65-
// selected provider. Also, we have to check if any ERI
66-
// data sources changed their security requirements.
67-
//
68-
List<IDataSource> preselectedDataSources = chatThread.DataSourceOptions.PreselectedDataSourceIds.Select(id => settings.ConfigurationData.DataSources.FirstOrDefault(ds => ds.Id == id)).Where(ds => ds is not null).ToList()!;
69-
var dataSources = await dataSourceService.GetDataSources(provider, preselectedDataSources);
70-
var selectedDataSources = dataSources.SelectedDataSources;
71-
72-
//
73-
// Should the AI select the data sources?
74-
//
75-
if (chatThread.DataSourceOptions.AutomaticDataSourceSelection)
76-
{
77-
// Get the agent for the data source selection:
78-
var selectionAgent = Program.SERVICE_PROVIDER.GetService<AgentDataSourceSelection>()!;
79-
80-
// Let the AI agent do its work:
81-
IReadOnlyList<DataSourceAgentSelected> finalAISelection = [];
82-
var aiSelectedDataSources = await selectionAgent.PerformSelectionAsync(provider, lastPrompt, chatThread, dataSources, token);
83-
84-
// Check if the AI selected any data sources:
85-
if(aiSelectedDataSources.Count is 0)
86-
{
87-
logger.LogWarning("The AI did not select any data sources. The RAG process is skipped.");
88-
proceedWithRAG = false;
89-
90-
// Send the selected data sources to the data source selection component.
91-
// Then, the user can see which data sources were selected by the AI.
92-
await MessageBus.INSTANCE.SendMessage(null, Event.RAG_AUTO_DATA_SOURCES_SELECTED, finalAISelection);
93-
chatThread.AISelectedDataSources = finalAISelection;
94-
}
95-
else
96-
{
97-
// Log the selected data sources:
98-
var selectedDataSourceInfo = aiSelectedDataSources.Select(ds => $"[Id={ds.Id}, reason={ds.Reason}, confidence={ds.Confidence}]").Aggregate((a, b) => $"'{a}', '{b}'");
99-
logger.LogInformation($"The AI selected the data sources automatically. {aiSelectedDataSources.Count} data source(s) are selected: {selectedDataSourceInfo}.");
100-
101-
//
102-
// Check how many data sources were hallucinated by the AI:
103-
//
104-
var totalAISelectedDataSources = aiSelectedDataSources.Count;
105-
106-
// Filter out the data sources that are not available:
107-
aiSelectedDataSources = aiSelectedDataSources.Where(x => settings.ConfigurationData.DataSources.FirstOrDefault(ds => ds.Id == x.Id) is not null).ToList();
108-
109-
// Store the real AI-selected data sources:
110-
finalAISelection = aiSelectedDataSources.Select(x => new DataSourceAgentSelected { DataSource = settings.ConfigurationData.DataSources.First(ds => ds.Id == x.Id), AIDecision = x, Selected = false }).ToList();
111-
112-
var numHallucinatedSources = totalAISelectedDataSources - aiSelectedDataSources.Count;
113-
if(numHallucinatedSources > 0)
114-
logger.LogWarning($"The AI hallucinated {numHallucinatedSources} data source(s). We ignore them.");
115-
116-
if (aiSelectedDataSources.Count > 3)
117-
{
118-
//
119-
// We have more than 3 data sources. Let's filter by confidence.
120-
// In order to do that, we must identify the lower and upper
121-
// bounds of the confidence interval:
122-
//
123-
var confidenceValues = aiSelectedDataSources.Select(x => x.Confidence).ToList();
124-
var lowerBound = confidenceValues.Min();
125-
var upperBound = confidenceValues.Max();
126-
127-
//
128-
// Next, we search for a threshold so that we have between 2 and 3
129-
// data sources. When not possible, we take all data sources.
130-
//
131-
var threshold = 0.0f;
132-
133-
// Check the case where the confidence values are too close:
134-
if (upperBound - lowerBound >= 0.01)
135-
{
136-
var previousThreshold = 0.0f;
137-
for (var i = 0; i < 10; i++)
138-
{
139-
threshold = lowerBound + (upperBound - lowerBound) * i / 10;
140-
var numMatches = aiSelectedDataSources.Count(x => x.Confidence >= threshold);
141-
if (numMatches <= 1)
142-
{
143-
threshold = previousThreshold;
144-
break;
145-
}
146-
147-
if (numMatches is <= 3 and >= 2)
148-
break;
149-
150-
previousThreshold = threshold;
151-
}
152-
}
153-
154-
//
155-
// Filter the data sources by the threshold:
156-
//
157-
aiSelectedDataSources = aiSelectedDataSources.Where(x => x.Confidence >= threshold).ToList();
158-
foreach (var dataSource in finalAISelection)
159-
if(aiSelectedDataSources.Any(x => x.Id == dataSource.DataSource.Id))
160-
dataSource.Selected = true;
161-
162-
logger.LogInformation($"The AI selected {aiSelectedDataSources.Count} data source(s) with a confidence of at least {threshold}.");
163-
164-
// Transform the final data sources to the actual data sources:
165-
selectedDataSources = aiSelectedDataSources.Select(x => settings.ConfigurationData.DataSources.FirstOrDefault(ds => ds.Id == x.Id)).Where(ds => ds is not null).ToList()!;
166-
}
167-
168-
// We have max. 3 data sources. We take all of them:
169-
else
170-
{
171-
// Transform the selected data sources to the actual data sources:
172-
selectedDataSources = aiSelectedDataSources.Select(x => settings.ConfigurationData.DataSources.FirstOrDefault(ds => ds.Id == x.Id)).Where(ds => ds is not null).ToList()!;
173-
174-
// Mark the data sources as selected:
175-
foreach (var dataSource in finalAISelection)
176-
dataSource.Selected = true;
177-
}
178-
179-
// Send the selected data sources to the data source selection component.
180-
// Then, the user can see which data sources were selected by the AI.
181-
await MessageBus.INSTANCE.SendMessage(null, Event.RAG_AUTO_DATA_SOURCES_SELECTED, finalAISelection);
182-
chatThread.AISelectedDataSources = finalAISelection;
183-
}
184-
}
185-
else
186-
{
187-
//
188-
// No, the user made the choice manually:
189-
//
190-
var selectedDataSourceInfo = selectedDataSources.Select(ds => ds.Name).Aggregate((a, b) => $"'{a}', '{b}'");
191-
logger.LogInformation($"The user selected the data sources manually. {selectedDataSources.Count} data source(s) are selected: {selectedDataSourceInfo}.");
192-
}
193-
194-
if(selectedDataSources.Count == 0)
195-
{
196-
logger.LogWarning("No data sources are selected. The RAG process is skipped.");
197-
proceedWithRAG = false;
198-
}
199-
200-
//
201-
// Trigger the retrieval part of the (R)AG process:
202-
//
203-
var dataContexts = new List<IRetrievalContext>();
204-
if (proceedWithRAG)
205-
{
206-
//
207-
// We kick off the retrieval process for each data source in parallel:
208-
//
209-
var retrievalTasks = new List<Task<IReadOnlyList<IRetrievalContext>>>(selectedDataSources.Count);
210-
foreach (var dataSource in selectedDataSources)
211-
retrievalTasks.Add(dataSource.RetrieveDataAsync(lastPrompt, chatThread, token));
212-
213-
//
214-
// Wait for all retrieval tasks to finish:
215-
//
216-
foreach (var retrievalTask in retrievalTasks)
217-
{
218-
try
219-
{
220-
dataContexts.AddRange(await retrievalTask);
221-
}
222-
catch (Exception e)
223-
{
224-
logger.LogError(e, "An error occurred during the retrieval process.");
225-
}
226-
}
227-
}
228-
229-
//
230-
// Perform the augmentation of the R(A)G process:
231-
//
232-
if (proceedWithRAG)
233-
{
234-
235-
}
47+
var rag = new AISrcSelWithRetCtxVal();
48+
chatThread = await rag.ProcessAsync(provider, lastPrompt, chatThread, token);
23649
}
237-
50+
23851
// Store the last time we got a response. We use this later
23952
// to determine whether we should notify the UI about the
24053
// new content or not. Depends on the energy saving mode
24154
// the user chose.
24255
var last = DateTimeOffset.Now;
24356

57+
// Get the settings manager:
58+
var settings = Program.SERVICE_PROVIDER.GetService<SettingsManager>()!;
59+
24460
// Start another thread by using a task to uncouple
24561
// the UI thread from the AI processing:
24662
await Task.Run(async () =>
Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
using AIStudio.Settings;
2+
3+
namespace AIStudio.Tools.RAG;
4+
5+
/// <summary>
6+
/// Result of any data selection process.
7+
/// </summary>
8+
/// <param name="ProceedWithRAG">Whether it makes sense to proceed with the RAG process.</param>
9+
/// <param name="SelectedDataSources">The selected data sources.</param>
10+
public readonly record struct DataSelectionResult(bool ProceedWithRAG, IReadOnlyList<IDataSource> SelectedDataSources);

0 commit comments

Comments
 (0)