Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add image generation #510

Merged
merged 9 commits into from
Jun 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@ public enum AgentField
Template,
Response,
Sample,
LlmConfig
LlmConfig,
Tool
}

public enum AgentTaskField
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
namespace BotSharp.Abstraction.Agents.Enums;

public class AgentTool
{
public const string FileAnalyzer = "file-analyzer";
public const string ImageGenerator = "image-generator";
public const string HttpHandler = "http-handler";
}
Original file line number Diff line number Diff line change
Expand Up @@ -51,4 +51,6 @@ public interface IAgentService
List<Agent> GetAgentsByUser(string userId);

PluginDef GetPlugin(string agentId);

IEnumerable<string> GetAgentTools();
}
13 changes: 13 additions & 0 deletions src/Infrastructure/BotSharp.Abstraction/Agents/Models/Agent.cs
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,12 @@ public class Agent
public List<string> Profiles { get; set; }
= new List<string>();

/// <summary>
/// Useful tools
/// </summary>
public List<string> Tools { get; set; }
= new List<string>();

/// <summary>
/// Inherit from agent
/// </summary>
Expand Down Expand Up @@ -121,6 +127,7 @@ public static Agent Clone(Agent agent)
Functions = agent.Functions,
Responses = agent.Responses,
Samples = agent.Samples,
Tools = agent.Tools,
Knowledges = agent.Knowledges,
IsPublic = agent.IsPublic,
Disabled = agent.Disabled,
Expand Down Expand Up @@ -162,6 +169,12 @@ public Agent SetSamples(List<string> samples)
return this;
}

public Agent SetTools(List<string> tools)
{
Tools = tools ?? new List<string>();
return this;
}

