Skip to content
This repository was archived by the owner on Nov 27, 2024. It is now read-only.

Support for low memory devices #111

Merged
merged 4 commits into from
Feb 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 7 additions & 3 deletions OnnxStack.Core/Model/OnnxModelSession.cs
Original file line number Diff line number Diff line change
Expand Up @@ -67,14 +67,18 @@ public async Task LoadAsync()
/// Unloads the model session.
/// </summary>
/// <returns></returns>
public Task UnloadAsync()
public async Task UnloadAsync()
{
// TODO: deadlock on model dispose when no synchronization context exists(console app)
// Task.Yield seems to force a context switch resolving any issues, revist this
await Task.Yield();

if (_session is not null)
{
_metadata = null;
_session.Dispose();
_metadata = null;
_session = null;
}
return Task.CompletedTask;
}


Expand Down
7 changes: 7 additions & 0 deletions OnnxStack.StableDiffusion/Config/PipelineOptions.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
using OnnxStack.StableDiffusion.Enums;

namespace OnnxStack.StableDiffusion.Config
{
public record PipelineOptions(string Name, MemoryModeType MemoryMode);

}
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ public record StableDiffusionModelSet : IOnnxModelSetConfig
public int SampleSize { get; set; } = 512;
public DiffuserPipelineType PipelineType { get; set; }
public List<DiffuserType> Diffusers { get; set; } = new List<DiffuserType>();

public MemoryModeType MemoryMode { get; set; }
public int DeviceId { get; set; }
public int InterOpNumThreads { get; set; }
public int IntraOpNumThreads { get; set; }
Expand Down
9 changes: 8 additions & 1 deletion OnnxStack.StableDiffusion/Diffusers/DiffuserBase.cs
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ public abstract class DiffuserBase : IDiffuser
protected readonly UNetConditionModel _unet;
protected readonly AutoEncoderModel _vaeDecoder;
protected readonly AutoEncoderModel _vaeEncoder;
protected readonly MemoryModeType _memoryMode;

/// <summary>
/// Initializes a new instance of the <see cref="DiffuserBase"/> class.
Expand All @@ -31,12 +32,13 @@ public abstract class DiffuserBase : IDiffuser
/// <param name="vaeDecoder">The vae decoder.</param>
/// <param name="vaeEncoder">The vae encoder.</param>
/// <param name="logger">The logger.</param>
public DiffuserBase(UNetConditionModel unet, AutoEncoderModel vaeDecoder, AutoEncoderModel vaeEncoder, ILogger logger = default)
public DiffuserBase(UNetConditionModel unet, AutoEncoderModel vaeDecoder, AutoEncoderModel vaeEncoder, MemoryModeType memoryMode, ILogger logger = default)
{
_logger = logger;
_unet = unet;
_vaeDecoder = vaeDecoder;
_vaeEncoder = vaeEncoder;
_memoryMode = memoryMode;
}

/// <summary>
Expand Down Expand Up @@ -137,10 +139,15 @@ protected virtual async Task<DenseTensor<float>> DecodeLatentsAsync(PromptOption
var results = await _vaeDecoder.RunInferenceAsync(inferenceParameters);
using (var imageResult = results.First())
{
// Unload if required
if (_memoryMode == MemoryModeType.Minimum)
await _vaeDecoder.UnloadAsync();

_logger?.LogEnd("Latents decoded", timestamp);
return imageResult.ToDenseTensor();
}
}

}


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,8 @@ public class ControlNetDiffuser : InstaFlowDiffuser
/// <param name="vaeDecoder">The vae decoder.</param>
/// <param name="vaeEncoder">The vae encoder.</param>
/// <param name="logger">The logger.</param>
public ControlNetDiffuser(ControlNetModel controlNet, UNetConditionModel unet, AutoEncoderModel vaeDecoder, AutoEncoderModel vaeEncoder, ILogger logger = default)
: base(unet, vaeDecoder, vaeEncoder, logger)
public ControlNetDiffuser(ControlNetModel controlNet, UNetConditionModel unet, AutoEncoderModel vaeDecoder, AutoEncoderModel vaeEncoder, MemoryModeType memoryMode, ILogger logger = default)
: base(unet, vaeDecoder, vaeEncoder, memoryMode, logger)
{
_controlNet = controlNet;
}
Expand Down Expand Up @@ -144,9 +144,13 @@ public override async Task<DenseTensor<float>> DiffuseAsync(PromptOptions prompt
}

