diff --git a/OnnxStack.Console/Examples/StableDebug.cs b/OnnxStack.Console/Examples/StableDebug.cs index 9409c685..d2015423 100644 --- a/OnnxStack.Console/Examples/StableDebug.cs +++ b/OnnxStack.Console/Examples/StableDebug.cs @@ -1,7 +1,7 @@ -using OnnxStack.StableDiffusion.Common; +using OnnxStack.StableDiffusion; +using OnnxStack.StableDiffusion.Common; using OnnxStack.StableDiffusion.Config; using OnnxStack.StableDiffusion.Enums; -using OnnxStack.StableDiffusion.Services; using SixLabors.ImageSharp; using System.Diagnostics; @@ -37,11 +37,11 @@ public async Task RunAsync() { Prompt = prompt, NegativePrompt = negativePrompt, - SchedulerType = SchedulerType.LMS }; var schedulerOptions = new SchedulerOptions { + SchedulerType = SchedulerType.LMS, Seed = 624461087, //Seed = Random.Shared.Next(), GuidanceScale = 8, @@ -54,9 +54,9 @@ public async Task RunAsync() OutputHelpers.WriteConsole($"Loading Model `{model.Name}`...", ConsoleColor.Green); await _stableDiffusionService.LoadModel(model); - foreach (var schedulerType in Helpers.GetPipelineSchedulers(model.PipelineType)) + foreach (var schedulerType in model.PipelineType.GetSchedulerTypes()) { - promptOptions.SchedulerType = schedulerType; + schedulerOptions.SchedulerType = schedulerType; OutputHelpers.WriteConsole($"Generating {schedulerType} Image...", ConsoleColor.Green); await GenerateImage(model, promptOptions, schedulerOptions); } @@ -72,12 +72,12 @@ public async Task RunAsync() private async Task GenerateImage(ModelOptions model, PromptOptions prompt, SchedulerOptions options) { var timestamp = Stopwatch.GetTimestamp(); - var outputFilename = Path.Combine(_outputDirectory, $"{options.Seed}_{prompt.SchedulerType}.png"); + var outputFilename = Path.Combine(_outputDirectory, $"{options.Seed}_{options.SchedulerType}.png"); var result = await _stableDiffusionService.GenerateAsImageAsync(model, prompt, options); if (result is not null) { await result.SaveAsPngAsync(outputFilename); - OutputHelpers.WriteConsole($"{prompt.SchedulerType} Image Created: {Path.GetFileName(outputFilename)}", ConsoleColor.Green); + OutputHelpers.WriteConsole($"{options.SchedulerType} Image Created: {Path.GetFileName(outputFilename)}", ConsoleColor.Green); OutputHelpers.WriteConsole($"Elapsed: {Stopwatch.GetElapsedTime(timestamp)}ms", ConsoleColor.Yellow); return true; } diff --git a/OnnxStack.Console/Examples/StableDiffusionBatch.cs b/OnnxStack.Console/Examples/StableDiffusionBatch.cs index 3cb33ae7..0aa01de4 100644 --- a/OnnxStack.Console/Examples/StableDiffusionBatch.cs +++ b/OnnxStack.Console/Examples/StableDiffusionBatch.cs @@ -1,9 +1,9 @@ -using OnnxStack.Core; -using OnnxStack.StableDiffusion.Common; +using OnnxStack.StableDiffusion.Common; using OnnxStack.StableDiffusion.Config; using OnnxStack.StableDiffusion.Enums; -using OnnxStack.StableDiffusion.Helpers; +using OnnxStack.StableDiffusion; using SixLabors.ImageSharp; +using OnnxStack.StableDiffusion.Helpers; namespace OnnxStack.Console.Runner { @@ -31,22 +31,10 @@ public async Task RunAsync() while (true) { - OutputHelpers.WriteConsole("Please type a prompt and press ENTER", ConsoleColor.Yellow); - var prompt = OutputHelpers.ReadConsole(ConsoleColor.Cyan); - - OutputHelpers.WriteConsole("Please type a negative prompt and press ENTER (optional)", ConsoleColor.Yellow); - var negativePrompt = OutputHelpers.ReadConsole(ConsoleColor.Cyan); - - OutputHelpers.WriteConsole("Please enter a batch count and press ENTER", ConsoleColor.Yellow); - var batch = OutputHelpers.ReadConsole(ConsoleColor.Cyan); - int.TryParse(batch, out var batchCount); - batchCount = Math.Max(1, batchCount); var promptOptions = new PromptOptions { - Prompt = prompt, - NegativePrompt = negativePrompt, - BatchCount = batchCount + Prompt = "Photo of a cat" }; var schedulerOptions = new SchedulerOptions @@ -54,20 +42,33 @@ public async Task RunAsync() Seed = Random.Shared.Next(), GuidanceScale = 8, - InferenceSteps = 22, + InferenceSteps = 20, Strength = 0.6f }; + var batchOptions = new BatchOptions + { + BatchType = BatchOptionType.Scheduler + }; + foreach (var model in _stableDiffusionService.Models) { OutputHelpers.WriteConsole($"Loading Model `{model.Name}`...", ConsoleColor.Green); await _stableDiffusionService.LoadModel(model); - foreach (var schedulerType in Helpers.GetPipelineSchedulers(model.PipelineType)) + var batchIndex = 0; + var callback = (int batch, int batchCount, int step, int steps) => + { + batchIndex = batch; + OutputHelpers.WriteConsole($"Image: {batch}/{batchCount} - Step: {step}/{steps}", ConsoleColor.Cyan); + }; + + await foreach (var result in _stableDiffusionService.GenerateBatchAsync(model, promptOptions, schedulerOptions, batchOptions, callback)) { - promptOptions.SchedulerType = schedulerType; - OutputHelpers.WriteConsole($"Generating {schedulerType} Image...", ConsoleColor.Green); - await GenerateImage(model, promptOptions, schedulerOptions); + var outputFilename = Path.Combine(_outputDirectory, $"{batchIndex}_{result.SchedulerOptions.Seed}.png"); + var image = result.ImageResult.ToImage(); + await image.SaveAsPngAsync(outputFilename); + OutputHelpers.WriteConsole($"Image Created: {Path.GetFileName(outputFilename)}", ConsoleColor.Green); } OutputHelpers.WriteConsole($"Unloading Model `{model.Name}`...", ConsoleColor.Green); @@ -75,24 +76,5 @@ public async Task RunAsync() } } } - - private async Task GenerateImage(ModelOptions model, PromptOptions prompt, SchedulerOptions options) - { - - var result = await _stableDiffusionService.GenerateAsync(model, prompt, options); - if (result == null) - return false; - - var imageTensors = result.Split(prompt.BatchCount); - for (int i = 0; i < imageTensors.Length; i++) - { - var outputFilename = Path.Combine(_outputDirectory, $"{options.Seed}_{prompt.SchedulerType}_{i}.png"); - var image = imageTensors[i].ToImage(); - await image.SaveAsPngAsync(outputFilename); - OutputHelpers.WriteConsole($"{prompt.SchedulerType} Image Created: {Path.GetFileName(outputFilename)}", ConsoleColor.Green); - } - - return true; - } } } diff --git a/OnnxStack.Console/Examples/StableDiffusionExample.cs b/OnnxStack.Console/Examples/StableDiffusionExample.cs index ce925c1b..45f0f763 100644 --- a/OnnxStack.Console/Examples/StableDiffusionExample.cs +++ b/OnnxStack.Console/Examples/StableDiffusionExample.cs @@ -1,7 +1,6 @@ -using OnnxStack.Core; +using OnnxStack.StableDiffusion; using OnnxStack.StableDiffusion.Common; using OnnxStack.StableDiffusion.Config; -using OnnxStack.StableDiffusion.Enums; using SixLabors.ImageSharp; namespace OnnxStack.Console.Runner @@ -53,9 +52,9 @@ public async Task RunAsync() OutputHelpers.WriteConsole($"Loading Model `{model.Name}`...", ConsoleColor.Green); await _stableDiffusionService.LoadModel(model); - foreach (var schedulerType in Helpers.GetPipelineSchedulers(model.PipelineType)) + foreach (var schedulerType in model.PipelineType.GetSchedulerTypes()) { - promptOptions.SchedulerType = schedulerType; + schedulerOptions.SchedulerType = schedulerType; OutputHelpers.WriteConsole($"Generating {schedulerType} Image...", ConsoleColor.Green); await GenerateImage(model, promptOptions, schedulerOptions); } @@ -68,13 +67,13 @@ public async Task RunAsync() private async Task GenerateImage(ModelOptions model, PromptOptions prompt, SchedulerOptions options) { - var outputFilename = Path.Combine(_outputDirectory, $"{options.Seed}_{prompt.SchedulerType}.png"); + var outputFilename = Path.Combine(_outputDirectory, $"{options.Seed}_{options.SchedulerType}.png"); var result = await _stableDiffusionService.GenerateAsImageAsync(model, prompt, options); if (result == null) return false; await result.SaveAsPngAsync(outputFilename); - OutputHelpers.WriteConsole($"{prompt.SchedulerType} Image Created: {Path.GetFileName(outputFilename)}", ConsoleColor.Green); + OutputHelpers.WriteConsole($"{options.SchedulerType} Image Created: {Path.GetFileName(outputFilename)}", ConsoleColor.Green); return true; } } diff --git a/OnnxStack.Console/Examples/StableDiffusionGenerator.cs b/OnnxStack.Console/Examples/StableDiffusionGenerator.cs index 3a8d6a29..a9317b25 100644 --- a/OnnxStack.Console/Examples/StableDiffusionGenerator.cs +++ b/OnnxStack.Console/Examples/StableDiffusionGenerator.cs @@ -1,7 +1,6 @@ -using OnnxStack.Core; +using OnnxStack.StableDiffusion; using OnnxStack.StableDiffusion.Common; using OnnxStack.StableDiffusion.Config; -using OnnxStack.StableDiffusion.Enums; using SixLabors.ImageSharp; using System.Collections.ObjectModel; @@ -48,9 +47,9 @@ public async Task RunAsync() { Seed = Random.Shared.Next() }; - foreach (var schedulerType in Helpers.GetPipelineSchedulers(model.PipelineType)) + foreach (var schedulerType in model.PipelineType.GetSchedulerTypes()) { - promptOptions.SchedulerType = schedulerType; + schedulerOptions.SchedulerType = schedulerType; OutputHelpers.WriteConsole($"Generating {schedulerType} Image...", ConsoleColor.Green); await GenerateImage(model, promptOptions, schedulerOptions, generationPrompt.Key); } @@ -65,13 +64,13 @@ public async Task RunAsync() private async Task GenerateImage(ModelOptions model, PromptOptions prompt, SchedulerOptions options, string key) { - var outputFilename = Path.Combine(_outputDirectory, $"{options.Seed}_{prompt.SchedulerType}_{key}.png"); + var outputFilename = Path.Combine(_outputDirectory, $"{options.Seed}_{options.SchedulerType}_{key}.png"); var result = await _stableDiffusionService.GenerateAsImageAsync(model, prompt, options); if (result == null) return false; await result.SaveAsPngAsync(outputFilename); - OutputHelpers.WriteConsole($"{prompt.SchedulerType} Image Created: {Path.GetFileName(outputFilename)}", ConsoleColor.Green); + OutputHelpers.WriteConsole($"{options.SchedulerType} Image Created: {Path.GetFileName(outputFilename)}", ConsoleColor.Green); return true; } diff --git a/OnnxStack.Console/Helpers.cs b/OnnxStack.Console/Helpers.cs deleted file mode 100644 index 5d47cdd2..00000000 --- a/OnnxStack.Console/Helpers.cs +++ /dev/null @@ -1,28 +0,0 @@ -using OnnxStack.StableDiffusion.Enums; - -namespace OnnxStack.Console -{ - internal static class Helpers - { - public static SchedulerType[] GetPipelineSchedulers(DiffuserPipelineType pipelineType) - { - return pipelineType switch - { - DiffuserPipelineType.StableDiffusion => new[] - { - SchedulerType.LMS, - SchedulerType.Euler, - SchedulerType.EulerAncestral, - SchedulerType.DDPM, - SchedulerType.DDIM, - SchedulerType.KDPM2 - }, - DiffuserPipelineType.LatentConsistency => new[] - { - SchedulerType.LCM - }, - _ => default - }; - } - } -} diff --git a/OnnxStack.StableDiffusion/Common/IStableDiffusionService.cs b/OnnxStack.StableDiffusion/Common/IStableDiffusionService.cs index 68f6e237..15396cd0 100644 --- a/OnnxStack.StableDiffusion/Common/IStableDiffusionService.cs +++ b/OnnxStack.StableDiffusion/Common/IStableDiffusionService.cs @@ -2,6 +2,7 @@ using OnnxStack.Core.Config; using OnnxStack.Core.Model; using OnnxStack.StableDiffusion.Config; +using OnnxStack.StableDiffusion.Models; using SixLabors.ImageSharp; using SixLabors.ImageSharp.PixelFormats; using System; @@ -83,5 +84,53 @@ public interface IStableDiffusionService /// The cancellation token. /// The diffusion result as Task GenerateAsStreamAsync(IModelOptions model, PromptOptions prompt, SchedulerOptions options, Action progressCallback = null, CancellationToken cancellationToken = default); + + /// + /// Generates a batch of StableDiffusion image using the prompt and options provided. + /// + /// The model options. + /// The prompt options. + /// The scheduler options. + /// The batch options. + /// The progress callback. + /// The cancellation token. + /// + IAsyncEnumerable GenerateBatchAsync(IModelOptions modelOptions, PromptOptions promptOptions, SchedulerOptions schedulerOptions, BatchOptions batchOptions, Action progressCallback = null, CancellationToken cancellationToken = default); + + /// + /// Generates a batch of StableDiffusion image using the prompt and options provided. + /// + /// The model options. + /// The prompt options. + /// The scheduler options. + /// The batch options. + /// The progress callback. + /// The cancellation token. + /// + IAsyncEnumerable> GenerateBatchAsImageAsync(IModelOptions modelOptions, PromptOptions promptOptions, SchedulerOptions schedulerOptions, BatchOptions batchOptions, Action progressCallback = null, CancellationToken cancellationToken = default); + + /// + /// Generates a batch of StableDiffusion image using the prompt and options provided. + /// + /// The model options. + /// The prompt options. + /// The scheduler options. + /// The batch options. + /// The progress callback. + /// The cancellation token. + /// + IAsyncEnumerable GenerateBatchAsBytesAsync(IModelOptions modelOptions, PromptOptions promptOptions, SchedulerOptions schedulerOptions, BatchOptions batchOptions, Action progressCallback = null, CancellationToken cancellationToken = default); + + /// + /// Generates a batch of StableDiffusion image using the prompt and options provided. + /// + /// The model options. + /// The prompt options. + /// The scheduler options. + /// The batch options. + /// The progress callback. + /// The cancellation token. + /// + IAsyncEnumerable GenerateBatchAsStreamAsync(IModelOptions modelOptions, PromptOptions promptOptions, SchedulerOptions schedulerOptions, BatchOptions batchOptions, Action progressCallback = null, CancellationToken cancellationToken = default); } } \ No newline at end of file diff --git a/OnnxStack.StableDiffusion/Config/BatchOptions.cs b/OnnxStack.StableDiffusion/Config/BatchOptions.cs new file mode 100644 index 00000000..b0c34f30 --- /dev/null +++ b/OnnxStack.StableDiffusion/Config/BatchOptions.cs @@ -0,0 +1,12 @@ +using OnnxStack.StableDiffusion.Enums; + +namespace OnnxStack.StableDiffusion.Config +{ + public record BatchOptions + { + public BatchOptionType BatchType { get; set; } + public float ValueTo { get; set; } + public float ValueFrom { get; set; } + public float Increment { get; set; } = 1f; + } +} diff --git a/OnnxStack.StableDiffusion/Config/PromptOptions.cs b/OnnxStack.StableDiffusion/Config/PromptOptions.cs index c942931b..037de6ed 100644 --- a/OnnxStack.StableDiffusion/Config/PromptOptions.cs +++ b/OnnxStack.StableDiffusion/Config/PromptOptions.cs @@ -14,7 +14,6 @@ public class PromptOptions [StringLength(512)] public string NegativePrompt { get; set; } - public SchedulerType SchedulerType { get; set; } public int BatchCount { get; set; } = 1; diff --git a/OnnxStack.StableDiffusion/Config/SchedulerOptions.cs b/OnnxStack.StableDiffusion/Config/SchedulerOptions.cs index e10f5482..badd37d1 100644 --- a/OnnxStack.StableDiffusion/Config/SchedulerOptions.cs +++ b/OnnxStack.StableDiffusion/Config/SchedulerOptions.cs @@ -4,8 +4,13 @@ namespace OnnxStack.StableDiffusion.Config { - public class SchedulerOptions + public record SchedulerOptions { + /// + /// Gets or sets the type of scheduler. + /// + public SchedulerType SchedulerType { get; set; } + /// /// Gets or sets the height. /// diff --git a/OnnxStack.StableDiffusion/Diffusers/IDiffuser.cs b/OnnxStack.StableDiffusion/Diffusers/IDiffuser.cs index 25161a7f..0f2b1ba8 100644 --- a/OnnxStack.StableDiffusion/Diffusers/IDiffuser.cs +++ b/OnnxStack.StableDiffusion/Diffusers/IDiffuser.cs @@ -2,7 +2,9 @@ using OnnxStack.StableDiffusion.Common; using OnnxStack.StableDiffusion.Config; using OnnxStack.StableDiffusion.Enums; +using OnnxStack.StableDiffusion.Models; using System; +using System.Collections.Generic; using System.Threading; using System.Threading.Tasks; @@ -33,5 +35,18 @@ public interface IDiffuser /// The cancellation token. /// Task> DiffuseAsync(IModelOptions modelOptions, PromptOptions promptOptions, SchedulerOptions schedulerOptions, Action progressCallback = null, CancellationToken cancellationToken = default); + + + /// + /// Runs the stable diffusion batch loop + /// + /// The model options. + /// The prompt options. + /// The scheduler options. + /// The batch options. + /// The progress callback. + /// The cancellation token. + /// + IAsyncEnumerable DiffuseBatchAsync(IModelOptions modelOptions, PromptOptions promptOptions, SchedulerOptions schedulerOptions, BatchOptions batchOptions, Action progressCallback = null, CancellationToken cancellationToken = default); } } diff --git a/OnnxStack.StableDiffusion/Diffusers/LatentConsistency/LatentConsistencyDiffuser.cs b/OnnxStack.StableDiffusion/Diffusers/LatentConsistency/LatentConsistencyDiffuser.cs index b8ac8ebd..6f3dfa5d 100644 --- a/OnnxStack.StableDiffusion/Diffusers/LatentConsistency/LatentConsistencyDiffuser.cs +++ b/OnnxStack.StableDiffusion/Diffusers/LatentConsistency/LatentConsistencyDiffuser.cs @@ -8,11 +8,13 @@ using OnnxStack.StableDiffusion.Config; using OnnxStack.StableDiffusion.Enums; using OnnxStack.StableDiffusion.Helpers; +using OnnxStack.StableDiffusion.Models; using OnnxStack.StableDiffusion.Schedulers.LatentConsistency; using System; using System.Collections.Generic; using System.Diagnostics; using System.Linq; +using System.Runtime.CompilerServices; using System.Threading; using System.Threading.Tasks; @@ -83,17 +85,78 @@ public virtual async Task> DiffuseAsync(IModelOptions modelOp schedulerOptions.Seed = schedulerOptions.Seed > 0 ? schedulerOptions.Seed : Random.Shared.Next(); var diffuseTime = _logger?.LogBegin("Begin..."); - _logger?.Log($"Model: {modelOptions.Name}, Pipeline: {modelOptions.PipelineType}, Diffuser: {promptOptions.DiffuserType}, Scheduler: {promptOptions.SchedulerType}"); + _logger?.Log($"Model: {modelOptions.Name}, Pipeline: {modelOptions.PipelineType}, Diffuser: {promptOptions.DiffuserType}, Scheduler: {schedulerOptions.SchedulerType}"); // LCM does not support negative prompting + var performGuidance = false; promptOptions.NegativePrompt = string.Empty; + // Process prompts + var promptEmbeddings = await _promptService.CreatePromptAsync(modelOptions, promptOptions, performGuidance); + + // Run Scheduler steps + var schedulerResult = await RunSchedulerSteps(modelOptions, promptOptions, schedulerOptions, promptEmbeddings, performGuidance, progressCallback, cancellationToken); + + _logger?.LogEnd($"End", diffuseTime); + + return schedulerResult; + } + + + /// + /// Runs the stable diffusion batch loop + /// + /// The model options. + /// The prompt options. + /// The scheduler options. + /// The batch options. + /// The progress callback. + /// The cancellation token. + /// + /// + public async IAsyncEnumerable DiffuseBatchAsync(IModelOptions modelOptions, PromptOptions promptOptions, SchedulerOptions schedulerOptions, BatchOptions batchOptions, Action progressCallback = null, [EnumeratorCancellation] CancellationToken cancellationToken = default) + { + var diffuseBatchTime = _logger?.LogBegin("Begin..."); + _logger?.Log($"Model: {modelOptions.Name}, Pipeline: {modelOptions.PipelineType}, Diffuser: {promptOptions.DiffuserType}, Scheduler: {schedulerOptions.SchedulerType}"); + + // LCM does not support negative prompting + var performGuidance = false; + promptOptions.NegativePrompt = string.Empty; + + // Process prompts + var promptEmbeddings = await _promptService.CreatePromptAsync(modelOptions, promptOptions, performGuidance); + + // Generate batch options + var batchSchedulerOptions = BatchGenerator.GenerateBatch(modelOptions, batchOptions, schedulerOptions); + + var batchIndex = 1; + var schedulerCallback = (int step, int steps) => progressCallback?.Invoke(batchIndex, batchSchedulerOptions.Count, step, steps); + foreach (var batchSchedulerOption in batchSchedulerOptions) + { + yield return new BatchResult(batchSchedulerOption, await RunSchedulerSteps(modelOptions, promptOptions, batchSchedulerOption, promptEmbeddings, performGuidance, schedulerCallback, cancellationToken)); + batchIndex++; + } + + _logger?.LogEnd($"End", diffuseBatchTime); + } + + + /// + /// Runs the scheduler steps. + /// + /// The model options. + /// The prompt options. + /// The scheduler options. + /// The prompt embeddings. + /// if set to true [perform guidance]. + /// The progress callback. + /// The cancellation token. + /// + protected virtual async Task> RunSchedulerSteps(IModelOptions modelOptions, PromptOptions promptOptions, SchedulerOptions schedulerOptions, DenseTensor promptEmbeddings, bool performGuidance, Action progressCallback = null, CancellationToken cancellationToken = default) + { // Get Scheduler using (var scheduler = GetScheduler(promptOptions, schedulerOptions)) { - // Process prompts - var promptEmbeddings = await _promptService.CreatePromptAsync(modelOptions, promptOptions, false); - // Get timesteps var timesteps = GetTimesteps(promptOptions, schedulerOptions, scheduler); @@ -137,9 +200,7 @@ public virtual async Task> DiffuseAsync(IModelOptions modelOp } // Decode Latents - var result = await DecodeLatents(modelOptions, promptOptions, schedulerOptions, denoised); - _logger?.LogEnd($"End", diffuseTime); - return result; + return await DecodeLatents(modelOptions, promptOptions, schedulerOptions, denoised); } } @@ -218,7 +279,7 @@ protected virtual IReadOnlyList CreateUnetInputParams(IModelOpti /// protected IScheduler GetScheduler(PromptOptions prompt, SchedulerOptions options) { - return prompt.SchedulerType switch + return options.SchedulerType switch { SchedulerType.LCM => new LCMScheduler(options), _ => default @@ -261,5 +322,7 @@ protected static IReadOnlyList CreateInputParameters(params Name { return parameters.ToList(); } + + } } diff --git a/OnnxStack.StableDiffusion/Diffusers/StableDiffusion/InpaintDiffuser.cs b/OnnxStack.StableDiffusion/Diffusers/StableDiffusion/InpaintDiffuser.cs index b3d05bdb..974318f2 100644 --- a/OnnxStack.StableDiffusion/Diffusers/StableDiffusion/InpaintDiffuser.cs +++ b/OnnxStack.StableDiffusion/Diffusers/StableDiffusion/InpaintDiffuser.cs @@ -39,10 +39,13 @@ public InpaintDiffuser(IOnnxModelService onnxModelService, IPromptService prompt /// - /// Runs the Stable Diffusion inference. + /// Runs the stable diffusion loop /// - /// The options. - /// The scheduler configuration. + /// + /// The prompt options. + /// The scheduler options. + /// + /// The cancellation token. /// public override async Task> DiffuseAsync(IModelOptions modelOptions, PromptOptions promptOptions, SchedulerOptions schedulerOptions, Action progressCallback = null, CancellationToken cancellationToken = default) { @@ -50,18 +53,39 @@ public override async Task> DiffuseAsync(IModelOptions modelO schedulerOptions.Seed = schedulerOptions.Seed > 0 ? schedulerOptions.Seed : Random.Shared.Next(); var diffuseTime = _logger?.LogBegin("Begin..."); - _logger?.Log($"Model: {modelOptions.Name}, Pipeline: {modelOptions.PipelineType}, Diffuser: {promptOptions.DiffuserType}, Scheduler: {promptOptions.SchedulerType}"); + _logger?.Log($"Model: {modelOptions.Name}, Pipeline: {modelOptions.PipelineType}, Diffuser: {promptOptions.DiffuserType}, Scheduler: {schedulerOptions.SchedulerType}"); + + // Should we perform classifier free guidance + var performGuidance = schedulerOptions.GuidanceScale > 1.0f; + + // Process prompts + var promptEmbeddings = await _promptService.CreatePromptAsync(modelOptions, promptOptions, performGuidance); + + // Run Scheduler steps + var schedulerResult = await RunSchedulerSteps(modelOptions, promptOptions, schedulerOptions, promptEmbeddings, performGuidance, progressCallback, cancellationToken); + + _logger?.LogEnd($"End", diffuseTime); + + return schedulerResult; + } + /// + /// Runs the scheduler steps. + /// + /// The model options. + /// The prompt options. + /// The scheduler options. + /// The prompt embeddings. + /// if set to true [perform guidance]. + /// The progress callback. + /// The cancellation token. + /// + protected override async Task> RunSchedulerSteps(IModelOptions modelOptions, PromptOptions promptOptions, SchedulerOptions schedulerOptions, DenseTensor promptEmbeddings, bool performGuidance, Action progressCallback = null, CancellationToken cancellationToken = default) + { // Get Scheduler using (var scheduler = GetScheduler(promptOptions, schedulerOptions)) { - // Should we perform classifier free guidance - var performGuidance = schedulerOptions.GuidanceScale > 1.0f; - - // Process prompts - var promptEmbeddings = await _promptService.CreatePromptAsync(modelOptions, promptOptions, performGuidance); - // Get timesteps var timesteps = GetTimesteps(promptOptions, schedulerOptions, scheduler); @@ -110,9 +134,7 @@ public override async Task> DiffuseAsync(IModelOptions modelO } // Decode Latents - var result = await DecodeLatents(modelOptions, promptOptions, schedulerOptions, latents); - _logger?.LogEnd($"End", diffuseTime); - return result; + return await DecodeLatents(modelOptions, promptOptions, schedulerOptions, latents); } } diff --git a/OnnxStack.StableDiffusion/Diffusers/StableDiffusion/InpaintLegacyDiffuser.cs b/OnnxStack.StableDiffusion/Diffusers/StableDiffusion/InpaintLegacyDiffuser.cs index 0a30687b..d8495b7a 100644 --- a/OnnxStack.StableDiffusion/Diffusers/StableDiffusion/InpaintLegacyDiffuser.cs +++ b/OnnxStack.StableDiffusion/Diffusers/StableDiffusion/InpaintLegacyDiffuser.cs @@ -53,17 +53,38 @@ public override async Task> DiffuseAsync(IModelOptions modelO schedulerOptions.Seed = schedulerOptions.Seed > 0 ? schedulerOptions.Seed : Random.Shared.Next(); var diffuseTime = _logger?.LogBegin("Begin..."); - _logger?.Log($"Model: {modelOptions.Name}, Pipeline: {modelOptions.PipelineType}, Diffuser: {promptOptions.DiffuserType}, Scheduler: {promptOptions.SchedulerType}"); + _logger?.Log($"Model: {modelOptions.Name}, Pipeline: {modelOptions.PipelineType}, Diffuser: {promptOptions.DiffuserType}, Scheduler: {schedulerOptions.SchedulerType}"); - // Get Scheduler - using (var scheduler = GetScheduler(promptOptions, schedulerOptions)) - { - // Should we perform classifier free guidance - var performGuidance = schedulerOptions.GuidanceScale > 1.0f; + // Should we perform classifier free guidance + var performGuidance = schedulerOptions.GuidanceScale > 1.0f; + + // Process prompts + var promptEmbeddings = await _promptService.CreatePromptAsync(modelOptions, promptOptions, performGuidance); + + // Run Scheduler steps + var schedulerResult = await RunSchedulerSteps(modelOptions, promptOptions, schedulerOptions, promptEmbeddings, performGuidance, progressCallback, cancellationToken); - // Process prompts - var promptEmbeddings = await _promptService.CreatePromptAsync(modelOptions, promptOptions, performGuidance); + _logger?.LogEnd($"End", diffuseTime); + return schedulerResult; + } + + + /// + /// Runs the scheduler steps. + /// + /// The model options. + /// The prompt options. + /// The scheduler options. + /// The prompt embeddings. + /// if set to true [perform guidance]. + /// The progress callback. + /// The cancellation token. + /// + protected override async Task> RunSchedulerSteps(IModelOptions modelOptions, PromptOptions promptOptions, SchedulerOptions schedulerOptions, DenseTensor promptEmbeddings, bool performGuidance, Action progressCallback = null, CancellationToken cancellationToken = default) + { + using (var scheduler = GetScheduler(promptOptions, schedulerOptions)) + { // Get timesteps var timesteps = GetTimesteps(promptOptions, schedulerOptions, scheduler); @@ -120,9 +141,7 @@ public override async Task> DiffuseAsync(IModelOptions modelO } // Decode Latents - var result = await DecodeLatents(modelOptions, promptOptions, schedulerOptions, latents); - _logger?.LogEnd($"End", diffuseTime); - return result; + return await DecodeLatents(modelOptions, promptOptions, schedulerOptions, latents); } } diff --git a/OnnxStack.StableDiffusion/Diffusers/StableDiffusion/StableDiffusionDiffuser.cs b/OnnxStack.StableDiffusion/Diffusers/StableDiffusion/StableDiffusionDiffuser.cs index 6cfed168..a3db90a7 100644 --- a/OnnxStack.StableDiffusion/Diffusers/StableDiffusion/StableDiffusionDiffuser.cs +++ b/OnnxStack.StableDiffusion/Diffusers/StableDiffusion/StableDiffusionDiffuser.cs @@ -8,11 +8,13 @@ using OnnxStack.StableDiffusion.Config; using OnnxStack.StableDiffusion.Enums; using OnnxStack.StableDiffusion.Helpers; +using OnnxStack.StableDiffusion.Models; using OnnxStack.StableDiffusion.Schedulers.StableDiffusion; using System; using System.Collections.Generic; using System.Diagnostics; using System.Linq; +using System.Runtime.CompilerServices; using System.Threading; using System.Threading.Tasks; @@ -82,17 +84,76 @@ public virtual async Task> DiffuseAsync(IModelOptions modelOp schedulerOptions.Seed = schedulerOptions.Seed > 0 ? schedulerOptions.Seed : Random.Shared.Next(); var diffuseTime = _logger?.LogBegin("Begin..."); - _logger?.Log($"Model: {modelOptions.Name}, Pipeline: {modelOptions.PipelineType}, Diffuser: {promptOptions.DiffuserType}, Scheduler: {promptOptions.SchedulerType}"); + _logger?.Log($"Model: {modelOptions.Name}, Pipeline: {modelOptions.PipelineType}, Diffuser: {promptOptions.DiffuserType}, Scheduler: {schedulerOptions.SchedulerType}"); - // Get Scheduler - using (var scheduler = GetScheduler(promptOptions, schedulerOptions)) + // Should we perform classifier free guidance + var performGuidance = schedulerOptions.GuidanceScale > 1.0f; + + // Process prompts + var promptEmbeddings = await _promptService.CreatePromptAsync(modelOptions, promptOptions, performGuidance); + + // Run Scheduler steps + var schedulerResult = await RunSchedulerSteps(modelOptions, promptOptions, schedulerOptions, promptEmbeddings, performGuidance, progressCallback, cancellationToken); + + _logger?.LogEnd($"End", diffuseTime); + + return schedulerResult; + } + + + /// + /// Runs the stable diffusion batch loop + /// + /// The model options. + /// The prompt options. + /// The scheduler options. + /// The batch options. + /// The progress callback. + /// The cancellation token. + /// + /// + public async IAsyncEnumerable DiffuseBatchAsync(IModelOptions modelOptions, PromptOptions promptOptions, SchedulerOptions schedulerOptions, BatchOptions batchOptions, Action progressCallback = null, [EnumeratorCancellation]CancellationToken cancellationToken = default) + { + var diffuseBatchTime = _logger?.LogBegin("Begin..."); + _logger?.Log($"Model: {modelOptions.Name}, Pipeline: {modelOptions.PipelineType}, Diffuser: {promptOptions.DiffuserType}, Scheduler: {schedulerOptions.SchedulerType}"); + + // Should we perform classifier free guidance + var performGuidance = schedulerOptions.GuidanceScale > 1.0f; + + // Process prompts + var promptEmbeddings = await _promptService.CreatePromptAsync(modelOptions, promptOptions, performGuidance); + + // Generate batch options + var batchSchedulerOptions = BatchGenerator.GenerateBatch(modelOptions, batchOptions, schedulerOptions); + + var batchIndex = 1; + var schedulerCallback = (int step, int steps) => progressCallback?.Invoke(batchIndex, batchSchedulerOptions.Count, step, steps); + foreach (var batchSchedulerOption in batchSchedulerOptions) { - // Should we perform classifier free guidance - var performGuidance = schedulerOptions.GuidanceScale > 1.0f; + yield return new BatchResult(batchSchedulerOption, await RunSchedulerSteps(modelOptions, promptOptions, batchSchedulerOption, promptEmbeddings, performGuidance, schedulerCallback, cancellationToken)); + batchIndex++; + } - // Process prompts - var promptEmbeddings = await _promptService.CreatePromptAsync(modelOptions, promptOptions, performGuidance); + _logger?.LogEnd($"End", diffuseBatchTime); + } + + /// + /// Runs the scheduler steps. + /// + /// The model options. + /// The prompt options. + /// The scheduler options. + /// The prompt embeddings. + /// if set to true [perform guidance]. + /// The progress callback. + /// The cancellation token. + /// + protected virtual async Task> RunSchedulerSteps(IModelOptions modelOptions, PromptOptions promptOptions, SchedulerOptions schedulerOptions, DenseTensor promptEmbeddings, bool performGuidance, Action progressCallback = null, CancellationToken cancellationToken = default) + { + // Get Scheduler + using (var scheduler = GetScheduler(promptOptions, schedulerOptions)) + { // Get timesteps var timesteps = GetTimesteps(promptOptions, schedulerOptions, scheduler); @@ -130,13 +191,11 @@ public virtual async Task> DiffuseAsync(IModelOptions modelOp } progressCallback?.Invoke(step, timesteps.Count); - _logger?.LogEnd(LogLevel.Debug,$"Step {step}/{timesteps.Count}", stepTime); + _logger?.LogEnd(LogLevel.Debug, $"Step {step}/{timesteps.Count}", stepTime); } // Decode Latents - var result = await DecodeLatents(modelOptions, promptOptions, schedulerOptions, latents); - _logger?.LogEnd($"End", diffuseTime); - return result; + return await DecodeLatents(modelOptions, promptOptions, schedulerOptions, latents); } } @@ -149,7 +208,7 @@ public virtual async Task> DiffuseAsync(IModelOptions modelOp /// protected virtual async Task> DecodeLatents(IModelOptions model, PromptOptions prompt, SchedulerOptions options, DenseTensor latents) { - var timestamp = _logger?.LogBegin("Begin..."); + var timestamp = _logger?.LogBegin("Begin..."); // Scale and decode the image latents with vae. latents = latents.MultiplyBy(1.0f / model.ScaleFactor); @@ -237,7 +296,7 @@ protected virtual IReadOnlyList CreateUnetInputParams(IModelOpti /// protected virtual IScheduler GetScheduler(PromptOptions prompt, SchedulerOptions options) { - return prompt.SchedulerType switch + return options.SchedulerType switch { SchedulerType.LMS => new LMSScheduler(options), SchedulerType.Euler => new EulerScheduler(options), diff --git a/OnnxStack.StableDiffusion/Enums/BatchOptionType.cs b/OnnxStack.StableDiffusion/Enums/BatchOptionType.cs new file mode 100644 index 00000000..f9eb329b --- /dev/null +++ b/OnnxStack.StableDiffusion/Enums/BatchOptionType.cs @@ -0,0 +1,11 @@ +namespace OnnxStack.StableDiffusion.Enums +{ + public enum BatchOptionType + { + Seed = 0, + Step = 1, + Guidance = 2, + Strength = 3, + Scheduler = 4 + } +} diff --git a/OnnxStack.StableDiffusion/Extensions.cs b/OnnxStack.StableDiffusion/Extensions.cs index ae87a1fa..22890470 100644 --- a/OnnxStack.StableDiffusion/Extensions.cs +++ b/OnnxStack.StableDiffusion/Extensions.cs @@ -1,13 +1,12 @@ using Microsoft.ML.OnnxRuntime; -using OnnxStack.Core; using OnnxStack.StableDiffusion.Config; +using OnnxStack.StableDiffusion.Enums; using System; -using System.Collections.Generic; using System.Linq; namespace OnnxStack.StableDiffusion { - internal static class Extensions + public static class Extensions { /// /// Gets the first element and casts it to the specified type. @@ -15,7 +14,7 @@ internal static class Extensions /// Desired return type /// The collection. /// Firts element in the collection cast as - public static T FirstElementAs(this IDisposableReadOnlyCollection collection) + internal static T FirstElementAs(this IDisposableReadOnlyCollection collection) { if (collection is null || collection.Count == 0) return default; @@ -34,7 +33,7 @@ public static T FirstElementAs(this IDisposableReadOnlyCollectionDesired return type /// The collection. /// Last element in the collection cast as - public static T LastElementAs(this IDisposableReadOnlyCollection collection) + internal static T LastElementAs(this IDisposableReadOnlyCollection collection) { if (collection is null || collection.Count == 0) return default; @@ -53,7 +52,7 @@ public static T LastElementAs(this IDisposableReadOnlyCollectionThe options. /// /// Width must be divisible by 64 - public static int GetScaledWidth(this SchedulerOptions options) + internal static int GetScaledWidth(this SchedulerOptions options) { if (options.Width % 64 > 0) throw new ArgumentOutOfRangeException(nameof(options.Width), $"{nameof(options.Width)} must be divisible by 64"); @@ -68,7 +67,7 @@ public static int GetScaledWidth(this SchedulerOptions options) /// The options. /// /// Height must be divisible by 64 - public static int GetScaledHeight(this SchedulerOptions options) + internal static int GetScaledHeight(this SchedulerOptions options) { if (options.Height % 64 > 0) throw new ArgumentOutOfRangeException(nameof(options.Height), $"{nameof(options.Height)} must be divisible by 64"); @@ -84,9 +83,36 @@ public static int GetScaledHeight(this SchedulerOptions options) /// The batch. /// The channels. /// Tensor dimension of [batch, channels, (Height / 8), (Width / 8)] - public static int[] GetScaledDimension(this SchedulerOptions options, int batch = 1, int channels = 4) + internal static int[] GetScaledDimension(this SchedulerOptions options, int batch = 1, int channels = 4) { return new[] { batch, channels, options.GetScaledHeight(), options.GetScaledWidth() }; } + + + /// + /// Gets the pipeline schedulers. + /// + /// Type of the pipeline. + /// + public static SchedulerType[] GetSchedulerTypes(this DiffuserPipelineType pipelineType) + { + return pipelineType switch + { + DiffuserPipelineType.StableDiffusion => new[] + { + SchedulerType.LMS, + SchedulerType.Euler, + SchedulerType.EulerAncestral, + SchedulerType.DDPM, + SchedulerType.DDIM, + SchedulerType.KDPM2 + }, + DiffuserPipelineType.LatentConsistency => new[] + { + SchedulerType.LCM + }, + _ => default + }; + } } } diff --git a/OnnxStack.StableDiffusion/Helpers/BatchGenerator.cs b/OnnxStack.StableDiffusion/Helpers/BatchGenerator.cs new file mode 100644 index 00000000..648da062 --- /dev/null +++ b/OnnxStack.StableDiffusion/Helpers/BatchGenerator.cs @@ -0,0 +1,56 @@ +using OnnxStack.StableDiffusion.Common; +using OnnxStack.StableDiffusion.Config; +using OnnxStack.StableDiffusion.Enums; +using System; +using System.Collections.Generic; +using System.Linq; + +namespace OnnxStack.StableDiffusion.Helpers +{ + public static class BatchGenerator + { + /// + /// Generates the batch of SchedulerOptions fo batch processing. + /// + /// The batch options. + /// The scheduler options. + /// + public static List GenerateBatch(IModelOptions modelOptions, BatchOptions batchOptions, SchedulerOptions schedulerOptions) + { + if (batchOptions.BatchType == BatchOptionType.Seed) + { + return Enumerable.Range(0, Math.Max(1, (int)batchOptions.ValueTo)) + .Select(x => Random.Shared.Next()) + .Select(x => schedulerOptions with { Seed = x }) + .ToList(); + } + else if (batchOptions.BatchType == BatchOptionType.Step) + { + return Enumerable.Range(Math.Max(0, (int)batchOptions.ValueFrom), Math.Max(1, (int)batchOptions.ValueTo)) + .Select(x => schedulerOptions with { InferenceSteps = x }) + .ToList(); + } + else if (batchOptions.BatchType == BatchOptionType.Guidance) + { + var totalIncrements = (int)Math.Max(1, (batchOptions.ValueTo - batchOptions.ValueFrom) / batchOptions.Increment); + return Enumerable.Range(0, totalIncrements) + .Select(x => schedulerOptions with { GuidanceScale = batchOptions.ValueFrom + (batchOptions.Increment * x) }) + .ToList(); + } + else if (batchOptions.BatchType == BatchOptionType.Strength) + { + var totalIncrements = (int)Math.Max(1, (batchOptions.ValueTo - batchOptions.ValueFrom) / batchOptions.Increment); + return Enumerable.Range(0, totalIncrements) + .Select(x => schedulerOptions with { Strength = batchOptions.ValueFrom + (batchOptions.Increment * x) }) + .ToList(); + } + else if (batchOptions.BatchType == BatchOptionType.Scheduler) + { + return modelOptions.PipelineType.GetSchedulerTypes() + .Select(x => schedulerOptions with { SchedulerType = x }) + .ToList(); + } + return new List(); + } + } +} diff --git a/OnnxStack.StableDiffusion/Models/BatchResult.cs b/OnnxStack.StableDiffusion/Models/BatchResult.cs new file mode 100644 index 00000000..066a9e36 --- /dev/null +++ b/OnnxStack.StableDiffusion/Models/BatchResult.cs @@ -0,0 +1,7 @@ +using Microsoft.ML.OnnxRuntime.Tensors; +using OnnxStack.StableDiffusion.Config; + +namespace OnnxStack.StableDiffusion.Models +{ + public record BatchResult(SchedulerOptions SchedulerOptions, DenseTensor ImageResult); +} diff --git a/OnnxStack.StableDiffusion/Services/StableDiffusionService.cs b/OnnxStack.StableDiffusion/Services/StableDiffusionService.cs index e3af60b4..75d7c0ac 100644 --- a/OnnxStack.StableDiffusion/Services/StableDiffusionService.cs +++ b/OnnxStack.StableDiffusion/Services/StableDiffusionService.cs @@ -5,12 +5,14 @@ using OnnxStack.StableDiffusion.Config; using OnnxStack.StableDiffusion.Enums; using OnnxStack.StableDiffusion.Helpers; +using OnnxStack.StableDiffusion.Models; using SixLabors.ImageSharp; using SixLabors.ImageSharp.PixelFormats; using System; using System.Collections.Concurrent; using System.Collections.Generic; using System.IO; +using System.Runtime.CompilerServices; using System.Threading; using System.Threading.Tasks; @@ -139,6 +141,73 @@ public async Task GenerateAsStreamAsync(IModelOptions model, PromptOptio } + /// + /// Generates a batch of StableDiffusion image using the prompt and options provided. + /// + /// The model options. + /// The prompt options. + /// The scheduler options. + /// The batch options. + /// The progress callback. + /// The cancellation token. + /// + public IAsyncEnumerable GenerateBatchAsync(IModelOptions modelOptions, PromptOptions promptOptions, SchedulerOptions schedulerOptions, BatchOptions batchOptions, Action progressCallback = null, CancellationToken cancellationToken = default) + { + return DiffuseBatchAsync(modelOptions, promptOptions, schedulerOptions, batchOptions, progressCallback, cancellationToken); + } + + + /// + /// Generates a batch of StableDiffusion image using the prompt and options provided. + /// + /// The model options. + /// The prompt options. + /// The scheduler options. + /// The batch options. + /// The progress callback. + /// The cancellation token. + /// + public async IAsyncEnumerable> GenerateBatchAsImageAsync(IModelOptions modelOptions, PromptOptions promptOptions, SchedulerOptions schedulerOptions, BatchOptions batchOptions, Action progressCallback = null, [EnumeratorCancellation] CancellationToken cancellationToken = default) + { + await foreach (var result in GenerateBatchAsync(modelOptions, promptOptions, schedulerOptions, batchOptions, progressCallback, cancellationToken)) + yield return result.ImageResult.ToImage(); + } + + + /// + /// Generates a batch of StableDiffusion image using the prompt and options provided. + /// + /// The model options. + /// The prompt options. + /// The scheduler options. + /// The batch options. + /// The progress callback. + /// The cancellation token. + /// + public async IAsyncEnumerable GenerateBatchAsBytesAsync(IModelOptions modelOptions, PromptOptions promptOptions, SchedulerOptions schedulerOptions, BatchOptions batchOptions, Action progressCallback = null, [EnumeratorCancellation] CancellationToken cancellationToken = default) + { + await foreach (var result in GenerateBatchAsync(modelOptions, promptOptions, schedulerOptions, batchOptions, progressCallback, cancellationToken)) + yield return result.ImageResult.ToImageBytes(); + } + + + /// + /// Generates a batch of StableDiffusion image using the prompt and options provided. + /// + /// The model options. + /// The prompt options. + /// The scheduler options. + /// The batch options. + /// The progress callback. + /// The cancellation token. + /// + public async IAsyncEnumerable GenerateBatchAsStreamAsync(IModelOptions modelOptions, PromptOptions promptOptions, SchedulerOptions schedulerOptions, BatchOptions batchOptions, Action progressCallback = null, [EnumeratorCancellation] CancellationToken cancellationToken = default) + { + await foreach (var result in GenerateBatchAsync(modelOptions, promptOptions, schedulerOptions, batchOptions, progressCallback, cancellationToken)) + yield return result.ImageResult.ToImageStream(); + } + + private async Task> DiffuseAsync(IModelOptions modelOptions, PromptOptions promptOptions, SchedulerOptions schedulerOptions, Action progress = null, CancellationToken cancellationToken = default) { if (!_pipelines.TryGetValue(modelOptions.PipelineType, out var pipeline)) @@ -150,5 +219,18 @@ private async Task> DiffuseAsync(IModelOptions modelOptions, return await diffuser.DiffuseAsync(modelOptions, promptOptions, schedulerOptions, progress, cancellationToken); } + + + private IAsyncEnumerable DiffuseBatchAsync(IModelOptions modelOptions, PromptOptions promptOptions, SchedulerOptions schedulerOptions, BatchOptions batchOptions, Action progress = null, CancellationToken cancellationToken = default) + { + if (!_pipelines.TryGetValue(modelOptions.PipelineType, out var pipeline)) + throw new Exception("Pipeline not found or is unsupported"); + + var diffuser = pipeline.GetDiffuser(promptOptions.DiffuserType); + if (diffuser is null) + throw new Exception("Diffuser not found or is unsupported"); + + return diffuser.DiffuseBatchAsync(modelOptions, promptOptions, schedulerOptions, batchOptions, progress, cancellationToken); + } } } diff --git a/OnnxStack.UI/App.xaml b/OnnxStack.UI/App.xaml index 9f239cfc..181eab26 100644 --- a/OnnxStack.UI/App.xaml +++ b/OnnxStack.UI/App.xaml @@ -91,14 +91,17 @@ - - + + + + + diff --git a/OnnxStack.UI/Models/BatchOptionsModel.cs b/OnnxStack.UI/Models/BatchOptionsModel.cs new file mode 100644 index 00000000..bd5019ec --- /dev/null +++ b/OnnxStack.UI/Models/BatchOptionsModel.cs @@ -0,0 +1,82 @@ +using OnnxStack.StableDiffusion.Enums; +using System.ComponentModel; +using System.Runtime.CompilerServices; + +namespace Models +{ + public class BatchOptionsModel : INotifyPropertyChanged + { + private float _valueTo; + private float _valueFrom; + private float _increment = 1; + private BatchOptionType _batchType; + private bool _isAutomationEnabled; + private int _stepValue; + private int _stepsValue = 1; + private int _batchValue; + private int _batchsValue = 1; + + public BatchOptionType BatchType + { + get { return _batchType; } + set { _batchType = value; NotifyPropertyChanged(); } + } + + public float ValueTo + { + get { return _valueTo; } + set { _valueTo = value; NotifyPropertyChanged(); } + } + + public float ValueFrom + { + get { return _valueFrom; } + set { _valueFrom = value; NotifyPropertyChanged(); } + } + + public float Increment + { + get { return _increment; } + set { _increment = value; NotifyPropertyChanged(); } + } + + public bool IsAutomationEnabled + { + get { return _isAutomationEnabled; } + set { _isAutomationEnabled = value; NotifyPropertyChanged(); } + } + + public int StepValue + { + get { return _stepValue; } + set { _stepValue = value; NotifyPropertyChanged(); } + } + + public int StepsValue + { + get { return _stepsValue; } + set { _stepsValue = value; NotifyPropertyChanged(); } + } + + public int BatchValue + { + get { return _batchValue; } + set { _batchValue = value; NotifyPropertyChanged(); } + } + + public int BatchsValue + { + get { return _batchsValue; } + set { _batchsValue = value; NotifyPropertyChanged(); } + } + + + #region INotifyPropertyChanged + public event PropertyChangedEventHandler PropertyChanged; + public void NotifyPropertyChanged([CallerMemberName] string property = "") + { + PropertyChanged?.Invoke(this, new PropertyChangedEventArgs(property)); + } + #endregion + } +} diff --git a/OnnxStack.UI/Models/PromptOptionsModel.cs b/OnnxStack.UI/Models/PromptOptionsModel.cs index cd3f813e..46ca68dd 100644 --- a/OnnxStack.UI/Models/PromptOptionsModel.cs +++ b/OnnxStack.UI/Models/PromptOptionsModel.cs @@ -10,7 +10,6 @@ public class PromptOptionsModel : INotifyPropertyChanged { private string _prompt; private string _negativePrompt; - private SchedulerType _schedulerType; [Required] [StringLength(512, MinimumLength = 1)] @@ -27,13 +26,6 @@ public string NegativePrompt set { _prompt = value; NotifyPropertyChanged(); } } - public SchedulerType SchedulerType - { - get { return _schedulerType; } - set { _schedulerType = value; NotifyPropertyChanged(); } - } - - #region INotifyPropertyChanged public event PropertyChangedEventHandler PropertyChanged; public void NotifyPropertyChanged([CallerMemberName] string property = "") diff --git a/OnnxStack.UI/Models/SchedulerOptionsModel.cs b/OnnxStack.UI/Models/SchedulerOptionsModel.cs index a31b15cd..14fcb788 100644 --- a/OnnxStack.UI/Models/SchedulerOptionsModel.cs +++ b/OnnxStack.UI/Models/SchedulerOptionsModel.cs @@ -32,6 +32,7 @@ public class SchedulerOptionsModel : INotifyPropertyChanged private AlphaTransformType _alphaTransformType = AlphaTransformType.Cosine; private float _maximumBeta = 0.999f; private int _originalInferenceSteps = 100; + private SchedulerType _schedulerType; /// /// Gets or sets the height. @@ -203,6 +204,13 @@ public int OriginalInferenceSteps set { _originalInferenceSteps = value; NotifyPropertyChanged(); } } + public SchedulerType SchedulerType + { + get { return _schedulerType; } + set { _schedulerType = value; NotifyPropertyChanged(); } + } + + #region INotifyPropertyChanged public event PropertyChangedEventHandler PropertyChanged; public void NotifyPropertyChanged([CallerMemberName] string property = "") diff --git a/OnnxStack.UI/UserControls/PromptControl.xaml b/OnnxStack.UI/UserControls/PromptControl.xaml index eb944861..04f8ace6 100644 --- a/OnnxStack.UI/UserControls/PromptControl.xaml +++ b/OnnxStack.UI/UserControls/PromptControl.xaml @@ -12,21 +12,16 @@ - - - - - - - - - - - - - - - - + + + + + + + + + + + diff --git a/OnnxStack.UI/UserControls/PromptControl.xaml.cs b/OnnxStack.UI/UserControls/PromptControl.xaml.cs index db7c0209..8650b292 100644 --- a/OnnxStack.UI/UserControls/PromptControl.xaml.cs +++ b/OnnxStack.UI/UserControls/PromptControl.xaml.cs @@ -18,8 +18,6 @@ namespace OnnxStack.UI.UserControls /// public partial class PromptControl : UserControl, INotifyPropertyChanged { - private ObservableCollection _schedulerTypes = new(); - /// Initializes a new instance of the class. public PromptControl() { @@ -56,41 +54,8 @@ public ModelOptionsModel SelectedModel } public static readonly DependencyProperty SelectedModelProperty = - DependencyProperty.Register("SelectedModel", typeof(ModelOptionsModel), typeof(PromptControl), new PropertyMetadata((d, e) => - { - if (d is PromptControl schedulerControl) - schedulerControl.OnModelChanged(e.NewValue as ModelOptionsModel); - })); - - public ObservableCollection SchedulerTypes - { - get { return _schedulerTypes; } - set { _schedulerTypes = value; NotifyPropertyChanged(); } - } - + DependencyProperty.Register("SelectedModel", typeof(ModelOptionsModel), typeof(PromptControl)); - /// - /// Called when the selected model has changed. - /// - /// The model options model. - private void OnModelChanged(ModelOptionsModel model) - { - SchedulerTypes.Clear(); - if (model is null) - return; - - if (model.ModelOptions.PipelineType == DiffuserPipelineType.StableDiffusion) - { - foreach (SchedulerType type in Enum.GetValues().Where(x => x != SchedulerType.LCM)) - SchedulerTypes.Add(type); - } - else if (model.ModelOptions.PipelineType == DiffuserPipelineType.LatentConsistency) - { - SchedulerTypes.Add(SchedulerType.LCM); - } - - PromptOptions.SchedulerType = SchedulerTypes.FirstOrDefault(); - } /// /// Resets the parameters. diff --git a/OnnxStack.UI/UserControls/SchedulerControl.xaml b/OnnxStack.UI/UserControls/SchedulerControl.xaml index 6bd309f5..6b673ef1 100644 --- a/OnnxStack.UI/UserControls/SchedulerControl.xaml +++ b/OnnxStack.UI/UserControls/SchedulerControl.xaml @@ -26,7 +26,10 @@ - + + + + @@ -73,7 +76,7 @@ - + @@ -213,10 +216,9 @@ - - + @@ -228,6 +230,159 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/OnnxStack.UI/UserControls/SchedulerControl.xaml.cs b/OnnxStack.UI/UserControls/SchedulerControl.xaml.cs index 8b68bcda..eef8e23e 100644 --- a/OnnxStack.UI/UserControls/SchedulerControl.xaml.cs +++ b/OnnxStack.UI/UserControls/SchedulerControl.xaml.cs @@ -7,6 +7,7 @@ using System; using System.Collections.ObjectModel; using System.ComponentModel; +using System.Linq; using System.Runtime.CompilerServices; using System.Windows; using System.Windows.Controls; @@ -20,6 +21,7 @@ namespace OnnxStack.UI.UserControls public partial class SchedulerControl : UserControl, INotifyPropertyChanged { private SchedulerOptionsConfig _optionsConfig = new(); + private ObservableCollection _schedulerTypes = new(); /// Initializes a new instance of the class. public SchedulerControl() @@ -34,6 +36,12 @@ public SchedulerControl() public ICommand RandomSeedCommand { get; } public ObservableCollection ValidSizes { get; } + public ObservableCollection SchedulerTypes + { + get { return _schedulerTypes; } + set { _schedulerTypes = value; NotifyPropertyChanged(); } + } + /// /// Gets or sets the selected model. /// @@ -74,6 +82,18 @@ public SchedulerOptionsModel SchedulerOptions DependencyProperty.Register("SchedulerOptions", typeof(SchedulerOptionsModel), typeof(SchedulerControl)); + public BatchOptionsModel BatchOptions + { + get { return (BatchOptionsModel)GetValue(BatchOptionsProperty); } + set { SetValue(BatchOptionsProperty, value); } + } + public static readonly DependencyProperty BatchOptionsProperty = + DependencyProperty.Register("BatchOptions", typeof(BatchOptionsModel), typeof(SchedulerControl)); + + + + + /// /// Gets or sets the options configuration. /// @@ -103,6 +123,23 @@ private void OnModelChanged(ModelOptionsModel model) SchedulerOptions.OriginalInferenceSteps = 50; SchedulerOptions.InferenceSteps = 6; } + + + SchedulerTypes.Clear(); + if (model is null) + return; + + if (model.ModelOptions.PipelineType == DiffuserPipelineType.StableDiffusion) + { + foreach (SchedulerType type in Enum.GetValues().Where(x => x != SchedulerType.LCM)) + SchedulerTypes.Add(type); + } + else if (model.ModelOptions.PipelineType == DiffuserPipelineType.LatentConsistency) + { + SchedulerTypes.Add(SchedulerType.LCM); + } + + SchedulerOptions.SchedulerType = SchedulerTypes.FirstOrDefault(); } diff --git a/OnnxStack.UI/Utils.cs b/OnnxStack.UI/Utils.cs index 5e65c095..0398c76b 100644 --- a/OnnxStack.UI/Utils.cs +++ b/OnnxStack.UI/Utils.cs @@ -1,4 +1,6 @@ -using OnnxStack.StableDiffusion.Config; +using Models; +using OnnxStack.StableDiffusion.Config; +using OnnxStack.StableDiffusion.Enums; using OnnxStack.UI.Models; using System; using System.IO; @@ -81,8 +83,8 @@ public static async Task AutoSave(this ImageResult imageResult, string aut Directory.CreateDirectory(autosaveDirectory); var random = RandomString(); - var imageFile = Path.Combine(autosaveDirectory, $"image-{imageResult.SchedulerOptions.Seed}-{random}.png"); - var blueprintFile = Path.Combine(autosaveDirectory, $"image-{imageResult.SchedulerOptions.Seed}-{random}.json"); + var imageFile = Path.Combine(autosaveDirectory, $"image-{imageResult.SchedulerOptions.Seed}-{random}.png"); + var blueprintFile = Path.Combine(autosaveDirectory, $"image-{imageResult.SchedulerOptions.Seed}-{random}.json"); if (!await imageResult.SaveImageFile(imageFile)) return false; @@ -120,7 +122,8 @@ public static SchedulerOptions ToSchedulerOptions(this SchedulerOptionsModel mod TrainTimesteps = model.TrainTimesteps, UseKarrasSigmas = model.UseKarrasSigmas, VarianceType = model.VarianceType, - OriginalInferenceSteps = model.OriginalInferenceSteps + OriginalInferenceSteps = model.OriginalInferenceSteps, + SchedulerType = model.SchedulerType }; } @@ -152,6 +155,7 @@ public static SchedulerOptionsModel ToSchedulerOptionsModel(this SchedulerOption UseKarrasSigmas = model.UseKarrasSigmas, VarianceType = model.VarianceType, OriginalInferenceSteps = model.OriginalInferenceSteps, + SchedulerType = model.SchedulerType }; } @@ -160,11 +164,37 @@ public static PromptOptionsModel ToPromptOptionsModel(this PromptOptions promptO return new PromptOptionsModel { Prompt = promptOptions.Prompt, - NegativePrompt = promptOptions.NegativePrompt, - SchedulerType = promptOptions.SchedulerType + NegativePrompt = promptOptions.NegativePrompt + }; + } + + + + public static BatchOptionsModel ToBatchOptionsModel(this BatchOptions batchOptions) + { + return new BatchOptionsModel + { + BatchType = batchOptions.BatchType, + ValueTo = batchOptions.ValueTo, + Increment = batchOptions.Increment, + ValueFrom = batchOptions.ValueFrom }; } + + public static BatchOptions ToBatchOptions(this BatchOptionsModel batchOptionsModel) + { + return new BatchOptions + { + BatchType = batchOptionsModel.BatchType, + ValueTo = batchOptionsModel.ValueTo, + Increment = batchOptionsModel.Increment, + ValueFrom = batchOptionsModel.ValueFrom + }; + } + + + public static void LogToWindow(string message) { System.Windows.Application.Current.Dispatcher.BeginInvoke(DispatcherPriority.Render, new Action(() => diff --git a/OnnxStack.UI/Views/ImageInpaint.xaml b/OnnxStack.UI/Views/ImageInpaint.xaml index 57b93544..025e388b 100644 --- a/OnnxStack.UI/Views/ImageInpaint.xaml +++ b/OnnxStack.UI/Views/ImageInpaint.xaml @@ -19,23 +19,26 @@