Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Reply to callback support, but still wait for completion #3

Merged
merged 2 commits into from
Aug 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
192 changes: 181 additions & 11 deletions BlazorDiffusion.ServiceInterface/CreativeService.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,15 @@
using System.Data;
using System.IO;
using System.Linq;
using System.Net;
using System.Threading.Tasks;
using BlazorDiffusion.ServiceModel;
using Microsoft.AspNetCore.Identity;
using Microsoft.AspNetCore.Mvc.RazorPages;
using Microsoft.Extensions.Logging;
using ServiceStack;
using ServiceStack.Auth;
using ServiceStack.DataAnnotations;
using ServiceStack.IO;
using ServiceStack.Logging;
using ServiceStack.OrmLite;
Expand Down Expand Up @@ -67,7 +69,37 @@ public async Task<object> Any(CheckQuota request)
};
}

public async Task<object> Post(CreateCreative request)
public async Task<object> Post(CreateCreativeCallback request)
{
var createCreative = request.Context;

var modifiers = await Db.SelectAsync<Modifier>(x => Sql.In(x.Id, createCreative.ModifierIds));
var artists = createCreative.ArtistIds.Count == 0 ? new List<Artist>() :
await Db.SelectAsync<Artist>(x => Sql.In(x.Id, createCreative.ArtistIds));

var imageGenerationResponse = await stableDiffusion.GetQueueResult(request.RefId);
if (imageGenerationResponse == null)
throw HttpError.NotFound("ImageGenerationResponse not found");

var creativeQueue = await Db.SingleAsync<CreativeQueue>(x => x.RefId == request.RefId);
if (creativeQueue == null)
throw HttpError.NotFound("CreativeQueue not found");

var creativeId = await PersistCreative(createCreative,
imageGenerationResponse,
modifiers,
artists,
creativeQueue.UserId,
creativeQueue.RefId);

var creative = await Db.LoadSingleByIdAsync<Creative>(creativeId);

PublishMessage(new BackgroundTasks { NewCreative = creative });

return new CreateCreativeResponse { Result = creative };
}

public async Task<object> Post(QueueCreative request)
{
var session = await SessionAsAsync<CustomUserSession>();
var userId = session.GetUserId();
Expand Down Expand Up @@ -96,7 +128,7 @@ public async Task<object> Post(CreateCreative request)
var artists = request.ArtistIds.Count == 0 ? new List<Artist>() :
await Db.SelectAsync<Artist>(x => Sql.In(x.Id, request.ArtistIds));

var imageGenerationRequest = CreateImageGenerationRequest(request, modifiers, artists, userRoles);
var imageGenerationRequest = CreateImageGenerationRequest(request.ConvertTo<CreateCreative>(), modifiers, artists, userRoles);

var quotaError = await userQuotas.ValidateQuotaAsync(Db, imageGenerationRequest, userId, userRoles);
if (quotaError != null)
Expand All @@ -108,12 +140,90 @@ public async Task<object> Post(CreateCreative request)
return quotaError.ToHttpError(quotaError.ToResponseStatus());
}

var imageGenerationResponse = await GenerateImage(imageGenerationRequest);
var queueReq = new QueueImageGeneration
{
ImageGeneration = imageGenerationRequest,
Context = request
};
var queueImageGenerationResponse = await stableDiffusion.QueueGenerateImageAsync(queueReq);
Db.Insert(new CreativeQueue
{
RefId = queueImageGenerationResponse.RefId,
UserId = userId
});
return queueImageGenerationResponse;
}

