diff --git a/OnnxStack.StableDiffusion/Diffusers/DiffuserBase.cs b/OnnxStack.StableDiffusion/Diffusers/DiffuserBase.cs index 2eab6f88..a953c489 100644 --- a/OnnxStack.StableDiffusion/Diffusers/DiffuserBase.cs +++ b/OnnxStack.StableDiffusion/Diffusers/DiffuserBase.cs @@ -105,7 +105,7 @@ public virtual async Task> DiffuseAsync(ModelOptions modelOpt // Create random seed if none was set schedulerOptions.Seed = schedulerOptions.Seed > 0 ? schedulerOptions.Seed : Random.Shared.Next(); - var diffuseTime = _logger?.LogBegin("Diffuse starting..."); + var diffuseTime = _logger?.LogBegin("Diffuser starting..."); _logger?.Log($"Model: {modelOptions.Name}, Pipeline: {modelOptions.PipelineType}, Diffuser: {promptOptions.DiffuserType}, Scheduler: {schedulerOptions.SchedulerType}"); // Check guidance @@ -114,36 +114,15 @@ public virtual async Task> DiffuseAsync(ModelOptions modelOpt // Process prompts var promptEmbeddings = await _promptService.CreatePromptAsync(modelOptions.BaseModel, promptOptions, performGuidance); - // If video input, process frames - if (promptOptions.HasInputVideo) - { - var frameIndex = 0; - DenseTensor videoTensor = null; - var videoFrames = promptOptions.InputVideo.VideoFrames.Frames; - var schedulerFrameCallback = CreateBatchCallback(progressCallback, videoFrames.Count, () => frameIndex); - foreach (var videoFrame in videoFrames) - { - frameIndex++; - promptOptions.InputImage = promptOptions.DiffuserType == DiffuserType.ControlNet ? default : new InputImage(videoFrame); - promptOptions.InputContolImage = promptOptions.DiffuserType == DiffuserType.ImageToImage ? default : new InputImage(videoFrame); - var frameResultTensor = await SchedulerStepAsync(modelOptions, promptOptions, schedulerOptions, promptEmbeddings, performGuidance, schedulerFrameCallback, cancellationToken); - - // Frame Progress - ReportBatchProgress(progressCallback, frameIndex, videoFrames.Count, frameResultTensor); + var tensorResult = promptOptions.HasInputVideo + ? 
await DiffuseVideoAsync(modelOptions, promptOptions, schedulerOptions, promptEmbeddings, performGuidance, progressCallback, cancellationToken) + : await DiffuseImageAsync(modelOptions, promptOptions, schedulerOptions, promptEmbeddings, performGuidance, progressCallback, cancellationToken); - // Concatenate frame - videoTensor = videoTensor.Concatenate(frameResultTensor); - } + _logger?.LogEnd($"Diffuser complete", diffuseTime); + return tensorResult; + } - _logger?.LogEnd($"Diffuse complete", diffuseTime); - return videoTensor; - } - // Run Scheduler steps - var schedulerResult = await SchedulerStepAsync(modelOptions, promptOptions, schedulerOptions, promptEmbeddings, performGuidance, progressCallback, cancellationToken); - _logger?.LogEnd($"Diffuse complete", diffuseTime); - return schedulerResult; - } @@ -180,13 +159,73 @@ public virtual async IAsyncEnumerable DiffuseBatchAsync(ModelOption var batchSchedulerCallback = CreateBatchCallback(progressCallback, batchSchedulerOptions.Count, () => batchIndex); foreach (var batchSchedulerOption in batchSchedulerOptions) { - var diffuseTime = _logger?.LogBegin("Diffuse starting..."); - yield return new BatchResult(batchSchedulerOption, await SchedulerStepAsync(modelOptions, promptOptions, batchSchedulerOption, promptEmbeddings, performGuidance, batchSchedulerCallback, cancellationToken)); - _logger?.LogEnd($"Diffuse complete", diffuseTime); + var tensorResult = promptOptions.HasInputVideo + ? 
await DiffuseVideoAsync(modelOptions, promptOptions, batchSchedulerOption, promptEmbeddings, performGuidance, progressCallback, cancellationToken) // NOTE(review): the video branch receives progressCallback while the image branch receives batchSchedulerCallback - confirm this asymmetry is intended
+                    : await DiffuseImageAsync(modelOptions, promptOptions, batchSchedulerOption, promptEmbeddings, performGuidance, batchSchedulerCallback, cancellationToken);
+
+                yield return new BatchResult(batchSchedulerOption, tensorResult);
                 batchIndex++;
             }
