|
1 | 1 | using System.Text.Json.Serialization; |
2 | 2 |
|
3 | | -using AIStudio.Agents; |
4 | | -using AIStudio.Components; |
5 | 3 | using AIStudio.Provider; |
6 | 4 | using AIStudio.Settings; |
7 | | -using AIStudio.Tools.RAG; |
8 | | -using AIStudio.Tools.Services; |
| 5 | +using AIStudio.Tools.RAG.RAGProcesses; |
9 | 6 |
|
10 | 7 | namespace AIStudio.Chat; |
11 | 8 |
|
@@ -43,204 +40,23 @@ public async Task CreateFromProviderAsync(IProvider provider, Model chatModel, I |
43 | 40 | { |
44 | 41 | if(chatThread is null) |
45 | 42 | return; |
46 | | - |
47 | | - var logger = Program.SERVICE_PROVIDER.GetService<ILogger<ContentText>>()!; |
48 | | - var settings = Program.SERVICE_PROVIDER.GetService<SettingsManager>()!; |
49 | | - var dataSourceService = Program.SERVICE_PROVIDER.GetService<DataSourceService>()!; |
50 | 43 |
|
51 | | - // |
52 | | - // 1. Check if the user wants to bind any data sources to the chat: |
53 | | - // |
54 | | - if (chatThread.DataSourceOptions.IsEnabled() && lastPrompt is not null) |
| 44 | + // Call the RAG process. Right now, we only have one RAG process: |
| 45 | + if (lastPrompt is not null) |
55 | 46 | { |
56 | | - logger.LogInformation("Data sources are enabled for this chat."); |
57 | | - |
58 | | - // Across the different code-branches, we keep track of whether it |
59 | | - // makes sense to proceed with the RAG process: |
60 | | - var proceedWithRAG = true; |
61 | | - |
62 | | - // |
63 | | - // When the user wants to bind data sources to the chat, we |
64 | | - // have to check if the data sources are available for the |
65 | | - // selected provider. Also, we have to check if any ERI |
66 | | - // data sources changed their security requirements. |
67 | | - // |
68 | | - List<IDataSource> preselectedDataSources = chatThread.DataSourceOptions.PreselectedDataSourceIds.Select(id => settings.ConfigurationData.DataSources.FirstOrDefault(ds => ds.Id == id)).Where(ds => ds is not null).ToList()!; |
69 | | - var dataSources = await dataSourceService.GetDataSources(provider, preselectedDataSources); |
70 | | - var selectedDataSources = dataSources.SelectedDataSources; |
71 | | - |
72 | | - // |
73 | | - // Should the AI select the data sources? |
74 | | - // |
75 | | - if (chatThread.DataSourceOptions.AutomaticDataSourceSelection) |
76 | | - { |
77 | | - // Get the agent for the data source selection: |
78 | | - var selectionAgent = Program.SERVICE_PROVIDER.GetService<AgentDataSourceSelection>()!; |
79 | | - |
80 | | - // Let the AI agent do its work: |
81 | | - IReadOnlyList<DataSourceAgentSelected> finalAISelection = []; |
82 | | - var aiSelectedDataSources = await selectionAgent.PerformSelectionAsync(provider, lastPrompt, chatThread, dataSources, token); |
83 | | - |
84 | | - // Check if the AI selected any data sources: |
85 | | - if(aiSelectedDataSources.Count is 0) |
86 | | - { |
87 | | - logger.LogWarning("The AI did not select any data sources. The RAG process is skipped."); |
88 | | - proceedWithRAG = false; |
89 | | - |
90 | | - // Send the selected data sources to the data source selection component. |
91 | | - // Then, the user can see which data sources were selected by the AI. |
92 | | - await MessageBus.INSTANCE.SendMessage(null, Event.RAG_AUTO_DATA_SOURCES_SELECTED, finalAISelection); |
93 | | - chatThread.AISelectedDataSources = finalAISelection; |
94 | | - } |
95 | | - else |
96 | | - { |
97 | | - // Log the selected data sources: |
98 | | - var selectedDataSourceInfo = aiSelectedDataSources.Select(ds => $"[Id={ds.Id}, reason={ds.Reason}, confidence={ds.Confidence}]").Aggregate((a, b) => $"'{a}', '{b}'"); |
99 | | - logger.LogInformation($"The AI selected the data sources automatically. {aiSelectedDataSources.Count} data source(s) are selected: {selectedDataSourceInfo}."); |
100 | | - |
101 | | - // |
102 | | - // Check how many data sources were hallucinated by the AI: |
103 | | - // |
104 | | - var totalAISelectedDataSources = aiSelectedDataSources.Count; |
105 | | - |
106 | | - // Filter out the data sources that are not available: |
107 | | - aiSelectedDataSources = aiSelectedDataSources.Where(x => settings.ConfigurationData.DataSources.FirstOrDefault(ds => ds.Id == x.Id) is not null).ToList(); |
108 | | - |
109 | | - // Store the real AI-selected data sources: |
110 | | - finalAISelection = aiSelectedDataSources.Select(x => new DataSourceAgentSelected { DataSource = settings.ConfigurationData.DataSources.First(ds => ds.Id == x.Id), AIDecision = x, Selected = false }).ToList(); |
111 | | - |
112 | | - var numHallucinatedSources = totalAISelectedDataSources - aiSelectedDataSources.Count; |
113 | | - if(numHallucinatedSources > 0) |
114 | | - logger.LogWarning($"The AI hallucinated {numHallucinatedSources} data source(s). We ignore them."); |
115 | | - |
116 | | - if (aiSelectedDataSources.Count > 3) |
117 | | - { |
118 | | - // |
119 | | - // We have more than 3 data sources. Let's filter by confidence. |
120 | | - // In order to do that, we must identify the lower and upper |
121 | | - // bounds of the confidence interval: |
122 | | - // |
123 | | - var confidenceValues = aiSelectedDataSources.Select(x => x.Confidence).ToList(); |
124 | | - var lowerBound = confidenceValues.Min(); |
125 | | - var upperBound = confidenceValues.Max(); |
126 | | - |
127 | | - // |
128 | | - // Next, we search for a threshold so that we have between 2 and 3 |
129 | | - // data sources. When not possible, we take all data sources. |
130 | | - // |
131 | | - var threshold = 0.0f; |
132 | | - |
133 | | - // Check the case where the confidence values are too close: |
134 | | - if (upperBound - lowerBound >= 0.01) |
135 | | - { |
136 | | - var previousThreshold = 0.0f; |
137 | | - for (var i = 0; i < 10; i++) |
138 | | - { |
139 | | - threshold = lowerBound + (upperBound - lowerBound) * i / 10; |
140 | | - var numMatches = aiSelectedDataSources.Count(x => x.Confidence >= threshold); |
141 | | - if (numMatches <= 1) |
142 | | - { |
143 | | - threshold = previousThreshold; |
144 | | - break; |
145 | | - } |
146 | | - |
147 | | - if (numMatches is <= 3 and >= 2) |
148 | | - break; |
149 | | - |
150 | | - previousThreshold = threshold; |
151 | | - } |
152 | | - } |
153 | | - |
154 | | - // |
155 | | - // Filter the data sources by the threshold: |
156 | | - // |
157 | | - aiSelectedDataSources = aiSelectedDataSources.Where(x => x.Confidence >= threshold).ToList(); |
158 | | - foreach (var dataSource in finalAISelection) |
159 | | - if(aiSelectedDataSources.Any(x => x.Id == dataSource.DataSource.Id)) |
160 | | - dataSource.Selected = true; |
161 | | - |
162 | | - logger.LogInformation($"The AI selected {aiSelectedDataSources.Count} data source(s) with a confidence of at least {threshold}."); |
163 | | - |
164 | | - // Transform the final data sources to the actual data sources: |
165 | | - selectedDataSources = aiSelectedDataSources.Select(x => settings.ConfigurationData.DataSources.FirstOrDefault(ds => ds.Id == x.Id)).Where(ds => ds is not null).ToList()!; |
166 | | - } |
167 | | - |
168 | | - // We have max. 3 data sources. We take all of them: |
169 | | - else |
170 | | - { |
171 | | - // Transform the selected data sources to the actual data sources: |
172 | | - selectedDataSources = aiSelectedDataSources.Select(x => settings.ConfigurationData.DataSources.FirstOrDefault(ds => ds.Id == x.Id)).Where(ds => ds is not null).ToList()!; |
173 | | - |
174 | | - // Mark the data sources as selected: |
175 | | - foreach (var dataSource in finalAISelection) |
176 | | - dataSource.Selected = true; |
177 | | - } |
178 | | - |
179 | | - // Send the selected data sources to the data source selection component. |
180 | | - // Then, the user can see which data sources were selected by the AI. |
181 | | - await MessageBus.INSTANCE.SendMessage(null, Event.RAG_AUTO_DATA_SOURCES_SELECTED, finalAISelection); |
182 | | - chatThread.AISelectedDataSources = finalAISelection; |
183 | | - } |
184 | | - } |
185 | | - else |
186 | | - { |
187 | | - // |
188 | | - // No, the user made the choice manually: |
189 | | - // |
190 | | - var selectedDataSourceInfo = selectedDataSources.Select(ds => ds.Name).Aggregate((a, b) => $"'{a}', '{b}'"); |
191 | | - logger.LogInformation($"The user selected the data sources manually. {selectedDataSources.Count} data source(s) are selected: {selectedDataSourceInfo}."); |
192 | | - } |
193 | | - |
194 | | - if(selectedDataSources.Count == 0) |
195 | | - { |
196 | | - logger.LogWarning("No data sources are selected. The RAG process is skipped."); |
197 | | - proceedWithRAG = false; |
198 | | - } |
199 | | - |
200 | | - // |
201 | | - // Trigger the retrieval part of the (R)AG process: |
202 | | - // |
203 | | - var dataContexts = new List<IRetrievalContext>(); |
204 | | - if (proceedWithRAG) |
205 | | - { |
206 | | - // |
207 | | - // We kick off the retrieval process for each data source in parallel: |
208 | | - // |
209 | | - var retrievalTasks = new List<Task<IReadOnlyList<IRetrievalContext>>>(selectedDataSources.Count); |
210 | | - foreach (var dataSource in selectedDataSources) |
211 | | - retrievalTasks.Add(dataSource.RetrieveDataAsync(lastPrompt, chatThread, token)); |
212 | | - |
213 | | - // |
214 | | - // Wait for all retrieval tasks to finish: |
215 | | - // |
216 | | - foreach (var retrievalTask in retrievalTasks) |
217 | | - { |
218 | | - try |
219 | | - { |
220 | | - dataContexts.AddRange(await retrievalTask); |
221 | | - } |
222 | | - catch (Exception e) |
223 | | - { |
224 | | - logger.LogError(e, "An error occurred during the retrieval process."); |
225 | | - } |
226 | | - } |
227 | | - } |
228 | | - |
229 | | - // |
230 | | - // Perform the augmentation of the R(A)G process: |
231 | | - // |
232 | | - if (proceedWithRAG) |
233 | | - { |
234 | | - |
235 | | - } |
| 47 | + var rag = new AISrcSelWithRetCtxVal(); |
| 48 | + chatThread = await rag.ProcessAsync(provider, lastPrompt, chatThread, token); |
236 | 49 | } |
237 | | - |
| 50 | + |
238 | 51 | // Store the last time we got a response. We use this later |
239 | 52 | // to determine whether we should notify the UI about the |
240 | 53 | // new content or not. Depends on the energy saving mode |
241 | 54 | // the user chose. |
242 | 55 | var last = DateTimeOffset.Now; |
243 | 56 |
|
| 57 | + // Get the settings manager: |
| 58 | + var settings = Program.SERVICE_PROVIDER.GetService<SettingsManager>()!; |
| 59 | + |
244 | 60 | // Start another thread by using a task to uncouple |
245 | 61 | // the UI thread from the AI processing: |
246 | 62 | await Task.Run(async () => |
|
0 commit comments