public async Task<object> Post(CreateCreative request)
{
var session = await SessionAsAsync<CustomUserSession>();
var userId = session.GetUserId();
var userAuth = await userManager.FindByIdAsync(session.UserAuthId);
if (userAuth == null)
throw HttpError.Unauthorized("Not Authorized");

var requestLower = request.UserPrompt.ToLower();
foreach (var banWord in AppConfig.Instance.BanWords)
{
if (requestLower.Contains(banWord))
{
await userManager.SetLockoutEnabledAsync(userAuth, true);
await userManager.SetLockoutEndDateAsync(userAuth, DateTime.UtcNow.AddYears(1));
break;
}
}
if (userAuth.LockoutEnd > DateTime.UtcNow)
{
throw HttpError.Forbidden("Account is locked");
}

var userRoles = await session.GetRolesAsync(AuthRepositoryAsync);

var creativeId = await PersistCreative(request, imageGenerationResponse, modifiers, artists);
var modifiers = await Db.SelectAsync<Modifier>(x => Sql.In(x.Id, request.ModifierIds));
var artists = request.ArtistIds.Count == 0 ? new List<Artist>() :
await Db.SelectAsync<Artist>(x => Sql.In(x.Id, request.ArtistIds));

var creative = await Db.LoadSingleByIdAsync<Creative>(creativeId);
var imageGenerationRequest = CreateImageGenerationRequest(request, modifiers, artists, userRoles);

var quotaError = await userQuotas.ValidateQuotaAsync(Db, imageGenerationRequest, userId, userRoles);
if (quotaError != null)
{
log.LogInformation("User #{Id} {UserName} exceeded quota, credits: {CreditsUsed} + {CreditsRequested} > {DailyQuota}, time remaining: {TimeRemaining}",
session.UserAuthId, session.UserAuthName, quotaError.CreditsUsed,
quotaError.CreditsRequested, quotaError.DailyQuota, quotaError.TimeRemaining);

return quotaError.ToHttpError(quotaError.ToResponseStatus());
}

var queueReq = new QueueImageGeneration
{
ImageGeneration = imageGenerationRequest,
Context = request
};
var imageGenerationResponse = await stableDiffusion.QueueGenerateImageAsync(queueReq);
Db.Insert(new CreativeQueue
{
RefId = imageGenerationResponse.RefId,
UserId = userId
});
// Poll for completion
var refId = imageGenerationResponse.RefId;

Creative? creative = null;
var timeout = DateTime.UtcNow.AddSeconds(120);
while (DateTime.UtcNow < timeout)
{
creative = await Db.SingleAsync<Creative>(x => x.RefId == refId);
if (creative != null)
{
break;
}
await Task.Delay(1000);
}

if(creative == null)
throw new HttpError(HttpStatusCode.BadRequest,"Request timed out waiting for Creative");

PublishMessage(new BackgroundTasks { NewCreative = creative });

return new CreateCreativeResponse { Result = creative };
Expand Down Expand Up @@ -198,15 +308,21 @@ public async Task<object> Patch(UpdateArtifact request)

return artifact;
}
private async Task<int> PersistCreative(CreateCreative request,

private async Task<int> PersistCreative(CreateCreative request,
ImageGenerationResponse imageGenerationResponse,
List<Modifier> modifiers,
List<Artist> artists)
List<Modifier> modifiers,
List<Artist> artists,
int callbackUserId = 0,
string? refId = null)
{
request.UserPrompt = request.UserPrompt.Trim();
var session = await SessionAsAsync<CustomUserSession>();
var userId = session.GetUserId();
if(userId == 0)
userId = callbackUserId;
if(userId == 0)
throw HttpError.Unauthorized("Not Authorized");
var now = DateTime.UtcNow;
var creative = request.ConvertTo<Creative>()
.WithAudit(session.UserAuthId, now);
Expand All @@ -219,7 +335,7 @@ private async Task<int> PersistCreative(CreateCreative request,
creative.ArtistNames = artists.Select(x => x.GetArtistName()).ToList();
creative.ModifierNames = modifiers.Select(x => x.Name).ToList();
creative.Prompt = request.UserPrompt.ConstructPrompt(modifiers, artists);
creative.RefId = Guid.NewGuid().ToString("D");
creative.RefId = refId ?? Guid.NewGuid().ToString("D");

using var db = HostContext.AppHost.GetDbConnection();
using var transaction = db.OpenTransaction();
Expand Down Expand Up @@ -314,7 +430,7 @@ private async Task<ImageGenerationResponse> GenerateImage(ImageGeneration reques
throw HttpError.ServiceUnavailable($"Failed to generate image: {e.Message}");
}
}

public async Task<object> Get(ResizedArtifact request)
{
var artifact = Db.SingleById<Artifact>(request.ArtifactId);
Expand Down Expand Up @@ -445,6 +561,47 @@ public async Task<object> Any(GetCreative request)
}
}

public class QueueCreative
{
[Required]
public string UserPrompt { get; set; }

public int? Images { get; set; }

public int? Width { get; set; }

public int? Height { get; set; }

public int? Steps { get; set; }
public long? Seed { get; set; }

public string? EngineId { get; set; }

public List<int> ArtistIds { get; set; }
public List<int> ModifierIds { get; set; }
}

public class QueueCreativeResponse
{
public string? RefId { get; set; }
}

[Route("/creatives/callback", "POST")]
public class CreateCreativeCallback : IReturn<CreateCreativeCallbackResponse>
{
/// <summary>
/// Reference passed from original request
/// </summary>
public string RefId { get; set; }

public CreateCreative Context { get; set; }
}

public class CreateCreativeCallbackResponse
{

}

public static class CreateServiceUtils
{
public static async Task<bool> IsOwnerOrModerator(this Service service, int? userId) =>
Expand Down Expand Up @@ -477,6 +634,8 @@ public static async Task<Creative> SaveCreativeAsync(this IStableDiffusionClient
public interface IStableDiffusionClient
{
Task<ImageGenerationResponse> GenerateImageAsync(ImageGeneration request);
Task<QueueImageGenerationResponse> QueueGenerateImageAsync(QueueImageGeneration request);
Task<ImageGenerationResponse> GetQueueResult(string refId);
IVirtualFile? GetMetadataFile(Creative creative);
Task SaveMetadataAsync(Creative entry);
Task DeleteFolderAsync(Creative entry);
Expand Down Expand Up @@ -530,6 +689,17 @@ public class ImageGenerationResult
public ImageDetails? ImageDetails { get; set; }
}

public class QueueImageGeneration : IReturn<QueueImageGenerationResponse>
{
public ImageGeneration ImageGeneration { get; set; }
public object Context { get; set; }
}

public class QueueImageGenerationResponse
{
public string RefId { get; set; }
}

public class PrerenderView
{
public string ViewPath { get; set; }
Expand Down
11 changes: 11 additions & 0 deletions BlazorDiffusion.ServiceModel/Creatives.cs
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,17 @@ public class Creative : AuditBase
public string EngineId { get; set; }
}

[Tag(Tag.Creatives)]
public class CreativeQueue
{
[AutoIncrement]
public int Id { get; set; }

public string RefId { get; set; }

public int UserId { get; set; }
}

[Tag(Tag.Creatives)]
public class QueryCreatives : QueryDb<Creative>
{
Expand Down
10 changes: 10 additions & 0 deletions BlazorDiffusion.Tests/CreativeServiceMockedTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,16 @@ public async Task<ImageGenerationResponse> GenerateImageAsync(ImageGeneration re
};
}

public Task<QueueImageGenerationResponse> QueueGenerateImageAsync(QueueImageGeneration request)
{
throw new NotImplementedException();
}

public Task<ImageGenerationResponse> GetQueueResult(string refId)
{
throw new NotImplementedException();
}

public Task SaveMetadataAsync(Creative entry) => Task.CompletedTask;
public Task DeleteFolderAsync(Creative entry) => Task.CompletedTask;

Expand Down
Loading
Loading