public Agent SetResponses(List<AgentResponse> responses)
{
Responses = responses ?? new List<AgentResponse>(); ;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,4 +58,6 @@ Task<bool> SendMessage(string agentId,
Task<string> GetConversationSummary(IEnumerable<string> conversationId);

Task<Conversation> GetConversationRecordOrCreateNew(string agentId);

bool IsConversationMode();
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
namespace BotSharp.Abstraction.MLTasks;

public interface IImageGeneration
{
/// <summary>
/// The LLM provider like Microsoft Azure, OpenAI, ClaudAI
/// </summary>
string Provider { get; }

/// <summary>
/// Set model name, one provider can consume different model or version(s)
/// </summary>
/// <param name="model">deployment name</param>
void SetModelName(string model);

Task<RoleDialogModel> GetImageGeneration(Agent agent, List<RoleDialogModel> conversations);
}
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,6 @@ public interface ILlmProviderService
{
LlmModelSetting GetSetting(string provider, string model);
List<string> GetProviders();
LlmModelSetting GetProviderModel(string provider, string id, bool? multiModal = null);
LlmModelSetting GetProviderModel(string provider, string id, bool? multiModal = null, bool imageGenerate = false);
List<LlmModelSetting> GetProviderModels(string provider);
}
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,11 @@ public class LlmModelSetting
/// </summary>
public bool MultiModal { get; set; }

/// <summary>
/// If true, allow generating images
/// </summary>
public bool ImageGeneration { get; set; }

/// <summary>
/// Prompt cost per 1K token
/// </summary>
Expand All @@ -51,5 +56,6 @@ public override string ToString()
public enum LlmModelType
{
Text = 1,
Chat = 2
Chat = 2,
Image = 3
}
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ public partial class AgentService
[MemoryCache(10 * 60, perInstanceCache: true)]
public async Task<Agent> LoadAgent(string id)
{
if (string.IsNullOrEmpty(id) || id == Guid.Empty.ToString())
if (string.IsNullOrEmpty(id))
{
return null;
}
Expand All @@ -28,7 +28,7 @@ public async Task<Agent> LoadAgent(string id)
var agent = await GetAgent(id);
if (agent == null)
{
throw new Exception($"Can't load agent by id: {id}");
return null;
}

if (agent.InheritAgentId != null)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
using BotSharp.Abstraction.Agents;
using BotSharp.Abstraction.Repositories.Enums;
using BotSharp.Abstraction.Routing.Models;
using BotSharp.Abstraction.Users.Enums;
using Microsoft.EntityFrameworkCore.Metadata;
using System.IO;

namespace BotSharp.Core.Agents.Services;
Expand Down Expand Up @@ -34,6 +32,7 @@ public async Task UpdateAgent(Agent agent, AgentField updateField)
record.Templates = agent.Templates ?? new List<AgentTemplate>();
record.Responses = agent.Responses ?? new List<AgentResponse>();
record.Samples = agent.Samples ?? new List<string>();
record.Tools = agent.Tools ?? new List<string>();
if (agent.LlmConfig != null && !agent.LlmConfig.IsInherit)
{
record.LlmConfig = agent.LlmConfig;
Expand Down Expand Up @@ -95,6 +94,7 @@ public async Task<string> UpdateAgentFromFile(string id)
.SetFunctions(foundAgent.Functions)
.SetResponses(foundAgent.Responses)
.SetSamples(foundAgent.Samples)
.SetTools(foundAgent.Tools)
.SetLlmConfig(foundAgent.LlmConfig);

_db.UpdateAgent(clonedAgent, AgentField.All);
Expand Down
11 changes: 11 additions & 0 deletions src/Infrastructure/BotSharp.Core/Agents/Services/AgentService.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
using System.IO;
using System.Reflection;

namespace BotSharp.Core.Agents.Services;

Expand Down Expand Up @@ -53,4 +54,14 @@ public List<Agent> GetAgentsByUser(string userId)
var agents = _db.GetAgentsByUser(userId);
return agents;
}

public IEnumerable<string> GetAgentTools()
{
var tools = typeof(AgentTool).GetFields(BindingFlags.Public | BindingFlags.Static)
.Where(f => f.IsLiteral && f.FieldType == typeof(string))
.Select(x => x.GetRawConstantValue()?.ToString())
.ToList();

return tools;
}
}
16 changes: 16 additions & 0 deletions src/Infrastructure/BotSharp.Core/BotSharp.Core.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,10 @@
</PropertyGroup>

<ItemGroup>
<None Remove="data\agents\00000000-0000-0000-0000-000000000000\agent.json" />
<None Remove="data\agents\00000000-0000-0000-0000-000000000000\instruction.liquid" />
<None Remove="data\agents\00000000-0000-0000-0000-000000000000\functions.json" />
<None Remove="data\agents\00000000-0000-0000-0000-000000000000\templates\load_attachment_prompt.liquid" />
<None Remove="data\agents\01dcc3e5-0af7-49e6-ad7a-a760bd12dc4b\agent.json" />
<None Remove="data\agents\01dcc3e5-0af7-49e6-ad7a-a760bd12dc4b\functions.json" />
<None Remove="data\agents\01dcc3e5-0af7-49e6-ad7a-a760bd12dc4b\instruction.liquid" />
Expand Down Expand Up @@ -146,6 +150,18 @@
<Content Include="data\agents\01fcc3e5-9af7-49e6-ad7a-a760bd12dc4a\templates\conversation.summary.liquid">
<CopyToOutputDirectory>PreserveNewest</CopyToOutputDirectory>
</Content>
<Content Include="data\agents\00000000-0000-0000-0000-000000000000\agent.json">
<CopyToOutputDirectory>PreserveNewest</CopyToOutputDirectory>
</Content>
<Content Include="data\agents\00000000-0000-0000-0000-000000000000\instruction.liquid">
<CopyToOutputDirectory>PreserveNewest</CopyToOutputDirectory>
</Content>
<Content Include="data\agents\00000000-0000-0000-0000-000000000000\functions.json">
<CopyToOutputDirectory>PreserveNewest</CopyToOutputDirectory>
</Content>
<Content Include="data\agents\00000000-0000-0000-0000-000000000000\templates\load_attachment_prompt.liquid">
<CopyToOutputDirectory>PreserveNewest</CopyToOutputDirectory>
</Content>
<Content Include="data\plugins\config.json">
<CopyToOutputDirectory>PreserveNewest</CopyToOutputDirectory>
</Content>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -161,4 +161,9 @@ public async Task<Conversation> GetConversationRecordOrCreateNew(string agentId)

return converation;
}

public bool IsConversationMode()
{
return !string.IsNullOrWhiteSpace(_conversationId);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,9 @@ public class LoadAttachmentFn : IFunctionCallback

private readonly IServiceProvider _services;
private readonly ILogger<LoadAttachmentFn> _logger;
private const string AIAssistant = "01fcc3e5-9af7-49e6-ad7a-a760bd12dc4a";
private readonly IEnumerable<string> _imageTypes = new List<string> { "image", "images", "png", "jpg", "jpeg" };
private readonly IEnumerable<string> _pdfTypes = new List<string> { "pdf" };
private static string TOOL_ASSISTANT = Guid.Empty.ToString();

public LoadAttachmentFn(
IServiceProvider services,
Expand All @@ -29,13 +29,13 @@ public async Task<bool> Execute(RoleDialogModel message)
var agentService = _services.GetRequiredService<IAgentService>();

var wholeDialogs = conv.GetDialogHistory();
var fileTypes = args?.FileTypes?.Split(",")?.ToList() ?? new List<string>();
var fileTypes = args?.FileTypes?.Split(",", StringSplitOptions.RemoveEmptyEntries)?.ToList() ?? new List<string>();
var dialogs = await AssembleFiles(conv.ConversationId, wholeDialogs, fileTypes);
var agent = await agentService.LoadAgent(!string.IsNullOrEmpty(message.CurrentAgentId) ? message.CurrentAgentId : AIAssistant);
var agent = await agentService.LoadAgent(TOOL_ASSISTANT);
var fileAgent = new Agent
{
Id = agent.Id,
Name = agent.Name,
Id = agent?.Id ?? Guid.Empty.ToString(),
Name = agent?.Name ?? "Unkown",
Instruction = !string.IsNullOrWhiteSpace(args?.UserRequest) ? args.UserRequest : "Please describe the files.",
TemplateDict = new Dictionary<string, object>()
};
Expand Down
Original file line number Diff line number Diff line change
@@ -1,63 +1,53 @@

using Microsoft.EntityFrameworkCore;

namespace BotSharp.Core.Files.Hooks;

public class AttachmentProcessingHook : AgentHookBase
{
private readonly IServiceProvider _services;
private static string TOOL_ASSISTANT = Guid.Empty.ToString();

public override string SelfId => string.Empty;

public AttachmentProcessingHook(IServiceProvider services, AgentSettings settings)
: base(services, settings)
{
_services = services;
}

public override void OnAgentLoaded(Agent agent)
{
var fileService = _services.GetRequiredService<IBotSharpFileService>();
var conv = _services.GetRequiredService<IConversationService>();
var hasConvFiles = fileService.HasConversationUserFiles(conv.ConversationId);
var isConvMode = conv.IsConversationMode();
var isEnabled = !agent.Tools.IsNullOrEmpty() && agent.Tools.Contains(AgentTool.FileAnalyzer);

if (hasConvFiles)
if (isConvMode && isEnabled)
{
agent.Instruction += "\r\n\r\nPlease call load_attachment if user wants to describe files, such as images, pdf.\r\n\r\n";

if (agent.Functions != null)
var (prompt, loadAttachmentFn) = GetLoadAttachmentFn();
if (loadAttachmentFn != null)
{
var json = JsonSerializer.Serialize(new
if (!string.IsNullOrWhiteSpace(prompt))
{
agent.Instruction += $"\r\n\r\n{prompt}\r\n\r\n";
}

if (agent.Functions == null)
{
user_request = new
{
type = "string",
description = "The request posted by user, which is related to analyzing requested files. User can request for multiple files to process at one time."
},
file_types = new
{
type = "string",
description = "The file types requested by user to analyze, such as image, png, jpeg, and pdf. There can be multiple file types in a single request. An example output is, 'image,pdf'"
}
});

agent.Functions.Add(new FunctionDef
agent.Functions = new List<FunctionDef> { loadAttachmentFn };
}
else
{
Name = "load_attachment",
Description = "If the user's request is related to analyzing files and/or images, you can call this function to analyze files and images.",
Parameters =
{
Properties = JsonSerializer.Deserialize<JsonDocument>(json),
Required = new List<string>
{
"user_request",
"file_types"
}
}
});
agent.Functions.Add(loadAttachmentFn);
}
}
}

base.OnAgentLoaded(agent);
}

private (string, FunctionDef?) GetLoadAttachmentFn()
{
var fnName = "load_attachment";
var db = _services.GetRequiredService<IBotSharpRepository>();
var agent = db.GetAgent(TOOL_ASSISTANT);
var prompt = agent?.Templates?.FirstOrDefault(x => x.Name.IsEqualTo($"{fnName}_prompt"))?.Content ?? string.Empty;
var loadAttachmentFn = agent?.Functions?.FirstOrDefault(x => x.Name.IsEqualTo(fnName));
return (prompt, loadAttachmentFn);
}
}
Loading