
Support LCM-SDXL guidance embeddings #123

Merged · 1 commit · Feb 27, 2024

@@ -1,11 +1,18 @@
using Microsoft.Extensions.Logging;
using Microsoft.ML.OnnxRuntime.Tensors;
using OnnxStack.Core;
using OnnxStack.Core.Model;
using OnnxStack.StableDiffusion.Common;
using OnnxStack.StableDiffusion.Config;
using OnnxStack.StableDiffusion.Diffusers.StableDiffusionXL;
using OnnxStack.StableDiffusion.Enums;
using OnnxStack.StableDiffusion.Models;
using OnnxStack.StableDiffusion.Schedulers.LatentConsistency;
using System.Diagnostics;
using System.Linq;
using System.Threading.Tasks;
using System.Threading;
using System;

namespace OnnxStack.StableDiffusion.Diffusers.LatentConsistencyXL
{
@@ -29,6 +36,92 @@ protected LatentConsistencyXLDiffuser(UNetConditionModel unet, AutoEncoderModel
        public override DiffuserPipelineType PipelineType => DiffuserPipelineType.LatentConsistencyXL;


        /// <summary>
        /// Runs the scheduler steps.
        /// </summary>
        /// <param name="promptOptions">The prompt options.</param>
        /// <param name="schedulerOptions">The scheduler options.</param>
        /// <param name="promptEmbeddings">The prompt embeddings.</param>
        /// <param name="performGuidance">if set to <c>true</c> [perform guidance].</param>
        /// <param name="progressCallback">The progress callback.</param>
        /// <param name="cancellationToken">The cancellation token.</param>
        /// <returns></returns>
        public override async Task<DenseTensor<float>> DiffuseAsync(PromptOptions promptOptions, SchedulerOptions schedulerOptions, PromptEmbeddingsResult promptEmbeddings, bool performGuidance, Action<DiffusionProgress> progressCallback = null, CancellationToken cancellationToken = default)
        {
            // Get Scheduler
            using (var scheduler = GetScheduler(schedulerOptions))
            {
                // Get timesteps
                var timesteps = GetTimesteps(schedulerOptions, scheduler);

                // Create latent sample
                var latents = await PrepareLatentsAsync(promptOptions, schedulerOptions, scheduler, timesteps);

                // Get Model metadata
                var metadata = await _unet.GetMetadataAsync();

                // Get Time ids
                var addTimeIds = GetAddTimeIds(schedulerOptions);

                // Get Guidance Scale Embedding
                var guidanceEmbeddings = GetGuidanceScaleEmbedding(schedulerOptions.GuidanceScale);

                // Loop through the timesteps
                var step = 0;
                foreach (var timestep in timesteps)
                {
                    step++;
                    var stepTime = Stopwatch.GetTimestamp();
                    cancellationToken.ThrowIfCancellationRequested();

                    // Create input tensors
                    var inputLatent = performGuidance ? latents.Repeat(2) : latents;
                    var inputTensor = scheduler.ScaleInput(inputLatent, timestep);
                    var timestepTensor = CreateTimestepTensor(timestep);
                    var timeids = performGuidance ? addTimeIds.Repeat(2) : addTimeIds;

                    // Output batch size: two when guidance doubled the input, otherwise one
                    var outputChannels = performGuidance ? 2 : 1;
                    var outputDimension = schedulerOptions.GetScaledDimension(outputChannels);
                    using (var inferenceParameters = new OnnxInferenceParameters(metadata))
                    {
                        inferenceParameters.AddInputTensor(inputTensor);
                        inferenceParameters.AddInputTensor(timestepTensor);
                        inferenceParameters.AddInputTensor(promptEmbeddings.PromptEmbeds);
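                        // LCM-SDXL UNets exported with guidance-embedding support take an
                        // extra conditioning input (six in total); bind it only when the
                        // model metadata reports six inputs.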
                        if (inferenceParameters.InputCount == 6)
                            inferenceParameters.AddInputTensor(guidanceEmbeddings);
                        inferenceParameters.AddInputTensor(promptEmbeddings.PooledPromptEmbeds);
                        inferenceParameters.AddInputTensor(timeids);
                        inferenceParameters.AddOutputBuffer(outputDimension);

                        var results = await _unet.RunInferenceAsync(inferenceParameters);
                        using (var result = results.First())
                        {
                            var noisePred = result.ToDenseTensor();

                            // Perform guidance
                            if (performGuidance)
                                noisePred = PerformGuidance(noisePred, schedulerOptions.GuidanceScale);

                            // Scheduler Step
                            latents = scheduler.Step(noisePred, timestep, latents).Result;
                        }
                    }

                    ReportProgress(progressCallback, step, timesteps.Count, latents);
                    _logger?.LogEnd(LogLevel.Debug, $"Step {step}/{timesteps.Count}", stepTime);
                }

                // Unload if required
                if (_memoryMode == MemoryModeType.Minimum)
                    await _unet.UnloadAsync();

                // Decode Latents
                return await DecodeLatentsAsync(promptOptions, schedulerOptions, latents);
            }
        }


        /// <summary>
        /// Gets the scheduler.
        /// </summary>
@@ -42,5 +135,26 @@ protected override IScheduler GetScheduler(SchedulerOptions options)
                _ => default
            };
        }


        /// <summary>
        /// Gets the guidance scale embedding.
        /// </summary>
        /// <param name="guidance">The guidance scale.</param>
        /// <param name="embeddingDim">The embedding dimension.</param>
        /// <returns></returns>
        protected DenseTensor<float> GetGuidanceScaleEmbedding(float guidance, int embeddingDim = 256)
        {
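            // Sinusoidal embedding of the guidance scale, analogous to a timestep
            // embedding: shift/scale w, sample halfDim exponentially decaying
            // frequencies, then concatenate the sin and cos halves.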
            var scale = (guidance - 1f) * 1000.0f;
            var halfDim = embeddingDim / 2;
            float log = MathF.Log(10000.0f) / (halfDim - 1);
            var emb = Enumerable.Range(0, halfDim)
                .Select(x => scale * MathF.Exp(-log * x))
                .ToArray();
            var embSin = emb.Select(MathF.Sin);
            var embCos = emb.Select(MathF.Cos);
            var guidanceEmbedding = embSin.Concat(embCos).ToArray();
            return new DenseTensor<float>(guidanceEmbedding, new[] { 1, embeddingDim });
        }
    }
}
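
For reference, GetGuidanceScaleEmbedding is the standard sinusoidal embedding recipe applied to the guidance scale w: the scale is shifted and rescaled to w' = (w - 1) * 1000, projected onto halfDim exponentially decaying frequencies 10000^(-i / (halfDim - 1)), and the sin and cos components are concatenated into a [1, embeddingDim] tensor. A minimal standalone sketch for sanity-checking the output shape and values follows; the GuidanceEmbeddingCheck class, Main wrapper, and sample guidance value are illustrative only and not part of this PR:

using System;
using System.Linq;

internal static class GuidanceEmbeddingCheck
{
    // Standalone copy of the embedding logic from this PR, for verification.
    private static float[] GetGuidanceScaleEmbedding(float guidance, int embeddingDim = 256)
    {
        var scale = (guidance - 1f) * 1000.0f;           // w' = (w - 1) * 1000
        var halfDim = embeddingDim / 2;
        float log = MathF.Log(10000.0f) / (halfDim - 1);
        var emb = Enumerable.Range(0, halfDim)
            .Select(x => scale * MathF.Exp(-log * x))    // w' * 10000^(-x / (halfDim - 1))
            .ToArray();
        return emb.Select(MathF.Sin).Concat(emb.Select(MathF.Cos)).ToArray();
    }

    private static void Main()
    {
        var embedding = GetGuidanceScaleEmbedding(5f);   // guidance scale w = 5
        Console.WriteLine(embedding.Length);             // 256, reshaped to [1, 256] by the caller
        Console.WriteLine(embedding[0]);                 // MathF.Sin(4000f), the highest-frequency term
    }
}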