Skip to content

Commit

Permalink
Update lab 035 to new OpenAI SDK
Browse files Browse the repository at this point in the history
  • Loading branch information
rstropek committed Jun 10, 2024
1 parent af53ce2 commit f11e034
Show file tree
Hide file tree
Showing 5 changed files with 83 additions and 109 deletions.
2 changes: 1 addition & 1 deletion labs/035-assistants-dotnet/035-assistants-dotnet.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
<PackageReference Include="Dapper" />
<PackageReference Include="dotenv.net" />
<PackageReference Include="Microsoft.EntityFrameworkCore.SqlServer" />
<PackageReference Include="Azure.AI.OpenAI.Assistants" />
<PackageReference Include="OpenAI" />
</ItemGroup>

</Project>
36 changes: 20 additions & 16 deletions labs/035-assistants-dotnet/Functions.cs
Original file line number Diff line number Diff line change
@@ -1,17 +1,19 @@
using System.Data;
using System.Text;
using Azure.AI.OpenAI.Assistants;
using Dapper;
using Microsoft.Data.SqlClient;
using OpenAI.Assistants;
using OpenAI.Chat;

namespace AssistantsDotNet;

public static class Functions
{
public static readonly FunctionToolDefinition GetCustomersFunctionDefinition = new(
"getCustomers",
"Gets a filtered list of customers. At least one filter MUST be provided in the parameters. The result list is limited to 25 customers.",
JsonHelpers.FromObjectAsJson(
public static readonly FunctionToolDefinition GetCustomersFunctionDefinition = new()
{
FunctionName = "getCustomers",
Description = "Gets a filtered list of customers. At least one filter MUST be provided in the parameters. The result list is limited to 25 customers.",
Parameters = JsonHelpers.FromObjectAsJson(
new
{
Type = "object",
Expand All @@ -25,7 +27,7 @@ public static class Functions
},
Required = Array.Empty<string>()
})
);
};

public class GetCustomersParameters
{
Expand Down Expand Up @@ -88,10 +90,11 @@ public static async Task<IEnumerable<Customer>> GetCustomers(SqlConnection conne
return result;
}

public static readonly FunctionToolDefinition GetProductsFunctionDefinition = new(
"getProducts",
"Gets a filtered list of products. At least one filter MUST be provided in the parameters. The result list is limited to 25 products.",
JsonHelpers.FromObjectAsJson(
public static readonly FunctionToolDefinition GetProductsFunctionDefinition = new()
{
FunctionName = "getProducts",
Description = "Gets a filtered list of products. At least one filter MUST be provided in the parameters. The result list is limited to 25 products.",
Parameters = JsonHelpers.FromObjectAsJson(
new
{
Type = "object",
Expand All @@ -103,7 +106,7 @@ public static async Task<IEnumerable<Customer>> GetCustomers(SqlConnection conne
},
Required = Array.Empty<string>()
})
);
};

public class GetProductsParameters
{
Expand Down Expand Up @@ -153,10 +156,11 @@ public static async Task<IEnumerable<Product>> GetProducts(SqlConnection connect
return result;
}

public static readonly FunctionToolDefinition GetCustomerProductsRevenueFunctionDefinition = new(
"getCustomerProductsRevenue",
"Gets the revenue of the customer and products. The result is ordered by the revenue in descending order. The result list is limited to 25 records.",
JsonHelpers.FromObjectAsJson(
public static readonly FunctionToolDefinition GetCustomerProductsRevenueFunctionDefinition = new()
{
FunctionName = "getCustomerProductsRevenue",
Description = "Gets the revenue of the customer and products. The result is ordered by the revenue in descending order. The result list is limited to 25 records.",
Parameters = JsonHelpers.FromObjectAsJson(
new
{
Type = "object",
Expand All @@ -173,7 +177,7 @@ public static async Task<IEnumerable<Product>> GetProducts(SqlConnection connect
},
Required = Array.Empty<string>()
})
);
};

public class GetCustomerProductsRevenueParameters
{
Expand Down
122 changes: 44 additions & 78 deletions labs/035-assistants-dotnet/OpenAIExtensions.cs
Original file line number Diff line number Diff line change
@@ -1,126 +1,92 @@
using Azure.AI.OpenAI.Assistants;
using OpenAI;
using OpenAI.Assistants;
#pragma warning disable OPENAI001

namespace AssistantsDotNet;

static class OpenAIExtensions
{
public static async Task<Assistant?> FindAssistantByName(this AssistantsClient client, string name)
public static async Task<Assistant?> FindAssistantByName(this AssistantClient client, string name)
{
PageableList<Assistant> assistants;
string? after = null;
do
await foreach (var assistant in client.GetAssistantsAsync())
{
assistants = await client.GetAssistantsAsync(after: after);
foreach (var assistant in assistants)
{
if (assistant.Name == name) { return assistant; }
}

after = assistants.LastId;
if (assistant.Name == name) { return assistant; }
}
while (assistants.HasMore);

return null;
}

public static async Task<Assistant> CreateOrUpdate(this AssistantsClient client, AssistantCreationOptions assistant)
public static async Task<Assistant> CreateOrUpdate(this AssistantClient client, string model, AssistantCreationOptions assistant)
{
var existing = await client.FindAssistantByName(assistant.Name);
if (existing != null)
{
var updateOptions = new UpdateAssistantOptions()
var updateOptions = new AssistantModificationOptions
{
Model = assistant.Model,
Model = model,
Name = assistant.Name,
Description = assistant.Description,
Instructions = assistant.Instructions,
Metadata = assistant.Metadata
DefaultTools = assistant.Tools,
};
foreach (var tool in assistant.Tools) { updateOptions.Tools.Add(tool); }
foreach (var fileId in assistant.FileIds) { updateOptions.FileIds.Add(fileId); }

return await client.UpdateAssistantAsync(existing.Id, updateOptions);
return await client.ModifyAssistantAsync(existing.Id, updateOptions);
}

return await client.CreateAssistantAsync(assistant);
return await client.CreateAssistantAsync(model, assistant);
}

public static async Task<ThreadRun> AddMessageAndRunToCompletion(this AssistantsClient client, string threadId, string assistantId,
string message, Func<RunStepFunctionToolCall, Task<object>>? functionCallback = null)
public static async IAsyncEnumerable<string> AddMessageAndRunToCompletion(this AssistantClient client, string threadId, string assistantId,
string message, Func<RequiredActionUpdate, Task<object>>? functionCallback = null)
{
await client.CreateMessageAsync(threadId, MessageRole.User, message);
var run = await client.CreateRunAsync(threadId, new CreateRunOptions(assistantId));
Console.WriteLine($"Run created { run.Value.Id }");
await client.CreateMessageAsync(threadId, [message]);
var asyncUpdate = client.CreateRunStreamingAsync(threadId, assistantId);

while (run.Value.Status == RunStatus.Queued || run.Value.Status == RunStatus.InProgress || run.Value.Status == RunStatus.Cancelling || run.Value.Status == RunStatus.RequiresAction)
ThreadRun? currentRun;
do
{
Console.WriteLine($"Run status { run.Value.Status }");

var steps = await client.GetRunStepsAsync(run, 1, ListSortOrder.Descending);

// If last step is a code interpreter call, log it (including generated Python code)
if (steps.Value.Any() && steps.Value.First().StepDetails is RunStepToolCallDetails toolCallDetails)
currentRun = null;
List<ToolOutput> outputsToSumit = [];
await foreach (var update in asyncUpdate)
{
foreach(var call in toolCallDetails.ToolCalls)
if (update is RunUpdate runUpdate) { currentRun = runUpdate; }
else if (update is RequiredActionUpdate requiredActionUpdate && functionCallback != null)
{
if (call is RunStepCodeInterpreterToolCall codeInterpreterToolCall && !string.IsNullOrEmpty(codeInterpreterToolCall.Input))
Console.WriteLine($"Calling function {requiredActionUpdate.ToolCallId} {requiredActionUpdate.FunctionName} {requiredActionUpdate.FunctionArguments}");

string functionResponse;
try
{
Console.WriteLine($"Code Interpreter Tool Call: {codeInterpreterToolCall.Input}");
var result = await functionCallback(requiredActionUpdate);
functionResponse = JsonHelpers.Serialize(result);
}
}
}

// Check if the run requires us to execute a function
if (run.Value.Status == RunStatus.RequiresAction && functionCallback != null)
{
var toolOutput = new List<ToolOutput>();
if (steps.Value.First().StepDetails is RunStepToolCallDetails stepDetails)
{
foreach(var call in stepDetails.ToolCalls.OfType<RunStepFunctionToolCall>())
catch (Exception ex)
{
Console.WriteLine($"Calling function { call.Id } { call.Name } { call.Arguments }");

string functionResponse;
try
{
var result = await functionCallback(call);
functionResponse = JsonHelpers.Serialize(result);
}
catch (Exception ex)
{
Console.WriteLine($"Function call failed, returning error message to ChatGPT { call.Name } { ex.Message }");
functionResponse = ex.Message;
}

toolOutput.Add(new()
{
ToolCallId = call.Id,
Output = functionResponse
});
Console.WriteLine($"Function call failed, returning error message to ChatGPT {requiredActionUpdate.FunctionName} {ex.Message}");
functionResponse = ex.Message;
}
}

if (toolOutput.Count != 0)
outputsToSumit.Add(new ToolOutput(requiredActionUpdate.ToolCallId, functionResponse));
}
else if (update is MessageContentUpdate contentUpdate)
{
run = await client.SubmitToolOutputsToRunAsync(threadId, run.Value.Id, toolOutput);
yield return contentUpdate.Text;
}
}


await Task.Delay(1000);
run = await client.GetRunAsync(threadId, run.Value.Id);
if (outputsToSumit.Count != 0)
{
asyncUpdate = client.SubmitToolOutputsToRunStreamingAsync(currentRun, outputsToSumit);
}
}

Console.WriteLine($"Final run status { run.Value.Status }");
return run;
while (currentRun?.Status.IsTerminal is false);
}

public static async Task<string?> GetLatestMessage(this AssistantsClient client, string threadId)
public static async Task<string?> GetLatestMessage(this AssistantClient client, string threadId)
{
var messages = await client.GetMessagesAsync(threadId, 1, ListSortOrder.Descending);
if (messages.Value.FirstOrDefault()?.ContentItems[0] is MessageTextContent tc)
await foreach(var msg in client.GetMessagesAsync(threadId, ListOrder.NewestFirst))
{
return tc.Text;
return msg.Content[0].Text;
}

return null;
Expand Down
31 changes: 17 additions & 14 deletions labs/035-assistants-dotnet/Program.cs
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
using AssistantsDotNet;
using Azure.AI.OpenAI.Assistants;
using dotenv.net;
using Microsoft.Data.SqlClient;
using OpenAI.Assistants;
#pragma warning disable OPENAI001

// Get environment variables from .env file. We have to go up 7 levels to get to the root of the
// git repository (because of bin/Debug/net8.0 folder).
Expand All @@ -14,9 +15,9 @@
// In this sample, we use key-based authentication. This is only done because this sample
// will be done by a larger group in a hackathon event. In real world, AVOID key-based
// authentication. ALWAYS prefer Microsoft Entra-based authentication (Managed Identity)!
var client = new AssistantsClient(env["OPENAI_KEY"]);
var client = new AssistantClient(env["OPENAI_KEY"]);

var assistant = await client.CreateOrUpdate(new(env["OPENAI_MODEL"])
var assistant = await client.CreateOrUpdate(env["OPENAI_MODEL"], new AssistantCreationOptions
{
Name = "Revenue Analyzer",
Description = "Retrieves customer and product revenue and analyzes it using code interpreter",
Expand Down Expand Up @@ -64,24 +65,26 @@ to the required data.
userMessage = options[selection - 1];
}

var run = await client.AddMessageAndRunToCompletion(thread.Value.Id, assistant.Id, userMessage, async functionCall =>
var first = true;
await foreach (var message in client.AddMessageAndRunToCompletion(thread.Value.Id, assistant.Id, userMessage, async functionCall =>
{
switch (functionCall.Name)
switch (functionCall.FunctionName)
{
case "getCustomers":
return await Functions.GetCustomers(sqlConnection, JsonHelpers.Deserialize<Functions.GetCustomersParameters>(functionCall.Arguments)!);
return await Functions.GetCustomers(sqlConnection, JsonHelpers.Deserialize<Functions.GetCustomersParameters>(functionCall.FunctionArguments)!);
case "getProducts":
return await Functions.GetProducts(sqlConnection, JsonHelpers.Deserialize<Functions.GetProductsParameters>(functionCall.Arguments)!);
return await Functions.GetProducts(sqlConnection, JsonHelpers.Deserialize<Functions.GetProductsParameters>(functionCall.FunctionArguments)!);
case "getCustomerProductsRevenue":
return await Functions.GetCustomerProductsRevenue(sqlConnection, JsonHelpers.Deserialize<Functions.GetCustomerProductsRevenueParameters>(functionCall.Arguments)!);
return await Functions.GetCustomerProductsRevenue(sqlConnection, JsonHelpers.Deserialize<Functions.GetCustomerProductsRevenueParameters>(functionCall.FunctionArguments)!);
default:
throw new Exception($"Function {functionCall.Name} is not supported");
throw new Exception($"Function {functionCall.FunctionName} is not supported");
}
});

if (run.Status == "completed")
}))
{
var lastMessage = await client.GetLatestMessage(thread.Value.Id);
Console.WriteLine($"\n🤖: {lastMessage}");
if (first) { Console.Write($"\n🤖: "); first = false; }
Console.Write(message);
}

Console.WriteLine();
//var lastMessage = await client.GetLatestMessage(thread.Value.Id);
}
1 change: 1 addition & 0 deletions labs/Directory.Packages.props
Original file line number Diff line number Diff line change
Expand Up @@ -8,5 +8,6 @@
<PackageVersion Include="Dapper" Version="2.1.35" />
<PackageVersion Include="dotenv.net" Version="3.1.3" />
<PackageVersion Include="Microsoft.EntityFrameworkCore.SqlServer" Version="8.0.3" />
<PackageVersion Include="OpenAI" Version="2.0.0-beta.3" />
</ItemGroup>
</Project>

0 comments on commit f11e034

Please sign in to comment.