diff --git a/OnnxStack.Core/Model/OnnxModelSession.cs b/OnnxStack.Core/Model/OnnxModelSession.cs
index 15376c01..dadeef03 100644
--- a/OnnxStack.Core/Model/OnnxModelSession.cs
+++ b/OnnxStack.Core/Model/OnnxModelSession.cs
@@ -67,14 +67,18 @@ public async Task LoadAsync()
/// Unloads the model session.
/// </summary>
/// <returns></returns>
- public Task UnloadAsync()
+ public async Task UnloadAsync()
{
+ // TODO: deadlock on model dispose when no synchronization context exists (console app)
+ // Task.Yield seems to force a context switch resolving any issues, revisit this
+ await Task.Yield();
+
if (_session is not null)
{
- _metadata = null;
_session.Dispose();
+ _metadata = null;
+ _session = null;
}
- return Task.CompletedTask;
}
diff --git a/OnnxStack.StableDiffusion/Config/PipelineOptions.cs b/OnnxStack.StableDiffusion/Config/PipelineOptions.cs
new file mode 100644
index 00000000..e79a7845
--- /dev/null
+++ b/OnnxStack.StableDiffusion/Config/PipelineOptions.cs
@@ -0,0 +1,7 @@
+using OnnxStack.StableDiffusion.Enums;
+
+namespace OnnxStack.StableDiffusion.Config
+{
+ public record PipelineOptions(string Name, MemoryModeType MemoryMode);
+
+}
diff --git a/OnnxStack.StableDiffusion/Config/StableDiffusionModelSet.cs b/OnnxStack.StableDiffusion/Config/StableDiffusionModelSet.cs
index 27b7d13c..313e9637 100644
--- a/OnnxStack.StableDiffusion/Config/StableDiffusionModelSet.cs
+++ b/OnnxStack.StableDiffusion/Config/StableDiffusionModelSet.cs
@@ -14,7 +14,7 @@ public record StableDiffusionModelSet : IOnnxModelSetConfig
public int SampleSize { get; set; } = 512;
public DiffuserPipelineType PipelineType { get; set; }
public List Diffusers { get; set; } = new List();
-
+ public MemoryModeType MemoryMode { get; set; }
public int DeviceId { get; set; }
public int InterOpNumThreads { get; set; }
public int IntraOpNumThreads { get; set; }
diff --git a/OnnxStack.StableDiffusion/Diffusers/DiffuserBase.cs b/OnnxStack.StableDiffusion/Diffusers/DiffuserBase.cs
index fdc9d8b7..ad1f1699 100644
--- a/OnnxStack.StableDiffusion/Diffusers/DiffuserBase.cs
+++ b/OnnxStack.StableDiffusion/Diffusers/DiffuserBase.cs
@@ -22,6 +22,7 @@ public abstract class DiffuserBase : IDiffuser
protected readonly UNetConditionModel _unet;
protected readonly AutoEncoderModel _vaeDecoder;
protected readonly AutoEncoderModel _vaeEncoder;
+ protected readonly MemoryModeType _memoryMode;
///
/// Initializes a new instance of the class.
@@ -31,12 +32,13 @@ public abstract class DiffuserBase : IDiffuser
/// The vae decoder.
/// The vae encoder.
/// The logger.
- public DiffuserBase(UNetConditionModel unet, AutoEncoderModel vaeDecoder, AutoEncoderModel vaeEncoder, ILogger logger = default)
+ public DiffuserBase(UNetConditionModel unet, AutoEncoderModel vaeDecoder, AutoEncoderModel vaeEncoder, MemoryModeType memoryMode, ILogger logger = default)
{
_logger = logger;
_unet = unet;
_vaeDecoder = vaeDecoder;
_vaeEncoder = vaeEncoder;
+ _memoryMode = memoryMode;
}
///
@@ -137,10 +139,15 @@ protected virtual async Task> DecodeLatentsAsync(PromptOption
var results = await _vaeDecoder.RunInferenceAsync(inferenceParameters);
using (var imageResult = results.First())
{
+ // Unload if required
+ if (_memoryMode == MemoryModeType.Minimum)
+ await _vaeDecoder.UnloadAsync();
+
_logger?.LogEnd("Latents decoded", timestamp);
return imageResult.ToDenseTensor();
}
}
+
}
diff --git a/OnnxStack.StableDiffusion/Diffusers/InstaFlow/ControlNetDiffuser.cs b/OnnxStack.StableDiffusion/Diffusers/InstaFlow/ControlNetDiffuser.cs
index d3adaf3f..dfcf4d68 100644
--- a/OnnxStack.StableDiffusion/Diffusers/InstaFlow/ControlNetDiffuser.cs
+++ b/OnnxStack.StableDiffusion/Diffusers/InstaFlow/ControlNetDiffuser.cs
@@ -30,8 +30,8 @@ public class ControlNetDiffuser : InstaFlowDiffuser
/// The vae decoder.
/// The vae encoder.
/// The logger.
- public ControlNetDiffuser(ControlNetModel controlNet, UNetConditionModel unet, AutoEncoderModel vaeDecoder, AutoEncoderModel vaeEncoder, ILogger logger = default)
- : base(unet, vaeDecoder, vaeEncoder, logger)
+ public ControlNetDiffuser(ControlNetModel controlNet, UNetConditionModel unet, AutoEncoderModel vaeDecoder, AutoEncoderModel vaeEncoder, MemoryModeType memoryMode, ILogger logger = default)
+ : base(unet, vaeDecoder, vaeEncoder, memoryMode, logger)
{
_controlNet = controlNet;
}
@@ -144,9 +144,13 @@ public override async Task> DiffuseAsync(PromptOptions prompt
}
ReportProgress(progressCallback, step, timesteps.Count, latents);
- _logger?.LogEnd($"Step {step}/{timesteps.Count}", stepTime);
+ _logger?.LogEnd(LogLevel.Debug, $"Step {step}/{timesteps.Count}", stepTime);
}
+ // Unload if required
+ if (_memoryMode == MemoryModeType.Minimum)
+ await Task.WhenAll(_controlNet.UnloadAsync(), _unet.UnloadAsync());
+
// Decode Latents
return await DecodeLatentsAsync(promptOptions, schedulerOptions, latents);
}
diff --git a/OnnxStack.StableDiffusion/Diffusers/InstaFlow/InstaFlowDiffuser.cs b/OnnxStack.StableDiffusion/Diffusers/InstaFlow/InstaFlowDiffuser.cs
index 8f294118..5bbecc39 100644
--- a/OnnxStack.StableDiffusion/Diffusers/InstaFlow/InstaFlowDiffuser.cs
+++ b/OnnxStack.StableDiffusion/Diffusers/InstaFlow/InstaFlowDiffuser.cs
@@ -26,8 +26,8 @@ public abstract class InstaFlowDiffuser : DiffuserBase
/// The vae decoder.
/// The vae encoder.
/// The logger.
- public InstaFlowDiffuser(UNetConditionModel unet, AutoEncoderModel vaeDecoder, AutoEncoderModel vaeEncoder, ILogger logger = default)
- : base(unet, vaeDecoder, vaeEncoder, logger) { }
+ public InstaFlowDiffuser(UNetConditionModel unet, AutoEncoderModel vaeDecoder, AutoEncoderModel vaeEncoder, MemoryModeType memoryMode, ILogger logger = default)
+ : base(unet, vaeDecoder, vaeEncoder, memoryMode, logger) { }
///
/// Gets the type of the pipeline.
@@ -103,9 +103,13 @@ public override async Task> DiffuseAsync(PromptOptions prompt
}
ReportProgress(progressCallback, step, timesteps.Count, latents);
- _logger?.LogEnd($"Step {step}/{timesteps.Count}", stepTime);
+ _logger?.LogEnd(LogLevel.Debug, $"Step {step}/{timesteps.Count}", stepTime);
}
+ // Unload if required
+ if (_memoryMode == MemoryModeType.Minimum)
+ await _unet.UnloadAsync();
+
// Decode Latents
return await DecodeLatentsAsync(promptOptions, schedulerOptions, latents);
}
diff --git a/OnnxStack.StableDiffusion/Diffusers/InstaFlow/TextDiffuser.cs b/OnnxStack.StableDiffusion/Diffusers/InstaFlow/TextDiffuser.cs
index b26358dd..f5cccd9d 100644
--- a/OnnxStack.StableDiffusion/Diffusers/InstaFlow/TextDiffuser.cs
+++ b/OnnxStack.StableDiffusion/Diffusers/InstaFlow/TextDiffuser.cs
@@ -20,8 +20,8 @@ public sealed class TextDiffuser : InstaFlowDiffuser
/// The vae decoder.
/// The vae encoder.
/// The logger.
- public TextDiffuser(UNetConditionModel unet, AutoEncoderModel vaeDecoder, AutoEncoderModel vaeEncoder, ILogger logger = default)
- : base(unet, vaeDecoder, vaeEncoder, logger) { }
+ public TextDiffuser(UNetConditionModel unet, AutoEncoderModel vaeDecoder, AutoEncoderModel vaeEncoder, MemoryModeType memoryMode, ILogger logger = default)
+ : base(unet, vaeDecoder, vaeEncoder, memoryMode, logger) { }
///
diff --git a/OnnxStack.StableDiffusion/Diffusers/LatentConsistency/ControlNetDiffuser.cs b/OnnxStack.StableDiffusion/Diffusers/LatentConsistency/ControlNetDiffuser.cs
index 17696e08..dc4e4fd4 100644
--- a/OnnxStack.StableDiffusion/Diffusers/LatentConsistency/ControlNetDiffuser.cs
+++ b/OnnxStack.StableDiffusion/Diffusers/LatentConsistency/ControlNetDiffuser.cs
@@ -29,8 +29,8 @@ public class ControlNetDiffuser : LatentConsistencyDiffuser
/// The vae decoder.
/// The vae encoder.
/// The logger.
- public ControlNetDiffuser(ControlNetModel controlNet, UNetConditionModel unet, AutoEncoderModel vaeDecoder, AutoEncoderModel vaeEncoder, ILogger logger = default)
- : base(unet, vaeDecoder, vaeEncoder, logger)
+ public ControlNetDiffuser(ControlNetModel controlNet, UNetConditionModel unet, AutoEncoderModel vaeDecoder, AutoEncoderModel vaeEncoder, MemoryModeType memoryMode, ILogger logger = default)
+ : base(unet, vaeDecoder, vaeEncoder, memoryMode, logger)
{
_controlNet = controlNet;
}
@@ -141,9 +141,13 @@ public override async Task> DiffuseAsync(PromptOptions prompt
}
ReportProgress(progressCallback, step, timesteps.Count, latents);
- _logger?.LogEnd($"Step {step}/{timesteps.Count}", stepTime);
+ _logger?.LogEnd(LogLevel.Debug, $"Step {step}/{timesteps.Count}", stepTime);
}
+ // Unload if required
+ if (_memoryMode == MemoryModeType.Minimum)
+ await Task.WhenAll(_controlNet.UnloadAsync(), _unet.UnloadAsync());
+
// Decode Latents
return await DecodeLatentsAsync(promptOptions, schedulerOptions, latents);
}
diff --git a/OnnxStack.StableDiffusion/Diffusers/LatentConsistency/ControlNetImageDiffuser.cs b/OnnxStack.StableDiffusion/Diffusers/LatentConsistency/ControlNetImageDiffuser.cs
index ce777447..b9dc0c70 100644
--- a/OnnxStack.StableDiffusion/Diffusers/LatentConsistency/ControlNetImageDiffuser.cs
+++ b/OnnxStack.StableDiffusion/Diffusers/LatentConsistency/ControlNetImageDiffuser.cs
@@ -27,8 +27,8 @@ public sealed class ControlNetImageDiffuser : ControlNetDiffuser
/// The vae decoder.
/// The vae encoder.
/// The logger.
- public ControlNetImageDiffuser(ControlNetModel controlNet, UNetConditionModel unet, AutoEncoderModel vaeDecoder, AutoEncoderModel vaeEncoder, ILogger logger = default)
- : base(controlNet, unet, vaeDecoder, vaeEncoder, logger) { }
+ public ControlNetImageDiffuser(ControlNetModel controlNet, UNetConditionModel unet, AutoEncoderModel vaeDecoder, AutoEncoderModel vaeEncoder, MemoryModeType memoryMode, ILogger logger = default)
+ : base(controlNet, unet, vaeDecoder, vaeEncoder, memoryMode, logger) { }
///
@@ -73,6 +73,10 @@ protected override async Task> PrepareLatentsAsync(PromptOpti
var results = await _vaeEncoder.RunInferenceAsync(inferenceParameters);
using (var result = results.First())
{
+ // Unload if required
+ if (_memoryMode == MemoryModeType.Minimum)
+ await _vaeEncoder.UnloadAsync();
+
var outputResult = result.ToDenseTensor();
var scaledSample = outputResult.MultiplyBy(_vaeEncoder.ScaleFactor);
return scheduler.AddNoise(scaledSample, scheduler.CreateRandomSample(scaledSample.Dimensions), timesteps);
diff --git a/OnnxStack.StableDiffusion/Diffusers/LatentConsistency/ImageDiffuser.cs b/OnnxStack.StableDiffusion/Diffusers/LatentConsistency/ImageDiffuser.cs
index 8e32720b..211abd88 100644
--- a/OnnxStack.StableDiffusion/Diffusers/LatentConsistency/ImageDiffuser.cs
+++ b/OnnxStack.StableDiffusion/Diffusers/LatentConsistency/ImageDiffuser.cs
@@ -26,8 +26,8 @@ public sealed class ImageDiffuser : LatentConsistencyDiffuser
/// The vae decoder.
/// The vae encoder.
/// The logger.
- public ImageDiffuser(UNetConditionModel unet, AutoEncoderModel vaeDecoder, AutoEncoderModel vaeEncoder, ILogger logger = default)
- : base(unet, vaeDecoder, vaeEncoder, logger) { }
+ public ImageDiffuser(UNetConditionModel unet, AutoEncoderModel vaeDecoder, AutoEncoderModel vaeEncoder, MemoryModeType memoryMode, ILogger logger = default)
+ : base(unet, vaeDecoder, vaeEncoder, memoryMode, logger) { }
///
@@ -70,6 +70,10 @@ protected override async Task> PrepareLatentsAsync(PromptOpti
var results = await _vaeEncoder.RunInferenceAsync(inferenceParameters);
using (var result = results.First())
{
+ // Unload if required
+ if (_memoryMode == MemoryModeType.Minimum)
+ await _vaeEncoder.UnloadAsync();
+
var outputResult = result.ToDenseTensor();
var scaledSample = outputResult.MultiplyBy(_vaeEncoder.ScaleFactor);
return scheduler.AddNoise(scaledSample, scheduler.CreateRandomSample(scaledSample.Dimensions), timesteps);
diff --git a/OnnxStack.StableDiffusion/Diffusers/LatentConsistency/InpaintLegacyDiffuser.cs b/OnnxStack.StableDiffusion/Diffusers/LatentConsistency/InpaintLegacyDiffuser.cs
index 20bcb1ac..6d020866 100644
--- a/OnnxStack.StableDiffusion/Diffusers/LatentConsistency/InpaintLegacyDiffuser.cs
+++ b/OnnxStack.StableDiffusion/Diffusers/LatentConsistency/InpaintLegacyDiffuser.cs
@@ -29,8 +29,8 @@ public sealed class InpaintLegacyDiffuser : LatentConsistencyDiffuser
/// The vae decoder.
/// The vae encoder.
/// The logger.
- public InpaintLegacyDiffuser(UNetConditionModel unet, AutoEncoderModel vaeDecoder, AutoEncoderModel vaeEncoder, ILogger logger = default)
- : base(unet, vaeDecoder, vaeEncoder, logger) { }
+ public InpaintLegacyDiffuser(UNetConditionModel unet, AutoEncoderModel vaeDecoder, AutoEncoderModel vaeEncoder, MemoryModeType memoryMode, ILogger logger = default)
+ : base(unet, vaeDecoder, vaeEncoder, memoryMode, logger) { }
///
@@ -138,9 +138,13 @@ public override async Task> DiffuseAsync(PromptOptions prompt
}
ReportProgress(progressCallback, step, timesteps.Count, latents);
- _logger?.LogEnd($"Step {step}/{timesteps.Count}", stepTime);
+ _logger?.LogEnd(LogLevel.Debug, $"Step {step}/{timesteps.Count}", stepTime);
}
+ // Unload if required
+ if (_memoryMode == MemoryModeType.Minimum)
+ await _unet.UnloadAsync();
+
// Decode Latents
return await DecodeLatentsAsync(promptOptions, schedulerOptions, denoised);
}
@@ -168,6 +172,10 @@ protected override async Task> PrepareLatentsAsync(PromptOpti
var results = await _vaeEncoder.RunInferenceAsync(inferenceParameters);
using (var result = results.First())
{
+ // Unload if required
+ if (_memoryMode == MemoryModeType.Minimum)
+ await _vaeEncoder.UnloadAsync();
+
var outputResult = result.ToDenseTensor();
var scaledSample = outputResult.MultiplyBy(_vaeEncoder.ScaleFactor);
return scaledSample;
diff --git a/OnnxStack.StableDiffusion/Diffusers/LatentConsistency/LatentConsistencyDiffuser.cs b/OnnxStack.StableDiffusion/Diffusers/LatentConsistency/LatentConsistencyDiffuser.cs
index 1a659c5d..dfba7b23 100644
--- a/OnnxStack.StableDiffusion/Diffusers/LatentConsistency/LatentConsistencyDiffuser.cs
+++ b/OnnxStack.StableDiffusion/Diffusers/LatentConsistency/LatentConsistencyDiffuser.cs
@@ -26,8 +26,8 @@ public abstract class LatentConsistencyDiffuser : DiffuserBase
/// The vae decoder.
/// The vae encoder.
/// The logger.
- public LatentConsistencyDiffuser(UNetConditionModel unet, AutoEncoderModel vaeDecoder, AutoEncoderModel vaeEncoder, ILogger logger = default)
- : base(unet, vaeDecoder, vaeEncoder, logger) { }
+ public LatentConsistencyDiffuser(UNetConditionModel unet, AutoEncoderModel vaeDecoder, AutoEncoderModel vaeEncoder, MemoryModeType memoryMode, ILogger logger = default)
+ : base(unet, vaeDecoder, vaeEncoder, memoryMode, logger) { }
///
@@ -103,9 +103,13 @@ public override async Task> DiffuseAsync(PromptOptions prompt
}
ReportProgress(progressCallback, step, timesteps.Count, latents);
- _logger?.LogEnd($"Step {step}/{timesteps.Count}", stepTime);
+ _logger?.LogEnd(LogLevel.Debug, $"Step {step}/{timesteps.Count}", stepTime);
}
+ // Unload if required
+ if (_memoryMode == MemoryModeType.Minimum)
+ await _unet.UnloadAsync();
+
// Decode Latents
return await DecodeLatentsAsync(promptOptions, schedulerOptions, denoised);
}
diff --git a/OnnxStack.StableDiffusion/Diffusers/LatentConsistency/TextDiffuser.cs b/OnnxStack.StableDiffusion/Diffusers/LatentConsistency/TextDiffuser.cs
index 6955e712..b1d14580 100644
--- a/OnnxStack.StableDiffusion/Diffusers/LatentConsistency/TextDiffuser.cs
+++ b/OnnxStack.StableDiffusion/Diffusers/LatentConsistency/TextDiffuser.cs
@@ -20,8 +20,8 @@ public sealed class TextDiffuser : LatentConsistencyDiffuser
/// The vae decoder.
/// The vae encoder.
/// The logger.
- public TextDiffuser(UNetConditionModel unet, AutoEncoderModel vaeDecoder, AutoEncoderModel vaeEncoder, ILogger logger = default)
- : base(unet, vaeDecoder, vaeEncoder, logger) { }
+ public TextDiffuser(UNetConditionModel unet, AutoEncoderModel vaeDecoder, AutoEncoderModel vaeEncoder, MemoryModeType memoryMode, ILogger logger = default)
+ : base(unet, vaeDecoder, vaeEncoder, memoryMode, logger) { }
///
diff --git a/OnnxStack.StableDiffusion/Diffusers/LatentConsistencyXL/ControlNetDiffuser.cs b/OnnxStack.StableDiffusion/Diffusers/LatentConsistencyXL/ControlNetDiffuser.cs
index 7e67eac1..a077de6c 100644
--- a/OnnxStack.StableDiffusion/Diffusers/LatentConsistencyXL/ControlNetDiffuser.cs
+++ b/OnnxStack.StableDiffusion/Diffusers/LatentConsistencyXL/ControlNetDiffuser.cs
@@ -29,8 +29,8 @@ public class ControlNetDiffuser : LatentConsistencyXLDiffuser
/// The vae decoder.
/// The vae encoder.
/// The logger.
- public ControlNetDiffuser(ControlNetModel controlNet, UNetConditionModel unet, AutoEncoderModel vaeDecoder, AutoEncoderModel vaeEncoder, ILogger logger = default)
- : base(unet, vaeDecoder, vaeEncoder, logger)
+ public ControlNetDiffuser(ControlNetModel controlNet, UNetConditionModel unet, AutoEncoderModel vaeDecoder, AutoEncoderModel vaeEncoder, MemoryModeType memoryMode, ILogger logger = default)
+ : base(unet, vaeDecoder, vaeEncoder, memoryMode, logger)
{
_controlNet = controlNet;
}
@@ -146,9 +146,13 @@ public override async Task> DiffuseAsync(PromptOptions prompt
}
ReportProgress(progressCallback, step, timesteps.Count, latents);
- _logger?.LogEnd($"Step {step}/{timesteps.Count}", stepTime);
+ _logger?.LogEnd(LogLevel.Debug, $"Step {step}/{timesteps.Count}", stepTime);
}
+ // Unload if required
+ if (_memoryMode == MemoryModeType.Minimum)
+ await Task.WhenAll(_controlNet.UnloadAsync(), _unet.UnloadAsync());
+
// Decode Latents
return await DecodeLatentsAsync(promptOptions, schedulerOptions, latents);
}
diff --git a/OnnxStack.StableDiffusion/Diffusers/LatentConsistencyXL/ControlNetImageDiffuser.cs b/OnnxStack.StableDiffusion/Diffusers/LatentConsistencyXL/ControlNetImageDiffuser.cs
index 2ea7459a..28dde8e2 100644
--- a/OnnxStack.StableDiffusion/Diffusers/LatentConsistencyXL/ControlNetImageDiffuser.cs
+++ b/OnnxStack.StableDiffusion/Diffusers/LatentConsistencyXL/ControlNetImageDiffuser.cs
@@ -27,8 +27,8 @@ public sealed class ControlNetImageDiffuser : ControlNetDiffuser
/// The vae decoder.
/// The vae encoder.
/// The logger.
- public ControlNetImageDiffuser(ControlNetModel controlNet, UNetConditionModel unet, AutoEncoderModel vaeDecoder, AutoEncoderModel vaeEncoder, ILogger logger = default)
- : base(controlNet, unet, vaeDecoder, vaeEncoder, logger) { }
+ public ControlNetImageDiffuser(ControlNetModel controlNet, UNetConditionModel unet, AutoEncoderModel vaeDecoder, AutoEncoderModel vaeEncoder, MemoryModeType memoryMode, ILogger logger = default)
+ : base(controlNet, unet, vaeDecoder, vaeEncoder, memoryMode, logger) { }
///
@@ -71,6 +71,10 @@ protected override async Task> PrepareLatentsAsync(PromptOpti
var results = await _vaeEncoder.RunInferenceAsync(inferenceParameters);
using (var result = results.First())
{
+ // Unload if required
+ if (_memoryMode == MemoryModeType.Minimum)
+ await _vaeEncoder.UnloadAsync();
+
var outputResult = result.ToDenseTensor();
var scaledSample = outputResult.MultiplyBy(_vaeEncoder.ScaleFactor);
return scheduler.AddNoise(scaledSample, scheduler.CreateRandomSample(scaledSample.Dimensions), timesteps);
diff --git a/OnnxStack.StableDiffusion/Diffusers/LatentConsistencyXL/ImageDiffuser.cs b/OnnxStack.StableDiffusion/Diffusers/LatentConsistencyXL/ImageDiffuser.cs
index 36a10080..42f8110c 100644
--- a/OnnxStack.StableDiffusion/Diffusers/LatentConsistencyXL/ImageDiffuser.cs
+++ b/OnnxStack.StableDiffusion/Diffusers/LatentConsistencyXL/ImageDiffuser.cs
@@ -27,8 +27,8 @@ public sealed class ImageDiffuser : LatentConsistencyXLDiffuser
/// The vae decoder.
/// The vae encoder.
/// The logger.
- public ImageDiffuser(UNetConditionModel unet, AutoEncoderModel vaeDecoder, AutoEncoderModel vaeEncoder, ILogger logger = default)
- : base(unet, vaeDecoder, vaeEncoder, logger) { }
+ public ImageDiffuser(UNetConditionModel unet, AutoEncoderModel vaeDecoder, AutoEncoderModel vaeEncoder, MemoryModeType memoryMode, ILogger logger = default)
+ : base(unet, vaeDecoder, vaeEncoder, memoryMode, logger) { }
///
@@ -72,6 +72,10 @@ protected override async Task> PrepareLatentsAsync(PromptOpti
var results = await _vaeEncoder.RunInferenceAsync(inferenceParameters);
using (var result = results.First())
{
+ // Unload if required
+ if (_memoryMode == MemoryModeType.Minimum)
+ await _vaeEncoder.UnloadAsync();
+
var outputResult = result.ToDenseTensor();
var scaledSample = outputResult.MultiplyBy(_vaeEncoder.ScaleFactor);
return scheduler.AddNoise(scaledSample, scheduler.CreateRandomSample(scaledSample.Dimensions), timesteps);
diff --git a/OnnxStack.StableDiffusion/Diffusers/LatentConsistencyXL/InpaintLegacyDiffuser.cs b/OnnxStack.StableDiffusion/Diffusers/LatentConsistencyXL/InpaintLegacyDiffuser.cs
index 22a774b9..c18b1eac 100644
--- a/OnnxStack.StableDiffusion/Diffusers/LatentConsistencyXL/InpaintLegacyDiffuser.cs
+++ b/OnnxStack.StableDiffusion/Diffusers/LatentConsistencyXL/InpaintLegacyDiffuser.cs
@@ -29,8 +29,8 @@ public sealed class InpaintLegacyDiffuser : LatentConsistencyXLDiffuser
/// The vae decoder.
/// The vae encoder.
/// The logger.
- public InpaintLegacyDiffuser(UNetConditionModel unet, AutoEncoderModel vaeDecoder, AutoEncoderModel vaeEncoder, ILogger logger = default)
- : base(unet, vaeDecoder, vaeEncoder, logger) { }
+ public InpaintLegacyDiffuser(UNetConditionModel unet, AutoEncoderModel vaeDecoder, AutoEncoderModel vaeEncoder, MemoryModeType memoryMode, ILogger logger = default)
+ : base(unet, vaeDecoder, vaeEncoder, memoryMode, logger) { }
///
@@ -120,9 +120,13 @@ public override async Task> DiffuseAsync(PromptOptions prompt
}
ReportProgress(progressCallback, step, timesteps.Count, latents);
- _logger?.LogEnd($"Step {step}/{timesteps.Count}", stepTime);
+ _logger?.LogEnd(LogLevel.Debug, $"Step {step}/{timesteps.Count}", stepTime);
}
+ // Unload if required
+ if (_memoryMode == MemoryModeType.Minimum)
+ await _unet.UnloadAsync();
+
// Decode Latents
return await DecodeLatentsAsync(promptOptions, schedulerOptions, latents);
}
@@ -163,6 +167,10 @@ protected override async Task> PrepareLatentsAsync(PromptOpti
var results = await _vaeEncoder.RunInferenceAsync(inferenceParameters);
using (var result = results.First())
{
+ // Unload if required
+ if (_memoryMode == MemoryModeType.Minimum)
+ await _vaeEncoder.UnloadAsync();
+
var outputResult = result.ToDenseTensor();
var scaledSample = outputResult.MultiplyBy(_vaeEncoder.ScaleFactor);
return scaledSample;
diff --git a/OnnxStack.StableDiffusion/Diffusers/LatentConsistencyXL/LatentConsistencyXLDiffuser.cs b/OnnxStack.StableDiffusion/Diffusers/LatentConsistencyXL/LatentConsistencyXLDiffuser.cs
index 4e4bad52..73af2ef8 100644
--- a/OnnxStack.StableDiffusion/Diffusers/LatentConsistencyXL/LatentConsistencyXLDiffuser.cs
+++ b/OnnxStack.StableDiffusion/Diffusers/LatentConsistencyXL/LatentConsistencyXLDiffuser.cs
@@ -19,8 +19,8 @@ public abstract class LatentConsistencyXLDiffuser : StableDiffusionXLDiffuser
/// The vae decoder.
/// The vae encoder.
/// The logger.
- protected LatentConsistencyXLDiffuser(UNetConditionModel unet, AutoEncoderModel vaeDecoder, AutoEncoderModel vaeEncoder, ILogger logger = default)
- : base(unet, vaeDecoder, vaeEncoder, logger) { }
+ protected LatentConsistencyXLDiffuser(UNetConditionModel unet, AutoEncoderModel vaeDecoder, AutoEncoderModel vaeEncoder, MemoryModeType memoryMode, ILogger logger = default)
+ : base(unet, vaeDecoder, vaeEncoder, memoryMode, logger) { }
///
diff --git a/OnnxStack.StableDiffusion/Diffusers/LatentConsistencyXL/TextDiffuser.cs b/OnnxStack.StableDiffusion/Diffusers/LatentConsistencyXL/TextDiffuser.cs
index ef8b80e6..71566406 100644
--- a/OnnxStack.StableDiffusion/Diffusers/LatentConsistencyXL/TextDiffuser.cs
+++ b/OnnxStack.StableDiffusion/Diffusers/LatentConsistencyXL/TextDiffuser.cs
@@ -20,8 +20,8 @@ public sealed class TextDiffuser : LatentConsistencyXLDiffuser
/// The vae decoder.
/// The vae encoder.
/// The logger.
- public TextDiffuser(UNetConditionModel unet, AutoEncoderModel vaeDecoder, AutoEncoderModel vaeEncoder, ILogger logger = default)
- : base(unet, vaeDecoder, vaeEncoder, logger) { }
+ public TextDiffuser(UNetConditionModel unet, AutoEncoderModel vaeDecoder, AutoEncoderModel vaeEncoder, MemoryModeType memoryMode, ILogger logger = default)
+ : base(unet, vaeDecoder, vaeEncoder, memoryMode, logger) { }
///
diff --git a/OnnxStack.StableDiffusion/Diffusers/StableDiffusion/ControlNetDiffuser.cs b/OnnxStack.StableDiffusion/Diffusers/StableDiffusion/ControlNetDiffuser.cs
index b1210065..c39b7728 100644
--- a/OnnxStack.StableDiffusion/Diffusers/StableDiffusion/ControlNetDiffuser.cs
+++ b/OnnxStack.StableDiffusion/Diffusers/StableDiffusion/ControlNetDiffuser.cs
@@ -29,8 +29,8 @@ public class ControlNetDiffuser : StableDiffusionDiffuser
/// The vae decoder.
/// The vae encoder.
/// The logger.
- public ControlNetDiffuser(ControlNetModel controlNet, UNetConditionModel unet, AutoEncoderModel vaeDecoder, AutoEncoderModel vaeEncoder, ILogger logger = default)
- : base(unet, vaeDecoder, vaeEncoder, logger)
+ public ControlNetDiffuser(ControlNetModel controlNet, UNetConditionModel unet, AutoEncoderModel vaeDecoder, AutoEncoderModel vaeEncoder, MemoryModeType memoryMode, ILogger logger = default)
+ : base(unet, vaeDecoder, vaeEncoder, memoryMode, logger)
{
_controlNet = controlNet;
}
@@ -137,9 +137,13 @@ public override async Task> DiffuseAsync(PromptOptions prompt
}
ReportProgress(progressCallback, step, timesteps.Count, latents);
- _logger?.LogEnd($"Step {step}/{timesteps.Count}", stepTime);
+ _logger?.LogEnd(LogLevel.Debug, $"Step {step}/{timesteps.Count}", stepTime);
}
+ // Unload if required
+ if (_memoryMode == MemoryModeType.Minimum)
+ await Task.WhenAll(_controlNet.UnloadAsync(), _unet.UnloadAsync());
+
// Decode Latents
return await DecodeLatentsAsync(promptOptions, schedulerOptions, latents);
}
diff --git a/OnnxStack.StableDiffusion/Diffusers/StableDiffusion/ControlNetImageDiffuser.cs b/OnnxStack.StableDiffusion/Diffusers/StableDiffusion/ControlNetImageDiffuser.cs
index 9ca55c3a..f6a887f0 100644
--- a/OnnxStack.StableDiffusion/Diffusers/StableDiffusion/ControlNetImageDiffuser.cs
+++ b/OnnxStack.StableDiffusion/Diffusers/StableDiffusion/ControlNetImageDiffuser.cs
@@ -27,8 +27,8 @@ public sealed class ControlNetImageDiffuser : ControlNetDiffuser
/// The vae decoder.
/// The vae encoder.
/// The logger.
- public ControlNetImageDiffuser(ControlNetModel controlNet, UNetConditionModel unet, AutoEncoderModel vaeDecoder, AutoEncoderModel vaeEncoder, ILogger logger = default)
- : base(controlNet, unet, vaeDecoder, vaeEncoder, logger) { }
+ public ControlNetImageDiffuser(ControlNetModel controlNet, UNetConditionModel unet, AutoEncoderModel vaeDecoder, AutoEncoderModel vaeEncoder, MemoryModeType memoryMode, ILogger logger = default)
+ : base(controlNet, unet, vaeDecoder, vaeEncoder, memoryMode, logger) { }
///
@@ -71,6 +71,10 @@ protected override async Task> PrepareLatentsAsync(PromptOpti
var results = await _vaeEncoder.RunInferenceAsync(inferenceParameters);
using (var result = results.First())
{
+ // Unload if required
+ if (_memoryMode == MemoryModeType.Minimum)
+ await _vaeEncoder.UnloadAsync();
+
var outputResult = result.ToDenseTensor();
var scaledSample = outputResult.MultiplyBy(_vaeEncoder.ScaleFactor);
return scheduler.AddNoise(scaledSample, scheduler.CreateRandomSample(scaledSample.Dimensions), timesteps);
diff --git a/OnnxStack.StableDiffusion/Diffusers/StableDiffusion/ImageDiffuser.cs b/OnnxStack.StableDiffusion/Diffusers/StableDiffusion/ImageDiffuser.cs
index 18febd7c..7cad1190 100644
--- a/OnnxStack.StableDiffusion/Diffusers/StableDiffusion/ImageDiffuser.cs
+++ b/OnnxStack.StableDiffusion/Diffusers/StableDiffusion/ImageDiffuser.cs
@@ -26,8 +26,8 @@ public sealed class ImageDiffuser : StableDiffusionDiffuser
/// The vae decoder.
/// The vae encoder.
/// The logger.
- public ImageDiffuser(UNetConditionModel unet, AutoEncoderModel vaeDecoder, AutoEncoderModel vaeEncoder, ILogger logger = default)
- : base(unet, vaeDecoder, vaeEncoder, logger) { }
+ public ImageDiffuser(UNetConditionModel unet, AutoEncoderModel vaeDecoder, AutoEncoderModel vaeEncoder, MemoryModeType memoryMode, ILogger logger = default)
+ : base(unet, vaeDecoder, vaeEncoder, memoryMode, logger) { }
///
@@ -71,6 +71,10 @@ protected override async Task> PrepareLatentsAsync(PromptOpti
var results = await _vaeEncoder.RunInferenceAsync(inferenceParameters);
using (var result = results.First())
{
+ // Unload if required
+ if (_memoryMode == MemoryModeType.Minimum)
+ await _vaeEncoder.UnloadAsync();
+
var outputResult = result.ToDenseTensor();
var scaledSample = outputResult.MultiplyBy(_vaeEncoder.ScaleFactor);
return scheduler.AddNoise(scaledSample, scheduler.CreateRandomSample(scaledSample.Dimensions), timesteps);
diff --git a/OnnxStack.StableDiffusion/Diffusers/StableDiffusion/InpaintDiffuser.cs b/OnnxStack.StableDiffusion/Diffusers/StableDiffusion/InpaintDiffuser.cs
index 14298644..36a1eada 100644
--- a/OnnxStack.StableDiffusion/Diffusers/StableDiffusion/InpaintDiffuser.cs
+++ b/OnnxStack.StableDiffusion/Diffusers/StableDiffusion/InpaintDiffuser.cs
@@ -30,8 +30,8 @@ public sealed class InpaintDiffuser : StableDiffusionDiffuser
/// The vae decoder.
/// The vae encoder.
/// The logger.
- public InpaintDiffuser(UNetConditionModel unet, AutoEncoderModel vaeDecoder, AutoEncoderModel vaeEncoder, ILogger logger = default)
- : base(unet, vaeDecoder, vaeEncoder, logger) { }
+ public InpaintDiffuser(UNetConditionModel unet, AutoEncoderModel vaeDecoder, AutoEncoderModel vaeEncoder, MemoryModeType memoryMode, ILogger logger = default)
+ : base(unet, vaeDecoder, vaeEncoder, memoryMode, logger) { }
///
@@ -108,9 +108,13 @@ public override async Task> DiffuseAsync(PromptOptions prompt
}
ReportProgress(progressCallback, step, timesteps.Count, latents);
- _logger?.LogEnd($"Step {step}/{timesteps.Count}", stepTime);
+ _logger?.LogEnd(LogLevel.Debug, $"Step {step}/{timesteps.Count}", stepTime);
}
+ // Unload if required
+ if (_memoryMode == MemoryModeType.Minimum)
+ await _unet.UnloadAsync();
+
// Decode Latents
return await DecodeLatentsAsync(promptOptions, schedulerOptions, latents);
}
@@ -224,6 +228,10 @@ private async Task> PrepareImageMask(PromptOptions promptOpti
var results = await _vaeEncoder.RunInferenceAsync(inferenceParameters);
using (var result = results.First())
{
+ // Unload if required
+ if (_memoryMode == MemoryModeType.Minimum)
+ await _vaeEncoder.UnloadAsync();
+
var sample = result.ToDenseTensor();
var scaledSample = sample.MultiplyBy(_vaeEncoder.ScaleFactor);
if (schedulerOptions.GuidanceScale > 1f)
diff --git a/OnnxStack.StableDiffusion/Diffusers/StableDiffusion/InpaintLegacyDiffuser.cs b/OnnxStack.StableDiffusion/Diffusers/StableDiffusion/InpaintLegacyDiffuser.cs
index b5ab0f97..554894e5 100644
--- a/OnnxStack.StableDiffusion/Diffusers/StableDiffusion/InpaintLegacyDiffuser.cs
+++ b/OnnxStack.StableDiffusion/Diffusers/StableDiffusion/InpaintLegacyDiffuser.cs
@@ -29,8 +29,8 @@ public sealed class InpaintLegacyDiffuser : StableDiffusionDiffuser
/// The vae decoder.
/// The vae encoder.
/// The logger.
- public InpaintLegacyDiffuser(UNetConditionModel unet, AutoEncoderModel vaeDecoder, AutoEncoderModel vaeEncoder, ILogger logger = default)
- : base(unet, vaeDecoder, vaeEncoder, logger) { }
+ public InpaintLegacyDiffuser(UNetConditionModel unet, AutoEncoderModel vaeDecoder, AutoEncoderModel vaeEncoder, MemoryModeType memoryMode, ILogger logger = default)
+ : base(unet, vaeDecoder, vaeEncoder, memoryMode, logger) { }
///
@@ -114,9 +114,13 @@ public override async Task> DiffuseAsync(PromptOptions prompt
}
ReportProgress(progressCallback, step, timesteps.Count, latents);
- _logger?.LogEnd($"Step {step}/{timesteps.Count}", stepTime);
+ _logger?.LogEnd(LogLevel.Debug, $"Step {step}/{timesteps.Count}", stepTime);
}
+ // Unload if required
+ if (_memoryMode == MemoryModeType.Minimum)
+ await _unet.UnloadAsync();
+
// Decode Latents
return await DecodeLatentsAsync(promptOptions, schedulerOptions, latents);
}
@@ -157,6 +161,10 @@ protected override async Task> PrepareLatentsAsync(PromptOpti
var results = await _vaeEncoder.RunInferenceAsync(inferenceParameters);
using (var result = results.First())
{
+ // Unload if required
+ if (_memoryMode == MemoryModeType.Minimum)
+ await _vaeEncoder.UnloadAsync();
+
var outputResult = result.ToDenseTensor();
var scaledSample = outputResult.MultiplyBy(_vaeEncoder.ScaleFactor);
return scaledSample;
diff --git a/OnnxStack.StableDiffusion/Diffusers/StableDiffusion/StableDiffusionDiffuser.cs b/OnnxStack.StableDiffusion/Diffusers/StableDiffusion/StableDiffusionDiffuser.cs
index b7b115f4..f9037f59 100644
--- a/OnnxStack.StableDiffusion/Diffusers/StableDiffusion/StableDiffusionDiffuser.cs
+++ b/OnnxStack.StableDiffusion/Diffusers/StableDiffusion/StableDiffusionDiffuser.cs
@@ -25,8 +25,8 @@ public abstract class StableDiffusionDiffuser : DiffuserBase
/// The vae decoder.
/// The vae encoder.
/// The logger.
- public StableDiffusionDiffuser(UNetConditionModel unet, AutoEncoderModel vaeDecoder, AutoEncoderModel vaeEncoder, ILogger logger = default)
- : base(unet, vaeDecoder, vaeEncoder, logger) { }
+ public StableDiffusionDiffuser(UNetConditionModel unet, AutoEncoderModel vaeDecoder, AutoEncoderModel vaeEncoder, MemoryModeType memoryMode, ILogger logger = default)
+ : base(unet, vaeDecoder, vaeEncoder, memoryMode, logger) { }
///
@@ -95,10 +95,15 @@ public override async Task> DiffuseAsync(PromptOptions prompt
}
}
- ReportProgress(progressCallback, step, timesteps.Count, latents);
- _logger?.LogEnd($"Step {step}/{timesteps.Count}", stepTime);
+ ReportProgress(progressCallback, step, timesteps.Count, latents);
+ _logger?.LogEnd(LogLevel.Debug, $"Step {step}/{timesteps.Count}", stepTime);
}
+ // Unload if required
+ if (_memoryMode == MemoryModeType.Minimum)
+ await _unet.UnloadAsync();
+
+
// Decode Latents
return await DecodeLatentsAsync(promptOptions, schedulerOptions, latents);
}
diff --git a/OnnxStack.StableDiffusion/Diffusers/StableDiffusion/TextDiffuser.cs b/OnnxStack.StableDiffusion/Diffusers/StableDiffusion/TextDiffuser.cs
index 90d5bc07..3bb1723a 100644
--- a/OnnxStack.StableDiffusion/Diffusers/StableDiffusion/TextDiffuser.cs
+++ b/OnnxStack.StableDiffusion/Diffusers/StableDiffusion/TextDiffuser.cs
@@ -20,8 +20,8 @@ public sealed class TextDiffuser : StableDiffusionDiffuser
/// The vae decoder.
/// The vae encoder.
/// The logger.
- public TextDiffuser(UNetConditionModel unet, AutoEncoderModel vaeDecoder, AutoEncoderModel vaeEncoder, ILogger logger = default)
- : base(unet, vaeDecoder, vaeEncoder, logger) { }
+ public TextDiffuser(UNetConditionModel unet, AutoEncoderModel vaeDecoder, AutoEncoderModel vaeEncoder, MemoryModeType memoryMode, ILogger logger = default)
+ : base(unet, vaeDecoder, vaeEncoder, memoryMode, logger) { }
///
diff --git a/OnnxStack.StableDiffusion/Diffusers/StableDiffusionXL/ControlNetDiffuser.cs b/OnnxStack.StableDiffusion/Diffusers/StableDiffusionXL/ControlNetDiffuser.cs
index 37abb8fc..b803563a 100644
--- a/OnnxStack.StableDiffusion/Diffusers/StableDiffusionXL/ControlNetDiffuser.cs
+++ b/OnnxStack.StableDiffusion/Diffusers/StableDiffusionXL/ControlNetDiffuser.cs
@@ -29,8 +29,8 @@ public class ControlNetDiffuser : StableDiffusionXLDiffuser
/// The vae decoder.
/// The vae encoder.
/// The logger.
- public ControlNetDiffuser( ControlNetModel controlNet, UNetConditionModel unet, AutoEncoderModel vaeDecoder, AutoEncoderModel vaeEncoder, ILogger logger = default)
- : base(unet, vaeDecoder, vaeEncoder, logger)
+ public ControlNetDiffuser( ControlNetModel controlNet, UNetConditionModel unet, AutoEncoderModel vaeDecoder, AutoEncoderModel vaeEncoder, MemoryModeType memoryMode, ILogger logger = default)
+ : base(unet, vaeDecoder, vaeEncoder, memoryMode, logger)
{
_controlNet = controlNet;
}
@@ -147,9 +147,13 @@ public override async Task> DiffuseAsync(PromptOptions prompt
}
ReportProgress(progressCallback, step, timesteps.Count, latents);
- _logger?.LogEnd($"Step {step}/{timesteps.Count}", stepTime);
+ _logger?.LogEnd(LogLevel.Debug, $"Step {step}/{timesteps.Count}", stepTime);
}
+ // Unload if required
+ if (_memoryMode == MemoryModeType.Minimum)
+ await Task.WhenAll(_controlNet.UnloadAsync(), _unet.UnloadAsync());
+
// Decode Latents
return await DecodeLatentsAsync(promptOptions, schedulerOptions, latents);
}
diff --git a/OnnxStack.StableDiffusion/Diffusers/StableDiffusionXL/ControlNetImageDiffuser.cs b/OnnxStack.StableDiffusion/Diffusers/StableDiffusionXL/ControlNetImageDiffuser.cs
index f0f6ac37..9838a4d7 100644
--- a/OnnxStack.StableDiffusion/Diffusers/StableDiffusionXL/ControlNetImageDiffuser.cs
+++ b/OnnxStack.StableDiffusion/Diffusers/StableDiffusionXL/ControlNetImageDiffuser.cs
@@ -27,8 +27,8 @@ public sealed class ControlNetImageDiffuser : ControlNetDiffuser
/// The vae decoder.
/// The vae encoder.
/// The logger.
- public ControlNetImageDiffuser(ControlNetModel controlNet, UNetConditionModel unet, AutoEncoderModel vaeDecoder, AutoEncoderModel vaeEncoder, ILogger logger = default)
- : base(controlNet, unet, vaeDecoder, vaeEncoder, logger) { }
+ public ControlNetImageDiffuser(ControlNetModel controlNet, UNetConditionModel unet, AutoEncoderModel vaeDecoder, AutoEncoderModel vaeEncoder, MemoryModeType memoryMode, ILogger logger = default)
+ : base(controlNet, unet, vaeDecoder, vaeEncoder, memoryMode, logger) { }
///
@@ -74,6 +74,10 @@ protected override async Task> PrepareLatentsAsync(PromptOpti
var results = await _vaeEncoder.RunInferenceAsync(inferenceParameters);
using (var result = results.First())
{
+ // Unload if required
+ if (_memoryMode == MemoryModeType.Minimum)
+ await _vaeEncoder.UnloadAsync();
+
var outputResult = result.ToDenseTensor();
var scaledSample = outputResult.MultiplyBy(_vaeEncoder.ScaleFactor);
return scheduler.AddNoise(scaledSample, scheduler.CreateRandomSample(scaledSample.Dimensions), timesteps);
diff --git a/OnnxStack.StableDiffusion/Diffusers/StableDiffusionXL/ImageDiffuser.cs b/OnnxStack.StableDiffusion/Diffusers/StableDiffusionXL/ImageDiffuser.cs
index 012b2f5b..fa2e2c0d 100644
--- a/OnnxStack.StableDiffusion/Diffusers/StableDiffusionXL/ImageDiffuser.cs
+++ b/OnnxStack.StableDiffusion/Diffusers/StableDiffusionXL/ImageDiffuser.cs
@@ -26,8 +26,8 @@ public sealed class ImageDiffuser : StableDiffusionXLDiffuser
///
///
///
- public ImageDiffuser(UNetConditionModel unet, AutoEncoderModel vaeDecoder, AutoEncoderModel vaeEncoder, ILogger logger = default)
- : base(unet, vaeDecoder, vaeEncoder, logger) { }
+ public ImageDiffuser(UNetConditionModel unet, AutoEncoderModel vaeDecoder, AutoEncoderModel vaeEncoder, MemoryModeType memoryMode, ILogger logger = default)
+ : base(unet, vaeDecoder, vaeEncoder, memoryMode, logger) { }
///
@@ -74,6 +74,10 @@ protected override async Task> PrepareLatentsAsync(PromptOpti
var results = await _vaeEncoder.RunInferenceAsync(inferenceParameters);
using (var result = results.First())
{
+ // Unload if required
+ if (_memoryMode == MemoryModeType.Minimum)
+ await _vaeEncoder.UnloadAsync();
+
var outputResult = result.ToDenseTensor();
var scaledSample = outputResult.MultiplyBy(_vaeEncoder.ScaleFactor);
return scheduler.AddNoise(scaledSample, scheduler.CreateRandomSample(scaledSample.Dimensions), timesteps);
diff --git a/OnnxStack.StableDiffusion/Diffusers/StableDiffusionXL/InpaintLegacyDiffuser.cs b/OnnxStack.StableDiffusion/Diffusers/StableDiffusionXL/InpaintLegacyDiffuser.cs
index e3c18ce9..57bdd363 100644
--- a/OnnxStack.StableDiffusion/Diffusers/StableDiffusionXL/InpaintLegacyDiffuser.cs
+++ b/OnnxStack.StableDiffusion/Diffusers/StableDiffusionXL/InpaintLegacyDiffuser.cs
@@ -29,8 +29,8 @@ public sealed class InpaintLegacyDiffuser : StableDiffusionXLDiffuser
///
///
///
- public InpaintLegacyDiffuser(UNetConditionModel unet, AutoEncoderModel vaeDecoder, AutoEncoderModel vaeEncoder, ILogger logger = default)
- : base(unet, vaeDecoder, vaeEncoder, logger) { }
+ public InpaintLegacyDiffuser(UNetConditionModel unet, AutoEncoderModel vaeDecoder, AutoEncoderModel vaeEncoder, MemoryModeType memoryMode, ILogger logger = default)
+ : base(unet, vaeDecoder, vaeEncoder, memoryMode, logger) { }
///
@@ -121,9 +121,13 @@ public override async Task> DiffuseAsync(PromptOptions prompt
}
ReportProgress(progressCallback, step, timesteps.Count, latents);
- _logger?.LogEnd($"Step {step}/{timesteps.Count}", stepTime);
+ _logger?.LogEnd(LogLevel.Debug, $"Step {step}/{timesteps.Count}", stepTime);
}
+ // Unload if required
+ if (_memoryMode == MemoryModeType.Minimum)
+ await _unet.UnloadAsync();
+
// Decode Latents
return await DecodeLatentsAsync(promptOptions, schedulerOptions, latents);
}
@@ -167,6 +171,10 @@ protected override async Task> PrepareLatentsAsync(PromptOpti
var results = await _vaeEncoder.RunInferenceAsync(inferenceParameters);
using (var result = results.First())
{
+ // Unload if required
+ if (_memoryMode == MemoryModeType.Minimum)
+ await _vaeEncoder.UnloadAsync();
+
var outputResult = result.ToDenseTensor();
var scaledSample = outputResult.MultiplyBy(_vaeEncoder.ScaleFactor);
return scaledSample;
diff --git a/OnnxStack.StableDiffusion/Diffusers/StableDiffusionXL/StableDiffusionXLDiffuser.cs b/OnnxStack.StableDiffusion/Diffusers/StableDiffusionXL/StableDiffusionXLDiffuser.cs
index 09016d86..58c91990 100644
--- a/OnnxStack.StableDiffusion/Diffusers/StableDiffusionXL/StableDiffusionXLDiffuser.cs
+++ b/OnnxStack.StableDiffusion/Diffusers/StableDiffusionXL/StableDiffusionXLDiffuser.cs
@@ -26,8 +26,8 @@ public abstract class StableDiffusionXLDiffuser : DiffuserBase
/// The vae decoder.
/// The vae encoder.
/// The logger.
- public StableDiffusionXLDiffuser(UNetConditionModel unet, AutoEncoderModel vaeDecoder, AutoEncoderModel vaeEncoder, ILogger logger = default)
- : base(unet, vaeDecoder, vaeEncoder, logger) { }
+ public StableDiffusionXLDiffuser(UNetConditionModel unet, AutoEncoderModel vaeDecoder, AutoEncoderModel vaeEncoder, MemoryModeType memoryMode, ILogger logger = default)
+ : base(unet, vaeDecoder, vaeEncoder, memoryMode, logger) { }
///
@@ -104,9 +104,13 @@ public override async Task> DiffuseAsync(PromptOptions prompt
}
ReportProgress(progressCallback, step, timesteps.Count, latents);
- _logger?.LogEnd($"Step {step}/{timesteps.Count}", stepTime);
+ _logger?.LogEnd(LogLevel.Debug, $"Step {step}/{timesteps.Count}", stepTime);
}
+ // Unload if required
+ if (_memoryMode == MemoryModeType.Minimum)
+ await _unet.UnloadAsync();
+
// Decode Latents
return await DecodeLatentsAsync(promptOptions, schedulerOptions, latents);
}
diff --git a/OnnxStack.StableDiffusion/Diffusers/StableDiffusionXL/TextDiffuser.cs b/OnnxStack.StableDiffusion/Diffusers/StableDiffusionXL/TextDiffuser.cs
index 78716639..e15da820 100644
--- a/OnnxStack.StableDiffusion/Diffusers/StableDiffusionXL/TextDiffuser.cs
+++ b/OnnxStack.StableDiffusion/Diffusers/StableDiffusionXL/TextDiffuser.cs
@@ -20,8 +20,8 @@ public sealed class TextDiffuser : StableDiffusionXLDiffuser
/// The vae decoder.
/// The vae encoder.
/// The logger.
- public TextDiffuser(UNetConditionModel unet, AutoEncoderModel vaeDecoder, AutoEncoderModel vaeEncoder, ILogger logger = default)
- : base(unet, vaeDecoder, vaeEncoder, logger) { }
+ public TextDiffuser(UNetConditionModel unet, AutoEncoderModel vaeDecoder, AutoEncoderModel vaeEncoder, MemoryModeType memoryMode, ILogger logger = default)
+ : base(unet, vaeDecoder, vaeEncoder, memoryMode, logger) { }
///
diff --git a/OnnxStack.StableDiffusion/Enums/MemoryModeType.cs b/OnnxStack.StableDiffusion/Enums/MemoryModeType.cs
new file mode 100644
index 00000000..44a62010
--- /dev/null
+++ b/OnnxStack.StableDiffusion/Enums/MemoryModeType.cs
@@ -0,0 +1,8 @@
+namespace OnnxStack.StableDiffusion.Enums
+{
+ public enum MemoryModeType
+ {
+ Maximum = 0,
+ Minimum = 10
+ }
+}
diff --git a/OnnxStack.StableDiffusion/Helpers/ModelFactory.cs b/OnnxStack.StableDiffusion/Helpers/ModelFactory.cs
index ecbf4d1d..b59e1d89 100644
--- a/OnnxStack.StableDiffusion/Helpers/ModelFactory.cs
+++ b/OnnxStack.StableDiffusion/Helpers/ModelFactory.cs
@@ -22,7 +22,7 @@ public static class ModelFactory
/// The device identifier.
/// The execution provider.
///
- public static StableDiffusionModelSet CreateModelSet(string modelFolder, DiffuserPipelineType pipeline, ModelType modelType, int deviceId, ExecutionProvider executionProvider)
+ public static StableDiffusionModelSet CreateModelSet(string modelFolder, DiffuserPipelineType pipeline, ModelType modelType, int deviceId, ExecutionProvider executionProvider, MemoryModeType memoryMode)
{
var tokenizerPath = Path.Combine(modelFolder, "tokenizer", "model.onnx");
if (!File.Exists(tokenizerPath))
@@ -117,6 +117,7 @@ public static StableDiffusionModelSet CreateModelSet(string modelFolder, Diffuse
PipelineType = pipeline,
Diffusers = diffusers,
DeviceId = deviceId,
+ MemoryMode = memoryMode,
ExecutionProvider = executionProvider,
SchedulerOptions = GetDefaultSchedulerOptions(pipeline, modelType),
TokenizerConfig = tokenizerConfig,
diff --git a/OnnxStack.StableDiffusion/Pipelines/Base/PipelineBase.cs b/OnnxStack.StableDiffusion/Pipelines/Base/PipelineBase.cs
index b7675d1b..7f041055 100644
--- a/OnnxStack.StableDiffusion/Pipelines/Base/PipelineBase.cs
+++ b/OnnxStack.StableDiffusion/Pipelines/Base/PipelineBase.cs
@@ -18,14 +18,16 @@ namespace OnnxStack.StableDiffusion.Pipelines
public abstract class PipelineBase : IPipeline
{
protected readonly ILogger _logger;
+ protected readonly PipelineOptions _pipelineOptions;
///
/// Initializes a new instance of the class.
///
/// The logger.
- protected PipelineBase(ILogger logger)
+ protected PipelineBase(PipelineOptions pipelineOptions, ILogger logger)
{
_logger = logger;
+ _pipelineOptions = pipelineOptions;
}
diff --git a/OnnxStack.StableDiffusion/Pipelines/InstaFlowPipeline.cs b/OnnxStack.StableDiffusion/Pipelines/InstaFlowPipeline.cs
index dce079f4..2f16f39a 100644
--- a/OnnxStack.StableDiffusion/Pipelines/InstaFlowPipeline.cs
+++ b/OnnxStack.StableDiffusion/Pipelines/InstaFlowPipeline.cs
@@ -18,15 +18,15 @@ public sealed class InstaFlowPipeline : StableDiffusionPipeline
///
/// Initializes a new instance of the class.
///
- /// The model name.
+ /// The pipeline options
/// The tokenizer.
/// The text encoder.
/// The unet.
/// The vae decoder.
/// The vae encoder.
/// The logger.
- public InstaFlowPipeline(string name, TokenizerModel tokenizer, TextEncoderModel textEncoder, UNetConditionModel unet, AutoEncoderModel vaeDecoder, AutoEncoderModel vaeEncoder, List diffusers, SchedulerOptions defaultSchedulerOptions = default, ILogger logger = default)
- : base(name, tokenizer, textEncoder, unet, vaeDecoder, vaeEncoder, diffusers, defaultSchedulerOptions, logger)
+ public InstaFlowPipeline(PipelineOptions pipelineOptions, TokenizerModel tokenizer, TextEncoderModel textEncoder, UNetConditionModel unet, AutoEncoderModel vaeDecoder, AutoEncoderModel vaeEncoder, List diffusers, SchedulerOptions defaultSchedulerOptions = default, ILogger logger = default)
+ : base(pipelineOptions, tokenizer, textEncoder, unet, vaeDecoder, vaeEncoder, diffusers, defaultSchedulerOptions, logger)
{
_supportedDiffusers = diffusers ?? new List
{
@@ -61,8 +61,8 @@ protected override IDiffuser CreateDiffuser(DiffuserType diffuserType, ControlNe
{
return diffuserType switch
{
- DiffuserType.TextToImage => new TextDiffuser(_unet, _vaeDecoder, _vaeEncoder, _logger),
- DiffuserType.ControlNet => new ControlNetDiffuser(controlNetModel, _unet, _vaeDecoder, _vaeEncoder, _logger),
+ DiffuserType.TextToImage => new TextDiffuser(_unet, _vaeDecoder, _vaeEncoder, _pipelineOptions.MemoryMode, _logger),
+ DiffuserType.ControlNet => new ControlNetDiffuser(controlNetModel, _unet, _vaeDecoder, _vaeEncoder, _pipelineOptions.MemoryMode, _logger),
_ => throw new NotImplementedException()
};
}
@@ -81,7 +81,8 @@ protected override IDiffuser CreateDiffuser(DiffuserType diffuserType, ControlNe
var textEncoder = new TextEncoderModel(modelSet.TextEncoderConfig.ApplyDefaults(modelSet));
var vaeDecoder = new AutoEncoderModel(modelSet.VaeDecoderConfig.ApplyDefaults(modelSet));
var vaeEncoder = new AutoEncoderModel(modelSet.VaeEncoderConfig.ApplyDefaults(modelSet));
- return new InstaFlowPipeline(modelSet.Name, tokenizer, textEncoder, unet, vaeDecoder, vaeEncoder, modelSet.Diffusers, modelSet.SchedulerOptions, logger);
+ var pipelineOptions = new PipelineOptions(modelSet.Name, modelSet.MemoryMode);
+ return new InstaFlowPipeline(pipelineOptions, tokenizer, textEncoder, unet, vaeDecoder, vaeEncoder, modelSet.Diffusers, modelSet.SchedulerOptions, logger);
}
@@ -94,9 +95,9 @@ protected override IDiffuser CreateDiffuser(DiffuserType diffuserType, ControlNe
/// The execution provider.
/// The logger.
///
- public static new InstaFlowPipeline CreatePipeline(string modelFolder, ModelType modelType = ModelType.Base, int deviceId = 0, ExecutionProvider executionProvider = ExecutionProvider.DirectML, ILogger logger = default)
+ public static new InstaFlowPipeline CreatePipeline(string modelFolder, ModelType modelType = ModelType.Base, int deviceId = 0, ExecutionProvider executionProvider = ExecutionProvider.DirectML, MemoryModeType memoryMode = MemoryModeType.Maximum, ILogger logger = default)
{
- return CreatePipeline(ModelFactory.CreateModelSet(modelFolder, DiffuserPipelineType.InstaFlow, modelType, deviceId, executionProvider), logger);
+ return CreatePipeline(ModelFactory.CreateModelSet(modelFolder, DiffuserPipelineType.InstaFlow, modelType, deviceId, executionProvider, memoryMode), logger);
}
}
}
diff --git a/OnnxStack.StableDiffusion/Pipelines/LatentConsistencyPipeline.cs b/OnnxStack.StableDiffusion/Pipelines/LatentConsistencyPipeline.cs
index 335ded20..b8ae31a6 100644
--- a/OnnxStack.StableDiffusion/Pipelines/LatentConsistencyPipeline.cs
+++ b/OnnxStack.StableDiffusion/Pipelines/LatentConsistencyPipeline.cs
@@ -23,15 +23,15 @@ public sealed class LatentConsistencyPipeline : StableDiffusionPipeline
///
/// Initializes a new instance of the class.
///
- /// The model name.
+ /// The pipeline options
/// The tokenizer.
/// The text encoder.
/// The unet.
/// The vae decoder.
/// The vae encoder.
/// The logger.
- public LatentConsistencyPipeline(string name, TokenizerModel tokenizer, TextEncoderModel textEncoder, UNetConditionModel unet, AutoEncoderModel vaeDecoder, AutoEncoderModel vaeEncoder, List diffusers, SchedulerOptions defaultSchedulerOptions = default, ILogger logger = default)
- : base(name, tokenizer, textEncoder, unet, vaeDecoder, vaeEncoder, diffusers, defaultSchedulerOptions, logger)
+ public LatentConsistencyPipeline(PipelineOptions pipelineOptions, TokenizerModel tokenizer, TextEncoderModel textEncoder, UNetConditionModel unet, AutoEncoderModel vaeDecoder, AutoEncoderModel vaeEncoder, List diffusers, SchedulerOptions defaultSchedulerOptions = default, ILogger logger = default)
+ : base(pipelineOptions, tokenizer, textEncoder, unet, vaeDecoder, vaeEncoder, diffusers, defaultSchedulerOptions, logger)
{
_supportedSchedulers = new List
{
@@ -109,11 +109,11 @@ protected override IDiffuser CreateDiffuser(DiffuserType diffuserType, ControlNe
{
return diffuserType switch
{
- DiffuserType.TextToImage => new TextDiffuser(_unet, _vaeDecoder, _vaeEncoder, _logger),
- DiffuserType.ImageToImage => new ImageDiffuser(_unet, _vaeDecoder, _vaeEncoder, _logger),
- DiffuserType.ImageInpaintLegacy => new InpaintLegacyDiffuser(_unet, _vaeDecoder, _vaeEncoder, _logger),
- DiffuserType.ControlNet => new ControlNetDiffuser(controlNetModel, _unet, _vaeDecoder, _vaeEncoder, _logger),
- DiffuserType.ControlNetImage => new ControlNetImageDiffuser(controlNetModel, _unet, _vaeDecoder, _vaeEncoder, _logger),
+ DiffuserType.TextToImage => new TextDiffuser(_unet, _vaeDecoder, _vaeEncoder, _pipelineOptions.MemoryMode, _logger),
+ DiffuserType.ImageToImage => new ImageDiffuser(_unet, _vaeDecoder, _vaeEncoder, _pipelineOptions.MemoryMode, _logger),
+ DiffuserType.ImageInpaintLegacy => new InpaintLegacyDiffuser(_unet, _vaeDecoder, _vaeEncoder, _pipelineOptions.MemoryMode, _logger),
+ DiffuserType.ControlNet => new ControlNetDiffuser(controlNetModel, _unet, _vaeDecoder, _vaeEncoder, _pipelineOptions.MemoryMode, _logger),
+ DiffuserType.ControlNetImage => new ControlNetImageDiffuser(controlNetModel, _unet, _vaeDecoder, _vaeEncoder, _pipelineOptions.MemoryMode, _logger),
_ => throw new NotImplementedException()
};
}
@@ -132,7 +132,8 @@ protected override IDiffuser CreateDiffuser(DiffuserType diffuserType, ControlNe
var textEncoder = new TextEncoderModel(modelSet.TextEncoderConfig.ApplyDefaults(modelSet));
var vaeDecoder = new AutoEncoderModel(modelSet.VaeDecoderConfig.ApplyDefaults(modelSet));
var vaeEncoder = new AutoEncoderModel(modelSet.VaeEncoderConfig.ApplyDefaults(modelSet));
- return new LatentConsistencyPipeline(modelSet.Name, tokenizer, textEncoder, unet, vaeDecoder, vaeEncoder, modelSet.Diffusers, modelSet.SchedulerOptions, logger);
+ var pipelineOptions = new PipelineOptions(modelSet.Name, modelSet.MemoryMode);
+ return new LatentConsistencyPipeline(pipelineOptions, tokenizer, textEncoder, unet, vaeDecoder, vaeEncoder, modelSet.Diffusers, modelSet.SchedulerOptions, logger);
}
@@ -145,9 +146,9 @@ protected override IDiffuser CreateDiffuser(DiffuserType diffuserType, ControlNe
/// The execution provider.
/// The logger.
///
- public static new LatentConsistencyPipeline CreatePipeline(string modelFolder, ModelType modelType = ModelType.Base, int deviceId = 0, ExecutionProvider executionProvider = ExecutionProvider.DirectML, ILogger logger = default)
+ public static new LatentConsistencyPipeline CreatePipeline(string modelFolder, ModelType modelType = ModelType.Base, int deviceId = 0, ExecutionProvider executionProvider = ExecutionProvider.DirectML, MemoryModeType memoryMode = MemoryModeType.Maximum, ILogger logger = default)
{
- return CreatePipeline(ModelFactory.CreateModelSet(modelFolder, DiffuserPipelineType.LatentConsistency, modelType, deviceId, executionProvider), logger);
+ return CreatePipeline(ModelFactory.CreateModelSet(modelFolder, DiffuserPipelineType.LatentConsistency, modelType, deviceId, executionProvider, memoryMode), logger);
}
}
}
diff --git a/OnnxStack.StableDiffusion/Pipelines/LatentConsistencyXLPipeline.cs b/OnnxStack.StableDiffusion/Pipelines/LatentConsistencyXLPipeline.cs
index 17660002..aa41ec76 100644
--- a/OnnxStack.StableDiffusion/Pipelines/LatentConsistencyXLPipeline.cs
+++ b/OnnxStack.StableDiffusion/Pipelines/LatentConsistencyXLPipeline.cs
@@ -23,7 +23,7 @@ public sealed class LatentConsistencyXLPipeline : StableDiffusionXLPipeline
///
/// Initializes a new instance of the class.
///
- /// The model name.
+ /// The pipeline options
/// The tokenizer.
/// The tokenizer2.
/// The text encoder.
@@ -32,8 +32,8 @@ public sealed class LatentConsistencyXLPipeline : StableDiffusionXLPipeline
/// The vae decoder.
/// The vae encoder.
/// The logger.
- public LatentConsistencyXLPipeline(string name, TokenizerModel tokenizer, TokenizerModel tokenizer2, TextEncoderModel textEncoder, TextEncoderModel textEncoder2, UNetConditionModel unet, AutoEncoderModel vaeDecoder, AutoEncoderModel vaeEncoder, List diffusers, SchedulerOptions defaultSchedulerOptions = default, ILogger logger = default)
- : base(name, tokenizer, tokenizer2, textEncoder, textEncoder2, unet, vaeDecoder, vaeEncoder, diffusers, defaultSchedulerOptions, logger)
+ public LatentConsistencyXLPipeline(PipelineOptions pipelineOptions, TokenizerModel tokenizer, TokenizerModel tokenizer2, TextEncoderModel textEncoder, TextEncoderModel textEncoder2, UNetConditionModel unet, AutoEncoderModel vaeDecoder, AutoEncoderModel vaeEncoder, List diffusers, SchedulerOptions defaultSchedulerOptions = default, ILogger logger = default)
+ : base(pipelineOptions, tokenizer, tokenizer2, textEncoder, textEncoder2, unet, vaeDecoder, vaeEncoder, diffusers, defaultSchedulerOptions, logger)
{
_supportedSchedulers = new List
{
@@ -100,11 +100,11 @@ protected override IDiffuser CreateDiffuser(DiffuserType diffuserType, ControlNe
{
return diffuserType switch
{
- DiffuserType.TextToImage => new TextDiffuser(_unet, _vaeDecoder, _vaeEncoder, _logger),
- DiffuserType.ImageToImage => new ImageDiffuser(_unet, _vaeDecoder, _vaeEncoder, _logger),
- DiffuserType.ImageInpaintLegacy => new InpaintLegacyDiffuser(_unet, _vaeDecoder, _vaeEncoder, _logger),
- DiffuserType.ControlNet => new ControlNetDiffuser(controlNetModel, _unet, _vaeDecoder, _vaeEncoder, _logger),
- DiffuserType.ControlNetImage => new ControlNetImageDiffuser(controlNetModel, _unet, _vaeDecoder, _vaeEncoder, _logger),
+ DiffuserType.TextToImage => new TextDiffuser(_unet, _vaeDecoder, _vaeEncoder, _pipelineOptions.MemoryMode, _logger),
+ DiffuserType.ImageToImage => new ImageDiffuser(_unet, _vaeDecoder, _vaeEncoder, _pipelineOptions.MemoryMode, _logger),
+ DiffuserType.ImageInpaintLegacy => new InpaintLegacyDiffuser(_unet, _vaeDecoder, _vaeEncoder, _pipelineOptions.MemoryMode, _logger),
+ DiffuserType.ControlNet => new ControlNetDiffuser(controlNetModel, _unet, _vaeDecoder, _vaeEncoder, _pipelineOptions.MemoryMode, _logger),
+ DiffuserType.ControlNetImage => new ControlNetImageDiffuser(controlNetModel, _unet, _vaeDecoder, _vaeEncoder, _pipelineOptions.MemoryMode, _logger),
_ => throw new NotImplementedException()
};
}
@@ -125,7 +125,8 @@ protected override IDiffuser CreateDiffuser(DiffuserType diffuserType, ControlNe
var textEncoder2 = new TextEncoderModel(modelSet.TextEncoder2Config.ApplyDefaults(modelSet));
var vaeDecoder = new AutoEncoderModel(modelSet.VaeDecoderConfig.ApplyDefaults(modelSet));
var vaeEncoder = new AutoEncoderModel(modelSet.VaeEncoderConfig.ApplyDefaults(modelSet));
- return new LatentConsistencyXLPipeline(modelSet.Name, tokenizer, tokenizer2, textEncoder, textEncoder2, unet, vaeDecoder, vaeEncoder, modelSet.Diffusers, modelSet.SchedulerOptions, logger);
+ var pipelineOptions = new PipelineOptions(modelSet.Name, modelSet.MemoryMode);
+ return new LatentConsistencyXLPipeline(pipelineOptions, tokenizer, tokenizer2, textEncoder, textEncoder2, unet, vaeDecoder, vaeEncoder, modelSet.Diffusers, modelSet.SchedulerOptions, logger);
}
@@ -138,9 +139,9 @@ protected override IDiffuser CreateDiffuser(DiffuserType diffuserType, ControlNe
/// The execution provider.
/// The logger.
///
- public static new LatentConsistencyXLPipeline CreatePipeline(string modelFolder, ModelType modelType = ModelType.Base, int deviceId = 0, ExecutionProvider executionProvider = ExecutionProvider.DirectML, ILogger logger = default)
+ public static new LatentConsistencyXLPipeline CreatePipeline(string modelFolder, ModelType modelType = ModelType.Base, int deviceId = 0, ExecutionProvider executionProvider = ExecutionProvider.DirectML, MemoryModeType memoryMode = MemoryModeType.Maximum, ILogger logger = default)
{
- return CreatePipeline(ModelFactory.CreateModelSet(modelFolder, DiffuserPipelineType.LatentConsistencyXL, modelType, deviceId, executionProvider), logger);
+ return CreatePipeline(ModelFactory.CreateModelSet(modelFolder, DiffuserPipelineType.LatentConsistencyXL, modelType, deviceId, executionProvider, memoryMode), logger);
}
}
}
diff --git a/OnnxStack.StableDiffusion/Pipelines/StableDiffusionPipeline.cs b/OnnxStack.StableDiffusion/Pipelines/StableDiffusionPipeline.cs
index 93aae031..2e8a6fca 100644
--- a/OnnxStack.StableDiffusion/Pipelines/StableDiffusionPipeline.cs
+++ b/OnnxStack.StableDiffusion/Pipelines/StableDiffusionPipeline.cs
@@ -21,13 +21,12 @@ namespace OnnxStack.StableDiffusion.Pipelines
{
public class StableDiffusionPipeline : PipelineBase
{
- protected readonly string _name;
protected readonly UNetConditionModel _unet;
protected readonly TokenizerModel _tokenizer;
protected readonly TextEncoderModel _textEncoder;
- protected readonly AutoEncoderModel _vaeDecoder;
- protected readonly AutoEncoderModel _vaeEncoder;
+ protected AutoEncoderModel _vaeDecoder;
+ protected AutoEncoderModel _vaeEncoder;
protected OnnxModelSession _controlNet;
protected List _supportedDiffusers;
protected IReadOnlyList _supportedSchedulers;
@@ -36,16 +35,15 @@ public class StableDiffusionPipeline : PipelineBase
///
/// Initializes a new instance of the class.
///
- /// The model name.
+ /// The pipeline options
/// The tokenizer.
/// The text encoder.
/// The unet.
/// The vae decoder.
/// The vae encoder.
/// The logger.
- public StableDiffusionPipeline(string name, TokenizerModel tokenizer, TextEncoderModel textEncoder, UNetConditionModel unet, AutoEncoderModel vaeDecoder, AutoEncoderModel vaeEncoder, List diffusers = default, SchedulerOptions defaultSchedulerOptions = default, ILogger logger = default) : base(logger)
+ public StableDiffusionPipeline(PipelineOptions pipelineOptions, TokenizerModel tokenizer, TextEncoderModel textEncoder, UNetConditionModel unet, AutoEncoderModel vaeDecoder, AutoEncoderModel vaeEncoder, List diffusers = default, SchedulerOptions defaultSchedulerOptions = default, ILogger logger = default) : base(pipelineOptions, logger)
{
- _name = name;
_unet = unet;
_tokenizer = tokenizer;
_textEncoder = textEncoder;
@@ -78,7 +76,7 @@ public StableDiffusionPipeline(string name, TokenizerModel tokenizer, TextEncode
///
/// Gets the name.
///
- public override string Name => _name;
+ public override string Name => _pipelineOptions.Name;
///
@@ -110,12 +108,18 @@ public StableDiffusionPipeline(string name, TokenizerModel tokenizer, TextEncode
///
public override Task LoadAsync()
{
- return Task.WhenAll(
+ if (_pipelineOptions.MemoryMode == MemoryModeType.Minimum)
+ return Task.CompletedTask;
+
+ // Preload all models into VRAM
+ return Task.WhenAll
+ (
_unet.LoadAsync(),
_tokenizer.LoadAsync(),
_textEncoder.LoadAsync(),
_vaeDecoder.LoadAsync(),
- _vaeEncoder.LoadAsync());
+ _vaeEncoder.LoadAsync()
+ );
}
@@ -125,7 +129,10 @@ public override Task LoadAsync()
///
public override async Task UnloadAsync()
{
+ // TODO: deadlock on model dispose when no synchronization context exists (console app)
+ // Task.Yield seems to force a context switch resolving any issues, revisit this
await Task.Yield();
+
_unet?.Dispose();
_tokenizer?.Dispose();
_textEncoder?.Dispose();
@@ -231,6 +238,32 @@ public override async IAsyncEnumerable RunBatchAsync(BatchOptions b
}
+ ///
+ /// Overrides the vae encoder with a custom implementation. The caller is responsible for the model lifetime.
+ ///
+ /// The vae encoder.
+ public void OverrideVaeEncoder(AutoEncoderModel vaeEncoder)
+ {
+ if (_vaeEncoder != null)
+ _vaeEncoder.Dispose();
+
+ _vaeEncoder = vaeEncoder;
+ }
+
+
+ ///
+ /// Overrides the vae decoder with a custom implementation. The caller is responsible for the model lifetime.
+ ///
+ /// The vae decoder.
+ public void OverrideVaeDecoder(AutoEncoderModel vaeDecoder)
+ {
+ if (_vaeDecoder != null)
+ _vaeDecoder.Dispose();
+
+ _vaeDecoder = vaeDecoder;
+ }
+
+
///
/// Creates the diffuser.
///
@@ -241,12 +274,12 @@ protected override IDiffuser CreateDiffuser(DiffuserType diffuserType, ControlNe
{
return diffuserType switch
{
- DiffuserType.TextToImage => new TextDiffuser(_unet, _vaeDecoder, _vaeEncoder, _logger),
- DiffuserType.ImageToImage => new ImageDiffuser(_unet, _vaeDecoder, _vaeEncoder, _logger),
- DiffuserType.ImageInpaint => new InpaintDiffuser(_unet, _vaeDecoder, _vaeEncoder, _logger),
- DiffuserType.ImageInpaintLegacy => new InpaintLegacyDiffuser(_unet, _vaeDecoder, _vaeEncoder, _logger),
- DiffuserType.ControlNet => new ControlNetDiffuser(controlNetModel, _unet, _vaeDecoder, _vaeEncoder, _logger),
- DiffuserType.ControlNetImage => new ControlNetImageDiffuser(controlNetModel, _unet, _vaeDecoder, _vaeEncoder, _logger),
+ DiffuserType.TextToImage => new TextDiffuser(_unet, _vaeDecoder, _vaeEncoder, _pipelineOptions.MemoryMode, _logger),
+ DiffuserType.ImageToImage => new ImageDiffuser(_unet, _vaeDecoder, _vaeEncoder, _pipelineOptions.MemoryMode, _logger),
+ DiffuserType.ImageInpaint => new InpaintDiffuser(_unet, _vaeDecoder, _vaeEncoder, _pipelineOptions.MemoryMode, _logger),
+ DiffuserType.ImageInpaintLegacy => new InpaintLegacyDiffuser(_unet, _vaeDecoder, _vaeEncoder, _pipelineOptions.MemoryMode, _logger),
+ DiffuserType.ControlNet => new ControlNetDiffuser(controlNetModel, _unet, _vaeDecoder, _vaeEncoder, _pipelineOptions.MemoryMode, _logger),
+ DiffuserType.ControlNetImage => new ControlNetImageDiffuser(controlNetModel, _unet, _vaeDecoder, _vaeEncoder, _pipelineOptions.MemoryMode, _logger),
_ => throw new NotImplementedException()
};
}
@@ -269,6 +302,13 @@ protected virtual async Task CreatePromptEmbedsAsync(Pro
var promptEmbeddings = await GeneratePromptEmbedsAsync(promptTokens, maxPromptTokenCount);
var negativePromptEmbeddings = await GeneratePromptEmbedsAsync(negativePromptTokens, maxPromptTokenCount);
+
+ if (_pipelineOptions.MemoryMode == MemoryModeType.Minimum)
+ {
+ await _tokenizer.UnloadAsync();
+ await _textEncoder.UnloadAsync();
+ }
+
if (isGuidanceEnabled)
return new PromptEmbeddingsResult(negativePromptEmbeddings.Concatenate(promptEmbeddings));
@@ -380,7 +420,8 @@ protected IEnumerable PadWithBlankTokens(IEnumerable inputs, int requi
var textEncoder = new TextEncoderModel(modelSet.TextEncoderConfig.ApplyDefaults(modelSet));
var vaeDecoder = new AutoEncoderModel(modelSet.VaeDecoderConfig.ApplyDefaults(modelSet));
var vaeEncoder = new AutoEncoderModel(modelSet.VaeEncoderConfig.ApplyDefaults(modelSet));
- return new StableDiffusionPipeline(modelSet.Name, tokenizer, textEncoder, unet, vaeDecoder, vaeEncoder, modelSet.Diffusers, modelSet.SchedulerOptions, logger);
+ var pipelineOptions = new PipelineOptions(modelSet.Name, modelSet.MemoryMode);
+ return new StableDiffusionPipeline(pipelineOptions, tokenizer, textEncoder, unet, vaeDecoder, vaeEncoder, modelSet.Diffusers, modelSet.SchedulerOptions, logger);
}
@@ -393,9 +434,10 @@ protected IEnumerable PadWithBlankTokens(IEnumerable inputs, int requi
/// The execution provider.
/// The logger.
///
- public static StableDiffusionPipeline CreatePipeline(string modelFolder, ModelType modelType = ModelType.Base, int deviceId = 0, ExecutionProvider executionProvider = ExecutionProvider.DirectML, ILogger logger = default)
+ public static StableDiffusionPipeline CreatePipeline(string modelFolder, ModelType modelType = ModelType.Base, int deviceId = 0, ExecutionProvider executionProvider = ExecutionProvider.DirectML, MemoryModeType memoryMode = MemoryModeType.Maximum, ILogger logger = default)
{
- return CreatePipeline(ModelFactory.CreateModelSet(modelFolder, DiffuserPipelineType.StableDiffusion, modelType, deviceId, executionProvider), logger);
+ // Build the model set for the folder (including the requested memory mode),
+ // then hand it to the model-set overload to construct the pipeline.
+ var modelSet = ModelFactory.CreateModelSet(modelFolder, DiffuserPipelineType.StableDiffusion, modelType, deviceId, executionProvider, memoryMode);
+ return CreatePipeline(modelSet, logger);
}
}
+
}
diff --git a/OnnxStack.StableDiffusion/Pipelines/StableDiffusionXLPipeline.cs b/OnnxStack.StableDiffusion/Pipelines/StableDiffusionXLPipeline.cs
index 812cf5cb..57ca8059 100644
--- a/OnnxStack.StableDiffusion/Pipelines/StableDiffusionXLPipeline.cs
+++ b/OnnxStack.StableDiffusion/Pipelines/StableDiffusionXLPipeline.cs
@@ -25,7 +25,7 @@ public class StableDiffusionXLPipeline : StableDiffusionPipeline
///
/// Initializes a new instance of the class.
///
- /// The model name.
+ /// <param name="pipelineOptions">The pipeline options (name and memory mode).</param>
/// The tokenizer.
/// The tokenizer2.
/// The text encoder.
@@ -34,8 +34,8 @@ public class StableDiffusionXLPipeline : StableDiffusionPipeline
/// The vae decoder.
/// The vae encoder.
/// The logger.
- public StableDiffusionXLPipeline(string name, TokenizerModel tokenizer, TokenizerModel tokenizer2, TextEncoderModel textEncoder, TextEncoderModel textEncoder2, UNetConditionModel unet, AutoEncoderModel vaeDecoder, AutoEncoderModel vaeEncoder, List diffusers, SchedulerOptions defaultSchedulerOptions = default, ILogger logger = default)
- : base(name, tokenizer, textEncoder, unet, vaeDecoder, vaeEncoder, diffusers, defaultSchedulerOptions, logger)
+ public StableDiffusionXLPipeline(PipelineOptions pipelineOptions, TokenizerModel tokenizer, TokenizerModel tokenizer2, TextEncoderModel textEncoder, TextEncoderModel textEncoder2, UNetConditionModel unet, AutoEncoderModel vaeDecoder, AutoEncoderModel vaeEncoder, List diffusers, SchedulerOptions defaultSchedulerOptions = default, ILogger logger = default)
+ : base(pipelineOptions, tokenizer, textEncoder, unet, vaeDecoder, vaeEncoder, diffusers, defaultSchedulerOptions, logger)
{
_tokenizer2 = tokenizer2;
_textEncoder2 = textEncoder2;
@@ -66,12 +66,18 @@ public StableDiffusionXLPipeline(string name, TokenizerModel tokenizer, Tokenize
///
/// Loads the pipeline
///
- public override async Task LoadAsync()
+ public override Task LoadAsync()
{
- await Task.WhenAll(
+ // In minimum memory mode the secondary tokenizer and text encoder are not
+ // preloaded here; loading is delegated to the base pipeline alone.
+ // Otherwise the secondary models are loaded together with the base pipeline
+ // and the returned task completes once all of them have finished.
+ var useMinimumMemory = _pipelineOptions.MemoryMode == MemoryModeType.Minimum;
+ return useMinimumMemory
+     ? base.LoadAsync()
+     : Task.WhenAll(
_tokenizer2.LoadAsync(),
_textEncoder2.LoadAsync(),
base.LoadAsync());
}
@@ -97,11 +103,11 @@ protected override IDiffuser CreateDiffuser(DiffuserType diffuserType, ControlNe
{
return diffuserType switch
{
- DiffuserType.TextToImage => new TextDiffuser(_unet, _vaeDecoder, _vaeEncoder, _logger),
- DiffuserType.ImageToImage => new ImageDiffuser(_unet, _vaeDecoder, _vaeEncoder, _logger),
- DiffuserType.ImageInpaintLegacy => new InpaintLegacyDiffuser(_unet, _vaeDecoder, _vaeEncoder, _logger),
- DiffuserType.ControlNet => new ControlNetDiffuser(controlNetModel, _unet, _vaeDecoder, _vaeEncoder, _logger),
- DiffuserType.ControlNetImage => new ControlNetImageDiffuser(controlNetModel, _unet, _vaeDecoder, _vaeEncoder, _logger),
+ DiffuserType.TextToImage => new TextDiffuser(_unet, _vaeDecoder, _vaeEncoder, _pipelineOptions.MemoryMode, _logger),
+ DiffuserType.ImageToImage => new ImageDiffuser(_unet, _vaeDecoder, _vaeEncoder, _pipelineOptions.MemoryMode, _logger),
+ DiffuserType.ImageInpaintLegacy => new InpaintLegacyDiffuser(_unet, _vaeDecoder, _vaeEncoder, _pipelineOptions.MemoryMode, _logger),
+ DiffuserType.ControlNet => new ControlNetDiffuser(controlNetModel, _unet, _vaeDecoder, _vaeEncoder, _pipelineOptions.MemoryMode, _logger),
+ DiffuserType.ControlNetImage => new ControlNetImageDiffuser(controlNetModel, _unet, _vaeDecoder, _vaeEncoder, _pipelineOptions.MemoryMode, _logger),
_ => throw new NotImplementedException()
};
}
@@ -140,6 +146,13 @@ private async Task CreateEmbedsTwoAsync(PromptOptions pr
var promptEmbeddings = await GenerateEmbedsAsync(promptTokens, maxPromptTokenCount);
var negativePromptEmbeddings = await GenerateEmbedsAsync(negativePromptTokens, maxPromptTokenCount);
+ // In minimum memory mode, unload the secondary tokenizer and text encoder now that the embeddings have been generated
+ if (_pipelineOptions.MemoryMode == MemoryModeType.Minimum)
+ {
+ await _tokenizer2.UnloadAsync();
+ await _textEncoder2.UnloadAsync();
+ }
+
if (isGuidanceEnabled)
return new PromptEmbeddingsResult(
negativePromptEmbeddings.PromptEmbeds.Concatenate(promptEmbeddings.PromptEmbeds),
@@ -180,6 +193,13 @@ private async Task CreateEmbedsBothAsync(PromptOptions p
var pooledPromptEmbeds = dualPromptEmbeddings.PooledPromptEmbeds;
var pooledNegativePromptEmbeds = dualNegativePromptEmbeddings.PooledPromptEmbeds;
+ // In minimum memory mode, unload the secondary tokenizer and text encoder once the prompt embeddings are available
+ if (_pipelineOptions.MemoryMode == MemoryModeType.Minimum)
+ {
+ await _tokenizer2.UnloadAsync();
+ await _textEncoder2.UnloadAsync();
+ }
+
if (isGuidanceEnabled)
return new PromptEmbeddingsResult(dualNegativePrompt.Concatenate(dualPrompt), pooledNegativePromptEmbeds.Concatenate(pooledPromptEmbeds));
@@ -302,7 +322,8 @@ private IEnumerable PadWithBlankTokens(IEnumerable inputs, int requi
var textEncoder2 = new TextEncoderModel(modelSet.TextEncoder2Config.ApplyDefaults(modelSet));
var vaeDecoder = new AutoEncoderModel(modelSet.VaeDecoderConfig.ApplyDefaults(modelSet));
var vaeEncoder = new AutoEncoderModel(modelSet.VaeEncoderConfig.ApplyDefaults(modelSet));
- return new StableDiffusionXLPipeline(modelSet.Name, tokenizer, tokenizer2, textEncoder, textEncoder2, unet, vaeDecoder, vaeEncoder, modelSet.Diffusers, modelSet.SchedulerOptions, logger);
+ var pipelineOptions = new PipelineOptions(modelSet.Name, modelSet.MemoryMode);
+ return new StableDiffusionXLPipeline(pipelineOptions, tokenizer, tokenizer2, textEncoder, textEncoder2, unet, vaeDecoder, vaeEncoder, modelSet.Diffusers, modelSet.SchedulerOptions, logger);
}
@@ -315,9 +336,9 @@ private IEnumerable PadWithBlankTokens(IEnumerable inputs, int requi
/// The execution provider.
/// The logger.
///
- public static new StableDiffusionXLPipeline CreatePipeline(string modelFolder, ModelType modelType = ModelType.Base, int deviceId = 0, ExecutionProvider executionProvider = ExecutionProvider.DirectML, ILogger logger = default)
+ public static new StableDiffusionXLPipeline CreatePipeline(string modelFolder, ModelType modelType = ModelType.Base, int deviceId = 0, ExecutionProvider executionProvider = ExecutionProvider.DirectML, MemoryModeType memoryMode = MemoryModeType.Maximum, ILogger logger = default)
{
- return CreatePipeline(ModelFactory.CreateModelSet(modelFolder, DiffuserPipelineType.StableDiffusionXL, modelType, deviceId, executionProvider), logger);
+ // Build the SDXL model set for the folder (including the requested memory mode),
+ // then hand it to the model-set overload to construct the pipeline.
+ var modelSet = ModelFactory.CreateModelSet(modelFolder, DiffuserPipelineType.StableDiffusionXL, modelType, deviceId, executionProvider, memoryMode);
+ return CreatePipeline(modelSet, logger);
}
}
}
diff --git a/OnnxStack.UI/App.xaml b/OnnxStack.UI/App.xaml
index 6079e5d0..df6686a5 100644
--- a/OnnxStack.UI/App.xaml
+++ b/OnnxStack.UI/App.xaml
@@ -112,6 +112,12 @@
+
+
+
+
+
+
diff --git a/OnnxStack.UI/Dialogs/UpdateModelDialog.xaml b/OnnxStack.UI/Dialogs/UpdateModelDialog.xaml
index 3cab00cb..af52c565 100644
--- a/OnnxStack.UI/Dialogs/UpdateModelDialog.xaml
+++ b/OnnxStack.UI/Dialogs/UpdateModelDialog.xaml
@@ -65,8 +65,12 @@
DeviceId="{Binding UpdateModelSet.DeviceId, Mode=TwoWay}"
ExecutionProvider="{Binding UpdateModelSet.ExecutionProvider, Mode=TwoWay}" />
+
+
+
+
-
+
diff --git a/OnnxStack.UI/Models/OnnxStackUIConfig.cs b/OnnxStack.UI/Models/OnnxStackUIConfig.cs
index 63750597..b6f612fc 100644
--- a/OnnxStack.UI/Models/OnnxStackUIConfig.cs
+++ b/OnnxStack.UI/Models/OnnxStackUIConfig.cs
@@ -1,6 +1,7 @@
using Microsoft.ML.OnnxRuntime;
using OnnxStack.Common.Config;
using OnnxStack.Core.Config;
+using OnnxStack.StableDiffusion.Enums;
using System.Collections.Generic;
using System.Collections.ObjectModel;
using System.Linq;
@@ -15,6 +16,7 @@ public class OnnxStackUIConfig : IConfigSection
public int DefaultIntraOpNumThreads { get; set; }
public ExecutionMode DefaultExecutionMode { get; set; }
public ExecutionProvider DefaultExecutionProvider { get; set; }
+ public MemoryModeType DefaultMemoryMode { get; set; }
[JsonIgnore]
public ExecutionProvider SupportedExecutionProvider => GetSupportedExecutionProvider();
diff --git a/OnnxStack.UI/Models/UpdateStableDiffusionModelSetViewModel.cs b/OnnxStack.UI/Models/UpdateStableDiffusionModelSetViewModel.cs
index 957e063c..267592dc 100644
--- a/OnnxStack.UI/Models/UpdateStableDiffusionModelSetViewModel.cs
+++ b/OnnxStack.UI/Models/UpdateStableDiffusionModelSetViewModel.cs
@@ -165,6 +165,15 @@ public DiffuserPipelineType PipelineType
set { _pipelineType = value; NotifyPropertyChanged(); }
}
+ // Backing store for the MemoryMode property.
+ private MemoryModeType _memoryMode;
+
+ /// <summary>
+ /// Gets or sets the memory mode; setting it raises a property-changed notification.
+ /// </summary>
+ public MemoryModeType MemoryMode
+ {
+     get => _memoryMode;
+     set
+     {
+         _memoryMode = value;
+         NotifyPropertyChanged();
+     }
+ }
+
+
private ModelType _modelType;
public ModelType ModelType
@@ -250,6 +259,7 @@ public static UpdateStableDiffusionModelSetViewModel FromModelSet(StableDiffusio
ExecutionProvider = modelset.ExecutionProvider,
InterOpNumThreads = modelset.InterOpNumThreads,
IntraOpNumThreads = modelset.IntraOpNumThreads,
+ MemoryMode = modelset.MemoryMode,
Name = modelset.Name,
PipelineType = modelset.PipelineType,
@@ -402,6 +412,8 @@ public static StableDiffusionModelSet ToModelSet(UpdateStableDiffusionModelSetVi
InterOpNumThreads = modelset.InterOpNumThreads,
IntraOpNumThreads = modelset.IntraOpNumThreads,
+ MemoryMode = modelset.MemoryMode,
+
UnetConfig = new UNetConditionModelConfig
{
ModelType = modelset.ModelType,
diff --git a/OnnxStack.UI/Services/ModelFactory.cs b/OnnxStack.UI/Services/ModelFactory.cs
index eb9ee8ae..c3e94387 100644
--- a/OnnxStack.UI/Services/ModelFactory.cs
+++ b/OnnxStack.UI/Services/ModelFactory.cs
@@ -61,6 +61,7 @@ public StableDiffusionModelSet CreateStableDiffusionModelSet(string name, string
ExecutionProvider = _settings.DefaultExecutionProvider,
InterOpNumThreads = _settings.DefaultInterOpNumThreads,
IntraOpNumThreads = _settings.DefaultIntraOpNumThreads,
+ MemoryMode = _settings.DefaultMemoryMode,
IsEnabled = true,
};
diff --git a/OnnxStack.UI/Views/SettingsView.xaml b/OnnxStack.UI/Views/SettingsView.xaml
index d3ca351d..2f4c178a 100644
--- a/OnnxStack.UI/Views/SettingsView.xaml
+++ b/OnnxStack.UI/Views/SettingsView.xaml
@@ -65,9 +65,14 @@
DeviceId="{Binding UISettings.DefaultDeviceId, Mode=TwoWay}"
ExecutionProvider="{Binding UISettings.DefaultExecutionProvider, Mode=TwoWay}" />
-
+
+
+
+
+
+
-
+
diff --git a/OnnxStack.UI/appsettings.json b/OnnxStack.UI/appsettings.json
index 434ae234..aa0fa737 100644
--- a/OnnxStack.UI/appsettings.json
+++ b/OnnxStack.UI/appsettings.json
@@ -1,8 +1,8 @@
{
"Logging": {
"LogLevel": {
- "Default": "Information",
- "Microsoft.AspNetCore": "Warning"
+ "Default": "Debug",
+ "Microsoft": "Warning"
}
},
"AllowedHosts": "*",