Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Simplified OpenAI Provisioning #47174

Merged
merged 5 commits into from
Nov 15, 2024
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,8 @@ public static partial class AzureOpenAIExtensions
{
public static void Add(this System.Collections.Generic.List<OpenAI.Chat.ChatMessage> messages, OpenAI.Chat.ChatCompletion completion) { }
public static void Add(this System.Collections.Generic.List<OpenAI.Chat.ChatMessage> messages, System.Collections.Generic.IEnumerable<Azure.CloudMachine.OpenAI.VectorbaseEntry> entries) { }
public static string AsText(this OpenAI.Chat.ChatCompletion completion) { throw null; }
public static string AsText(this OpenAI.Chat.ChatMessageContent completion) { throw null; }
public static OpenAI.Chat.ChatClient GetOpenAIChatClient(this Azure.Core.ClientWorkspace workspace) { throw null; }
public static OpenAI.Embeddings.EmbeddingClient GetOpenAIEmbeddingsClient(this Azure.Core.ClientWorkspace workspace) { throw null; }
public static void Trim(this System.Collections.Generic.List<OpenAI.Chat.ChatMessage> messages) { }
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -88,9 +88,9 @@ public override ClientConnectionOptions GetConnectionOptions(Type clientType, st
case "Azure.AI.OpenAI.AzureOpenAIClient":
return new ClientConnectionOptions(new($"https://{Id}.openai.azure.com"), Credential);
case "OpenAI.Chat.ChatClient":
return new ClientConnectionOptions(Id);
return new ClientConnectionOptions($"{Id}_chat");
case "OpenAI.Embeddings.EmbeddingClient":
return new ClientConnectionOptions($"{Id}-embedding");
return new ClientConnectionOptions($"{Id}_embedding");
default:
throw new Exception($"unknown client {clientId}");
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@

using System.ClientModel;
using System.Collections.Generic;
using System.Runtime.CompilerServices;
using System.Text;
using Azure.AI.OpenAI;
using Azure.Core;
using OpenAI.Chat;
Expand Down Expand Up @@ -47,6 +49,37 @@ public static EmbeddingClient GetOpenAIEmbeddingsClient(this ClientWorkspace wor
return embeddingsClient;
}

/// <summary>
/// returns full text of all parts.
/// </summary>
/// <param name="completion"></param>
/// <returns></returns>
public static string AsText(this ChatCompletion completion)
=> completion.Content.AsText();

/// <summary>
/// returns full text of all parts.
/// </summary>
/// <param name="completion"></param>
/// <returns></returns>
public static string AsText(this ChatMessageContent completion)
{
StringBuilder sb = new();
foreach (ChatMessageContentPart part in completion)
{
switch (part.Kind)
{
case ChatMessageContentPartKind.Text:
sb.AppendLine(part.Text);
break;
default:
sb.AppendLine($"<{part.Kind}>");
break;
}
}
return sb.ToString();
}

private static AzureOpenAIClient CreateAzureOpenAIClient(this ClientWorkspace workspace)
{
ClientConnectionOptions connection = workspace.GetConnectionOptions(typeof(AzureOpenAIClient));
Expand Down
54 changes: 17 additions & 37 deletions sdk/cloudmachine/Azure.CloudMachine/tests/CloudMachineTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -15,26 +15,21 @@
using Microsoft.Extensions.Primitives;
using System.Linq;
using System.IO;
using Azure.Provisioning;

namespace Azure.CloudMachine.Tests;

public class CloudMachineTests
{
[Test]
[TestCase([new string[] { "-bicep" }])]
[TestCase([new string[] { "" }])]
public void Provisioning(string[] args)
public void Configuration()
{
if (CloudMachineInfrastructure.Configure(args, (cm) =>
CloudMachineCommands.Execute(["-bicep"], (infrastructure) =>
{
cm.AddFeature(new KeyVaultFeature());
cm.AddFeature(new OpenAIFeature() // TODO: rework it such that models can be added as features
{
Chat = new AIModel("gpt-35-turbo", "0125"),
Embeddings = new AIModel("text-embedding-ada-002", "2")
});
}))
return;
infrastructure.AddFeature(new KeyVaultFeature());
infrastructure.AddFeature(new OpenAIModel("gpt-35-turbo", "0125"));
infrastructure.AddFeature(new OpenAIModel("text-embedding-ada-002", "2", AIModelKind.Embedding));
}, exitProcessIfHandled: false);

CloudMachineWorkspace cm = new();
Console.WriteLine(cm.Id);
Expand All @@ -47,12 +42,9 @@ public void Provisioning(string[] args)
[TestCase([new string[] { "" }])]
public void Storage(string[] args)
{
ManualResetEventSlim eventSlim = new(false);
if (CloudMachineInfrastructure.Configure(args, (cm) =>
{
}))
return;
if (CloudMachineCommands.Execute(args, exitProcessIfHandled: false)) return;

ManualResetEventSlim eventSlim = new(false);
CloudMachineClient cm = new();

cm.Storage.WhenUploaded((StorageFile file) =>
Expand All @@ -78,24 +70,15 @@ public void Storage(string[] args)
[TestCase([new string[] { "" }])]
public void OpenAI(string[] args)
{
if (CloudMachineInfrastructure.Configure(args, (cm) =>
if (CloudMachineCommands.Execute(args, (infrastructure) =>
{
cm.AddFeature(new OpenAIFeature()
{
Chat = new AIModel("gpt-35-turbo", "0125")
});
}))
return;
infrastructure.AddFeature(new OpenAIModel("gpt-35-turbo", "0125"));
}, exitProcessIfHandled: false)) return;

CloudMachineWorkspace cm = new();
ChatClient chat = cm.GetOpenAIChatClient();
ChatCompletion completion = chat.CompleteChat("Is Azure programming easy?");

ChatMessageContent content = completion.Content;
foreach (ChatMessageContentPart part in content)
{
Console.WriteLine(part.Text);
}
Console.WriteLine(completion.AsText());
}

[Ignore("no recordings yet")]
Expand All @@ -104,11 +87,10 @@ public void OpenAI(string[] args)
[TestCase([new string[] { "" }])]
public void KeyVault(string[] args)
{
if (CloudMachineInfrastructure.Configure(args, (cm) =>
if (CloudMachineCommands.Execute(args, (cm) =>
{
cm.AddFeature(new KeyVaultFeature());
}))
return;
}, exitProcessIfHandled: false)) return;

CloudMachineWorkspace cm = new();
SecretClient secrets = cm.GetKeyVaultSecretsClient();
Expand All @@ -121,8 +103,7 @@ public void KeyVault(string[] args)
[TestCase([new string[] { "" }])]
public void Messaging(string[] args)
{
if (CloudMachineInfrastructure.Configure(args))
return;
CloudMachineCommands.Execute(args);

CloudMachineClient cm = new();
cm.Messaging.WhenMessageReceived(message =>
Expand All @@ -143,8 +124,7 @@ public void Messaging(string[] args)
[TestCase([new string[] { "" }])]
public void Demo(string[] args)
{
if (CloudMachineInfrastructure.Configure(args))
return;
if (CloudMachineCommands.Execute(args, exitProcessIfHandled: false)) return;

CloudMachineClient cm = new();

Expand Down
Original file line number Diff line number Diff line change
@@ -1,16 +1,49 @@
namespace Azure
{
public partial class RestCallFailedException : System.Exception
{
public RestCallFailedException(string message, System.ClientModel.Primitives.PipelineResponse response) { }
}
public partial class RestClient
{
public RestClient() { }
public RestClient(System.ClientModel.Primitives.PipelinePolicy auth) { }
public static Azure.RestClient Shared { get { throw null; } }
public System.ClientModel.Primitives.PipelineMessage Create(string method, System.Uri uri) { throw null; }
public System.ClientModel.Primitives.PipelineResponse Get(string uri, System.ClientModel.Primitives.RequestOptions? options = null) { throw null; }
public System.ClientModel.Primitives.PipelineResponse Patch(string uri, System.ClientModel.BinaryContent content, System.ClientModel.Primitives.RequestOptions? options = null) { throw null; }
public System.ClientModel.Primitives.PipelineResponse Post(string uri, System.ClientModel.BinaryContent content, System.ClientModel.Primitives.RequestOptions? options = null) { throw null; }
public System.ClientModel.Primitives.PipelineResponse Put(string uri, System.ClientModel.BinaryContent content, System.ClientModel.Primitives.RequestOptions? options = null) { throw null; }
public System.ClientModel.Primitives.PipelineResponse Send(System.ClientModel.Primitives.PipelineMessage message, System.ClientModel.Primitives.RequestOptions? options = null) { throw null; }
}
public partial class RestClientOptions : System.ClientModel.Primitives.ClientPipelineOptions
{
public RestClientOptions() { }
}
}
namespace Azure.CloudMachine
{
public partial class CloudMachineCommands
{
public CloudMachineCommands() { }
public static bool Execute(string[] args, System.Action<Azure.CloudMachine.CloudMachineInfrastructure>? configure = null, bool exitProcessIfHandled = true) { throw null; }
}
public partial class CloudMachineInfrastructure
{
public CloudMachineInfrastructure(string cmId) { }
public Azure.CloudMachine.FeatureCollection Features { get { throw null; } }
public string Id { get { throw null; } }
public Azure.Provisioning.Roles.UserAssignedIdentity Identity { get { throw null; } }
public Azure.Provisioning.ProvisioningParameter PrincipalIdParameter { get { throw null; } }
public void AddEndpoints<T>() { }
public void AddFeature(Azure.Provisioning.CloudMachine.CloudMachineFeature resource) { }
public void AddFeature(Azure.Provisioning.CloudMachine.CloudMachineFeature feature) { }
public void AddResource(Azure.Provisioning.Primitives.NamedProvisionableConstruct resource) { }
public Azure.Provisioning.ProvisioningPlan Build(Azure.Provisioning.ProvisioningBuildOptions? context = null) { throw null; }
public static bool Configure(string[] args, System.Action<Azure.CloudMachine.CloudMachineInfrastructure>? configure = null) { throw null; }
}
public partial class FeatureCollection
{
public FeatureCollection() { }
public System.Collections.Generic.IEnumerable<T> FindAll<T>() where T : Azure.Provisioning.CloudMachine.CloudMachineFeature { throw null; }
}
}
namespace Azure.CloudMachine.KeyVault
Expand All @@ -19,23 +52,23 @@ public partial class KeyVaultFeature : Azure.Provisioning.CloudMachine.CloudMach
{
public KeyVaultFeature(Azure.Provisioning.KeyVault.KeyVaultSku? sku = null) { }
public Azure.Provisioning.KeyVault.KeyVaultSku Sku { get { throw null; } set { } }
public override void AddTo(Azure.CloudMachine.CloudMachineInfrastructure infrastructure) { }
protected override Azure.Provisioning.Primitives.ProvisionableResource EmitCore(Azure.CloudMachine.CloudMachineInfrastructure infrastructure) { throw null; }
}
}
namespace Azure.CloudMachine.OpenAI
{
public partial class AIModel
public enum AIModelKind
{
public AIModel(string model, string modelVersion) { }
public string Model { get { throw null; } }
public string ModelVersion { get { throw null; } }
Chat = 0,
Embedding = 1,
}
public partial class OpenAIFeature : Azure.Provisioning.CloudMachine.CloudMachineFeature
public partial class OpenAIModel : Azure.Provisioning.CloudMachine.CloudMachineFeature
{
public OpenAIFeature() { }
public Azure.CloudMachine.OpenAI.AIModel? Chat { get { throw null; } set { } }
public Azure.CloudMachine.OpenAI.AIModel? Embeddings { get { throw null; } set { } }
public override void AddTo(Azure.CloudMachine.CloudMachineInfrastructure cloudMachine) { }
public OpenAIModel(string model, string modelVersion, Azure.CloudMachine.OpenAI.AIModelKind kind = Azure.CloudMachine.OpenAI.AIModelKind.Chat) { }
public string Model { get { throw null; } }
public string ModelVersion { get { throw null; } }
public override void AddTo(Azure.CloudMachine.CloudMachineInfrastructure cm) { }
protected override Azure.Provisioning.Primitives.ProvisionableResource EmitCore(Azure.CloudMachine.CloudMachineInfrastructure cm) { throw null; }
}
}
namespace Azure.Provisioning.CloudMachine
Expand All @@ -44,7 +77,12 @@ public abstract partial class CloudMachineFeature
{
protected CloudMachineFeature() { }
[System.ComponentModel.EditorBrowsableAttribute(System.ComponentModel.EditorBrowsableState.Never)]
public abstract void AddTo(Azure.CloudMachine.CloudMachineInfrastructure cm);
public Azure.Provisioning.Primitives.ProvisionableResource Emited { get { throw null; } protected set { } }
[System.ComponentModel.EditorBrowsableAttribute(System.ComponentModel.EditorBrowsableState.Never)]
public virtual void AddTo(Azure.CloudMachine.CloudMachineInfrastructure cm) { }
[System.ComponentModel.EditorBrowsableAttribute(System.ComponentModel.EditorBrowsableState.Never)]
public void Emit(Azure.CloudMachine.CloudMachineInfrastructure cm) { }
protected abstract Azure.Provisioning.Primitives.ProvisionableResource EmitCore(Azure.CloudMachine.CloudMachineInfrastructure cm);
}
}
namespace System.ClientModel.TypeSpec
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,11 @@
<LangVersion>12</LangVersion>

<!-- Disable warning CS1591: Missing XML comment for publicly visible type or member -->
<NoWarn>CS1591</NoWarn>
<NoWarn>CS1591;AZC0007</NoWarn>
</PropertyGroup>

<ItemGroup>
<PackageReference Include="Azure.Identity" />
<PackageReference Include="Azure.Provisioning" />
<PackageReference Include="Azure.Provisioning.KeyVault" />
<PackageReference Include="Azure.Provisioning.Storage" />
Expand Down

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
using Azure.Provisioning.CloudMachine;
using Azure.Provisioning.Expressions;
using Azure.Provisioning.KeyVault;
using Azure.Provisioning.Primitives;

namespace Azure.CloudMachine.KeyVault;

Expand All @@ -20,7 +21,7 @@ public KeyVaultFeature(KeyVaultSku? sku = default)
}
Sku = sku;
}
public override void AddTo(CloudMachineInfrastructure infrastructure)
protected override ProvisionableResource EmitCore(CloudMachineInfrastructure infrastructure)
{
// Add a KeyVault to the CloudMachine infrastructure.
KeyVaultService keyVaultResource = new("cm_kv")
Expand Down Expand Up @@ -57,5 +58,7 @@ public override void AddTo(CloudMachineInfrastructure infrastructure)
kvMiRoleAssignment.RoleDefinitionId = BicepFunction.GetSubscriptionResourceId("Microsoft.Authorization/roleDefinitions", KeyVaultBuiltInRole.KeyVaultAdministrator.ToString());
kvMiRoleAssignment.PrincipalId = infrastructure.Identity.PrincipalId;
infrastructure.AddResource(kvMiRoleAssignment);

return keyVaultResource;
}
}
Loading
Loading