ReportProgress(progressCallback, step, timesteps.Count, latents);
_logger?.LogEnd($"Step {step}/{timesteps.Count}", stepTime);
_logger?.LogEnd(LogLevel.Debug, $"Step {step}/{timesteps.Count}", stepTime);
}

// Unload if required
if (_memoryMode == MemoryModeType.Minimum)
await Task.WhenAll(_controlNet.UnloadAsync(), _unet.UnloadAsync());

// Decode Latents
return await DecodeLatentsAsync(promptOptions, schedulerOptions, latents);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,8 @@ public abstract class InstaFlowDiffuser : DiffuserBase
/// <param name="vaeDecoder">The vae decoder.</param>
/// <param name="vaeEncoder">The vae encoder.</param>
/// <param name="logger">The logger.</param>
public InstaFlowDiffuser(UNetConditionModel unet, AutoEncoderModel vaeDecoder, AutoEncoderModel vaeEncoder, ILogger logger = default)
: base(unet, vaeDecoder, vaeEncoder, logger) { }
public InstaFlowDiffuser(UNetConditionModel unet, AutoEncoderModel vaeDecoder, AutoEncoderModel vaeEncoder, MemoryModeType memoryMode, ILogger logger = default)
: base(unet, vaeDecoder, vaeEncoder, memoryMode, logger) { }

/// <summary>
/// Gets the type of the pipeline.
Expand Down Expand Up @@ -103,9 +103,13 @@ public override async Task<DenseTensor<float>> DiffuseAsync(PromptOptions prompt
}

ReportProgress(progressCallback, step, timesteps.Count, latents);
_logger?.LogEnd($"Step {step}/{timesteps.Count}", stepTime);
_logger?.LogEnd(LogLevel.Debug, $"Step {step}/{timesteps.Count}", stepTime);
}

// Unload if required
if (_memoryMode == MemoryModeType.Minimum)
await _unet.UnloadAsync();

// Decode Latents
return await DecodeLatentsAsync(promptOptions, schedulerOptions, latents);
}
Expand Down
4 changes: 2 additions & 2 deletions OnnxStack.StableDiffusion/Diffusers/InstaFlow/TextDiffuser.cs
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,8 @@ public sealed class TextDiffuser : InstaFlowDiffuser
/// <param name="vaeDecoder">The vae decoder.</param>
/// <param name="vaeEncoder">The vae encoder.</param>
/// <param name="logger">The logger.</param>
public TextDiffuser(UNetConditionModel unet, AutoEncoderModel vaeDecoder, AutoEncoderModel vaeEncoder, ILogger logger = default)
: base(unet, vaeDecoder, vaeEncoder, logger) { }
public TextDiffuser(UNetConditionModel unet, AutoEncoderModel vaeDecoder, AutoEncoderModel vaeEncoder, MemoryModeType memoryMode, ILogger logger = default)
: base(unet, vaeDecoder, vaeEncoder, memoryMode, logger) { }


/// <summary>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,8 @@ public class ControlNetDiffuser : LatentConsistencyDiffuser
/// <param name="vaeDecoder">The vae decoder.</param>
/// <param name="vaeEncoder">The vae encoder.</param>
/// <param name="logger">The logger.</param>
public ControlNetDiffuser(ControlNetModel controlNet, UNetConditionModel unet, AutoEncoderModel vaeDecoder, AutoEncoderModel vaeEncoder, ILogger logger = default)
: base(unet, vaeDecoder, vaeEncoder, logger)
public ControlNetDiffuser(ControlNetModel controlNet, UNetConditionModel unet, AutoEncoderModel vaeDecoder, AutoEncoderModel vaeEncoder, MemoryModeType memoryMode, ILogger logger = default)
: base(unet, vaeDecoder, vaeEncoder, memoryMode, logger)
{
_controlNet = controlNet;
}
Expand Down Expand Up @@ -141,9 +141,13 @@ public override async Task<DenseTensor<float>> DiffuseAsync(PromptOptions prompt
}

