diff --git a/OnnxStack.Console/Examples/StableDebug.cs b/OnnxStack.Console/Examples/StableDebug.cs index 7b29949f..47172495 100644 --- a/OnnxStack.Console/Examples/StableDebug.cs +++ b/OnnxStack.Console/Examples/StableDebug.cs @@ -48,7 +48,7 @@ public async Task RunAsync() Strength = 0.6f }; - foreach (var model in _stableDiffusionService.Models) + foreach (var model in _stableDiffusionService.ModelSets) { OutputHelpers.WriteConsole($"Loading Model `{model.Name}`...", ConsoleColor.Green); await _stableDiffusionService.LoadModelAsync(model); @@ -71,7 +71,7 @@ public async Task RunAsync() } - private async Task GenerateImage(ModelOptions model, PromptOptions prompt, SchedulerOptions options) + private async Task GenerateImage(StableDiffusionModelSet model, PromptOptions prompt, SchedulerOptions options) { var timestamp = Stopwatch.GetTimestamp(); var outputFilename = Path.Combine(_outputDirectory, $"{model.Name}_{options.Seed}_{options.SchedulerType}.png"); diff --git a/OnnxStack.Console/Examples/StableDiffusionBatch.cs b/OnnxStack.Console/Examples/StableDiffusionBatch.cs index 5faf87aa..7e1ab8a2 100644 --- a/OnnxStack.Console/Examples/StableDiffusionBatch.cs +++ b/OnnxStack.Console/Examples/StableDiffusionBatch.cs @@ -51,7 +51,7 @@ public async Task RunAsync() BatchType = BatchOptionType.Scheduler }; - foreach (var model in _stableDiffusionService.Models) + foreach (var model in _stableDiffusionService.ModelSets) { OutputHelpers.WriteConsole($"Loading Model `{model.Name}`...", ConsoleColor.Green); await _stableDiffusionService.LoadModelAsync(model); diff --git a/OnnxStack.Console/Examples/StableDiffusionExample.cs b/OnnxStack.Console/Examples/StableDiffusionExample.cs index e22fd05a..6045a64b 100644 --- a/OnnxStack.Console/Examples/StableDiffusionExample.cs +++ b/OnnxStack.Console/Examples/StableDiffusionExample.cs @@ -47,7 +47,7 @@ public async Task RunAsync() Seed = Random.Shared.Next() }; - foreach (var model in _stableDiffusionService.Models) + foreach (var model in _stableDiffusionService.ModelSets) { OutputHelpers.WriteConsole($"Loading Model `{model.Name}`...", ConsoleColor.Green); await _stableDiffusionService.LoadModelAsync(model); @@ -65,7 +65,7 @@ public async Task RunAsync() } } - private async Task GenerateImage(ModelOptions model, PromptOptions prompt, SchedulerOptions options) + private async Task GenerateImage(StableDiffusionModelSet model, PromptOptions prompt, SchedulerOptions options) { var outputFilename = Path.Combine(_outputDirectory, $"{options.Seed}_{options.SchedulerType}.png"); var result = await _stableDiffusionService.GenerateAsImageAsync(model, prompt, options); diff --git a/OnnxStack.Console/Examples/StableDiffusionGenerator.cs b/OnnxStack.Console/Examples/StableDiffusionGenerator.cs index 3dbd8b79..a51a9509 100644 --- a/OnnxStack.Console/Examples/StableDiffusionGenerator.cs +++ b/OnnxStack.Console/Examples/StableDiffusionGenerator.cs @@ -31,7 +31,7 @@ public async Task RunAsync() Directory.CreateDirectory(_outputDirectory); var seed = Random.Shared.Next(); - foreach (var model in _stableDiffusionService.Models) + foreach (var model in _stableDiffusionService.ModelSets) { OutputHelpers.WriteConsole($"Loading Model `{model.Name}`...", ConsoleColor.Green); await _stableDiffusionService.LoadModelAsync(model); @@ -62,7 +62,7 @@ public async Task RunAsync() OutputHelpers.ReadConsole(ConsoleColor.Gray); } - private async Task GenerateImage(ModelOptions model, PromptOptions prompt, SchedulerOptions options, string key) + private async Task GenerateImage(StableDiffusionModelSet model, PromptOptions prompt, SchedulerOptions options, string key) { var outputFilename = Path.Combine(_outputDirectory, $"{options.Seed}_{options.SchedulerType}_{key}.png"); var result = await _stableDiffusionService.GenerateAsImageAsync(model, prompt, options); diff --git a/OnnxStack.Console/Examples/StableDiffusionGif.cs b/OnnxStack.Console/Examples/StableDiffusionGif.cs index b04f0d8f..12db5412 100644 --- a/OnnxStack.Console/Examples/StableDiffusionGif.cs +++ b/OnnxStack.Console/Examples/StableDiffusionGif.cs @@ -54,7 +54,7 @@ public async Task RunAsync() }; // Choose Model - var model = _stableDiffusionService.Models.FirstOrDefault(x => x.Name == "LCM-Dreamshaper-V7"); + var model = _stableDiffusionService.ModelSets.FirstOrDefault(x => x.Name == "LCM-Dreamshaper-V7"); OutputHelpers.WriteConsole($"Loading Model `{model.Name}`...", ConsoleColor.Green); await _stableDiffusionService.LoadModelAsync(model); diff --git a/OnnxStack.Console/appsettings.json b/OnnxStack.Console/appsettings.json index 522b23ce..09bb6916 100644 --- a/OnnxStack.Console/appsettings.json +++ b/OnnxStack.Console/appsettings.json @@ -6,8 +6,8 @@ } }, "AllowedHosts": "*", - "OnnxStackConfig": { - "OnnxModelSets": [ + "StableDiffusionConfig": { + "ModelSets": [ { "Name": "StableDiffusion 1.5", "IsEnabled": true, diff --git a/OnnxStack.Core/Config/IOnnxModelSetConfig.cs b/OnnxStack.Core/Config/IOnnxModelSetConfig.cs index d53e011e..fb62b995 100644 --- a/OnnxStack.Core/Config/IOnnxModelSetConfig.cs +++ b/OnnxStack.Core/Config/IOnnxModelSetConfig.cs @@ -11,6 +11,6 @@ public interface IOnnxModelSetConfig : IOnnxModel int IntraOpNumThreads { get; set; } ExecutionMode ExecutionMode { get; set; } ExecutionProvider ExecutionProvider { get; set; } - List ModelConfigurations { get; set; } + List ModelConfigurations { get; set; } } } diff --git a/OnnxStack.Core/Config/OnnxModelSessionConfig.cs b/OnnxStack.Core/Config/OnnxModelConfig.cs similarity index 95% rename from OnnxStack.Core/Config/OnnxModelSessionConfig.cs rename to OnnxStack.Core/Config/OnnxModelConfig.cs index cb6cda12..ab1a46e4 100644 --- a/OnnxStack.Core/Config/OnnxModelSessionConfig.cs +++ b/OnnxStack.Core/Config/OnnxModelConfig.cs @@ -3,7 +3,7 @@ namespace OnnxStack.Core.Config { - public class OnnxModelSessionConfig + public class OnnxModelConfig { public OnnxModelType Type { get; set; } public string OnnxModelPath { get; set; } diff --git a/OnnxStack.Core/Config/OnnxModelEqualityComparer.cs b/OnnxStack.Core/Config/OnnxModelEqualityComparer.cs new file mode 100644 index 00000000..a42eb3be --- /dev/null +++ b/OnnxStack.Core/Config/OnnxModelEqualityComparer.cs @@ -0,0 +1,17 @@ +using System.Collections.Generic; + +namespace OnnxStack.Core.Config +{ + public class OnnxModelEqualityComparer : IEqualityComparer + { + public bool Equals(IOnnxModel x, IOnnxModel y) + { + return x != null && y != null && x.Name == y.Name; + } + + public int GetHashCode(IOnnxModel obj) + { + return obj?.Name?.GetHashCode() ?? 0; + } + } +} diff --git a/OnnxStack.Core/Config/OnnxModelSetConfig.cs b/OnnxStack.Core/Config/OnnxModelSetConfig.cs index 1361ede8..82b205c3 100644 --- a/OnnxStack.Core/Config/OnnxModelSetConfig.cs +++ b/OnnxStack.Core/Config/OnnxModelSetConfig.cs @@ -13,6 +13,6 @@ public class OnnxModelSetConfig : IOnnxModelSetConfig public int IntraOpNumThreads { get; set; } public ExecutionMode ExecutionMode { get; set; } public ExecutionProvider ExecutionProvider { get; set; } - public List ModelConfigurations { get; set; } + public List ModelConfigurations { get; set; } } } diff --git a/OnnxStack.Core/Config/OnnxStackConfig.cs b/OnnxStack.Core/Config/OnnxStackConfig.cs index 1b0a4e02..fb2c035b 100644 --- a/OnnxStack.Core/Config/OnnxStackConfig.cs +++ b/OnnxStack.Core/Config/OnnxStackConfig.cs @@ -1,22 +1,11 @@ using OnnxStack.Common.Config; -using System.Collections.Generic; -using System.Linq; namespace OnnxStack.Core.Config { public class OnnxStackConfig : IConfigSection { - public List OnnxModelSets { get; set; } = new List(); - public void Initialize() { - if (OnnxModelSets.IsNullOrEmpty()) - return; - - foreach (var modelSet in OnnxModelSets) - { - modelSet.ApplyConfigurationOverrides(); - } } } } diff --git a/OnnxStack.Core/Extensions/Extensions.cs b/OnnxStack.Core/Extensions/Extensions.cs index 12daf882..d4771dc2 100644 --- a/OnnxStack.Core/Extensions/Extensions.cs +++ b/OnnxStack.Core/Extensions/Extensions.cs @@ -10,7 +10,7 @@ namespace OnnxStack.Core { public static class Extensions { - public static SessionOptions GetSessionOptions(this OnnxModelSessionConfig configuration) + public static SessionOptions GetSessionOptions(this OnnxModelConfig configuration) { var sessionOptions = new SessionOptions { diff --git a/OnnxStack.Core/Model/OnnxModelSession.cs b/OnnxStack.Core/Model/OnnxModelSession.cs index 125d13fa..169c6856 100644 --- a/OnnxStack.Core/Model/OnnxModelSession.cs +++ b/OnnxStack.Core/Model/OnnxModelSession.cs @@ -9,7 +9,7 @@ public class OnnxModelSession : IDisposable { private readonly SessionOptions _options; private readonly InferenceSession _session; - private readonly OnnxModelSessionConfig _configuration; + private readonly OnnxModelConfig _configuration; /// /// Initializes a new instance of the class. @@ -17,7 +17,7 @@ public class OnnxModelSession : IDisposable /// The configuration. /// The container. /// Onnx model file not found - public OnnxModelSession(OnnxModelSessionConfig configuration, PrePackedWeightsContainer container) + public OnnxModelSession(OnnxModelConfig configuration, PrePackedWeightsContainer container) { if (!File.Exists(configuration.OnnxModelPath)) throw new FileNotFoundException("Onnx model file not found", configuration.OnnxModelPath); @@ -44,7 +44,7 @@ public OnnxModelSession(OnnxModelSessionConfig configuration, PrePackedWeightsCo /// /// Gets the configuration. /// - public OnnxModelSessionConfig Configuration => _configuration; + public OnnxModelConfig Configuration => _configuration; /// diff --git a/OnnxStack.Core/Model/OnnxModelSet.cs b/OnnxStack.Core/Model/OnnxModelSet.cs index 869a02db..64206f4e 100644 --- a/OnnxStack.Core/Model/OnnxModelSet.cs +++ b/OnnxStack.Core/Model/OnnxModelSet.cs @@ -78,7 +78,7 @@ public InferenceSession GetSession(OnnxModelType modelType) /// /// Type of the model. /// - public OnnxModelSessionConfig GetConfiguration(OnnxModelType modelType) + public OnnxModelConfig GetConfiguration(OnnxModelType modelType) { return _configuration.ModelConfigurations.FirstOrDefault(x => x.Type == modelType); } diff --git a/OnnxStack.Core/Registration.cs b/OnnxStack.Core/Registration.cs index 93c0f18c..7af89657 100644 --- a/OnnxStack.Core/Registration.cs +++ b/OnnxStack.Core/Registration.cs @@ -16,7 +16,7 @@ public static class Registration /// The service collection. public static void AddOnnxStack(this IServiceCollection serviceCollection) { - serviceCollection.AddSingleton(ConfigManager.LoadConfiguration()); + serviceCollection.AddSingleton(TryLoadAppSettings()); serviceCollection.AddSingleton(); } @@ -43,5 +43,22 @@ public static void AddOnnxStackConfig(this IServiceCollection serviceCollecti { serviceCollection.AddSingleton(ConfigManager.LoadConfiguration()); } + + + /// + /// Try load OnnxStackConfig from application settings if it exists. + /// + /// + private static OnnxStackConfig TryLoadAppSettings() + { + try + { + return ConfigManager.LoadConfiguration(); + } + catch + { + return new OnnxStackConfig(); + } + } } } diff --git a/OnnxStack.Core/Services/IOnnxModelService.cs b/OnnxStack.Core/Services/IOnnxModelService.cs index 8b9fe4ec..3bd259a1 100644 --- a/OnnxStack.Core/Services/IOnnxModelService.cs +++ b/OnnxStack.Core/Services/IOnnxModelService.cs @@ -26,14 +26,12 @@ public interface IOnnxModelService : IDisposable /// Task AddModelSet(IOnnxModelSetConfig modelSet); - /// /// Adds a collection of ModelSet /// /// The model sets. Task AddModelSet(IEnumerable modelSets); - /// /// Removes a model set. /// @@ -41,6 +39,13 @@ public interface IOnnxModelService : IDisposable /// Task RemoveModelSet(IOnnxModelSetConfig modelSet); + /// + /// Updates the model set. + /// + /// The model set. + /// + Task UpdateModelSet(IOnnxModelSetConfig modelSet); + /// /// Loads the model. /// @@ -65,13 +70,6 @@ public interface IOnnxModelService : IDisposable bool IsModelLoaded(IOnnxModel model); - /// - /// Updates the model set. - /// - /// The model set. - /// - bool UpdateModelSet(IOnnxModelSetConfig modelSet); - /// /// Determines whether the specified model type is enabled. /// diff --git a/OnnxStack.Core/Services/OnnxModelService.cs b/OnnxStack.Core/Services/OnnxModelService.cs index 0d93ff37..eb88a678 100644 --- a/OnnxStack.Core/Services/OnnxModelService.cs +++ b/OnnxStack.Core/Services/OnnxModelService.cs @@ -16,8 +16,8 @@ namespace OnnxStack.Core.Services public sealed class OnnxModelService : IOnnxModelService { private readonly OnnxStackConfig _configuration; - private readonly ConcurrentDictionary _onnxModelSets; - private readonly ConcurrentDictionary _onnxModelSetConfigs; + private readonly ConcurrentDictionary _onnxModelSets; + private readonly ConcurrentDictionary _onnxModelSetConfigs; /// /// Initializes a new instance of the class. @@ -26,8 +26,8 @@ public sealed class OnnxModelService : IOnnxModelService public OnnxModelService(OnnxStackConfig configuration) { _configuration = configuration; - _onnxModelSets = new ConcurrentDictionary(); - _onnxModelSetConfigs = _configuration.OnnxModelSets.ToConcurrentDictionary(x => x.Name, x => x as IOnnxModelSetConfig); + _onnxModelSets = new ConcurrentDictionary(new OnnxModelEqualityComparer()); + _onnxModelSetConfigs = new ConcurrentDictionary(new OnnxModelEqualityComparer()); } @@ -50,7 +50,7 @@ public OnnxModelService(OnnxStackConfig configuration) /// public Task AddModelSet(IOnnxModelSetConfig modelSet) { - return Task.FromResult(_onnxModelSetConfigs.TryAdd(modelSet.Name, modelSet)); + return Task.FromResult(_onnxModelSetConfigs.TryAdd(modelSet, modelSet)); } /// @@ -74,7 +74,7 @@ public Task AddModelSet(IEnumerable modelSets) /// public Task RemoveModelSet(IOnnxModelSetConfig modelSet) { - return Task.FromResult(_onnxModelSetConfigs.TryRemove(modelSet.Name, out _)); + return Task.FromResult(_onnxModelSetConfigs.TryRemove(modelSet, out _)); } @@ -83,10 +83,10 @@ public Task RemoveModelSet(IOnnxModelSetConfig modelSet) /// /// The model set. /// - public bool UpdateModelSet(IOnnxModelSetConfig modelSet) + public Task UpdateModelSet(IOnnxModelSetConfig modelSet) { - _onnxModelSetConfigs.TryRemove(modelSet.Name, out _); - return _onnxModelSetConfigs.TryAdd(modelSet.Name, modelSet); + _onnxModelSetConfigs.TryRemove(modelSet, out _); + return Task.FromResult(_onnxModelSetConfigs.TryAdd(modelSet, modelSet)); } @@ -120,7 +120,7 @@ public async Task UnloadModelAsync(IOnnxModel model) /// public bool IsModelLoaded(IOnnxModel model) { - return _onnxModelSets.ContainsKey(model.Name); + return _onnxModelSets.ContainsKey(model); } @@ -251,7 +251,7 @@ private OnnxMetadata GetNodeMetadataInternal(IOnnxModel model, OnnxModelType mod /// Model {model.Name} has not been loaded private OnnxModelSet GetModelSet(IOnnxModel model) { - if (!_onnxModelSets.TryGetValue(model.Name, out var modelSet)) + if (!_onnxModelSets.TryGetValue(model, out var modelSet)) throw new Exception($"Model {model.Name} has not been loaded"); return modelSet; @@ -266,17 +266,17 @@ private OnnxModelSet GetModelSet(IOnnxModel model) /// Model {model.Name} not found in configuration private OnnxModelSet LoadModelSet(IOnnxModel model) { - if (_onnxModelSets.ContainsKey(model.Name)) - return _onnxModelSets[model.Name]; + if (_onnxModelSets.ContainsKey(model)) + return _onnxModelSets[model]; - if (!_onnxModelSetConfigs.TryGetValue(model.Name, out var modelSetConfig)) - throw new Exception($"Model {model.Name} not found in configuration"); + if (!_onnxModelSetConfigs.TryGetValue(model, out var modelSetConfig)) + throw new Exception($"Model {model.Name} not found"); if (!modelSetConfig.IsEnabled) throw new Exception($"Model {model.Name} is not enabled"); var modelSet = new OnnxModelSet(modelSetConfig); - _onnxModelSets.TryAdd(model.Name, modelSet); + _onnxModelSets.TryAdd(model, modelSet); return modelSet; } @@ -288,10 +288,10 @@ private OnnxModelSet LoadModelSet(IOnnxModel model) /// private bool UnloadModelSet(IOnnxModel model) { - if (!_onnxModelSets.TryGetValue(model.Name, out var modelSet)) + if (!_onnxModelSets.TryGetValue(model, out _)) return true; - if (_onnxModelSets.TryRemove(model.Name, out modelSet)) + if (_onnxModelSets.TryRemove(model, out var modelSet)) { modelSet?.Dispose(); return true; @@ -310,9 +310,5 @@ public void Dispose() onnxModelSet?.Dispose(); } } - - } - - } diff --git a/OnnxStack.ImageUpscaler/Config/UpscaleModelSet.cs b/OnnxStack.ImageUpscaler/Config/UpscaleModelSet.cs index f6e00501..5a7ba714 100644 --- a/OnnxStack.ImageUpscaler/Config/UpscaleModelSet.cs +++ b/OnnxStack.ImageUpscaler/Config/UpscaleModelSet.cs @@ -16,6 +16,6 @@ public class UpscaleModelSet : IOnnxModelSetConfig public int IntraOpNumThreads { get; set; } public ExecutionMode ExecutionMode { get; set; } public ExecutionProvider ExecutionProvider { get; set; } - public List ModelConfigurations { get; set; } + public List ModelConfigurations { get; set; } } } diff --git a/OnnxStack.ImageUpscaler/Services/IUpscaleService.cs b/OnnxStack.ImageUpscaler/Services/IUpscaleService.cs index eeb00105..1b4fba95 100644 --- a/OnnxStack.ImageUpscaler/Services/IUpscaleService.cs +++ b/OnnxStack.ImageUpscaler/Services/IUpscaleService.cs @@ -14,16 +14,30 @@ public interface IUpscaleService { /// - /// Gets the configuration. + /// Gets the model sets. /// - ImageUpscalerConfig Configuration { get; } + IReadOnlyCollection ModelSets { get; } + /// + /// Adds the model. + /// + /// The model. + /// + Task AddModelAsync(UpscaleModelSet model); /// - /// Gets the model sets. + /// Removes the model. /// - IReadOnlyList ModelSets { get; } + /// The model. + /// + Task RemoveModelAsync(UpscaleModelSet model); + /// + /// Updates the model. + /// + /// The model. + /// + Task UpdateModelAsync(UpscaleModelSet model); /// /// Loads the model. diff --git a/OnnxStack.ImageUpscaler/Services/UpscaleService.cs b/OnnxStack.ImageUpscaler/Services/UpscaleService.cs index bbb689c6..eb2e6254 100644 --- a/OnnxStack.ImageUpscaler/Services/UpscaleService.cs +++ b/OnnxStack.ImageUpscaler/Services/UpscaleService.cs @@ -11,6 +11,7 @@ using SixLabors.ImageSharp; using SixLabors.ImageSharp.PixelFormats; using SixLabors.ImageSharp.Processing; +using System; using System.Collections.Generic; using System.IO; using System.Linq; @@ -22,6 +23,7 @@ public class UpscaleService : IUpscaleService { private readonly IOnnxModelService _modelService; private readonly ImageUpscalerConfig _configuration; + private readonly HashSet _modelSetConfigs; /// /// Initializes a new instance of the class. @@ -33,20 +35,67 @@ public UpscaleService(ImageUpscalerConfig configuration, IOnnxModelService model { _configuration = configuration; _modelService = modelService; - _modelService.AddModelSet(_configuration.ModelSets); + _modelSetConfigs = new HashSet(_configuration.ModelSets, new OnnxModelEqualityComparer()); + _modelService.AddModelSet(_modelSetConfigs); } /// - /// Gets the configuration. + /// Gets the model sets. /// - public ImageUpscalerConfig Configuration => _configuration; + public IReadOnlyCollection ModelSets => _modelSetConfigs; /// - /// Gets the model sets. + /// Adds the model. /// - public IReadOnlyList ModelSets => _configuration.ModelSets; + /// The model. + /// + /// + public async Task AddModelAsync(UpscaleModelSet model) + { + if (await _modelService.AddModelSet(model)) + { + _modelSetConfigs.Add(model); + return true; + } + return false; + } + + + /// + /// Removes the model. + /// + /// The model. + /// + /// + public async Task RemoveModelAsync(UpscaleModelSet model) + { + if (await _modelService.RemoveModelSet(model)) + { + _modelSetConfigs.Remove(model); + return true; + } + return false; + } + + + /// + /// Updates the model. + /// + /// The model. + /// + /// + public async Task UpdateModelAsync(UpscaleModelSet model) + { + if (await _modelService.UpdateModelSet(model)) + { + _modelSetConfigs.Remove(model); + _modelSetConfigs.Add(model); + return true; + } + return false; + } /// @@ -56,6 +105,9 @@ public UpscaleService(ImageUpscalerConfig configuration, IOnnxModelService model /// public async Task LoadModelAsync(UpscaleModelSet model) { + if (!_modelSetConfigs.TryGetValue(model, out _)) + throw new Exception("ModelSet not found"); + var modelSet = await _modelService.LoadModelAsync(model); return modelSet is not null; } diff --git a/OnnxStack.StableDiffusion/Common/IModelOptions.cs b/OnnxStack.StableDiffusion/Common/IModelOptions.cs deleted file mode 100644 index 654fd40c..00000000 --- a/OnnxStack.StableDiffusion/Common/IModelOptions.cs +++ /dev/null @@ -1,24 +0,0 @@ -using OnnxStack.Core.Config; -using OnnxStack.StableDiffusion.Enums; -using System.Collections.Generic; -using System.Collections.Immutable; - -namespace OnnxStack.StableDiffusion.Common -{ - public interface IModelOptions : IOnnxModel - { - bool IsEnabled { get; set; } - int PadTokenId { get; set; } - int BlankTokenId { get; set; } - int SampleSize { get; set; } - float ScaleFactor { get; set; } - int TokenizerLimit { get; set; } - int TokenizerLength { get; set; } - int Tokenizer2Length { get; set; } - ModelType ModelType { get; set; } - TokenizerType TokenizerType { get; set; } - DiffuserPipelineType PipelineType { get; set; } - List Diffusers { get; set; } - ImmutableArray BlankTokenValueArray { get; set; } - } -} \ No newline at end of file diff --git a/OnnxStack.StableDiffusion/Common/IPromptService.cs b/OnnxStack.StableDiffusion/Common/IPromptService.cs index bef8fa52..c3bc9015 100644 --- a/OnnxStack.StableDiffusion/Common/IPromptService.cs +++ b/OnnxStack.StableDiffusion/Common/IPromptService.cs @@ -1,11 +1,10 @@ -using Microsoft.ML.OnnxRuntime.Tensors; -using OnnxStack.StableDiffusion.Config; +using OnnxStack.StableDiffusion.Config; using System.Threading.Tasks; namespace OnnxStack.StableDiffusion.Common { public interface IPromptService { - Task CreatePromptAsync(IModelOptions model, PromptOptions promptOptions, bool isGuidanceEnabled); + Task CreatePromptAsync(StableDiffusionModelSet model, PromptOptions promptOptions, bool isGuidanceEnabled); } } \ No newline at end of file diff --git a/OnnxStack.StableDiffusion/Common/IStableDiffusionService.cs b/OnnxStack.StableDiffusion/Common/IStableDiffusionService.cs index 281ffbf2..2d2e0cd8 100644 --- a/OnnxStack.StableDiffusion/Common/IStableDiffusionService.cs +++ b/OnnxStack.StableDiffusion/Common/IStableDiffusionService.cs @@ -1,6 +1,4 @@ using Microsoft.ML.OnnxRuntime.Tensors; -using OnnxStack.Core.Config; -using OnnxStack.Core.Model; using OnnxStack.StableDiffusion.Config; using OnnxStack.StableDiffusion.Models; using SixLabors.ImageSharp; @@ -15,18 +13,40 @@ namespace OnnxStack.StableDiffusion.Common { public interface IStableDiffusionService { + /// + /// Gets the models. + /// + IReadOnlyCollection ModelSets { get; } /// - /// Gets the models. + /// Adds the model. + /// + /// The model. + /// + Task AddModelAsync(StableDiffusionModelSet model); + + + /// + /// Removes the model. /// - List Models { get; } + /// The model. + /// + Task RemoveModelAsync(StableDiffusionModelSet model); + + + /// + /// Updates the model. + /// + /// The model. + /// + Task UpdateModelAsync(StableDiffusionModelSet model); /// /// Loads the model. /// /// The model options. /// - Task LoadModelAsync(IModelOptions modelOptions); + Task LoadModelAsync(StableDiffusionModelSet model); /// @@ -34,7 +54,7 @@ public interface IStableDiffusionService /// /// The model options. /// - Task UnloadModelAsync(IModelOptions modelOptions); + Task UnloadModelAsync(StableDiffusionModelSet model); /// /// Determines whether the specified model is loaded @@ -43,7 +63,7 @@ public interface IStableDiffusionService /// /// true if the specified model is loaded; otherwise, false. /// - bool IsModelLoaded(IModelOptions modelOptions); + bool IsModelLoaded(StableDiffusionModelSet model); /// /// Generates the StableDiffusion image using the prompt and options provided. @@ -53,7 +73,7 @@ public interface IStableDiffusionService /// The callback used to provide progess of the current InferenceSteps. /// The cancellation token. /// The diffusion result as - Task> GenerateAsync(IModelOptions model, PromptOptions prompt, SchedulerOptions options, Action progressCallback = null, CancellationToken cancellationToken = default); + Task> GenerateAsync(StableDiffusionModelSet model, PromptOptions prompt, SchedulerOptions options, Action progressCallback = null, CancellationToken cancellationToken = default); /// /// Generates the StableDiffusion image using the prompt and options provided. @@ -63,7 +83,7 @@ public interface IStableDiffusionService /// The callback used to provide progess of the current InferenceSteps. /// The cancellation token. /// The diffusion result as - Task> GenerateAsImageAsync(IModelOptions model, PromptOptions prompt, SchedulerOptions options, Action progressCallback = null, CancellationToken cancellationToken = default); + Task> GenerateAsImageAsync(StableDiffusionModelSet model, PromptOptions prompt, SchedulerOptions options, Action progressCallback = null, CancellationToken cancellationToken = default); /// /// Generates the StableDiffusion image using the prompt and options provided. @@ -73,7 +93,7 @@ public interface IStableDiffusionService /// The callback used to provide progess of the current InferenceSteps. /// The cancellation token. /// The diffusion result as - Task GenerateAsBytesAsync(IModelOptions model, PromptOptions prompt, SchedulerOptions options, Action progressCallback = null, CancellationToken cancellationToken = default); + Task GenerateAsBytesAsync(StableDiffusionModelSet model, PromptOptions prompt, SchedulerOptions options, Action progressCallback = null, CancellationToken cancellationToken = default); /// /// Generates the StableDiffusion image using the prompt and options provided. @@ -83,7 +103,7 @@ public interface IStableDiffusionService /// The callback used to provide progess of the current InferenceSteps. /// The cancellation token. /// The diffusion result as - Task GenerateAsStreamAsync(IModelOptions model, PromptOptions prompt, SchedulerOptions options, Action progressCallback = null, CancellationToken cancellationToken = default); + Task GenerateAsStreamAsync(StableDiffusionModelSet model, PromptOptions prompt, SchedulerOptions options, Action progressCallback = null, CancellationToken cancellationToken = default); /// /// Generates a batch of StableDiffusion image using the prompt and options provided. @@ -95,7 +115,7 @@ public interface IStableDiffusionService /// The progress callback. /// The cancellation token. /// - IAsyncEnumerable GenerateBatchAsync(IModelOptions modelOptions, PromptOptions promptOptions, SchedulerOptions schedulerOptions, BatchOptions batchOptions, Action progressCallback = null, CancellationToken cancellationToken = default); + IAsyncEnumerable GenerateBatchAsync(StableDiffusionModelSet model, PromptOptions promptOptions, SchedulerOptions schedulerOptions, BatchOptions batchOptions, Action progressCallback = null, CancellationToken cancellationToken = default); /// /// Generates a batch of StableDiffusion image using the prompt and options provided. @@ -107,7 +127,7 @@ public interface IStableDiffusionService /// The progress callback. /// The cancellation token. /// - IAsyncEnumerable> GenerateBatchAsImageAsync(IModelOptions modelOptions, PromptOptions promptOptions, SchedulerOptions schedulerOptions, BatchOptions batchOptions, Action progressCallback = null, CancellationToken cancellationToken = default); + IAsyncEnumerable> GenerateBatchAsImageAsync(StableDiffusionModelSet model, PromptOptions promptOptions, SchedulerOptions schedulerOptions, BatchOptions batchOptions, Action progressCallback = null, CancellationToken cancellationToken = default); /// /// Generates a batch of StableDiffusion image using the prompt and options provided. @@ -119,7 +139,7 @@ public interface IStableDiffusionService /// The progress callback. /// The cancellation token. /// - IAsyncEnumerable GenerateBatchAsBytesAsync(IModelOptions modelOptions, PromptOptions promptOptions, SchedulerOptions schedulerOptions, BatchOptions batchOptions, Action progressCallback = null, CancellationToken cancellationToken = default); + IAsyncEnumerable GenerateBatchAsBytesAsync(StableDiffusionModelSet model, PromptOptions promptOptions, SchedulerOptions schedulerOptions, BatchOptions batchOptions, Action progressCallback = null, CancellationToken cancellationToken = default); /// /// Generates a batch of StableDiffusion image using the prompt and options provided. @@ -131,6 +151,6 @@ public interface IStableDiffusionService /// The progress callback. /// The cancellation token. /// - IAsyncEnumerable GenerateBatchAsStreamAsync(IModelOptions modelOptions, PromptOptions promptOptions, SchedulerOptions schedulerOptions, BatchOptions batchOptions, Action progressCallback = null, CancellationToken cancellationToken = default); + IAsyncEnumerable GenerateBatchAsStreamAsync(StableDiffusionModelSet model, PromptOptions promptOptions, SchedulerOptions schedulerOptions, BatchOptions batchOptions, Action progressCallback = null, CancellationToken cancellationToken = default); } } \ No newline at end of file diff --git a/OnnxStack.StableDiffusion/Config/StableDiffusionConfig.cs b/OnnxStack.StableDiffusion/Config/StableDiffusionConfig.cs index cd42e7b8..845ba91c 100644 --- a/OnnxStack.StableDiffusion/Config/StableDiffusionConfig.cs +++ b/OnnxStack.StableDiffusion/Config/StableDiffusionConfig.cs @@ -4,29 +4,29 @@ using System; using System.Collections.Generic; using System.IO; -using System.Linq; namespace OnnxStack.StableDiffusion.Config { public class StableDiffusionConfig : IConfigSection { - public List OnnxModelSets { get; set; } = new List(); + public List ModelSets { get; set; } = new List(); public void Initialize() { - if (OnnxModelSets.IsNullOrEmpty()) + if (ModelSets.IsNullOrEmpty()) return; var defaultTokenizer = Path.Combine(AppDomain.CurrentDomain.BaseDirectory, "cliptokenizer.onnx"); if (!File.Exists(defaultTokenizer)) defaultTokenizer = string.Empty; - foreach (var modelSet in OnnxModelSets) + foreach (var modelSet in ModelSets) { modelSet.InitBlankTokenArray(); + modelSet.ApplyConfigurationOverrides(); foreach (var model in modelSet.ModelConfigurations) { - if (model.Type == OnnxModelType.Tokenizer && string.IsNullOrEmpty(model.OnnxModelPath)) + if ((model.Type == OnnxModelType.Tokenizer || model.Type == OnnxModelType.Tokenizer2) && string.IsNullOrEmpty(model.OnnxModelPath)) model.OnnxModelPath = defaultTokenizer; if (!File.Exists(model.OnnxModelPath)) diff --git a/OnnxStack.StableDiffusion/Config/ModelOptions.cs b/OnnxStack.StableDiffusion/Config/StableDiffusionModelSet.cs similarity index 88% rename from OnnxStack.StableDiffusion/Config/ModelOptions.cs rename to OnnxStack.StableDiffusion/Config/StableDiffusionModelSet.cs index dd2a22e7..6ccf1460 100644 --- a/OnnxStack.StableDiffusion/Config/ModelOptions.cs +++ b/OnnxStack.StableDiffusion/Config/StableDiffusionModelSet.cs @@ -1,6 +1,5 @@ using Microsoft.ML.OnnxRuntime; using OnnxStack.Core.Config; -using OnnxStack.StableDiffusion.Common; using OnnxStack.StableDiffusion.Enums; using System.Collections.Generic; using System.Collections.Immutable; @@ -9,7 +8,7 @@ namespace OnnxStack.StableDiffusion.Config { - public class ModelOptions : IModelOptions, IOnnxModelSetConfig + public class StableDiffusionModelSet : IOnnxModelSetConfig { public string Name { get; set; } public bool IsEnabled { get; set; } @@ -30,7 +29,7 @@ public class ModelOptions : IModelOptions, IOnnxModelSetConfig public int IntraOpNumThreads { get; set; } public ExecutionMode ExecutionMode { get; set; } public ExecutionProvider ExecutionProvider { get; set; } - public List ModelConfigurations { get; set; } + public List ModelConfigurations { get; set; } [JsonIgnore] public ImmutableArray BlankTokenValueArray { get; set; } diff --git a/OnnxStack.StableDiffusion/Diffusers/DiffuserBase.cs b/OnnxStack.StableDiffusion/Diffusers/DiffuserBase.cs index 73397909..763076b7 100644 --- a/OnnxStack.StableDiffusion/Diffusers/DiffuserBase.cs +++ b/OnnxStack.StableDiffusion/Diffusers/DiffuserBase.cs @@ -74,7 +74,7 @@ public DiffuserBase(IOnnxModelService onnxModelService, IPromptService promptSer /// The scheduler. /// The timesteps. /// - protected abstract Task> PrepareLatentsAsync(IModelOptions model, PromptOptions prompt, SchedulerOptions options, IScheduler scheduler, IReadOnlyList timesteps); + protected abstract Task> PrepareLatentsAsync(StableDiffusionModelSet model, PromptOptions prompt, SchedulerOptions options, IScheduler scheduler, IReadOnlyList timesteps); /// @@ -88,7 +88,7 @@ public DiffuserBase(IOnnxModelService onnxModelService, IPromptService promptSer /// The progress callback. /// The cancellation token. /// - protected abstract Task> SchedulerStepAsync(IModelOptions modelOptions, PromptOptions promptOptions, SchedulerOptions schedulerOptions, PromptEmbeddingsResult promptEmbeddings, bool performGuidance, Action progressCallback = null, CancellationToken cancellationToken = default); + protected abstract Task> SchedulerStepAsync(StableDiffusionModelSet modelOptions, PromptOptions promptOptions, SchedulerOptions schedulerOptions, PromptEmbeddingsResult promptEmbeddings, bool performGuidance, Action progressCallback = null, CancellationToken cancellationToken = default); /// @@ -99,7 +99,7 @@ public DiffuserBase(IOnnxModelService onnxModelService, IPromptService promptSer /// The progress. /// The cancellation token. /// - public virtual async Task> DiffuseAsync(IModelOptions modelOptions, PromptOptions promptOptions, SchedulerOptions schedulerOptions, Action progressCallback = null, CancellationToken cancellationToken = default) + public virtual async Task> DiffuseAsync(StableDiffusionModelSet modelOptions, PromptOptions promptOptions, SchedulerOptions schedulerOptions, Action progressCallback = null, CancellationToken cancellationToken = default) { // Create random seed if none was set schedulerOptions.Seed = schedulerOptions.Seed > 0 ? schedulerOptions.Seed : Random.Shared.Next(); @@ -133,7 +133,7 @@ public virtual async Task> DiffuseAsync(IModelOptions modelOp /// The cancellation token. /// /// - public virtual async IAsyncEnumerable DiffuseBatchAsync(IModelOptions modelOptions, PromptOptions promptOptions, SchedulerOptions schedulerOptions, BatchOptions batchOptions, Action progressCallback = null, [EnumeratorCancellation] CancellationToken cancellationToken = default) + public virtual async IAsyncEnumerable DiffuseBatchAsync(StableDiffusionModelSet modelOptions, PromptOptions promptOptions, SchedulerOptions schedulerOptions, BatchOptions batchOptions, Action progressCallback = null, [EnumeratorCancellation] CancellationToken cancellationToken = default) { // Create random seed if none was set schedulerOptions.Seed = schedulerOptions.Seed > 0 ? schedulerOptions.Seed : Random.Shared.Next(); @@ -207,7 +207,7 @@ protected virtual DenseTensor PerformGuidance(DenseTensor noisePre /// The options. /// The latents. /// - protected virtual async Task> DecodeLatentsAsync(IModelOptions model, PromptOptions prompt, SchedulerOptions options, DenseTensor latents) + protected virtual async Task> DecodeLatentsAsync(StableDiffusionModelSet model, PromptOptions prompt, SchedulerOptions options, DenseTensor latents) { var timestamp = _logger.LogBegin(); diff --git a/OnnxStack.StableDiffusion/Diffusers/IDiffuser.cs b/OnnxStack.StableDiffusion/Diffusers/IDiffuser.cs index 0f2b1ba8..00e60c58 100644 --- a/OnnxStack.StableDiffusion/Diffusers/IDiffuser.cs +++ b/OnnxStack.StableDiffusion/Diffusers/IDiffuser.cs @@ -1,5 +1,4 @@ using Microsoft.ML.OnnxRuntime.Tensors; -using OnnxStack.StableDiffusion.Common; using OnnxStack.StableDiffusion.Config; using OnnxStack.StableDiffusion.Enums; using OnnxStack.StableDiffusion.Models; @@ -34,7 +33,7 @@ public interface IDiffuser /// The progress callback. /// The cancellation token. /// - Task> DiffuseAsync(IModelOptions modelOptions, PromptOptions promptOptions, SchedulerOptions schedulerOptions, Action progressCallback = null, CancellationToken cancellationToken = default); + Task> DiffuseAsync(StableDiffusionModelSet modelOptions, PromptOptions promptOptions, SchedulerOptions schedulerOptions, Action progressCallback = null, CancellationToken cancellationToken = default); /// @@ -47,6 +46,6 @@ public interface IDiffuser /// The progress callback. /// The cancellation token. /// - IAsyncEnumerable DiffuseBatchAsync(IModelOptions modelOptions, PromptOptions promptOptions, SchedulerOptions schedulerOptions, BatchOptions batchOptions, Action progressCallback = null, CancellationToken cancellationToken = default); + IAsyncEnumerable DiffuseBatchAsync(StableDiffusionModelSet modelOptions, PromptOptions promptOptions, SchedulerOptions schedulerOptions, BatchOptions batchOptions, Action progressCallback = null, CancellationToken cancellationToken = default); } } diff --git a/OnnxStack.StableDiffusion/Diffusers/InstaFlow/InstaFlowDiffuser.cs b/OnnxStack.StableDiffusion/Diffusers/InstaFlow/InstaFlowDiffuser.cs index e2aefdf9..6759b4e0 100644 --- a/OnnxStack.StableDiffusion/Diffusers/InstaFlow/InstaFlowDiffuser.cs +++ b/OnnxStack.StableDiffusion/Diffusers/InstaFlow/InstaFlowDiffuser.cs @@ -45,7 +45,7 @@ public InstaFlowDiffuser(IOnnxModelService onnxModelService, IPromptService prom /// The progress callback. /// The cancellation token. /// - protected override async Task> SchedulerStepAsync(IModelOptions modelOptions, PromptOptions promptOptions, SchedulerOptions schedulerOptions, PromptEmbeddingsResult promptEmbeddings, bool performGuidance, Action progressCallback = null, CancellationToken cancellationToken = default) + protected override async Task> SchedulerStepAsync(StableDiffusionModelSet modelOptions, PromptOptions promptOptions, SchedulerOptions schedulerOptions, PromptEmbeddingsResult promptEmbeddings, bool performGuidance, Action progressCallback = null, CancellationToken cancellationToken = default) { // Get Scheduler using (var scheduler = GetScheduler(schedulerOptions)) diff --git a/OnnxStack.StableDiffusion/Diffusers/InstaFlow/TextDiffuser.cs b/OnnxStack.StableDiffusion/Diffusers/InstaFlow/TextDiffuser.cs index e4af6e15..72be26c3 100644 --- a/OnnxStack.StableDiffusion/Diffusers/InstaFlow/TextDiffuser.cs +++ b/OnnxStack.StableDiffusion/Diffusers/InstaFlow/TextDiffuser.cs @@ -48,7 +48,7 @@ protected override IReadOnlyList GetTimesteps(SchedulerOptions options, ISc /// The options. /// The scheduler. /// - protected override Task> PrepareLatentsAsync(IModelOptions model, PromptOptions prompt, SchedulerOptions options, IScheduler scheduler, IReadOnlyList timesteps) + protected override Task> PrepareLatentsAsync(StableDiffusionModelSet model, PromptOptions prompt, SchedulerOptions options, IScheduler scheduler, IReadOnlyList timesteps) { return Task.FromResult(scheduler.CreateRandomSample(options.GetScaledDimension(), scheduler.InitNoiseSigma)); } diff --git a/OnnxStack.StableDiffusion/Diffusers/LatentConsistency/ImageDiffuser.cs b/OnnxStack.StableDiffusion/Diffusers/LatentConsistency/ImageDiffuser.cs index 5927007f..8e6ee1bf 100644 --- a/OnnxStack.StableDiffusion/Diffusers/LatentConsistency/ImageDiffuser.cs +++ b/OnnxStack.StableDiffusion/Diffusers/LatentConsistency/ImageDiffuser.cs @@ -57,7 +57,7 @@ protected override IReadOnlyList GetTimesteps(SchedulerOptions options, ISc /// The options. /// The scheduler. /// - protected override async Task> PrepareLatentsAsync(IModelOptions model, PromptOptions prompt, SchedulerOptions options, IScheduler scheduler, IReadOnlyList timesteps) + protected override async Task> PrepareLatentsAsync(StableDiffusionModelSet model, PromptOptions prompt, SchedulerOptions options, IScheduler scheduler, IReadOnlyList timesteps) { var imageTensor = prompt.InputImage.ToDenseTensor(new[] { 1, 3, options.Height, options.Width }); diff --git a/OnnxStack.StableDiffusion/Diffusers/LatentConsistency/InpaintLegacyDiffuser.cs b/OnnxStack.StableDiffusion/Diffusers/LatentConsistency/InpaintLegacyDiffuser.cs index 0e349529..b3f92cd9 100644 --- a/OnnxStack.StableDiffusion/Diffusers/LatentConsistency/InpaintLegacyDiffuser.cs +++ b/OnnxStack.StableDiffusion/Diffusers/LatentConsistency/InpaintLegacyDiffuser.cs @@ -65,7 +65,7 @@ protected override IReadOnlyList GetTimesteps(SchedulerOptions options, ISc /// The progress callback. /// The cancellation token. /// - protected override async Task> SchedulerStepAsync(IModelOptions modelOptions, PromptOptions promptOptions, SchedulerOptions schedulerOptions, PromptEmbeddingsResult promptEmbeddings, bool performGuidance, Action progressCallback = null, CancellationToken cancellationToken = default) + protected override async Task> SchedulerStepAsync(StableDiffusionModelSet modelOptions, PromptOptions promptOptions, SchedulerOptions schedulerOptions, PromptEmbeddingsResult promptEmbeddings, bool performGuidance, Action progressCallback = null, CancellationToken cancellationToken = default) { using (var scheduler = GetScheduler(schedulerOptions)) { @@ -157,7 +157,7 @@ protected override async Task> SchedulerStepAsync(IModelOptio /// The scheduler. /// The timesteps. /// - protected override async Task> PrepareLatentsAsync(IModelOptions model, PromptOptions prompt, SchedulerOptions options, IScheduler scheduler, IReadOnlyList timesteps) + protected override async Task> PrepareLatentsAsync(StableDiffusionModelSet model, PromptOptions prompt, SchedulerOptions options, IScheduler scheduler, IReadOnlyList timesteps) { // Image input, decode, add noise, return as latent 0 var imageTensor = prompt.InputImage.ToDenseTensor(new[] { 1, 3, options.Height, options.Width }); @@ -187,7 +187,7 @@ protected override async Task> PrepareLatentsAsync(IModelOpti /// The prompt options. /// The scheduler options. /// - private DenseTensor PrepareMask(IModelOptions modelOptions, PromptOptions promptOptions, SchedulerOptions schedulerOptions) + private DenseTensor PrepareMask(StableDiffusionModelSet modelOptions, PromptOptions promptOptions, SchedulerOptions schedulerOptions) { using (var mask = promptOptions.InputImageMask.ToImage()) { diff --git a/OnnxStack.StableDiffusion/Diffusers/LatentConsistency/LatentConsistencyDiffuser.cs b/OnnxStack.StableDiffusion/Diffusers/LatentConsistency/LatentConsistencyDiffuser.cs index 77ba7cab..b365d5ce 100644 --- a/OnnxStack.StableDiffusion/Diffusers/LatentConsistency/LatentConsistencyDiffuser.cs +++ b/OnnxStack.StableDiffusion/Diffusers/LatentConsistency/LatentConsistencyDiffuser.cs @@ -46,7 +46,7 @@ public LatentConsistencyDiffuser(IOnnxModelService onnxModelService, IPromptServ /// /// The cancellation token. /// - public override Task> DiffuseAsync(IModelOptions modelOptions, PromptOptions promptOptions, SchedulerOptions schedulerOptions, Action progressCallback = null, CancellationToken cancellationToken = default) + public override Task> DiffuseAsync(StableDiffusionModelSet modelOptions, PromptOptions promptOptions, SchedulerOptions schedulerOptions, Action progressCallback = null, CancellationToken cancellationToken = default) { // LCM does not support negative prompting promptOptions.NegativePrompt = string.Empty; @@ -64,7 +64,7 @@ public override Task> DiffuseAsync(IModelOptions modelOptions /// The progress callback. /// The cancellation token. /// - public override IAsyncEnumerable DiffuseBatchAsync(IModelOptions modelOptions, PromptOptions promptOptions, SchedulerOptions schedulerOptions, BatchOptions batchOptions, Action progressCallback = null, CancellationToken cancellationToken = default) + public override IAsyncEnumerable DiffuseBatchAsync(StableDiffusionModelSet modelOptions, PromptOptions promptOptions, SchedulerOptions schedulerOptions, BatchOptions batchOptions, Action progressCallback = null, CancellationToken cancellationToken = default) { // LCM does not support negative prompting promptOptions.NegativePrompt = string.Empty; @@ -88,7 +88,7 @@ protected override bool ShouldPerformGuidance(SchedulerOptions schedulerOptions) /// The progress callback. /// The cancellation token. /// - protected override async Task> SchedulerStepAsync(IModelOptions modelOptions, PromptOptions promptOptions, SchedulerOptions schedulerOptions, PromptEmbeddingsResult promptEmbeddings, bool performGuidance, Action progressCallback = null, CancellationToken cancellationToken = default) + protected override async Task> SchedulerStepAsync(StableDiffusionModelSet modelOptions, PromptOptions promptOptions, SchedulerOptions schedulerOptions, PromptEmbeddingsResult promptEmbeddings, bool performGuidance, Action progressCallback = null, CancellationToken cancellationToken = default) { // Get Scheduler using (var scheduler = GetScheduler(schedulerOptions)) diff --git a/OnnxStack.StableDiffusion/Diffusers/LatentConsistency/TextDiffuser.cs b/OnnxStack.StableDiffusion/Diffusers/LatentConsistency/TextDiffuser.cs index 6f188ff3..986e4299 100644 --- a/OnnxStack.StableDiffusion/Diffusers/LatentConsistency/TextDiffuser.cs +++ b/OnnxStack.StableDiffusion/Diffusers/LatentConsistency/TextDiffuser.cs @@ -46,7 +46,7 @@ protected override IReadOnlyList GetTimesteps(SchedulerOptions options, ISc /// The options. /// The scheduler. /// - protected override Task> PrepareLatentsAsync(IModelOptions model, PromptOptions prompt, SchedulerOptions options, IScheduler scheduler, IReadOnlyList timesteps) + protected override Task> PrepareLatentsAsync(StableDiffusionModelSet model, PromptOptions prompt, SchedulerOptions options, IScheduler scheduler, IReadOnlyList timesteps) { return Task.FromResult(scheduler.CreateRandomSample(options.GetScaledDimension(), scheduler.InitNoiseSigma)); } diff --git a/OnnxStack.StableDiffusion/Diffusers/StableDiffusion/AnimateDiffuser.cs b/OnnxStack.StableDiffusion/Diffusers/StableDiffusion/AnimateDiffuser.cs index 55063931..e02bb88e 100644 --- a/OnnxStack.StableDiffusion/Diffusers/StableDiffusion/AnimateDiffuser.cs +++ b/OnnxStack.StableDiffusion/Diffusers/StableDiffusion/AnimateDiffuser.cs @@ -49,7 +49,7 @@ protected override IReadOnlyList GetTimesteps(SchedulerOptions options, ISc /// The scheduler. /// The timesteps. /// - protected override Task> PrepareLatentsAsync(IModelOptions model, PromptOptions prompt, SchedulerOptions options, IScheduler scheduler, IReadOnlyList timesteps) + protected override Task> PrepareLatentsAsync(StableDiffusionModelSet model, PromptOptions prompt, SchedulerOptions options, IScheduler scheduler, IReadOnlyList timesteps) { return Task.FromResult(scheduler.CreateRandomSample(options.GetScaledDimension(), scheduler.InitNoiseSigma)); } diff --git a/OnnxStack.StableDiffusion/Diffusers/StableDiffusion/ImageDiffuser.cs b/OnnxStack.StableDiffusion/Diffusers/StableDiffusion/ImageDiffuser.cs index cefa05d0..fb4e192f 100644 --- a/OnnxStack.StableDiffusion/Diffusers/StableDiffusion/ImageDiffuser.cs +++ b/OnnxStack.StableDiffusion/Diffusers/StableDiffusion/ImageDiffuser.cs @@ -59,7 +59,7 @@ protected override IReadOnlyList GetTimesteps(SchedulerOptions options, ISc /// The options. /// The scheduler. /// - protected override async Task> PrepareLatentsAsync(IModelOptions model, PromptOptions prompt, SchedulerOptions options, IScheduler scheduler, IReadOnlyList timesteps) + protected override async Task> PrepareLatentsAsync(StableDiffusionModelSet model, PromptOptions prompt, SchedulerOptions options, IScheduler scheduler, IReadOnlyList timesteps) { var imageTensor = prompt.InputImage.ToDenseTensor(new[] { 1, 3, options.Height, options.Width }); diff --git a/OnnxStack.StableDiffusion/Diffusers/StableDiffusion/InpaintDiffuser.cs b/OnnxStack.StableDiffusion/Diffusers/StableDiffusion/InpaintDiffuser.cs index d5129a0b..1e509a51 100644 --- a/OnnxStack.StableDiffusion/Diffusers/StableDiffusion/InpaintDiffuser.cs +++ b/OnnxStack.StableDiffusion/Diffusers/StableDiffusion/InpaintDiffuser.cs @@ -51,7 +51,7 @@ public InpaintDiffuser(IOnnxModelService onnxModelService, IPromptService prompt /// The progress callback. /// The cancellation token. /// - protected override async Task> SchedulerStepAsync(IModelOptions modelOptions, PromptOptions promptOptions, SchedulerOptions schedulerOptions, PromptEmbeddingsResult promptEmbeddings, bool performGuidance, Action progressCallback = null, CancellationToken cancellationToken = default) + protected override async Task> SchedulerStepAsync(StableDiffusionModelSet modelOptions, PromptOptions promptOptions, SchedulerOptions schedulerOptions, PromptEmbeddingsResult promptEmbeddings, bool performGuidance, Action progressCallback = null, CancellationToken cancellationToken = default) { // Get Scheduler using (var scheduler = GetScheduler(schedulerOptions)) @@ -124,7 +124,7 @@ protected override async Task> SchedulerStepAsync(IModelOptio /// The prompt options. /// The scheduler options. /// - private DenseTensor PrepareMask(IModelOptions modelOptions, PromptOptions promptOptions, SchedulerOptions schedulerOptions) + private DenseTensor PrepareMask(StableDiffusionModelSet modelOptions, PromptOptions promptOptions, SchedulerOptions schedulerOptions) { using (var imageMask = promptOptions.InputImageMask.ToImage()) { @@ -169,7 +169,7 @@ private DenseTensor PrepareMask(IModelOptions modelOptions, PromptOptions /// The scheduler options. /// The scheduler. /// - private async Task> PrepareImageMask(IModelOptions modelOptions, PromptOptions promptOptions, SchedulerOptions schedulerOptions) + private async Task> PrepareImageMask(StableDiffusionModelSet modelOptions, PromptOptions promptOptions, SchedulerOptions schedulerOptions) { using (var image = promptOptions.InputImage.ToImage()) using (var mask = promptOptions.InputImageMask.ToImage()) @@ -259,7 +259,7 @@ protected override IReadOnlyList GetTimesteps(SchedulerOptions options, ISc /// The scheduler. /// The timesteps. /// - protected override Task> PrepareLatentsAsync(IModelOptions model, PromptOptions prompt, SchedulerOptions options, IScheduler scheduler, IReadOnlyList timesteps) + protected override Task> PrepareLatentsAsync(StableDiffusionModelSet model, PromptOptions prompt, SchedulerOptions options, IScheduler scheduler, IReadOnlyList timesteps) { return Task.FromResult(scheduler.CreateRandomSample(options.GetScaledDimension(), scheduler.InitNoiseSigma)); } diff --git a/OnnxStack.StableDiffusion/Diffusers/StableDiffusion/InpaintLegacyDiffuser.cs b/OnnxStack.StableDiffusion/Diffusers/StableDiffusion/InpaintLegacyDiffuser.cs index 08d85b62..6fdd5473 100644 --- a/OnnxStack.StableDiffusion/Diffusers/StableDiffusion/InpaintLegacyDiffuser.cs +++ b/OnnxStack.StableDiffusion/Diffusers/StableDiffusion/InpaintLegacyDiffuser.cs @@ -50,7 +50,7 @@ public InpaintLegacyDiffuser(IOnnxModelService onnxModelService, IPromptService /// The progress callback. /// The cancellation token. /// - protected override async Task> SchedulerStepAsync(IModelOptions modelOptions, PromptOptions promptOptions, SchedulerOptions schedulerOptions, PromptEmbeddingsResult promptEmbeddings, bool performGuidance, Action progressCallback = null, CancellationToken cancellationToken = default) + protected override async Task> SchedulerStepAsync(StableDiffusionModelSet modelOptions, PromptOptions promptOptions, SchedulerOptions schedulerOptions, PromptEmbeddingsResult promptEmbeddings, bool performGuidance, Action progressCallback = null, CancellationToken cancellationToken = default) { using (var scheduler = GetScheduler(schedulerOptions)) { @@ -146,7 +146,7 @@ protected override IReadOnlyList GetTimesteps(SchedulerOptions options, ISc /// The options. /// The scheduler. /// - protected override async Task> PrepareLatentsAsync(IModelOptions model, PromptOptions prompt, SchedulerOptions options, IScheduler scheduler, IReadOnlyList timesteps) + protected override async Task> PrepareLatentsAsync(StableDiffusionModelSet model, PromptOptions prompt, SchedulerOptions options, IScheduler scheduler, IReadOnlyList timesteps) { var imageTensor = prompt.InputImage.ToDenseTensor(new[] { 1, 3, options.Height, options.Width }); @@ -175,7 +175,7 @@ protected override async Task> PrepareLatentsAsync(IModelOpti /// The prompt options. /// The scheduler options. /// - private DenseTensor PrepareMask(IModelOptions modelOptions, PromptOptions promptOptions, SchedulerOptions schedulerOptions) + private DenseTensor PrepareMask(StableDiffusionModelSet modelOptions, PromptOptions promptOptions, SchedulerOptions schedulerOptions) { using (var mask = promptOptions.InputImageMask.ToImage()) { diff --git a/OnnxStack.StableDiffusion/Diffusers/StableDiffusion/StableDiffusionDiffuser.cs b/OnnxStack.StableDiffusion/Diffusers/StableDiffusion/StableDiffusionDiffuser.cs index 92e0f16a..60f9c913 100644 --- a/OnnxStack.StableDiffusion/Diffusers/StableDiffusion/StableDiffusionDiffuser.cs +++ b/OnnxStack.StableDiffusion/Diffusers/StableDiffusion/StableDiffusionDiffuser.cs @@ -48,7 +48,7 @@ public StableDiffusionDiffuser(IOnnxModelService onnxModelService, IPromptServic /// The progress callback. /// The cancellation token. /// - protected override async Task> SchedulerStepAsync(IModelOptions modelOptions, PromptOptions promptOptions, SchedulerOptions schedulerOptions, PromptEmbeddingsResult promptEmbeddings, bool performGuidance, Action progressCallback = null, CancellationToken cancellationToken = default) + protected override async Task> SchedulerStepAsync(StableDiffusionModelSet modelOptions, PromptOptions promptOptions, SchedulerOptions schedulerOptions, PromptEmbeddingsResult promptEmbeddings, bool performGuidance, Action progressCallback = null, CancellationToken cancellationToken = default) { // Get Scheduler using (var scheduler = GetScheduler(schedulerOptions)) diff --git a/OnnxStack.StableDiffusion/Diffusers/StableDiffusion/TextDiffuser.cs b/OnnxStack.StableDiffusion/Diffusers/StableDiffusion/TextDiffuser.cs index e4f0fae7..9ea8b0ad 100644 --- a/OnnxStack.StableDiffusion/Diffusers/StableDiffusion/TextDiffuser.cs +++ b/OnnxStack.StableDiffusion/Diffusers/StableDiffusion/TextDiffuser.cs @@ -48,7 +48,7 @@ protected override IReadOnlyList GetTimesteps(SchedulerOptions options, ISc /// The options. /// The scheduler. /// - protected override Task> PrepareLatentsAsync(IModelOptions model, PromptOptions prompt, SchedulerOptions options, IScheduler scheduler, IReadOnlyList timesteps) + protected override Task> PrepareLatentsAsync(StableDiffusionModelSet model, PromptOptions prompt, SchedulerOptions options, IScheduler scheduler, IReadOnlyList timesteps) { return Task.FromResult(scheduler.CreateRandomSample(options.GetScaledDimension(), scheduler.InitNoiseSigma)); } diff --git a/OnnxStack.StableDiffusion/Diffusers/StableDiffusionXL/ImageDiffuser.cs b/OnnxStack.StableDiffusion/Diffusers/StableDiffusionXL/ImageDiffuser.cs index 9096c92a..95ba11f5 100644 --- a/OnnxStack.StableDiffusion/Diffusers/StableDiffusionXL/ImageDiffuser.cs +++ b/OnnxStack.StableDiffusion/Diffusers/StableDiffusionXL/ImageDiffuser.cs @@ -58,7 +58,7 @@ protected override IReadOnlyList GetTimesteps(SchedulerOptions options, ISc /// The options. /// The scheduler. /// - protected override async Task> PrepareLatentsAsync(IModelOptions model, PromptOptions prompt, SchedulerOptions options, IScheduler scheduler, IReadOnlyList timesteps) + protected override async Task> PrepareLatentsAsync(StableDiffusionModelSet model, PromptOptions prompt, SchedulerOptions options, IScheduler scheduler, IReadOnlyList timesteps) { var imageTensor = prompt.InputImage.ToDenseTensor(new[] { 1, 3, options.Height, options.Width }); diff --git a/OnnxStack.StableDiffusion/Diffusers/StableDiffusionXL/InpaintLegacyDiffuser.cs b/OnnxStack.StableDiffusion/Diffusers/StableDiffusionXL/InpaintLegacyDiffuser.cs index 8d91b14b..206d1bde 100644 --- a/OnnxStack.StableDiffusion/Diffusers/StableDiffusionXL/InpaintLegacyDiffuser.cs +++ b/OnnxStack.StableDiffusion/Diffusers/StableDiffusionXL/InpaintLegacyDiffuser.cs @@ -49,7 +49,7 @@ public InpaintLegacyDiffuser(IOnnxModelService onnxModelService, IPromptService /// The progress callback. /// The cancellation token. /// - protected override async Task> SchedulerStepAsync(IModelOptions modelOptions, PromptOptions promptOptions, SchedulerOptions schedulerOptions, PromptEmbeddingsResult promptEmbeddings, bool performGuidance, Action progressCallback = null, CancellationToken cancellationToken = default) + protected override async Task> SchedulerStepAsync(StableDiffusionModelSet modelOptions, PromptOptions promptOptions, SchedulerOptions schedulerOptions, PromptEmbeddingsResult promptEmbeddings, bool performGuidance, Action progressCallback = null, CancellationToken cancellationToken = default) { using (var scheduler = GetScheduler(schedulerOptions)) { @@ -150,7 +150,7 @@ protected override IReadOnlyList GetTimesteps(SchedulerOptions options, ISc /// The options. /// The scheduler. /// - protected override async Task> PrepareLatentsAsync(IModelOptions model, PromptOptions prompt, SchedulerOptions options, IScheduler scheduler, IReadOnlyList timesteps) + protected override async Task> PrepareLatentsAsync(StableDiffusionModelSet model, PromptOptions prompt, SchedulerOptions options, IScheduler scheduler, IReadOnlyList timesteps) { var imageTensor = prompt.InputImage.ToDenseTensor(new[] { 1, 3, options.Height, options.Width }); @@ -179,7 +179,7 @@ protected override async Task> PrepareLatentsAsync(IModelOpti /// The prompt options. /// The scheduler options. /// - private DenseTensor PrepareMask(IModelOptions modelOptions, PromptOptions promptOptions, SchedulerOptions schedulerOptions) + private DenseTensor PrepareMask(StableDiffusionModelSet modelOptions, PromptOptions promptOptions, SchedulerOptions schedulerOptions) { using (var mask = promptOptions.InputImageMask.ToImage()) { diff --git a/OnnxStack.StableDiffusion/Diffusers/StableDiffusionXL/StableDiffusionXLDiffuser.cs b/OnnxStack.StableDiffusion/Diffusers/StableDiffusionXL/StableDiffusionXLDiffuser.cs index 5f8e9ab4..f7359d37 100644 --- a/OnnxStack.StableDiffusion/Diffusers/StableDiffusionXL/StableDiffusionXLDiffuser.cs +++ b/OnnxStack.StableDiffusion/Diffusers/StableDiffusionXL/StableDiffusionXLDiffuser.cs @@ -45,7 +45,7 @@ public StableDiffusionXLDiffuser(IOnnxModelService onnxModelService, IPromptServ /// The progress callback. /// The cancellation token. /// - protected override async Task> SchedulerStepAsync(IModelOptions modelOptions, PromptOptions promptOptions, SchedulerOptions schedulerOptions, PromptEmbeddingsResult promptEmbeddings, bool performGuidance, Action progressCallback = null, CancellationToken cancellationToken = default) + protected override async Task> SchedulerStepAsync(StableDiffusionModelSet modelOptions, PromptOptions promptOptions, SchedulerOptions schedulerOptions, PromptEmbeddingsResult promptEmbeddings, bool performGuidance, Action progressCallback = null, CancellationToken cancellationToken = default) { // Get Scheduler using (var scheduler = GetScheduler(schedulerOptions)) @@ -115,7 +115,7 @@ protected override async Task> SchedulerStepAsync(IModelOptio /// /// The scheduler options. /// - protected DenseTensor GetAddTimeIds(IModelOptions model, SchedulerOptions schedulerOptions, bool performGuidance) + protected DenseTensor GetAddTimeIds(StableDiffusionModelSet model, SchedulerOptions schedulerOptions, bool performGuidance) { float[] result; if (model.ModelType == ModelType.Refiner) diff --git a/OnnxStack.StableDiffusion/Diffusers/StableDiffusionXL/TextDiffuser.cs b/OnnxStack.StableDiffusion/Diffusers/StableDiffusionXL/TextDiffuser.cs index 3beefc1f..f36cd09f 100644 --- a/OnnxStack.StableDiffusion/Diffusers/StableDiffusionXL/TextDiffuser.cs +++ b/OnnxStack.StableDiffusion/Diffusers/StableDiffusionXL/TextDiffuser.cs @@ -48,7 +48,7 @@ protected override IReadOnlyList GetTimesteps(SchedulerOptions options, ISc /// The options. /// The scheduler. /// - protected override Task> PrepareLatentsAsync(IModelOptions model, PromptOptions prompt, SchedulerOptions options, IScheduler scheduler, IReadOnlyList timesteps) + protected override Task> PrepareLatentsAsync(StableDiffusionModelSet model, PromptOptions prompt, SchedulerOptions options, IScheduler scheduler, IReadOnlyList timesteps) { return Task.FromResult(scheduler.CreateRandomSample(options.GetScaledDimension(), scheduler.InitNoiseSigma)); } diff --git a/OnnxStack.StableDiffusion/Helpers/BatchGenerator.cs b/OnnxStack.StableDiffusion/Helpers/BatchGenerator.cs index b908d86b..5739a64b 100644 --- a/OnnxStack.StableDiffusion/Helpers/BatchGenerator.cs +++ b/OnnxStack.StableDiffusion/Helpers/BatchGenerator.cs @@ -1,5 +1,4 @@ -using OnnxStack.StableDiffusion.Common; -using OnnxStack.StableDiffusion.Config; +using OnnxStack.StableDiffusion.Config; using OnnxStack.StableDiffusion.Enums; using System; using System.Collections.Generic; @@ -15,7 +14,7 @@ public static class BatchGenerator /// The batch options. /// The scheduler options. /// - public static List GenerateBatch(IModelOptions modelOptions, BatchOptions batchOptions, SchedulerOptions schedulerOptions) + public static List GenerateBatch(StableDiffusionModelSet modelOptions, BatchOptions batchOptions, SchedulerOptions schedulerOptions) { if (batchOptions.BatchType == BatchOptionType.Seed) { diff --git a/OnnxStack.StableDiffusion/Registration.cs b/OnnxStack.StableDiffusion/Registration.cs index 87cf4d9e..828668d2 100644 --- a/OnnxStack.StableDiffusion/Registration.cs +++ b/OnnxStack.StableDiffusion/Registration.cs @@ -23,7 +23,7 @@ public static void AddOnnxStackStableDiffusion(this IServiceCollection serviceCo { serviceCollection.AddOnnxStack(); serviceCollection.RegisterServices(); - serviceCollection.AddSingleton(ConfigManager.LoadConfiguration(nameof(OnnxStackConfig))); + serviceCollection.AddSingleton(TryLoadAppSettings()); } @@ -86,5 +86,22 @@ private static void ConfigureLibraries() MaximumPoolSizeMegabytes = 100, }); } + + + /// + /// Try load StableDiffusionConfig from application settings. + /// + /// + private static StableDiffusionConfig TryLoadAppSettings() + { + try + { + return ConfigManager.LoadConfiguration(); + } + catch + { + return new StableDiffusionConfig(); + } + } } } diff --git a/OnnxStack.StableDiffusion/Services/PromptService.cs b/OnnxStack.StableDiffusion/Services/PromptService.cs index 1e9dce1d..38dfb0db 100644 --- a/OnnxStack.StableDiffusion/Services/PromptService.cs +++ b/OnnxStack.StableDiffusion/Services/PromptService.cs @@ -39,7 +39,7 @@ public record EmbedsResult(DenseTensor PromptEmbeds, DenseTensor P /// The prompt. /// The negative prompt. /// Tensor containing all text embeds generated from the prompt and negative prompt - public async Task CreatePromptAsync(IModelOptions model, PromptOptions promptOptions, bool isGuidanceEnabled) + public async Task CreatePromptAsync(StableDiffusionModelSet model, PromptOptions promptOptions, bool isGuidanceEnabled) { return model.TokenizerType switch { @@ -58,7 +58,7 @@ public async Task CreatePromptAsync(IModelOptions model, /// The prompt options. /// if set to true is guidance enabled. /// - private async Task CreateEmbedsOneAsync(IModelOptions model, PromptOptions promptOptions, bool isGuidanceEnabled) + private async Task CreateEmbedsOneAsync(StableDiffusionModelSet model, PromptOptions promptOptions, bool isGuidanceEnabled) { // Tokenize Prompt and NegativePrompt var promptTokens = await DecodeTextAsIntAsync(model, promptOptions.Prompt); @@ -82,7 +82,7 @@ private async Task CreateEmbedsOneAsync(IModelOptions mo /// The prompt options. /// if set to true is guidance enabled. /// - private async Task CreateEmbedsTwoAsync(IModelOptions model, PromptOptions promptOptions, bool isGuidanceEnabled) + private async Task CreateEmbedsTwoAsync(StableDiffusionModelSet model, PromptOptions promptOptions, bool isGuidanceEnabled) { /// Tokenize Prompt and NegativePrompt with Tokenizer2 var promptTokens = await DecodeTextAsLongAsync(model, promptOptions.Prompt); @@ -109,7 +109,7 @@ private async Task CreateEmbedsTwoAsync(IModelOptions mo /// The prompt options. /// if set to true is guidance enabled. /// - private async Task CreateEmbedsBothAsync(IModelOptions model, PromptOptions promptOptions, bool isGuidanceEnabled) + private async Task CreateEmbedsBothAsync(StableDiffusionModelSet model, PromptOptions promptOptions, bool isGuidanceEnabled) { // Tokenize Prompt and NegativePrompt var promptTokens = await DecodeTextAsIntAsync(model, promptOptions.Prompt); @@ -145,7 +145,7 @@ private async Task CreateEmbedsBothAsync(IModelOptions m /// /// The input text. /// Tokens generated for the specified text input - private Task DecodeTextAsIntAsync(IModelOptions model, string inputText) + private Task DecodeTextAsIntAsync(StableDiffusionModelSet model, string inputText) { if (string.IsNullOrEmpty(inputText)) return Task.FromResult(Array.Empty()); @@ -171,7 +171,7 @@ private Task DecodeTextAsIntAsync(IModelOptions model, string inputText) /// /// The input text. /// Tokens generated for the specified text input - private Task DecodeTextAsLongAsync(IModelOptions model, string inputText) + private Task DecodeTextAsLongAsync(StableDiffusionModelSet model, string inputText) { if (string.IsNullOrEmpty(inputText)) return Task.FromResult(Array.Empty()); @@ -197,7 +197,7 @@ private Task DecodeTextAsLongAsync(IModelOptions model, string inputText /// /// The tokenized input. /// - private async Task EncodeTokensAsync(IModelOptions model, int[] tokenizedInput) + private async Task EncodeTokensAsync(StableDiffusionModelSet model, int[] tokenizedInput) { var inputDim = new[] { 1, tokenizedInput.Length }; var outputDim = new[] { 1, tokenizedInput.Length, model.TokenizerLength }; @@ -223,7 +223,7 @@ private async Task EncodeTokensAsync(IModelOptions model, int[] tokeniz /// The model. /// The tokenized input. /// - private async Task EncodeTokensAsync(IModelOptions model, long[] tokenizedInput) + private async Task EncodeTokensAsync(StableDiffusionModelSet model, long[] tokenizedInput) { var inputDim = new[] { 1, tokenizedInput.Length }; var promptOutputDim = new[] { 1, tokenizedInput.Length, model.Tokenizer2Length }; @@ -250,7 +250,7 @@ private async Task EncodeTokensAsync(IModelOptions model, long[] /// The input tokens. /// The minimum length. /// - private async Task GenerateEmbedsAsync(IModelOptions model, long[] inputTokens, int minimumLength) + private async Task GenerateEmbedsAsync(StableDiffusionModelSet model, long[] inputTokens, int minimumLength) { // If less than minimumLength pad with blank tokens if (inputTokens.Length < minimumLength) @@ -284,7 +284,7 @@ private async Task GenerateEmbedsAsync(IModelOptions model, long[] /// The input tokens. /// The minimum length. /// - private async Task> GenerateEmbedsAsync(IModelOptions model, int[] inputTokens, int minimumLength) + private async Task> GenerateEmbedsAsync(StableDiffusionModelSet model, int[] inputTokens, int minimumLength) { // If less than minimumLength pad with blank tokens if (inputTokens.Length < minimumLength) diff --git a/OnnxStack.StableDiffusion/Services/StableDiffusionService.cs b/OnnxStack.StableDiffusion/Services/StableDiffusionService.cs index 052f30b6..ebdd2091 100644 --- a/OnnxStack.StableDiffusion/Services/StableDiffusionService.cs +++ b/OnnxStack.StableDiffusion/Services/StableDiffusionService.cs @@ -1,5 +1,6 @@ using Microsoft.ML.OnnxRuntime.Tensors; using OnnxStack.Core; +using OnnxStack.Core.Config; using OnnxStack.Core.Services; using OnnxStack.StableDiffusion.Common; using OnnxStack.StableDiffusion.Config; @@ -25,8 +26,9 @@ namespace OnnxStack.StableDiffusion.Services /// public sealed class StableDiffusionService : IStableDiffusionService { - private readonly IOnnxModelService _onnxModelService; + private readonly IOnnxModelService _modelService; private readonly StableDiffusionConfig _configuration; + private readonly HashSet _modelSetConfigs; private readonly ConcurrentDictionary _pipelines; /// @@ -36,48 +38,102 @@ public sealed class StableDiffusionService : IStableDiffusionService public StableDiffusionService(StableDiffusionConfig configuration, IOnnxModelService onnxModelService, IEnumerable pipelines) { _configuration = configuration; - _onnxModelService = onnxModelService; + _modelService = onnxModelService; + _modelSetConfigs = new HashSet(_configuration.ModelSets, new OnnxModelEqualityComparer()); + _modelService.AddModelSet(_modelSetConfigs); _pipelines = pipelines.ToConcurrentDictionary(k => k.PipelineType, k => k); } /// - /// Gets the models. + /// Gets the model sets. /// - public List Models => _configuration.OnnxModelSets; + public IReadOnlyCollection ModelSets => _modelSetConfigs; + + + /// + /// Adds the model. + /// + /// The model. + /// + public async Task AddModelAsync(StableDiffusionModelSet model) + { + if (await _modelService.AddModelSet(model)) + { + _modelSetConfigs.Add(model); + return true; + } + return false; + } + + + /// + /// Removes the model. + /// + /// The model. + /// + public async Task RemoveModelAsync(StableDiffusionModelSet model) + { + if (await _modelService.RemoveModelSet(model)) + { + _modelSetConfigs.Remove(model); + return true; + } + return false; + } + + + /// + /// Updates the model. + /// + /// The model. + /// + public async Task UpdateModelAsync(StableDiffusionModelSet model) + { + if (await _modelService.UpdateModelSet(model)) + { + _modelSetConfigs.Remove(model); + _modelSetConfigs.Add(model); + return true; + } + return false; + } /// /// Loads the model. /// - /// The model options. + /// The model options. /// - public async Task LoadModelAsync(IModelOptions modelOptions) + public async Task LoadModelAsync(StableDiffusionModelSet model) { - var model = await _onnxModelService.LoadModelAsync(modelOptions); - return model is not null; + if (!_modelSetConfigs.TryGetValue(model, out _)) + throw new Exception("ModelSet not found"); + + var modelSet = await _modelService.LoadModelAsync(model); + return modelSet is not null; } /// /// Unloads the model. /// - /// The model options. + /// The model options. /// - public async Task UnloadModelAsync(IModelOptions modelOptions) + public async Task UnloadModelAsync(StableDiffusionModelSet modelSet) { - return await _onnxModelService.UnloadModelAsync(modelOptions); + return await _modelService.UnloadModelAsync(modelSet); } /// /// Is the model loaded. /// - /// The model options. + /// The model options. /// - public bool IsModelLoaded(IModelOptions modelOptions) + public bool IsModelLoaded(StableDiffusionModelSet modelSet) { - return _onnxModelService.IsModelLoaded(modelOptions); + return _modelService.IsModelLoaded(modelSet); } /// @@ -88,7 +144,7 @@ public bool IsModelLoaded(IModelOptions modelOptions) /// The callback used to provide progess of the current InferenceSteps. /// The cancellation token. /// The diffusion result as - public async Task> GenerateAsync(IModelOptions model, PromptOptions prompt, SchedulerOptions options, Action progressCallback = null, CancellationToken cancellationToken = default) + public async Task> GenerateAsync(StableDiffusionModelSet model, PromptOptions prompt, SchedulerOptions options, Action progressCallback = null, CancellationToken cancellationToken = default) { return await DiffuseAsync(model, prompt, options, progressCallback, cancellationToken).ConfigureAwait(false); } @@ -102,7 +158,7 @@ public async Task> GenerateAsync(IModelOptions model, PromptO /// The callback used to provide progess of the current InferenceSteps. /// The cancellation token. /// The diffusion result as - public async Task> GenerateAsImageAsync(IModelOptions model, PromptOptions prompt, SchedulerOptions options, Action progressCallback = null, CancellationToken cancellationToken = default) + public async Task> GenerateAsImageAsync(StableDiffusionModelSet model, PromptOptions prompt, SchedulerOptions options, Action progressCallback = null, CancellationToken cancellationToken = default) { return await GenerateAsync(model, prompt, options, progressCallback, cancellationToken) .ContinueWith(t => t.Result.ToImage(), cancellationToken) @@ -118,7 +174,7 @@ public async Task> GenerateAsImageAsync(IModelOptions model, Promp /// The callback used to provide progess of the current InferenceSteps. /// The cancellation token. /// The diffusion result as - public async Task GenerateAsBytesAsync(IModelOptions model, PromptOptions prompt, SchedulerOptions options, Action progressCallback = null, CancellationToken cancellationToken = default) + public async Task GenerateAsBytesAsync(StableDiffusionModelSet model, PromptOptions prompt, SchedulerOptions options, Action progressCallback = null, CancellationToken cancellationToken = default) { return await GenerateAsync(model, prompt, options, progressCallback, cancellationToken) .ContinueWith(t => t.Result.ToImageBytes(), cancellationToken) @@ -134,7 +190,7 @@ public async Task GenerateAsBytesAsync(IModelOptions model, PromptOption /// The callback used to provide progess of the current InferenceSteps. /// The cancellation token. /// The diffusion result as - public async Task GenerateAsStreamAsync(IModelOptions model, PromptOptions prompt, SchedulerOptions options, Action progressCallback = null, CancellationToken cancellationToken = default) + public async Task GenerateAsStreamAsync(StableDiffusionModelSet model, PromptOptions prompt, SchedulerOptions options, Action progressCallback = null, CancellationToken cancellationToken = default) { return await GenerateAsync(model, prompt, options, progressCallback, cancellationToken) .ContinueWith(t => t.Result.ToImageStream(), cancellationToken) @@ -152,7 +208,7 @@ public async Task GenerateAsStreamAsync(IModelOptions model, PromptOptio /// The progress callback. /// The cancellation token. /// - public IAsyncEnumerable GenerateBatchAsync(IModelOptions modelOptions, PromptOptions promptOptions, SchedulerOptions schedulerOptions, BatchOptions batchOptions, Action progressCallback = null, CancellationToken cancellationToken = default) + public IAsyncEnumerable GenerateBatchAsync(StableDiffusionModelSet modelOptions, PromptOptions promptOptions, SchedulerOptions schedulerOptions, BatchOptions batchOptions, Action progressCallback = null, CancellationToken cancellationToken = default) { return DiffuseBatchAsync(modelOptions, promptOptions, schedulerOptions, batchOptions, progressCallback, cancellationToken); } @@ -168,7 +224,7 @@ public IAsyncEnumerable GenerateBatchAsync(IModelOptions modelOptio /// The progress callback. /// The cancellation token. /// - public async IAsyncEnumerable> GenerateBatchAsImageAsync(IModelOptions modelOptions, PromptOptions promptOptions, SchedulerOptions schedulerOptions, BatchOptions batchOptions, Action progressCallback = null, [EnumeratorCancellation] CancellationToken cancellationToken = default) + public async IAsyncEnumerable> GenerateBatchAsImageAsync(StableDiffusionModelSet modelOptions, PromptOptions promptOptions, SchedulerOptions schedulerOptions, BatchOptions batchOptions, Action progressCallback = null, [EnumeratorCancellation] CancellationToken cancellationToken = default) { await foreach (var result in GenerateBatchAsync(modelOptions, promptOptions, schedulerOptions, batchOptions, progressCallback, cancellationToken)) yield return result.ImageResult.ToImage(); @@ -185,7 +241,7 @@ public async IAsyncEnumerable> GenerateBatchAsImageAsync(IModelOpt /// The progress callback. /// The cancellation token. /// - public async IAsyncEnumerable GenerateBatchAsBytesAsync(IModelOptions modelOptions, PromptOptions promptOptions, SchedulerOptions schedulerOptions, BatchOptions batchOptions, Action progressCallback = null, [EnumeratorCancellation] CancellationToken cancellationToken = default) + public async IAsyncEnumerable GenerateBatchAsBytesAsync(StableDiffusionModelSet modelOptions, PromptOptions promptOptions, SchedulerOptions schedulerOptions, BatchOptions batchOptions, Action progressCallback = null, [EnumeratorCancellation] CancellationToken cancellationToken = default) { await foreach (var result in GenerateBatchAsync(modelOptions, promptOptions, schedulerOptions, batchOptions, progressCallback, cancellationToken)) yield return result.ImageResult.ToImageBytes(); @@ -202,14 +258,14 @@ public async IAsyncEnumerable GenerateBatchAsBytesAsync(IModelOptions mo /// The progress callback. /// The cancellation token. /// - public async IAsyncEnumerable GenerateBatchAsStreamAsync(IModelOptions modelOptions, PromptOptions promptOptions, SchedulerOptions schedulerOptions, BatchOptions batchOptions, Action progressCallback = null, [EnumeratorCancellation] CancellationToken cancellationToken = default) + public async IAsyncEnumerable GenerateBatchAsStreamAsync(StableDiffusionModelSet modelOptions, PromptOptions promptOptions, SchedulerOptions schedulerOptions, BatchOptions batchOptions, Action progressCallback = null, [EnumeratorCancellation] CancellationToken cancellationToken = default) { await foreach (var result in GenerateBatchAsync(modelOptions, promptOptions, schedulerOptions, batchOptions, progressCallback, cancellationToken)) yield return result.ImageResult.ToImageStream(); } - private async Task> DiffuseAsync(IModelOptions modelOptions, PromptOptions promptOptions, SchedulerOptions schedulerOptions, Action progress = null, CancellationToken cancellationToken = default) + private async Task> DiffuseAsync(StableDiffusionModelSet modelOptions, PromptOptions promptOptions, SchedulerOptions schedulerOptions, Action progress = null, CancellationToken cancellationToken = default) { if (!_pipelines.TryGetValue(modelOptions.PipelineType, out var pipeline)) throw new Exception("Pipeline not found or is unsupported"); @@ -226,7 +282,7 @@ private async Task> DiffuseAsync(IModelOptions modelOptions, } - private IAsyncEnumerable DiffuseBatchAsync(IModelOptions modelOptions, PromptOptions promptOptions, SchedulerOptions schedulerOptions, BatchOptions batchOptions, Action progress = null, CancellationToken cancellationToken = default) + private IAsyncEnumerable DiffuseBatchAsync(StableDiffusionModelSet modelOptions, PromptOptions promptOptions, SchedulerOptions schedulerOptions, BatchOptions batchOptions, Action progress = null, CancellationToken cancellationToken = default) { if (!_pipelines.TryGetValue(modelOptions.PipelineType, out var pipeline)) throw new Exception("Pipeline not found or is unsupported"); @@ -241,5 +297,7 @@ private IAsyncEnumerable DiffuseBatchAsync(IModelOptions modelOptio return diffuser.DiffuseBatchAsync(modelOptions, promptOptions, schedulerOptions, batchOptions, progress, cancellationToken); } + + } } diff --git a/OnnxStack.UI/MainWindow.xaml.cs b/OnnxStack.UI/MainWindow.xaml.cs index f3efd04f..3f6376b5 100644 --- a/OnnxStack.UI/MainWindow.xaml.cs +++ b/OnnxStack.UI/MainWindow.xaml.cs @@ -43,7 +43,7 @@ public MainWindow(StableDiffusionConfig configuration, OnnxStackUIConfig uiSetti WindowRestoreCommand = new AsyncRelayCommand(WindowRestore); WindowMinimizeCommand = new AsyncRelayCommand(WindowMinimize); WindowMaximizeCommand = new AsyncRelayCommand(WindowMaximize); - Models = CreateModelOptions(configuration.OnnxModelSets); + Models = CreateModelOptions(configuration.ModelSets); InitializeComponent(); } @@ -119,7 +119,7 @@ private enum TabId PaintToImage = 3 } - private ObservableCollection CreateModelOptions(List onnxModelSets) + private ObservableCollection CreateModelOptions(List onnxModelSets) { var models = onnxModelSets .Select(model => new ModelOptionsModel diff --git a/OnnxStack.UI/Models/ModelOptionsModel.cs b/OnnxStack.UI/Models/ModelOptionsModel.cs index a1c81635..ac1b0976 100644 --- a/OnnxStack.UI/Models/ModelOptionsModel.cs +++ b/OnnxStack.UI/Models/ModelOptionsModel.cs @@ -1,4 +1,5 @@ using OnnxStack.StableDiffusion.Common; +using OnnxStack.StableDiffusion.Config; using System.ComponentModel; using System.Runtime.CompilerServices; @@ -35,7 +36,7 @@ public bool IsEnabled set { _isEnabled = value; NotifyPropertyChanged(); } } - public IModelOptions ModelOptions { get; set; } + public StableDiffusionModelSet ModelOptions { get; set; } #region INotifyPropertyChanged diff --git a/OnnxStack.UI/Views/ImageInpaintView.xaml.cs b/OnnxStack.UI/Views/ImageInpaintView.xaml.cs index fe45052c..13f450c0 100644 --- a/OnnxStack.UI/Views/ImageInpaintView.xaml.cs +++ b/OnnxStack.UI/Views/ImageInpaintView.xaml.cs @@ -336,7 +336,7 @@ private void Reset() /// The scheduler options. /// The batch options. /// - private async IAsyncEnumerable ExecuteStableDiffusion(IModelOptions modelOptions, PromptOptions promptOptions, SchedulerOptions schedulerOptions, BatchOptions batchOptions) + private async IAsyncEnumerable ExecuteStableDiffusion(StableDiffusionModelSet modelOptions, PromptOptions promptOptions, SchedulerOptions schedulerOptions, BatchOptions batchOptions) { _cancelationTokenSource = new CancellationTokenSource(); diff --git a/OnnxStack.UI/Views/ImageToImageView.xaml.cs b/OnnxStack.UI/Views/ImageToImageView.xaml.cs index 1ec1a935..5a1dd6e9 100644 --- a/OnnxStack.UI/Views/ImageToImageView.xaml.cs +++ b/OnnxStack.UI/Views/ImageToImageView.xaml.cs @@ -318,7 +318,7 @@ private void Reset() /// The scheduler options. /// The batch options. /// - private async IAsyncEnumerable ExecuteStableDiffusion(IModelOptions modelOptions, PromptOptions promptOptions, SchedulerOptions schedulerOptions, BatchOptions batchOptions) + private async IAsyncEnumerable ExecuteStableDiffusion(StableDiffusionModelSet modelOptions, PromptOptions promptOptions, SchedulerOptions schedulerOptions, BatchOptions batchOptions) { _cancelationTokenSource = new CancellationTokenSource(); diff --git a/OnnxStack.UI/Views/ModelView.xaml.cs b/OnnxStack.UI/Views/ModelView.xaml.cs index 9f63855b..b225e4b4 100644 --- a/OnnxStack.UI/Views/ModelView.xaml.cs +++ b/OnnxStack.UI/Views/ModelView.xaml.cs @@ -35,7 +35,6 @@ public partial class ModelView : UserControl, INavigatable, INotifyPropertyChang private readonly ILogger _logger; private readonly string _defaultTokenizerPath; private readonly IDialogService _dialogService; - private readonly IOnnxModelService _onnxModelService; private readonly IModelDownloadService _modelDownloadService; private readonly StableDiffusionConfig _stableDiffusionConfig; private readonly IStableDiffusionService _stableDiffusionService; @@ -55,7 +54,6 @@ public ModelView() { _logger = App.GetService>(); _dialogService = App.GetService(); - _onnxModelService = App.GetService(); _stableDiffusionConfig = App.GetService(); _stableDiffusionService = App.GetService(); _modelDownloadService = App.GetService(); @@ -141,7 +139,7 @@ public Task NavigateAsync(ImageResult imageResult) private void Initialize() { ModelSets = new ObservableCollection(); - foreach (var installedModel in _stableDiffusionConfig.OnnxModelSets.Select(CreateViewModel)) + foreach (var installedModel in _stableDiffusionConfig.ModelSets.Select(CreateViewModel)) { _logger.LogDebug($"Initialize ModelSet: {installedModel.Name}"); @@ -521,7 +519,7 @@ private async Task SaveModelAsync(ModelSetViewModel modelSet) } // Add to Config file - _stableDiffusionConfig.OnnxModelSets.Add(newModelOption); + _stableDiffusionConfig.ModelSets.Add(newModelOption); // Update Templater if one was used UpdateTemplateStatus(newModelOption.Name, ModelTemplateStatus.Installed); @@ -533,7 +531,7 @@ private async Task SaveModelAsync(ModelSetViewModel modelSet) // Update OnnxStack Service newModelOption.ApplyConfigurationOverrides(); - _onnxModelService.UpdateModelSet(newModelOption); + await _stableDiffusionService.UpdateModelAsync(newModelOption); // Add new ViewModel ModelOptions.Add(new ModelOptionsModel @@ -622,7 +620,7 @@ private Task Copy() { var newModelSet = SelectedModelSet.IsTemplate ? CreateViewModel(UISettings.ModelTemplates.FirstOrDefault(x => x.Name == SelectedModelSet.Name)) - : CreateViewModel(_stableDiffusionConfig.OnnxModelSets.FirstOrDefault(x => x.Name == SelectedModelSet.Name)); + : CreateViewModel(_stableDiffusionConfig.ModelSets.FirstOrDefault(x => x.Name == SelectedModelSet.Name)); newModelSet.IsEnabled = false; newModelSet.IsTemplate = false; @@ -830,7 +828,7 @@ private bool CanExecuteExport() /// private async Task UnloadAndRemoveModelSetAsync(string name) { - var onnxModelSet = _stableDiffusionConfig.OnnxModelSets.FirstOrDefault(x => x.Name == name); + var onnxModelSet = _stableDiffusionConfig.ModelSets.FirstOrDefault(x => x.Name == name); if (onnxModelSet is not null) { // If model is loaded unload now @@ -844,7 +842,7 @@ private async Task UnloadAndRemoveModelSetAsync(string name) ModelOptions.Remove(viewModel); // Remove ModelSet - _stableDiffusionConfig.OnnxModelSets.Remove(onnxModelSet); + _stableDiffusionConfig.ModelSets.Remove(onnxModelSet); return true; } return false; @@ -877,7 +875,7 @@ private void UpdateTemplateStatus(string name, ModelTemplateStatus status) /// /// The model. /// - private bool ValidateModelSet(ModelOptions model) + private bool ValidateModelSet(StableDiffusionModelSet model) { if (model == null) return false; @@ -908,7 +906,7 @@ private Task SaveConfigurationFile() try { ConfigManager.SaveConfiguration(UISettings); - ConfigManager.SaveConfiguration(nameof(OnnxStackConfig), _stableDiffusionConfig); + ConfigManager.SaveConfiguration(_stableDiffusionConfig); return Task.FromResult(true); } catch (Exception ex) @@ -982,7 +980,7 @@ private ModelSetViewModel CreateViewModel(ModelConfigTemplate modelTemplate) /// /// The model options. /// - private ModelSetViewModel CreateViewModel(ModelOptions modelOptions) + private ModelSetViewModel CreateViewModel(StableDiffusionModelSet modelOptions) { var isValid = ValidateModelSet(modelOptions); return new ModelSetViewModel @@ -1057,9 +1055,9 @@ private ModelSetViewModel CreateViewModel(ModelOptions modelOptions) /// /// The edit model. /// - private ModelOptions CreateModelOptions(ModelSetViewModel editModel) + private StableDiffusionModelSet CreateModelOptions(ModelSetViewModel editModel) { - return new ModelOptions + return new StableDiffusionModelSet { IsEnabled = editModel.IsEnabled, Name = editModel.Name, @@ -1079,7 +1077,7 @@ private ModelOptions CreateModelOptions(ModelSetViewModel editModel) Diffusers = new List(editModel.GetDiffusers()), SampleSize = editModel.SampleSize, ModelType = editModel.ModelType, - ModelConfigurations = new List(editModel.ModelFiles.Select(x => new OnnxModelSessionConfig + ModelConfigurations = new List(editModel.ModelFiles.Select(x => new OnnxModelConfig { Type = x.Type, OnnxModelPath = x.OnnxModelPath, diff --git a/OnnxStack.UI/Views/PaintToImageView.xaml.cs b/OnnxStack.UI/Views/PaintToImageView.xaml.cs index f906c9fe..2a83f573 100644 --- a/OnnxStack.UI/Views/PaintToImageView.xaml.cs +++ b/OnnxStack.UI/Views/PaintToImageView.xaml.cs @@ -330,7 +330,7 @@ private void Reset() /// The scheduler options. /// The batch options. /// - private async IAsyncEnumerable ExecuteStableDiffusion(IModelOptions modelOptions, PromptOptions promptOptions, SchedulerOptions schedulerOptions, BatchOptions batchOptions) + private async IAsyncEnumerable ExecuteStableDiffusion(StableDiffusionModelSet modelOptions, PromptOptions promptOptions, SchedulerOptions schedulerOptions, BatchOptions batchOptions) { _cancelationTokenSource = new CancellationTokenSource(); diff --git a/OnnxStack.UI/Views/TextToImageView.xaml.cs b/OnnxStack.UI/Views/TextToImageView.xaml.cs index 7097b755..70621bf4 100644 --- a/OnnxStack.UI/Views/TextToImageView.xaml.cs +++ b/OnnxStack.UI/Views/TextToImageView.xaml.cs @@ -293,7 +293,7 @@ private void Reset() /// The prompt options. /// The scheduler options. /// - private async IAsyncEnumerable ExecuteStableDiffusion(IModelOptions modelOptions, PromptOptions promptOptions, SchedulerOptions schedulerOptions, BatchOptions batchOptions) + private async IAsyncEnumerable ExecuteStableDiffusion(StableDiffusionModelSet modelOptions, PromptOptions promptOptions, SchedulerOptions schedulerOptions, BatchOptions batchOptions) { _cancelationTokenSource = new CancellationTokenSource(); diff --git a/OnnxStack.UI/appsettings.json b/OnnxStack.UI/appsettings.json index c594f974..30e86ef3 100644 --- a/OnnxStack.UI/appsettings.json +++ b/OnnxStack.UI/appsettings.json @@ -6,8 +6,8 @@ } }, "AllowedHosts": "*", - "OnnxStackConfig": { - "OnnxModelSets": [ + "StableDiffusionConfig": { + "ModelSets": [ ] }, "OnnxStackUIConfig": {