From e2acc3ece7f9d6412f9dd3ba4feb03763b67d786 Mon Sep 17 00:00:00 2001 From: sa_ddam213 Date: Wed, 28 Feb 2024 12:28:49 +1300 Subject: [PATCH] Support LCM-SDXL guidance embeddings --- .../LatentConsistencyXLDiffuser.cs | 114 ++++++++++++++++++ 1 file changed, 114 insertions(+) diff --git a/OnnxStack.StableDiffusion/Diffusers/LatentConsistencyXL/LatentConsistencyXLDiffuser.cs b/OnnxStack.StableDiffusion/Diffusers/LatentConsistencyXL/LatentConsistencyXLDiffuser.cs index 73af2ef8..537a2170 100644 --- a/OnnxStack.StableDiffusion/Diffusers/LatentConsistencyXL/LatentConsistencyXLDiffuser.cs +++ b/OnnxStack.StableDiffusion/Diffusers/LatentConsistencyXL/LatentConsistencyXLDiffuser.cs @@ -1,4 +1,6 @@ using Microsoft.Extensions.Logging; +using Microsoft.ML.OnnxRuntime.Tensors; +using OnnxStack.Core; using OnnxStack.Core.Model; using OnnxStack.StableDiffusion.Common; using OnnxStack.StableDiffusion.Config; @@ -6,6 +8,11 @@ using OnnxStack.StableDiffusion.Enums; using OnnxStack.StableDiffusion.Models; using OnnxStack.StableDiffusion.Schedulers.LatentConsistency; +using System.Diagnostics; +using System.Linq; +using System.Threading.Tasks; +using System.Threading; +using System; namespace OnnxStack.StableDiffusion.Diffusers.LatentConsistencyXL { @@ -29,6 +36,92 @@ protected LatentConsistencyXLDiffuser(UNetConditionModel unet, AutoEncoderModel public override DiffuserPipelineType PipelineType => DiffuserPipelineType.LatentConsistencyXL; + /// + /// Runs the scheduler steps. + /// + /// The model options. + /// The prompt options. + /// The scheduler options. + /// The prompt embeddings. + /// if set to true [perform guidance]. + /// The progress callback. + /// The cancellation token. 
/// <summary>
/// Runs the scheduler steps.
/// </summary>
/// <param name="promptOptions">The prompt options.</param>
/// <param name="schedulerOptions">The scheduler options.</param>
/// <param name="promptEmbeddings">The prompt embeddings.</param>
/// <param name="performGuidance">if set to <c>true</c> [perform guidance].</param>
/// <param name="progressCallback">The progress callback.</param>
/// <param name="cancellationToken">The cancellation token.</param>
/// <returns>The result of decoding the final latents.</returns>
/// <remarks>
/// NOTE(review): the generic arguments of the return type and of <paramref name="progressCallback"/>
/// are assumed to be <c>DenseTensor&lt;float&gt;</c> and <c>DiffusionProgress</c>; confirm against the
/// base class declaration being overridden.
/// </remarks>
public override async Task<DenseTensor<float>> DiffuseAsync(PromptOptions promptOptions, SchedulerOptions schedulerOptions, PromptEmbeddingsResult promptEmbeddings, bool performGuidance, Action<DiffusionProgress> progressCallback = null, CancellationToken cancellationToken = default)
{
    // Input count at which the guidance scale embedding is also bound.
    // NOTE(review): presumably the number of inputs declared by a UNet exported with a guidance
    // embedding input; confirm what InputCount reports.
    const int guidanceEmbeddingInputCount = 6;

    // Get Scheduler
    using (var scheduler = GetScheduler(schedulerOptions))
    {
        // Get timesteps
        var timesteps = GetTimesteps(schedulerOptions, scheduler);

        // Create latent sample
        var latents = await PrepareLatentsAsync(promptOptions, schedulerOptions, scheduler, timesteps);

        // Get Model metadata
        var metadata = await _unet.GetMetadataAsync();

        // Get Time ids. These do not change between steps, so the guidance copy is built once
        // instead of once per iteration.
        var addTimeIds = GetAddTimeIds(schedulerOptions);
        var timeids = performGuidance ? addTimeIds.Repeat(2) : addTimeIds;

        // Get Guidance Scale Embedding
        var guidanceEmbeddings = GetGuidanceScaleEmbedding(schedulerOptions.GuidanceScale);

        // The output buffer dimension is also the same for every step.
        var outputChannels = performGuidance ? 2 : 1;
        var outputDimension = schedulerOptions.GetScaledDimension(outputChannels);

        // Loop though the timesteps
        var step = 0;
        foreach (var timestep in timesteps)
        {
            step++;
            var stepTime = Stopwatch.GetTimestamp();
            cancellationToken.ThrowIfCancellationRequested();

            // Create input tensor.
            var inputLatent = performGuidance ? latents.Repeat(2) : latents;
            var inputTensor = scheduler.ScaleInput(inputLatent, timestep);
            var timestepTensor = CreateTimestepTensor(timestep);

            using (var inferenceParameters = new OnnxInferenceParameters(metadata))
            {
                // Inputs are added in a fixed order; the guidance embedding sits between the
                // prompt embeddings and the pooled prompt embeddings when it is bound.
                inferenceParameters.AddInputTensor(inputTensor);
                inferenceParameters.AddInputTensor(timestepTensor);
                inferenceParameters.AddInputTensor(promptEmbeddings.PromptEmbeds);
                if (inferenceParameters.InputCount == guidanceEmbeddingInputCount)
                    inferenceParameters.AddInputTensor(guidanceEmbeddings);
                inferenceParameters.AddInputTensor(promptEmbeddings.PooledPromptEmbeds);
                inferenceParameters.AddInputTensor(timeids);
                inferenceParameters.AddOutputBuffer(outputDimension);

                var results = await _unet.RunInferenceAsync(inferenceParameters);
                using (var result = results.First())
                {
                    var noisePred = result.ToDenseTensor();

                    // Perform guidance
                    if (performGuidance)
                        noisePred = PerformGuidance(noisePred, schedulerOptions.GuidanceScale);

                    // Scheduler Step
                    // NOTE(review): .Result is read from the value returned by Step; presumably a
                    // property of the step result rather than a blocking Task.Result - confirm.
                    latents = scheduler.Step(noisePred, timestep, latents).Result;
                }
            }

            ReportProgress(progressCallback, step, timesteps.Count, latents);
            _logger?.LogEnd(LogLevel.Debug, $"Step {step}/{timesteps.Count}", stepTime);
        }

        // Unload if required
        if (_memoryMode == MemoryModeType.Minimum)
            await _unet.UnloadAsync();

        // Decode Latents
        return await DecodeLatentsAsync(promptOptions, schedulerOptions, latents);
    }
}
/// <summary>
/// Gets the guidance scale embedding.
/// </summary>
/// <param name="guidance">The guidance scale; it is offset by 1 and multiplied by 1000 before being embedded.</param>
/// <param name="embeddingDim">The embedding dim (length of the second tensor dimension).</param>
/// <returns>
/// A tensor of shape [1, <paramref name="embeddingDim"/>] holding the sine terms followed by the
/// cosine terms. When <paramref name="embeddingDim"/> is odd the final element is zero.
/// </returns>
/// <exception cref="ArgumentOutOfRangeException"><paramref name="embeddingDim"/> is negative.</exception>
protected DenseTensor<float> GetGuidanceScaleEmbedding(float guidance, int embeddingDim = 256)
{
    if (embeddingDim < 0)
        throw new ArgumentOutOfRangeException(nameof(embeddingDim), embeddingDim, "The embedding dimension must not be negative.");

    var scale = (guidance - 1f) * 1000.0f;
    var halfDim = embeddingDim / 2;

    // With a single frequency the exponent index is always 0, so the decay rate is irrelevant;
    // using 0 avoids dividing by (halfDim - 1) == 0, which would turn the embedding into NaN.
    var decay = halfDim > 1 ? MathF.Log(10000.0f) / (halfDim - 1) : 0f;

    // Layout: [sin(0) .. sin(halfDim - 1), cos(0) .. cos(halfDim - 1), optional zero pad].
    // Allocating embeddingDim elements keeps the buffer length equal to the tensor shape even
    // when embeddingDim is odd; the unused last slot stays at its default of 0.
    var guidanceEmbedding = new float[embeddingDim];
    for (var i = 0; i < halfDim; i++)
    {
        var angle = scale * MathF.Exp(-decay * i);
        guidanceEmbedding[i] = MathF.Sin(angle);
        guidanceEmbedding[halfDim + i] = MathF.Cos(angle);
    }

    return new DenseTensor<float>(guidanceEmbedding, new[] { 1, embeddingDim });
}