Skip to content
This repository was archived by the owner on Nov 27, 2024. It is now read-only.

Batch Video Processing #95

Merged
merged 1 commit into from
Jan 8, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
103 changes: 71 additions & 32 deletions OnnxStack.StableDiffusion/Diffusers/DiffuserBase.cs
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ public virtual async Task<DenseTensor<float>> DiffuseAsync(ModelOptions modelOpt
// Create random seed if none was set
schedulerOptions.Seed = schedulerOptions.Seed > 0 ? schedulerOptions.Seed : Random.Shared.Next();

var diffuseTime = _logger?.LogBegin("Diffuse starting...");
var diffuseTime = _logger?.LogBegin("Diffuser starting...");
_logger?.Log($"Model: {modelOptions.Name}, Pipeline: {modelOptions.PipelineType}, Diffuser: {promptOptions.DiffuserType}, Scheduler: {schedulerOptions.SchedulerType}");

// Check guidance
Expand All @@ -114,36 +114,15 @@ public virtual async Task<DenseTensor<float>> DiffuseAsync(ModelOptions modelOpt
// Process prompts
var promptEmbeddings = await _promptService.CreatePromptAsync(modelOptions.BaseModel, promptOptions, performGuidance);

// If video input, process frames
if (promptOptions.HasInputVideo)
{
var frameIndex = 0;
DenseTensor<float> videoTensor = null;
var videoFrames = promptOptions.InputVideo.VideoFrames.Frames;
var schedulerFrameCallback = CreateBatchCallback(progressCallback, videoFrames.Count, () => frameIndex);
foreach (var videoFrame in videoFrames)
{
frameIndex++;
promptOptions.InputImage = promptOptions.DiffuserType == DiffuserType.ControlNet ? default : new InputImage(videoFrame);
promptOptions.InputContolImage = promptOptions.DiffuserType == DiffuserType.ImageToImage ? default : new InputImage(videoFrame);
var frameResultTensor = await SchedulerStepAsync(modelOptions, promptOptions, schedulerOptions, promptEmbeddings, performGuidance, schedulerFrameCallback, cancellationToken);

// Frame Progress
ReportBatchProgress(progressCallback, frameIndex, videoFrames.Count, frameResultTensor);
var tensorResult = promptOptions.HasInputVideo
? await DiffuseVideoAsync(modelOptions, promptOptions, schedulerOptions, promptEmbeddings, performGuidance, progressCallback, cancellationToken)
: await DiffuseImageAsync(modelOptions, promptOptions, schedulerOptions, promptEmbeddings, performGuidance, progressCallback, cancellationToken);

// Concatenate frame
videoTensor = videoTensor.Concatenate(frameResultTensor);
}
_logger?.LogEnd($"Diffuser complete", diffuseTime);
return tensorResult;
}

_logger?.LogEnd($"Diffuse complete", diffuseTime);
return videoTensor;
}

// Run Scheduler steps
var schedulerResult = await SchedulerStepAsync(modelOptions, promptOptions, schedulerOptions, promptEmbeddings, performGuidance, progressCallback, cancellationToken);
_logger?.LogEnd($"Diffuse complete", diffuseTime);
return schedulerResult;
}



Expand Down Expand Up @@ -180,13 +159,73 @@ public virtual async IAsyncEnumerable<BatchResult> DiffuseBatchAsync(ModelOption
var batchSchedulerCallback = CreateBatchCallback(progressCallback, batchSchedulerOptions.Count, () => batchIndex);
foreach (var batchSchedulerOption in batchSchedulerOptions)
{
var diffuseTime = _logger?.LogBegin("Diffuse starting...");
yield return new BatchResult(batchSchedulerOption, await SchedulerStepAsync(modelOptions, promptOptions, batchSchedulerOption, promptEmbeddings, performGuidance, batchSchedulerCallback, cancellationToken));
_logger?.LogEnd($"Diffuse complete", diffuseTime);
var tensorResult = promptOptions.HasInputVideo
? await DiffuseVideoAsync(modelOptions, promptOptions, batchSchedulerOption, promptEmbeddings, performGuidance, progressCallback, cancellationToken)
: await DiffuseImageAsync(modelOptions, promptOptions, batchSchedulerOption, promptEmbeddings, performGuidance, batchSchedulerCallback, cancellationToken);

yield return new BatchResult(batchSchedulerOption, tensorResult);
batchIndex++;
}

_logger?.LogEnd($"Diffuse batch complete", diffuseBatchTime);
_logger?.LogEnd($"Batch Diffuser complete", diffuseBatchTime);
}


