Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 21 additions & 6 deletions shell/agents/AIShell.OpenAI.Agent/GPT.cs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ internal enum EndpointType
{
AzureOpenAI,
OpenAI,
CompatibleThirdParty,
}

public class GPT
Expand Down Expand Up @@ -56,9 +57,16 @@ public GPT(
bool noDeployment = string.IsNullOrEmpty(Deployment);
Type = noEndpoint && noDeployment
? EndpointType.OpenAI
: !noEndpoint && !noDeployment
? EndpointType.AzureOpenAI
: throw new InvalidOperationException($"Invalid setting: {(noEndpoint ? "Endpoint" : "Deployment")} key is missing. To use Azure OpenAI service, please specify both the 'Endpoint' and 'Deployment' keys. To use OpenAI service, please ignore both keys.");
: !noEndpoint && noDeployment
? EndpointType.CompatibleThirdParty
: !noEndpoint && !noDeployment
? EndpointType.AzureOpenAI
: throw new InvalidOperationException($"Invalid setting: 'Deployment' key present but 'Endpoint' key is missing. To use Azure OpenAI service, please specify both the 'Endpoint' and 'Deployment' keys. To use OpenAI service, please ignore both keys.");

if (ModelInfo is null && Type is EndpointType.CompatibleThirdParty)
{
ModelInfo = ModelInfo.ThirdPartyModel;
}
}

/// <summary>
Expand Down Expand Up @@ -142,11 +150,18 @@ private void ShowEndpointInfo(IHost host)
new(label: " Model", m => m.ModelName),
},

EndpointType.OpenAI => new CustomElement<GPT>[]
{
EndpointType.OpenAI =>
[
new(label: " Type", m => m.Type.ToString()),
new(label: " Model", m => m.ModelName),
},
],

EndpointType.CompatibleThirdParty =>
[
new(label: " Type", m => m.Type.ToString()),
new(label: " Endpoint", m => m.Endpoint),
new(label: " Model", m => m.ModelName),
],

_ => throw new UnreachableException(),
};
Expand Down
5 changes: 5 additions & 0 deletions shell/agents/AIShell.OpenAI.Agent/ModelInfo.cs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,11 @@ internal class ModelInfo
private static readonly Dictionary<string, ModelInfo> s_modelMap;
private static readonly Dictionary<string, Task<Tokenizer>> s_encodingMap;

// Fallback model info for compatible third-party endpoints; a rough estimate meant to cover all third-party models.
// - most popular models today support a 32K+ context length, hence the 32_000 value;
// - the gpt-4o encoding is used as an estimate for the token count.
internal static readonly ModelInfo ThirdPartyModel = new(32_000, encoding: Gpt4oEncoding);

static ModelInfo()
{
// For reference, see https://platform.openai.com/docs/models and the "Counting tokens" section in
Expand Down
8 changes: 7 additions & 1 deletion shell/agents/AIShell.OpenAI.Agent/Service.cs
Original file line number Diff line number Diff line change
Expand Up @@ -122,9 +122,10 @@ private void RefreshOpenAIClient()
return;
}

EndpointType type = _gptToUse.Type;
string userKey = Utils.ConvertFromSecureString(_gptToUse.Key);

if (_gptToUse.Type is EndpointType.AzureOpenAI)
if (type is EndpointType.AzureOpenAI)
{
// Create a client that targets Azure OpenAI service or Azure API Management service.
var clientOptions = new AzureOpenAIClientOptions() { RetryPolicy = new ChatRetryPolicy() };
Expand Down Expand Up @@ -152,6 +153,11 @@ private void RefreshOpenAIClient()
{
// Create a client that targets the non-Azure OpenAI service.
var clientOptions = new OpenAIClientOptions() { RetryPolicy = new ChatRetryPolicy() };
if (type is EndpointType.CompatibleThirdParty)
{
clientOptions.Endpoint = new(_gptToUse.Endpoint);
}

var aiClient = new OpenAIClient(new ApiKeyCredential(userKey), clientOptions);
_client = aiClient.GetChatClient(_gptToUse.ModelName);
}
Expand Down
Loading