ReportProgress(progressCallback, step, timesteps.Count, latents);
_logger?.LogEnd($"Step {step}/{timesteps.Count}", stepTime);
_logger?.LogEnd(LogLevel.Debug, $"Step {step}/{timesteps.Count}", stepTime);
}

// Unload if required
if (_memoryMode == MemoryModeType.Minimum)
await Task.WhenAll(_controlNet.UnloadAsync(), _unet.UnloadAsync());

// Decode Latents
return await DecodeLatentsAsync(promptOptions, schedulerOptions, latents);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,8 @@ public sealed class ControlNetImageDiffuser : ControlNetDiffuser
/// <param name="vaeDecoder">The vae decoder.</param>
/// <param name="vaeEncoder">The vae encoder.</param>
/// <param name="logger">The logger.</param>
public ControlNetImageDiffuser(ControlNetModel controlNet, UNetConditionModel unet, AutoEncoderModel vaeDecoder, AutoEncoderModel vaeEncoder, ILogger logger = default)
: base(controlNet, unet, vaeDecoder, vaeEncoder, logger) { }
public ControlNetImageDiffuser(ControlNetModel controlNet, UNetConditionModel unet, AutoEncoderModel vaeDecoder, AutoEncoderModel vaeEncoder, MemoryModeType memoryMode, ILogger logger = default)
: base(controlNet, unet, vaeDecoder, vaeEncoder, memoryMode, logger) { }