-            _logger?.LogEnd($"Diffuse batch complete", diffuseBatchTime);
+            _logger?.LogEnd($"Batch Diffuser complete", diffuseBatchTime);
+        }
+
+
+        /// <summary>
+        /// Diffuses the image by awaiting a single <c>SchedulerStepAsync</c> call, bracketed by begin/end log entries.
+        /// </summary>
+        /// <param name="modelOptions">The model options; forwarded unchanged to <c>SchedulerStepAsync</c>.</param>
+        /// <param name="promptOptions">The prompt options; forwarded unchanged to <c>SchedulerStepAsync</c>.</param>
+        /// <param name="schedulerOptions">The scheduler options; forwarded unchanged to <c>SchedulerStepAsync</c>.</param>
+        /// <param name="promptEmbeddings">The prompt embeddings; forwarded unchanged to <c>SchedulerStepAsync</c>.</param>
+        /// <param name="performGuidance">if set to <c>true</c> [perform guidance]; forwarded unchanged to <c>SchedulerStepAsync</c>.</param>
+        /// <param name="progressCallback">The progress callback (optional, defaults to null); forwarded unchanged to <c>SchedulerStepAsync</c>.</param>
+        /// <param name="cancellationToken">The cancellation token (optional); forwarded unchanged to <c>SchedulerStepAsync</c>.</param>
+        /// <returns>The result produced by the awaited <c>SchedulerStepAsync</c> call.</returns>
+        protected virtual async Task> DiffuseImageAsync(ModelOptions modelOptions, PromptOptions promptOptions, SchedulerOptions schedulerOptions, PromptEmbeddingsResult promptEmbeddings, bool performGuidance, Action progressCallback = null, CancellationToken cancellationToken = default)
+        {
+            var diffuseTime = _logger?.LogBegin("Image Diffuser starting..."); // _logger may be null, hence the null-conditional call
+            var schedulerResult = await SchedulerStepAsync(modelOptions, promptOptions, schedulerOptions, promptEmbeddings, performGuidance, progressCallback, cancellationToken);
+            _logger?.LogEnd($"Image Diffuser complete", diffuseTime); // not reached if SchedulerStepAsync throws (no try/finally here)
+            return schedulerResult;
+        }
+
+
+        /// <summary>
+        /// Diffuses the video by awaiting <c>SchedulerStepAsync</c> once per input video frame and accumulating the per-frame results with <c>Concatenate</c>.
+        /// </summary>
+        /// <param name="modelOptions">The model options; forwarded unchanged to <c>SchedulerStepAsync</c>.</param>
+        /// <param name="promptOptions">The prompt options; supplies the frames via <c>InputVideo.VideoFrames.Frames</c>, and its <c>InputImage</c> / <c>InputContolImage</c> are overwritten for every frame.</param>
+        /// <param name="schedulerOptions">The scheduler options; the same instance is reused for every frame.</param>
+        /// <param name="promptEmbeddings">The prompt embeddings; the same instance is reused for every frame.</param>
+        /// <param name="performGuidance">if set to <c>true</c> [perform guidance]; forwarded unchanged to <c>SchedulerStepAsync</c>.</param>
+        /// <param name="progressCallback">The progress callback (optional, defaults to null); passed to <c>CreateBatchCallback</c> and <c>ReportBatchProgress</c>.</param>
+        /// <param name="cancellationToken">The cancellation token (optional); forwarded to every <c>SchedulerStepAsync</c> call.</param>
+        /// <returns>The tensor accumulated by calling <c>Concatenate</c> with each frame's result, in frame order; null when the input video has no frames.</returns>
+        protected virtual async Task> DiffuseVideoAsync(ModelOptions modelOptions, PromptOptions promptOptions, SchedulerOptions schedulerOptions, PromptEmbeddingsResult promptEmbeddings, bool performGuidance, Action progressCallback = null, CancellationToken cancellationToken = default)
+        {
+            var diffuseTime = _logger?.LogBegin("Video Diffuser starting..."); // _logger may be null, hence the null-conditional call
+
+            var frameIndex = 0; // incremented before each frame is processed, so it is 1-based inside the loop; also captured by the lambda below
+            DenseTensor videoTensor = null; // only assigned inside the loop, so it stays null when there are no frames
+            var videoFrames = promptOptions.InputVideo.VideoFrames.Frames; // no null checks here: presumably the caller has already verified HasInputVideo - confirm
+            var schedulerFrameCallback = CreateBatchCallback(progressCallback, videoFrames.Count, () => frameIndex); // the lambda reads frameIndex at call time, so the callback sees the current frame
+            foreach (var videoFrame in videoFrames)
+            {
+                frameIndex++;
+                promptOptions.InputImage = promptOptions.DiffuserType == DiffuserType.ControlNet ? default : new InputImage(videoFrame); // ControlNet gets no input image; every other diffuser type gets the frame. NOTE(review): mutates the caller's promptOptions and never restores it
+                promptOptions.InputContolImage = promptOptions.DiffuserType == DiffuserType.ImageToImage ? default : new InputImage(videoFrame); // ImageToImage gets no control image; every other diffuser type gets the frame
+                var frameResultTensor = await SchedulerStepAsync(modelOptions, promptOptions, schedulerOptions, promptEmbeddings, performGuidance, schedulerFrameCallback, cancellationToken);
+
+                // Frame Progress: report the finished frame (1-based index, total frame count, frame result) through the caller's original callback
+                ReportBatchProgress(progressCallback, frameIndex, videoFrames.Count, frameResultTensor);
+
+                // Concatenate frame. NOTE(review): videoTensor is null on the first iteration; this presumably relies on Concatenate being an extension method that tolerates a null receiver - confirm
+                videoTensor = videoTensor.Concatenate(frameResultTensor);
+            }
+
+            _logger?.LogEnd($"Video Diffuser complete", diffuseTime); // not reached if any frame throws (no try/finally here)
+            return videoTensor;
         }