Skip to content

Commit

Permalink
Update to OpenAI
Browse files Browse the repository at this point in the history
  • Loading branch information
rstropek committed Jun 11, 2024
1 parent 7f8e9ce commit 26947bd
Show file tree
Hide file tree
Showing 2 changed files with 99 additions and 122 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
</PropertyGroup>

<ItemGroup>
<PackageReference Include="Azure.AI.OpenAI" />
<PackageReference Include="OpenAI" />
<PackageReference Include="dotenv.net" />
<PackageReference Include="Microsoft.EntityFrameworkCore.SqlServer" />
</ItemGroup>
Expand Down
219 changes: 98 additions & 121 deletions labs/020-functions-dotnet/FunctionCallingDotNet/Program.cs
Original file line number Diff line number Diff line change
@@ -1,10 +1,8 @@
using System.Text.Json;
using System.Text.Json.Serialization.Metadata;
using Azure;
using Azure.AI.OpenAI;
using dotenv.net;
using FunctionCallingDotNet;
using Microsoft.EntityFrameworkCore;
using OpenAI.Chat;

// Get environment variables from .env file. We have to go up 7 levels to get to the root of the
// git repository (because of bin/Debug/net8.0 folder).
Expand All @@ -18,15 +16,12 @@
// In this sample, we use key-based authentication. This is only done because this sample
// will be done by a larger group in a hackathon event. In real world, AVOID key-based
// authentication. ALWAYS prefer Microsoft Entra-based authentication (Managed Identity)!
var client = new OpenAIClient(
new Uri(env["OPENAI_AZURE_ENDPOINT"]),
new AzureKeyCredential(env["OPENAI_AZURE_KEY"]));
var client = new ChatClient("gpt-4o", env["OPENAI_KEY"]);

var chatCompletionOptions = new ChatCompletionsOptions(
env["OPENAI_AZURE_DEPLOYMENT"],
List<ChatMessage> messages =
[
// System prompt
new ChatRequestSystemMessage("""
new SystemChatMessage("""
You are an assistant supporting business users who need to analyze the revene of
customers and products. Use the provided function tools to access the order database
and answer the user's questions.
Expand All @@ -40,24 +35,23 @@ tell her or him that you cannot answer the question because of a lack of access
to the required data.
"""),
// Initial assistant message to get the conversation started
new ChatRequestAssistantMessage("""
new AssistantChatMessage("""
Hi! I can help you with questions about customer and product revenue. What would you like to know?
"""),
]
)
];