/// <summary>
Expand Down Expand Up @@ -73,6 +73,10 @@ protected override async Task<DenseTensor<float>> PrepareLatentsAsync(PromptOpti
var results = await _vaeEncoder.RunInferenceAsync(inferenceParameters);
using (var result = results.First())
{
// Unload if required
if (_memoryMode == MemoryModeType.Minimum)
await _vaeEncoder.UnloadAsync();

var outputResult = result.ToDenseTensor();
var scaledSample = outputResult.MultiplyBy(_vaeEncoder.ScaleFactor);
return scheduler.AddNoise(scaledSample, scheduler.CreateRandomSample(scaledSample.Dimensions), timesteps);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,8 @@ public sealed class ImageDiffuser : LatentConsistencyDiffuser
/// <param name="vaeDecoder">The vae decoder.</param>
/// <param name="vaeEncoder">The vae encoder.</param>
/// <param name="logger">The logger.</param>
public ImageDiffuser(UNetConditionModel unet, AutoEncoderModel vaeDecoder, AutoEncoderModel vaeEncoder, ILogger logger = default)
: base(unet, vaeDecoder, vaeEncoder, logger) { }
public ImageDiffuser(UNetConditionModel unet, AutoEncoderModel vaeDecoder, AutoEncoderModel vaeEncoder, MemoryModeType memoryMode, ILogger logger = default)
: base(unet, vaeDecoder, vaeEncoder, memoryMode, logger) { }


/// <summary>
Expand Down Expand Up @@ -70,6 +70,10 @@ protected override async Task<DenseTensor<float>> PrepareLatentsAsync(PromptOpti
var results = await _vaeEncoder.RunInferenceAsync(inferenceParameters);
using (var result = results.First())
{
// Unload if required
if (_memoryMode == MemoryModeType.Minimum)
await _vaeEncoder.UnloadAsync();

var outputResult = result.ToDenseTensor();
var scaledSample = outputResult.MultiplyBy(_vaeEncoder.ScaleFactor);
return scheduler.AddNoise(scaledSample, scheduler.CreateRandomSample(scaledSample.Dimensions), timesteps);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,8 @@ public sealed class InpaintLegacyDiffuser : LatentConsistencyDiffuser
/// <param name="vaeDecoder">The vae decoder.</param>
/// <param name="vaeEncoder">The vae encoder.</param>
/// <param name="logger">The logger.</param>
public InpaintLegacyDiffuser(UNetConditionModel unet, AutoEncoderModel vaeDecoder, AutoEncoderModel vaeEncoder, ILogger logger = default)
: base(unet, vaeDecoder, vaeEncoder, logger) { }
public InpaintLegacyDiffuser(UNetConditionModel unet, AutoEncoderModel vaeDecoder, AutoEncoderModel vaeEncoder, MemoryModeType memoryMode, ILogger logger = default)
: base(unet, vaeDecoder, vaeEncoder, memoryMode, logger) { }


/// <summary>
Expand Down Expand Up @@ -138,9 +138,13 @@ public override async Task<DenseTensor<float>> DiffuseAsync(PromptOptions prompt
}

ReportProgress(progressCallback, step, timesteps.Count, latents);
_logger?.LogEnd($"Step {step}/{timesteps.Count}", stepTime);
_logger?.LogEnd(LogLevel.Debug, $"Step {step}/{timesteps.Count}", stepTime);
}

// Unload if required
if (_memoryMode == MemoryModeType.Minimum)
await _unet.UnloadAsync();

// Decode Latents
return await DecodeLatentsAsync(promptOptions, schedulerOptions, denoised);
}
Expand Down Expand Up @@ -168,6 +172,10 @@ protected override async Task<DenseTensor<float>> PrepareLatentsAsync(PromptOpti
var results = await _vaeEncoder.RunInferenceAsync(inferenceParameters);
using (var result = results.First())
{
// Unload if required
if (_memoryMode == MemoryModeType.Minimum)
await _vaeEncoder.UnloadAsync();

var outputResult = result.ToDenseTensor();
var scaledSample = outputResult.MultiplyBy(_vaeEncoder.ScaleFactor);
return scaledSample;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,8 @@ public abstract class LatentConsistencyDiffuser : DiffuserBase
/// <param name="vaeDecoder">The vae decoder.</param>
/// <param name="vaeEncoder">The vae encoder.</param>
/// <param name="logger">The logger.</param>
public LatentConsistencyDiffuser(UNetConditionModel unet, AutoEncoderModel vaeDecoder, AutoEncoderModel vaeEncoder, ILogger logger = default)
: base(unet, vaeDecoder, vaeEncoder, logger) { }
public LatentConsistencyDiffuser(UNetConditionModel unet, AutoEncoderModel vaeDecoder, AutoEncoderModel vaeEncoder, MemoryModeType memoryMode, ILogger logger = default)
: base(unet, vaeDecoder, vaeEncoder, memoryMode, logger) { }


/// <summary>
Expand Down Expand Up @@ -103,9 +103,13 @@ public override async Task<DenseTensor<float>> DiffuseAsync(PromptOptions prompt
}

ReportProgress(progressCallback, step, timesteps.Count, latents);
_logger?.LogEnd($"Step {step}/{timesteps.Count}", stepTime);
_logger?.LogEnd(LogLevel.Debug, $"Step {step}/{timesteps.Count}", stepTime);
}

// Unload if required
if (_memoryMode == MemoryModeType.Minimum)
await _unet.UnloadAsync();

// Decode Latents
return await DecodeLatentsAsync(promptOptions, schedulerOptions, denoised);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,8 @@ public sealed class TextDiffuser : LatentConsistencyDiffuser
/// <param name="vaeDecoder">The vae decoder.</param>
/// <param name="vaeEncoder">The vae encoder.</param>
/// <param name="logger">The logger.</param>
public TextDiffuser(UNetConditionModel unet, AutoEncoderModel vaeDecoder, AutoEncoderModel vaeEncoder, ILogger logger = default)
: base(unet, vaeDecoder, vaeEncoder, logger) { }
public TextDiffuser(UNetConditionModel unet, AutoEncoderModel vaeDecoder, AutoEncoderModel vaeEncoder, MemoryModeType memoryMode, ILogger logger = default)
: base(unet, vaeDecoder, vaeEncoder, memoryMode, logger) { }


/// <summary>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,8 @@ public class ControlNetDiffuser : LatentConsistencyXLDiffuser
/// <param name="vaeDecoder">The vae decoder.</param>
/// <param name="vaeEncoder">The vae encoder.</param>
/// <param name="logger">The logger.</param>
public ControlNetDiffuser(ControlNetModel controlNet, UNetConditionModel unet, AutoEncoderModel vaeDecoder, AutoEncoderModel vaeEncoder, ILogger logger = default)
: base(unet, vaeDecoder, vaeEncoder, logger)
public ControlNetDiffuser(ControlNetModel controlNet, UNetConditionModel unet, AutoEncoderModel vaeDecoder, AutoEncoderModel vaeEncoder, MemoryModeType memoryMode, ILogger logger = default)
: base(unet, vaeDecoder, vaeEncoder, memoryMode, logger)
{
_controlNet = controlNet;
}
Expand Down Expand Up @@ -146,9 +146,13 @@ public override async Task<DenseTensor<float>> DiffuseAsync(PromptOptions prompt
}

ReportProgress(progressCallback, step, timesteps.Count, latents);
_logger?.LogEnd($"Step {step}/{timesteps.Count}", stepTime);
_logger?.LogEnd(LogLevel.Debug, $"Step {step}/{timesteps.Count}", stepTime);
}

// Unload if required
if (_memoryMode == MemoryModeType.Minimum)
await Task.WhenAll(_controlNet.UnloadAsync(), _unet.UnloadAsync());

// Decode Latents
return await DecodeLatentsAsync(promptOptions, schedulerOptions, latents);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,8 @@ public sealed class ControlNetImageDiffuser : ControlNetDiffuser
/// <param name="vaeDecoder">The vae decoder.</param>
/// <param name="vaeEncoder">The vae encoder.</param>
/// <param name="logger">The logger.</param>
public ControlNetImageDiffuser(ControlNetModel controlNet, UNetConditionModel unet, AutoEncoderModel vaeDecoder, AutoEncoderModel vaeEncoder, ILogger logger = default)
: base(controlNet, unet, vaeDecoder, vaeEncoder, logger) { }
public ControlNetImageDiffuser(ControlNetModel controlNet, UNetConditionModel unet, AutoEncoderModel vaeDecoder, AutoEncoderModel vaeEncoder, MemoryModeType memoryMode, ILogger logger = default)
: base(controlNet, unet, vaeDecoder, vaeEncoder, memoryMode, logger) { }


/// <summary>
Expand Down Expand Up @@ -71,6 +71,10 @@ protected override async Task<DenseTensor<float>> PrepareLatentsAsync(PromptOpti
var results = await _vaeEncoder.RunInferenceAsync(inferenceParameters);
using (var result = results.First())
{
// Unload if required
if (_memoryMode == MemoryModeType.Minimum)
await _vaeEncoder.UnloadAsync();

var outputResult = result.ToDenseTensor();
var scaledSample = outputResult.MultiplyBy(_vaeEncoder.ScaleFactor);
return scheduler.AddNoise(scaledSample, scheduler.CreateRandomSample(scaledSample.Dimensions), timesteps);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,8 @@ public sealed class ImageDiffuser : LatentConsistencyXLDiffuser
/// <param name="vaeDecoder">The vae decoder.</param>
/// <param name="vaeEncoder">The vae encoder.</param>
/// <param name="logger">The logger.</param>
public ImageDiffuser(UNetConditionModel unet, AutoEncoderModel vaeDecoder, AutoEncoderModel vaeEncoder, ILogger logger = default)
: base(unet, vaeDecoder, vaeEncoder, logger) { }
public ImageDiffuser(UNetConditionModel unet, AutoEncoderModel vaeDecoder, AutoEncoderModel vaeEncoder, MemoryModeType memoryMode, ILogger logger = default)
: base(unet, vaeDecoder, vaeEncoder, memoryMode, logger) { }


/// <summary>
Expand Down Expand Up @@ -72,6 +72,10 @@ protected override async Task<DenseTensor<float>> PrepareLatentsAsync(PromptOpti
var results = await _vaeEncoder.RunInferenceAsync(inferenceParameters);
using (var result = results.First())
{
// Unload if required
if (_memoryMode == MemoryModeType.Minimum)
await _vaeEncoder.UnloadAsync();

var outputResult = result.ToDenseTensor();
var scaledSample = outputResult.MultiplyBy(_vaeEncoder.ScaleFactor);
return scheduler.AddNoise(scaledSample, scheduler.CreateRandomSample(scaledSample.Dimensions), timesteps);
Expand Down
Loading