/// <summary>
/// Diffuses a single image by running the scheduler steps once, bracketed by begin/end log entries.
/// </summary>
/// <param name="modelOptions">The model options, forwarded unchanged to the scheduler step.</param>
/// <param name="promptOptions">The prompt options, forwarded unchanged to the scheduler step.</param>
/// <param name="schedulerOptions">The scheduler options, forwarded unchanged to the scheduler step.</param>
/// <param name="promptEmbeddings">The prompt embeddings, forwarded unchanged to the scheduler step.</param>
/// <param name="performGuidance">if set to <c>true</c> [perform guidance].</param>
/// <param name="progressCallback">Optional progress callback handed to the scheduler step.</param>
/// <param name="cancellationToken">The cancellation token handed to the scheduler step.</param>
/// <returns>The tensor returned by <c>SchedulerStepAsync</c>.</returns>
protected virtual async Task<DenseTensor<float>> DiffuseImageAsync(ModelOptions modelOptions, PromptOptions promptOptions, SchedulerOptions schedulerOptions, PromptEmbeddingsResult promptEmbeddings, bool performGuidance, Action<DiffusionProgress> progressCallback = null, CancellationToken cancellationToken = default)
{
    // Value produced by LogBegin and handed back to LogEnd; null when no logger is attached
    var startedAt = _logger?.LogBegin("Image Diffuser starting...");

    // All of the actual work is delegated to the scheduler step
    var imageTensor = await SchedulerStepAsync(
        modelOptions,
        promptOptions,
        schedulerOptions,
        promptEmbeddings,
        performGuidance,
        progressCallback,
        cancellationToken);

    _logger?.LogEnd($"Image Diffuser complete", startedAt);
    return imageTensor;
}


/// <summary>
/// Diffuses a video by running the scheduler steps once per input frame, in order,
/// and concatenating the per-frame results into a single tensor.
/// </summary>
/// <remarks>
/// While a frame is being processed it is temporarily placed into
/// <c>promptOptions.InputImage</c> and/or <c>promptOptions.InputContolImage</c>:
/// ControlNet receives the frame as control image only, ImageToImage as input image only,
/// and every other diffuser type receives it in both. The values the caller supplied for
/// those two properties are restored before this method returns or throws.
/// </remarks>
/// <param name="modelOptions">The model options.</param>
/// <param name="promptOptions">The prompt options; <c>InputVideo.VideoFrames.Frames</c> supplies the frames.</param>
/// <param name="schedulerOptions">The scheduler options, reused for every frame.</param>
/// <param name="promptEmbeddings">The prompt embeddings, reused for every frame.</param>
/// <param name="performGuidance">if set to <c>true</c> [perform guidance].</param>
/// <param name="progressCallback">The progress callback; used for per-step and per-frame reporting.</param>
/// <param name="cancellationToken">The cancellation token, passed to each frame's scheduler step.</param>
/// <returns>The concatenated frame results, or <c>null</c> when the input video has no frames.</returns>
protected virtual async Task<DenseTensor<float>> DiffuseVideoAsync(ModelOptions modelOptions, PromptOptions promptOptions, SchedulerOptions schedulerOptions, PromptEmbeddingsResult promptEmbeddings, bool performGuidance, Action<DiffusionProgress> progressCallback = null, CancellationToken cancellationToken = default)
{
    var diffuseTime = _logger?.LogBegin("Video Diffuser starting...");

    var frameIndex = 0;
    DenseTensor<float> videoTensor = null;
    var videoFrames = promptOptions.InputVideo.VideoFrames.Frames;

    // The callback reads frameIndex lazily, so it always sees the frame currently being processed
    var schedulerFrameCallback = CreateBatchCallback(progressCallback, videoFrames.Count, () => frameIndex);

    // The loop below overwrites the caller's image inputs for every frame; remember them
    // so the caller's options are not left pointing at the last frame (or cleared).
    var originalInputImage = promptOptions.InputImage;
    var originalInputControlImage = promptOptions.InputContolImage;
    try
    {
        foreach (var videoFrame in videoFrames)
        {
            // Incremented before the step so progress is reported 1-based
            frameIndex++;

            // Route the frame to the input(s) the selected diffuser type consumes
            promptOptions.InputImage = promptOptions.DiffuserType == DiffuserType.ControlNet ? default : new InputImage(videoFrame);
            promptOptions.InputContolImage = promptOptions.DiffuserType == DiffuserType.ImageToImage ? default : new InputImage(videoFrame);
            var frameResultTensor = await SchedulerStepAsync(modelOptions, promptOptions, schedulerOptions, promptEmbeddings, performGuidance, schedulerFrameCallback, cancellationToken);

            // Frame Progress
            ReportBatchProgress(progressCallback, frameIndex, videoFrames.Count, frameResultTensor);

            // Concatenate frame
            // NOTE(review): videoTensor is null on the first frame, so Concatenate is presumably
            // an extension method that accepts a null receiver - confirm.
            videoTensor = videoTensor.Concatenate(frameResultTensor);
        }
    }
    finally
    {
        // Undo the per-frame overrides even when a frame fails or the run is cancelled
        promptOptions.InputImage = originalInputImage;
        promptOptions.InputContolImage = originalInputControlImage;
    }

    _logger?.LogEnd($"Video Diffuser complete", diffuseTime);
    return videoTensor;
}


Expand Down