ChatCompletionOptions options = new()
{
// Define the tool functions that can be called from the assistant
Tools =
{
new ChatCompletionsFunctionToolDefinition(
new FunctionDefinition()
{
Name = "getCustomers",
Description = """
ChatTool.CreateFunctionTool(
functionName: "getCustomers",
functionDescription: """
Gets a filtered list of customers. At least one filter MUST be provided in
the parameters. The result list is limited to 25 customer.
""",
Parameters = BinaryData.FromObjectAsJson(
functionParameters: BinaryData.FromObjectAsJson(
new
{
Type = "object",
Expand Down Expand Up @@ -91,16 +85,14 @@ the parameters. The result list is limited to 25 customer.
},
Required = Array.Empty<string>()
}, new JsonSerializerOptions { PropertyNamingPolicy = JsonNamingPolicy.CamelCase })
}),
new ChatCompletionsFunctionToolDefinition(
new FunctionDefinition()
{
Name = "getProducts",
Description = """
),
ChatTool.CreateFunctionTool(
functionName: "getProducts",
functionDescription: """
Gets a filtered list of products. At least one filter MUST be
provided in the parameters. The result list is limited to 25 customer.
""",
Parameters = BinaryData.FromObjectAsJson(
functionParameters: BinaryData.FromObjectAsJson(
new
{
Type = "object",
Expand Down Expand Up @@ -129,15 +121,13 @@ provided in the parameters. The result list is limited to 25 customer.
},
Required = Array.Empty<string>()
}, new JsonSerializerOptions { PropertyNamingPolicy = JsonNamingPolicy.CamelCase })
}),
new ChatCompletionsFunctionToolDefinition(
new FunctionDefinition()
{
Name = "getTopCustomers",
Description = """
),
ChatTool.CreateFunctionTool(
functionName: "getTopCustomers",
functionDescription: """
Gets the customers with their revenue sorted by revenue in descending order.
""",
Parameters = BinaryData.FromObjectAsJson(
functionParameters: BinaryData.FromObjectAsJson(
new
{
Type = "object",
Expand All @@ -156,16 +146,14 @@ Gets the customers with their revenue sorted by revenue in descending order.
},
Required = Array.Empty<string>()
}, new JsonSerializerOptions { PropertyNamingPolicy = JsonNamingPolicy.CamelCase })
}),
new ChatCompletionsFunctionToolDefinition(
new FunctionDefinition()
{
Name = "getCustomerRevenueTrend",
Description = """
),
ChatTool.CreateFunctionTool(
functionName: "getCustomerRevenueTrend",
functionDescription: """
Gets the total revenue for a given customer per year, and month.
Use this function to analyze the revenue trend of a specific customer.
""",
Parameters = BinaryData.FromObjectAsJson(
functionParameters: BinaryData.FromObjectAsJson(
new
{
Type = "object",
Expand All @@ -179,16 +167,14 @@ Use this function to analyze the revenue trend of a specific customer.
},
Required = new[] { "customerID" }
}, new JsonSerializerOptions { PropertyNamingPolicy = JsonNamingPolicy.CamelCase })
}),
new ChatCompletionsFunctionToolDefinition(
new FunctionDefinition()
{
Name = "getCustomerProductBreakdown",
Description = """
),
ChatTool.CreateFunctionTool(
functionName: "getCustomerProductBreakdown",
functionDescription: """
Gets the total revenue for a given customer per product. Use this function
to analyze the revenue breakdown of a specific customer.
""",
Parameters = BinaryData.FromObjectAsJson(
functionParameters: BinaryData.FromObjectAsJson(
new
{
Type = "object",
Expand All @@ -202,121 +188,112 @@ to analyze the revenue breakdown of a specific customer.
},
Required = new[] { "customerID" }
}, new JsonSerializerOptions { PropertyNamingPolicy = JsonNamingPolicy.CamelCase })
})
)
},
};


while (true)
{
// Display the last message from the assistant
if (chatCompletionOptions.Messages.Last() is ChatRequestAssistantMessage am)
Console.WriteLine($"🤖: {messages.Last().Content[0].Text}");

// Ask the user for a message. Exit program in case of empty message.
string[] prompts =
[
"I am going to visit Carolyn Farino tomorrow. Tell me something about her and the products that she usually buys.",
"Did she ever buy a headset?",
"Give me a table by year and month of her revenues."
];
Console.WriteLine("\n");
for (int i = 0; i < prompts.Length; i++)
{
Console.WriteLine($"🤖: {am.Content}");
Console.WriteLine($"{i + 1}: {prompts[i]}");
}

// Ask the user for a message. Exit program in case of empty message.
Console.Write("\nYou (just press enter to exit the conversation): ");
var userMessage = Console.ReadLine();
if (string.IsNullOrEmpty(userMessage)) { break; }
if (int.TryParse(userMessage, out int selection) && selection >= 1 && selection <= prompts.Length)
{
userMessage = prompts[selection - 1];
}

// Add the user message to the list of messages to send to the API
chatCompletionOptions.Messages.Add(new ChatRequestUserMessage(userMessage));
messages.Add(new UserChatMessage(userMessage));

bool repeat;
bool requiresAction;
do
{
// If the last message from the assistant was a tool call, we need to
// add the tool call result to the list of messages to send to the API.
// We also need to repeat the call to the API to get the next message
// from the assistant. The next message could be another tool call.
// We have to repeat that process until the assistant sends a message
// that is not a tool call.
repeat = false;
requiresAction = false;

// Send the messages to the API and wait for the response. Display a
// waiting indicator while waiting for the response.
Console.Write("\nThinking...");
var chatTask = client.GetChatCompletionsAsync(chatCompletionOptions);
while (!chatTask.IsCompleted)
{
Console.Write(".");
await Task.Delay(1000);
}
ChatCompletion chatCompletion = await client.CompleteChatAsync(messages, options);

Console.WriteLine("\n");
var response = await chatTask;
if (response.GetRawResponse().IsError)
switch (chatCompletion.FinishReason)
{
Console.WriteLine($"Error: {response.GetRawResponse().ReasonPhrase}");
break;
}

// Add the response from the API to the list of messages to send to the API
chatCompletionOptions.Messages.Add(new ChatRequestAssistantMessage(response.Value.Choices[0].Message));
case ChatFinishReason.Stop:
messages.Add(new AssistantChatMessage(chatCompletion.Content[0].Text));
break;
case ChatFinishReason.ToolCalls:
{
messages.Add(new AssistantChatMessage(chatCompletion));
foreach (var toolCall in chatCompletion.ToolCalls)
{
Console.WriteLine($"\tExecuting tool {toolCall.FunctionName} with arguments {toolCall.FunctionArguments}.");
ToolChatMessage result;
switch (toolCall.FunctionName)
{
case "getCustomers":
result = await ExecuteQuery<CustomerFilter, Customer>(context, toolCall, context.GetCustomers);
break;

if (response.Value.Choices[0].Message.ToolCalls.Any())
{
// We have a tool call
case "getProducts":
result = await ExecuteQuery<ProductFilter, Product>(context, toolCall, context.GetProducts);
break;

foreach (var toolCall in response.Value.Choices[0].Message.ToolCalls.OfType<ChatCompletionsFunctionToolCall>())
{
Console.WriteLine($"\tExecuting tool {toolCall.Name} with arguments {toolCall.Arguments}.");
ChatRequestToolMessage result;
switch (toolCall.Name)
{
case "getCustomers":
result = await ExecuteQuery<CustomerFilter, Customer>(context, toolCall, context.GetCustomers);
break;
case "getTopCustomers":
result = await ExecuteQuery<TopCustomerFilter, TopCustomerResult>(context, toolCall, context.GetTopCustomers);
break;

case "getProducts":
result = await ExecuteQuery<ProductFilter, Product>(context, toolCall, context.GetProducts);
break;
case "getCustomerRevenueTrend":
result = await ExecuteQuery<CustomerDetailStatsFilter, CustomerRevenueTrendResult>(context, toolCall, context.GetCustomerRevenueTrend);
break;

case "getTopCustomers":
result = await ExecuteQuery<TopCustomerFilter, TopCustomerResult>(context, toolCall, context.GetTopCustomers);
break;
case "getCustomerProductBreakdown":
result = await ExecuteQuery<CustomerDetailStatsFilter, CustomerProductBreakdownResult>(context, toolCall, context.GetCustomerProductBreakdown);
break;

case "getCustomerRevenueTrend":
result = await ExecuteQuery<CustomerDetailStatsFilter, CustomerRevenueTrendResult>(context, toolCall, context.GetCustomerRevenueTrend);
break;
default:
throw new InvalidOperationException($"Tool {toolCall.FunctionName} does not exist.");
}

case "getCustomerProductBreakdown":
result = await ExecuteQuery<CustomerDetailStatsFilter, CustomerProductBreakdownResult>(context, toolCall, context.GetCustomerProductBreakdown);
break;
messages.Add(result);
}

default:
throw new InvalidOperationException($"Tool {toolCall.Name} does not exist.");
requiresAction = true;
break;
}

// Add the result of the tool call to the list of messages to send to the API
chatCompletionOptions.Messages.Add(result);
repeat = true;
}
default:
throw new NotImplementedException();
}
else
{
// We don't have a tool call. Add the response from the API to the list of messages to send to the API
chatCompletionOptions.Messages.Add(new ChatRequestAssistantMessage(response.Value.Choices[0].Message));
}
} while (repeat);

} while (requiresAction);
}

static async Task<ChatRequestToolMessage> ExecuteQuery<TFilter, TResult>(ApplicationDataContext context, ChatCompletionsFunctionToolCall toolCall, Func<TFilter, Task<TResult[]>> body)
static async Task<ToolChatMessage> ExecuteQuery<TFilter, TResult>(ApplicationDataContext context, ChatToolCall toolCall, Func<TFilter, Task<TResult[]>> body)
{
ChatRequestToolMessage result;
ToolChatMessage result;
var jsonOptions = new JsonSerializerOptions { PropertyNamingPolicy = JsonNamingPolicy.CamelCase };
try
{
// Deserialize arguments
var filter = JsonSerializer.Deserialize<TFilter>(toolCall.Arguments, jsonOptions)!;
var filter = JsonSerializer.Deserialize<TFilter>(toolCall.FunctionArguments, jsonOptions)!;

// Get result from the database
var customers = await body(filter);
result = new ChatRequestToolMessage(JsonSerializer.Serialize(customers, jsonOptions), toolCall.Id);
result = new ToolChatMessage(toolCall.Id, JsonSerializer.Serialize(customers, jsonOptions));
}
catch (Exception ex)
{
result = new ChatRequestToolMessage(JsonSerializer.Serialize(new { Error = ex.Message }, jsonOptions), toolCall.Id);
result = new ToolChatMessage(toolCall.Id, JsonSerializer.Serialize(new { Error = ex.Message }, jsonOptions));
}

return result;
Expand Down

0 comments on commit 26947bd

Please sign in to comment.