Skip to content
This repository was archived by the owner on Nov 27, 2024. It is now read-only.

Refactor StableDiffusionPipeline #119

Merged
merged 2 commits into from
Feb 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions OnnxStack.Core/Video/OnnxVideo.cs
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,20 @@ public OnnxVideo(VideoInfo info, DenseTensor<float> videoTensor)
}


/// <summary>
/// Initializes a new instance of the <see cref="OnnxVideo"/> class.
/// </summary>
/// <param name="info">The information.</param>
/// <param name="videoTensors">The video tensors.</param>
public OnnxVideo(VideoInfo info, IEnumerable<DenseTensor<float>> videoTensors)
{
_info = info;
_frames = videoTensors
.Select(x => new OnnxImage(x))
.ToList();
}


/// <summary>
/// Gets the height.
/// </summary>
Expand Down
4 changes: 4 additions & 0 deletions OnnxStack.StableDiffusion/Common/BatchResult.cs
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
using Microsoft.ML.OnnxRuntime.Tensors;
using OnnxStack.Core.Image;
using OnnxStack.Core.Video;
using OnnxStack.StableDiffusion.Config;

namespace OnnxStack.StableDiffusion.Common
{
public record BatchResult(SchedulerOptions SchedulerOptions, DenseTensor<float> Result);
public record BatchImageResult(SchedulerOptions SchedulerOptions, OnnxImage Result);
public record BatchVideoResult(SchedulerOptions SchedulerOptions, OnnxVideo Result);
}
4 changes: 2 additions & 2 deletions OnnxStack.StableDiffusion/Pipelines/Base/IPipeline.cs
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ public interface IPipeline
/// <param name="progressCallback">The progress callback.</param>
/// <param name="cancellationToken">The cancellation token.</param>
/// <returns></returns>
IAsyncEnumerable<OnnxImage> GenerateImageBatchAsync(BatchOptions batchOptions, PromptOptions promptOptions, SchedulerOptions schedulerOptions = default, ControlNetModel controlNet = default, Action<DiffusionProgress> progressCallback = null, CancellationToken cancellationToken = default);
IAsyncEnumerable<BatchImageResult> GenerateImageBatchAsync(BatchOptions batchOptions, PromptOptions promptOptions, SchedulerOptions schedulerOptions = default, ControlNetModel controlNet = default, Action<DiffusionProgress> progressCallback = null, CancellationToken cancellationToken = default);


/// <summary>
Expand All @@ -126,6 +126,6 @@ public interface IPipeline
/// <param name="progressCallback">The progress callback.</param>
/// <param name="cancellationToken">The cancellation token.</param>
/// <returns></returns>
IAsyncEnumerable<OnnxVideo> GenerateVideoBatchAsync(BatchOptions batchOptions, PromptOptions promptOptions, SchedulerOptions schedulerOptions = default, ControlNetModel controlNet = default, Action<DiffusionProgress> progressCallback = null, CancellationToken cancellationToken = default);
IAsyncEnumerable<BatchVideoResult> GenerateVideoBatchAsync(BatchOptions batchOptions, PromptOptions promptOptions, SchedulerOptions schedulerOptions = default, ControlNetModel controlNet = default, Action<DiffusionProgress> progressCallback = null, CancellationToken cancellationToken = default);
}
}
4 changes: 2 additions & 2 deletions OnnxStack.StableDiffusion/Pipelines/Base/PipelineBase.cs
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ protected PipelineBase(PipelineOptions pipelineOptions, ILogger logger)
/// <param name="progressCallback">The progress callback.</param>
/// <param name="cancellationToken">The cancellation token.</param>
/// <returns></returns>
public abstract IAsyncEnumerable<OnnxImage> GenerateImageBatchAsync(BatchOptions batchOptions, PromptOptions promptOptions, SchedulerOptions schedulerOptions = default, ControlNetModel controlNet = default, Action<DiffusionProgress> progressCallback = null, CancellationToken cancellationToken = default);
public abstract IAsyncEnumerable<BatchImageResult> GenerateImageBatchAsync(BatchOptions batchOptions, PromptOptions promptOptions, SchedulerOptions schedulerOptions = default, ControlNetModel controlNet = default, Action<DiffusionProgress> progressCallback = null, CancellationToken cancellationToken = default);


/// <summary>
Expand All @@ -158,7 +158,7 @@ protected PipelineBase(PipelineOptions pipelineOptions, ILogger logger)
/// <param name="progressCallback">The progress callback.</param>
/// <param name="cancellationToken">The cancellation token.</param>
/// <returns></returns>
public abstract IAsyncEnumerable<OnnxVideo> GenerateVideoBatchAsync(BatchOptions batchOptions, PromptOptions promptOptions, SchedulerOptions schedulerOptions = default, ControlNetModel controlNet = default, Action<DiffusionProgress> progressCallback = null, CancellationToken cancellationToken = default);
public abstract IAsyncEnumerable<BatchVideoResult> GenerateVideoBatchAsync(BatchOptions batchOptions, PromptOptions promptOptions, SchedulerOptions schedulerOptions = default, ControlNetModel controlNet = default, Action<DiffusionProgress> progressCallback = null, CancellationToken cancellationToken = default);


/// <summary>
Expand Down
Loading