From c1025c2bf776db9adf28feca1eeaa7ed36a50592 Mon Sep 17 00:00:00 2001 From: sa_ddam213 Date: Sun, 31 Mar 2024 07:54:23 +1300 Subject: [PATCH 1/5] Upscale tensor support, tiled tensor support --- OnnxStack.Console/Examples/UpscaleExample.cs | 2 +- .../Examples/UpscaleStreamExample.cs | 2 +- OnnxStack.Core/Extensions/TensorExtension.cs | 74 +---------- OnnxStack.Core/Image/Extensions.cs | 100 +++++++++++++++ OnnxStack.Core/Model/ImageTiles.cs | 2 +- .../Common/UpscaleModel.cs | 32 ++--- .../Common/UpscaleModelConfig.cs | 14 +++ .../Extensions/ImageExtensions.cs | 88 ------------- OnnxStack.ImageUpscaler/Models/ImageTile.cs | 14 --- .../Models/UpscaleInput.cs | 9 -- .../Pipelines/ImageUpscalePipeline.cs | 116 ++++++++++++------ 11 files changed, 211 insertions(+), 242 deletions(-) create mode 100644 OnnxStack.ImageUpscaler/Common/UpscaleModelConfig.cs delete mode 100644 OnnxStack.ImageUpscaler/Extensions/ImageExtensions.cs delete mode 100644 OnnxStack.ImageUpscaler/Models/ImageTile.cs delete mode 100644 OnnxStack.ImageUpscaler/Models/UpscaleInput.cs diff --git a/OnnxStack.Console/Examples/UpscaleExample.cs b/OnnxStack.Console/Examples/UpscaleExample.cs index d031b1a5..d0d86b82 100644 --- a/OnnxStack.Console/Examples/UpscaleExample.cs +++ b/OnnxStack.Console/Examples/UpscaleExample.cs @@ -25,7 +25,7 @@ public async Task RunAsync() var inputImage = await OnnxImage.FromFileAsync("D:\\Repositories\\OnnxStack\\Assets\\Samples\\Img2Img_Start.bmp"); // Create Pipeline - var pipeline = ImageUpscalePipeline.CreatePipeline("D:\\Repositories\\upscaler\\SwinIR\\003_realSR_BSRGAN_DFO_s64w8_SwinIR-M_x4_GAN.onnx", 4); + var pipeline = ImageUpscalePipeline.CreatePipeline("D:\\Repositories\\upscaler\\SwinIR\\003_realSR_BSRGAN_DFO_s64w8_SwinIR-M_x4_GAN.onnx", 4, 512); // Run pipeline var result = await pipeline.RunAsync(inputImage); diff --git a/OnnxStack.Console/Examples/UpscaleStreamExample.cs b/OnnxStack.Console/Examples/UpscaleStreamExample.cs index 3c35c4de..b544ab00 100644 --- 
a/OnnxStack.Console/Examples/UpscaleStreamExample.cs +++ b/OnnxStack.Console/Examples/UpscaleStreamExample.cs @@ -26,7 +26,7 @@ public async Task RunAsync() var videoInfo = await VideoHelper.ReadVideoInfoAsync(videoFile); // Create pipeline - var pipeline = ImageUpscalePipeline.CreatePipeline("D:\\Repositories\\upscaler\\SwinIR\\003_realSR_BSRGAN_DFO_s64w8_SwinIR-M_x4_GAN.onnx", 4); + var pipeline = ImageUpscalePipeline.CreatePipeline("D:\\Repositories\\upscaler\\SwinIR\\003_realSR_BSRGAN_DFO_s64w8_SwinIR-M_x4_GAN.onnx", 4, 512); // Load pipeline await pipeline.LoadAsync(); diff --git a/OnnxStack.Core/Extensions/TensorExtension.cs b/OnnxStack.Core/Extensions/TensorExtension.cs index ffcb29b0..3b8b07f1 100644 --- a/OnnxStack.Core/Extensions/TensorExtension.cs +++ b/OnnxStack.Core/Extensions/TensorExtension.cs @@ -5,6 +5,7 @@ using System.Numerics.Tensors; using System.Numerics; using OnnxStack.Core.Model; +using System.Threading.Tasks; namespace OnnxStack.Core { @@ -412,78 +413,5 @@ private static DenseTensor ConcatenateAxis2(DenseTensor tensor1, D return concatenatedTensor; } - - - /// - /// Splits the Tensor into 4 equal tiles. - /// - /// The source tensor. 
- /// TODO: Optimize - public static ImageTiles SplitTiles(this DenseTensor sourceTensor) - { - int tileWidth = sourceTensor.Dimensions[3] / 2; - int tileHeight = sourceTensor.Dimensions[2] / 2; - - return new ImageTiles( - SplitTile(sourceTensor, 0, 0, tileHeight, tileWidth), - SplitTile(sourceTensor, 0, tileWidth, tileHeight, tileWidth * 2), - SplitTile(sourceTensor, tileHeight, 0, tileHeight * 2, tileWidth), - SplitTile(sourceTensor, tileHeight, tileWidth, tileHeight * 2, tileWidth * 2)); - } - - private static DenseTensor SplitTile(DenseTensor tensor, int startRow, int startCol, int endRow, int endCol) - { - int height = endRow - startRow; - int width = endCol - startCol; - int channels = tensor.Dimensions[1]; - var slicedData = new DenseTensor(new[] { 1, channels, height, width }); - for (int c = 0; c < channels; c++) - { - for (int i = 0; i < height; i++) - { - for (int j = 0; j < width; j++) - { - slicedData[0, c, i, j] = tensor[0, c, startRow + i, startCol + j]; - } - } - } - return slicedData; - } - - - /// - /// Rejoins the tiles into a single Tensor. - /// - /// The tiles. 
- /// TODO: Optimize - public static DenseTensor RejoinTiles(this ImageTiles tiles) - { - int totalHeight = tiles.Tile1.Dimensions[2] + tiles.Tile3.Dimensions[2]; - int totalWidth = tiles.Tile1.Dimensions[3] + tiles.Tile2.Dimensions[3]; - int channels = tiles.Tile1.Dimensions[1]; - var destination = new DenseTensor(new[] { 1, channels, totalHeight, totalWidth }); - RejoinTile(destination, tiles.Tile1, 0, 0); - RejoinTile(destination, tiles.Tile2, 0, tiles.Tile1.Dimensions[3]); - RejoinTile(destination, tiles.Tile3, tiles.Tile1.Dimensions[2], 0); - RejoinTile(destination, tiles.Tile4, tiles.Tile1.Dimensions[2], tiles.Tile1.Dimensions[3]); - return destination; - } - - private static void RejoinTile(DenseTensor destination, DenseTensor tile, int startRow, int startCol) - { - int channels = tile.Dimensions[1]; - int height = tile.Dimensions[2]; - int width = tile.Dimensions[3]; - for (int c = 0; c < channels; c++) - { - for (int i = 0; i < height; i++) - { - for (int j = 0; j < width; j++) - { - destination[0, c, startRow + i, startCol + j] = tile[0, c, i, j]; - } - } - } - } } } diff --git a/OnnxStack.Core/Image/Extensions.cs b/OnnxStack.Core/Image/Extensions.cs index e75d86a8..a045223c 100644 --- a/OnnxStack.Core/Image/Extensions.cs +++ b/OnnxStack.Core/Image/Extensions.cs @@ -1,7 +1,9 @@ using Microsoft.ML.OnnxRuntime.Tensors; +using OnnxStack.Core.Model; using SixLabors.ImageSharp; using SixLabors.ImageSharp.PixelFormats; using SixLabors.ImageSharp.Processing; +using System.Threading.Tasks; namespace OnnxStack.Core.Image { @@ -40,6 +42,104 @@ public static ResizeMode ToResizeMode(this ImageResizeMode resizeMode) }; } + + /// + /// Splits the Tensor into 4 equal tiles. + /// + /// The source tensor. 
+ /// TODO: Optimize + public static ImageTiles SplitImageTiles(this DenseTensor sourceTensor, int overlap = 20) + { + var tileWidth = sourceTensor.Dimensions[3] / 2; + var tileHeight = sourceTensor.Dimensions[2] / 2; + return new ImageTiles(tileWidth, tileHeight, overlap, + SplitImageTile(sourceTensor, 0, 0, tileHeight + overlap, tileWidth + overlap), + SplitImageTile(sourceTensor, 0, tileWidth - overlap, tileHeight + overlap, tileWidth * 2), + SplitImageTile(sourceTensor, tileHeight - overlap, 0, tileHeight * 2, tileWidth + overlap), + SplitImageTile(sourceTensor, tileHeight - overlap, tileWidth - overlap, tileHeight * 2, tileWidth * 2)); + } + + + /// + /// Splits a tile from the source. + /// + /// The tensor. + /// The start row. + /// The start col. + /// The end row. + /// The end col. + /// + private static DenseTensor SplitImageTile(DenseTensor source, int startRow, int startCol, int endRow, int endCol) + { + int height = endRow - startRow; + int width = endCol - startCol; + int channels = source.Dimensions[1]; + var splitTensor = new DenseTensor(new[] { 1, channels, height, width }); + Parallel.For(0, channels, (c) => + { + Parallel.For(0, height, (i) => + { + Parallel.For(0, width, (j) => + { + splitTensor[0, c, i, j] = source[0, c, startRow + i, startCol + j]; + }); + }); + }); + return splitTensor; + } + + + /// + /// Joins the tiles into a single Tensor. + /// + /// The tiles. 
+ /// TODO: Optimize + public static DenseTensor JoinImageTiles(this ImageTiles tiles) + { + var totalWidth = tiles.Width * 2; + var totalHeight = tiles.Height * 2; + var channels = tiles.Tile1.Dimensions[1]; + var destination = new DenseTensor(new[] { 1, channels, totalHeight, totalWidth }); + JoinImageTile(destination, tiles.Tile1, 0, 0, tiles.Height + tiles.Overlap, tiles.Width + tiles.Overlap); + JoinImageTile(destination, tiles.Tile2, 0, tiles.Width - tiles.Overlap, tiles.Height + tiles.Overlap, totalWidth); + JoinImageTile(destination, tiles.Tile3, tiles.Height - tiles.Overlap, 0, totalHeight, tiles.Width + tiles.Overlap); + JoinImageTile(destination, tiles.Tile4, tiles.Height - tiles.Overlap, tiles.Width - tiles.Overlap, totalHeight, totalWidth); + return destination; + } + + + /// + /// Joins the tile to the destination tensor. + /// + /// The destination. + /// The tile. + /// The start row. + /// The start col. + /// The end row. + /// The end col. + private static void JoinImageTile(DenseTensor destination, DenseTensor tile, int startRow, int startCol, int endRow, int endCol) + { + int height = endRow - startRow; + int width = endCol - startCol; + int channels = tile.Dimensions[1]; + Parallel.For(0, channels, (c) => + { + Parallel.For(0, height, (i) => + { + Parallel.For(0, width, (j) => + { + var value = tile[0, c, i, j]; + var existing = destination[0, c, startRow + i, startCol + j]; + if (existing > 0) + { + // Blend overlap + value = (existing + value) / 2f; + } + destination[0, c, startRow + i, startCol + j] = value; + }); + }); + }); + } } public enum ImageNormalizeType diff --git a/OnnxStack.Core/Model/ImageTiles.cs b/OnnxStack.Core/Model/ImageTiles.cs index f44d1cf7..90425f56 100644 --- a/OnnxStack.Core/Model/ImageTiles.cs +++ b/OnnxStack.Core/Model/ImageTiles.cs @@ -2,5 +2,5 @@ namespace OnnxStack.Core.Model { - public record ImageTiles(DenseTensor Tile1, DenseTensor Tile2, DenseTensor Tile3, DenseTensor Tile4); + public record ImageTiles(int 
Width, int Height, int Overlap, DenseTensor Tile1, DenseTensor Tile2, DenseTensor Tile3, DenseTensor Tile4); } diff --git a/OnnxStack.ImageUpscaler/Common/UpscaleModel.cs b/OnnxStack.ImageUpscaler/Common/UpscaleModel.cs index a6352739..f8e9ec26 100644 --- a/OnnxStack.ImageUpscaler/Common/UpscaleModel.cs +++ b/OnnxStack.ImageUpscaler/Common/UpscaleModel.cs @@ -1,39 +1,39 @@ using Microsoft.ML.OnnxRuntime; using OnnxStack.Core.Config; using OnnxStack.Core.Model; +using System; namespace OnnxStack.ImageUpscaler.Common { public class UpscaleModel : OnnxModelSession { - private readonly int _channels; - private readonly int _sampleSize; - private readonly int _scaleFactor; + private readonly UpscaleModelConfig _configuration; public UpscaleModel(UpscaleModelConfig configuration) : base(configuration) { - _channels = configuration.Channels; - _sampleSize = configuration.SampleSize; - _scaleFactor = configuration.ScaleFactor; + _configuration = configuration; } - public int Channels => _channels; - public int SampleSize => _sampleSize; - public int ScaleFactor => _scaleFactor; - + public int Channels => _configuration.Channels; + public int SampleSize => _configuration.SampleSize; + public int ScaleFactor => _configuration.ScaleFactor; + public int TileSize => _configuration.TileSize; + public int TileOverlap => _configuration.TileOverlap; public static UpscaleModel Create(UpscaleModelConfig configuration) { return new UpscaleModel(configuration); } - public static UpscaleModel Create(string modelFile, int scaleFactor, int sampleSize = 512, int deviceId = 0, ExecutionProvider executionProvider = ExecutionProvider.DirectML) + public static UpscaleModel Create(string modelFile, int scaleFactor, int sampleSize, int tileSize = 0, int tileOverlap = 20, int channels = 3, int deviceId = 0, ExecutionProvider executionProvider = ExecutionProvider.DirectML) { var configuration = new UpscaleModelConfig { - Channels = 3, + Channels = channels, SampleSize = sampleSize, ScaleFactor = 
scaleFactor, + TileOverlap = tileOverlap, + TileSize = Math.Min(sampleSize, tileSize > 0 ? tileSize : sampleSize), DeviceId = deviceId, ExecutionProvider = executionProvider, ExecutionMode = ExecutionMode.ORT_SEQUENTIAL, @@ -44,12 +44,4 @@ public static UpscaleModel Create(string modelFile, int scaleFactor, int sampleS return new UpscaleModel(configuration); } } - - - public record UpscaleModelConfig : OnnxModelConfig - { - public int Channels { get; set; } - public int SampleSize { get; set; } - public int ScaleFactor { get; set; } - } } diff --git a/OnnxStack.ImageUpscaler/Common/UpscaleModelConfig.cs b/OnnxStack.ImageUpscaler/Common/UpscaleModelConfig.cs new file mode 100644 index 00000000..eca0666b --- /dev/null +++ b/OnnxStack.ImageUpscaler/Common/UpscaleModelConfig.cs @@ -0,0 +1,14 @@ +using OnnxStack.Core.Config; + +namespace OnnxStack.ImageUpscaler.Common +{ + public record UpscaleModelConfig : OnnxModelConfig + { + public int Channels { get; set; } + public int SampleSize { get; set; } + public int ScaleFactor { get; set; } + + public int TileSize { get; set; } + public int TileOverlap { get; set; } + } +} diff --git a/OnnxStack.ImageUpscaler/Extensions/ImageExtensions.cs b/OnnxStack.ImageUpscaler/Extensions/ImageExtensions.cs deleted file mode 100644 index 60aee828..00000000 --- a/OnnxStack.ImageUpscaler/Extensions/ImageExtensions.cs +++ /dev/null @@ -1,88 +0,0 @@ -using Microsoft.ML.OnnxRuntime.Tensors; -using OnnxStack.Core.Image; -using OnnxStack.ImageUpscaler.Models; -using SixLabors.ImageSharp; -using SixLabors.ImageSharp.PixelFormats; -using System; -using System.Collections.Generic; -using System.Threading.Tasks; - -namespace OnnxStack.ImageUpscaler.Extensions -{ - internal static class ImageExtensions - { - /// - /// Generates the image tiles. - /// - /// The image source. - /// Maximum size of the tile. - /// The scale factor. 
- /// - internal static List GenerateTiles(this OnnxImage imageSource, int sampleSize, int scaleFactor) - { - var tiles = new List(); - var tileSizeX = Math.Min(sampleSize, imageSource.Width); - var tileSizeY = Math.Min(sampleSize, imageSource.Height); - var columns = (int)Math.Ceiling((double)imageSource.Width / tileSizeX); - var rows = (int)Math.Ceiling((double)imageSource.Height / tileSizeY); - var tileWidth = imageSource.Width / columns; - var tileHeight = imageSource.Height / rows; - - for (int y = 0; y < rows; y++) - { - for (int x = 0; x < columns; x++) - { - var tileRect = new Rectangle(x * tileWidth, y * tileHeight, tileWidth, tileHeight); - var tileDest = new Rectangle(tileRect.X * scaleFactor, tileRect.Y * scaleFactor, tileWidth * scaleFactor, tileHeight * scaleFactor); - var tileImage = ExtractTile(imageSource, tileRect); - tiles.Add(new ImageTile { Image = tileImage, Destination = tileDest }); - } - } - return tiles; - } - - - /// - /// Extracts an image tile from a source image. - /// - /// The source image. - /// The source area. 
- /// - internal static OnnxImage ExtractTile(this OnnxImage sourceImage, Rectangle sourceArea) - { - var height = sourceArea.Height; - var targetImage = new Image(sourceArea.Width, sourceArea.Height); - sourceImage.GetImage().ProcessPixelRows(targetImage, (sourceAccessor, targetAccessor) => - { - for (int i = 0; i < height; i++) - { - var sourceRow = sourceAccessor.GetRowSpan(sourceArea.Y + i); - var targetRow = targetAccessor.GetRowSpan(i); - sourceRow.Slice(sourceArea.X, sourceArea.Width).CopyTo(targetRow); - } - }); - return new OnnxImage(targetImage); - } - - - internal static void ApplyImageTile(this DenseTensor imageTensor, DenseTensor tileTensor, Rectangle location) - { - var offsetY = location.Y; - var offsetX = location.X; - var dimensions = tileTensor.Dimensions.ToArray(); - Parallel.For(0, dimensions[0], (i) => - { - Parallel.For(0, dimensions[1], (j) => - { - Parallel.For(0, dimensions[2], (k) => - { - Parallel.For(0, dimensions[3], (l) => - { - imageTensor[i, j, k + offsetY, l + offsetX] = tileTensor[i, j, k, l]; - }); - }); - }); - }); - } - } -} diff --git a/OnnxStack.ImageUpscaler/Models/ImageTile.cs b/OnnxStack.ImageUpscaler/Models/ImageTile.cs deleted file mode 100644 index 57e4ac77..00000000 --- a/OnnxStack.ImageUpscaler/Models/ImageTile.cs +++ /dev/null @@ -1,14 +0,0 @@ -using OnnxStack.Core.Image; -using SixLabors.ImageSharp; - -namespace OnnxStack.ImageUpscaler.Models -{ - internal record ImageTile - { - public OnnxImage Image { get; set; } - public Rectangle Destination { get; set; } - } -} - - - diff --git a/OnnxStack.ImageUpscaler/Models/UpscaleInput.cs b/OnnxStack.ImageUpscaler/Models/UpscaleInput.cs deleted file mode 100644 index 05d9f3a5..00000000 --- a/OnnxStack.ImageUpscaler/Models/UpscaleInput.cs +++ /dev/null @@ -1,9 +0,0 @@ -using System.Collections.Generic; - -namespace OnnxStack.ImageUpscaler.Models -{ - internal record UpscaleInput(List ImageTiles, int OutputWidth, int OutputHeight); -} - - - diff --git 
a/OnnxStack.ImageUpscaler/Pipelines/ImageUpscalePipeline.cs b/OnnxStack.ImageUpscaler/Pipelines/ImageUpscalePipeline.cs index 1761a916..85eb331f 100644 --- a/OnnxStack.ImageUpscaler/Pipelines/ImageUpscalePipeline.cs +++ b/OnnxStack.ImageUpscaler/Pipelines/ImageUpscalePipeline.cs @@ -6,10 +6,7 @@ using OnnxStack.Core.Model; using OnnxStack.Core.Video; using OnnxStack.ImageUpscaler.Common; -using OnnxStack.ImageUpscaler.Extensions; -using OnnxStack.ImageUpscaler.Models; -using SixLabors.ImageSharp; -using SixLabors.ImageSharp.PixelFormats; +using System; using System.Collections.Generic; using System.IO; using System.Linq; @@ -66,6 +63,21 @@ public async Task UnloadAsync() } + /// + /// Runs the upscale pipeline. + /// + /// The input image. + /// The cancellation token. + /// + public async Task> RunAsync(DenseTensor inputImage, CancellationToken cancellationToken = default) + { + var timestamp = _logger?.LogBegin("Upscale image.."); + var result = await RunInternalAsync(inputImage, cancellationToken); + _logger?.LogEnd("Upscale image complete.", timestamp); + return result; + } + + /// /// Runs the upscale pipeline. /// @@ -125,46 +137,78 @@ public async IAsyncEnumerable RunAsync(IAsyncEnumerable im } + /// + /// Runs the upscale pipeline + /// + /// The input image. + /// The cancellation token. 
+ /// private async Task RunInternalAsync(OnnxImage inputImage, CancellationToken cancellationToken = default) { - var upscaleInput = CreateInputParams(inputImage, _upscaleModel.SampleSize, _upscaleModel.ScaleFactor); - var metadata = await _upscaleModel.GetMetadataAsync(); + var inputTensor = inputImage.GetImageTensor(ImageNormalizeType.ZeroToOne, _upscaleModel.Channels); + var outputTensor = await RunInternalAsync(inputTensor, cancellationToken); + return new OnnxImage(outputTensor, ImageNormalizeType.ZeroToOne); + } - var outputTensor = new DenseTensor(new[] { 1, _upscaleModel.Channels, upscaleInput.OutputHeight, upscaleInput.OutputWidth }); - foreach (var imageTile in upscaleInput.ImageTiles) - { - cancellationToken.ThrowIfCancellationRequested(); - var outputDimension = new[] { 1, _upscaleModel.Channels, imageTile.Destination.Height, imageTile.Destination.Width }; - var inputTensor = imageTile.Image.GetImageTensor(ImageNormalizeType.ZeroToOne, _upscaleModel.Channels); - using (var inferenceParameters = new OnnxInferenceParameters(metadata)) - { - inferenceParameters.AddInputTensor(inputTensor); - inferenceParameters.AddOutputBuffer(outputDimension); - - var results = await _upscaleModel.RunInferenceAsync(inferenceParameters); - using (var result = results.First()) - { - outputTensor.ApplyImageTile(result.ToDenseTensor(), imageTile.Destination); - } - } + /// + /// Runs the upscale pipeline + /// + /// The input tensor. + /// The cancellation token. 
+ /// + private async Task> RunInternalAsync(DenseTensor inputTensor, CancellationToken cancellationToken = default) + { + if (inputTensor.Dimensions[2] <= _upscaleModel.TileSize && inputTensor.Dimensions[3] <= _upscaleModel.TileSize) + { + return await RunInferenceAsync(inputTensor, cancellationToken); } - return new OnnxImage(outputTensor, ImageNormalizeType.ZeroToOne); + + var inputTiles = inputTensor.SplitImageTiles(_upscaleModel.TileOverlap); + var outputTiles = new ImageTiles + ( + inputTiles.Width * _upscaleModel.ScaleFactor, + inputTiles.Height * _upscaleModel.ScaleFactor, + inputTiles.Overlap * _upscaleModel.ScaleFactor, + await RunInternalAsync(inputTiles.Tile1, cancellationToken), + await RunInternalAsync(inputTiles.Tile2, cancellationToken), + await RunInternalAsync(inputTiles.Tile3, cancellationToken), + await RunInternalAsync(inputTiles.Tile4, cancellationToken) + ); + return outputTiles.JoinImageTiles(); } + /// - /// Creates the input parameters. + /// Runs the model inference. /// - /// The image source. - /// Maximum size of the tile. - /// The scale factor. + /// The input tensor. + /// The cancellation token. 
/// - private static UpscaleInput CreateInputParams(OnnxImage imageSource, int maxTileSize, int scaleFactor) + private async Task> RunInferenceAsync(DenseTensor inputTensor, CancellationToken cancellationToken = default) { - var tiles = imageSource.GenerateTiles(maxTileSize, scaleFactor); - var width = imageSource.Width * scaleFactor; - var height = imageSource.Height * scaleFactor; - return new UpscaleInput(tiles, width, height); + var metadata = await _upscaleModel.GetMetadataAsync(); + cancellationToken.ThrowIfCancellationRequested(); + + var outputDimension = new[] + { + 1, + _upscaleModel.Channels, + inputTensor.Dimensions[2] * _upscaleModel.ScaleFactor, + inputTensor.Dimensions[3] * _upscaleModel.ScaleFactor + }; + + using (var inferenceParameters = new OnnxInferenceParameters(metadata)) + { + inferenceParameters.AddInputTensor(inputTensor); + inferenceParameters.AddOutputBuffer(outputDimension); + + var results = await _upscaleModel.RunInferenceAsync(inferenceParameters); + using (var result = results.First()) + { + return result.ToDenseTensor(); + } + } } @@ -189,7 +233,7 @@ public static ImageUpscalePipeline CreatePipeline(UpscaleModelSet modelSet, ILog /// The execution provider. /// The logger. 
/// - public static ImageUpscalePipeline CreatePipeline(string modelFile, int scaleFactor, int sampleSize = 512, int deviceId = 0, ExecutionProvider executionProvider = ExecutionProvider.DirectML, ILogger logger = default) + public static ImageUpscalePipeline CreatePipeline(string modelFile, int scaleFactor, int sampleSize, int tileSize = 0, int tileOverlap = 20, int channels = 3, int deviceId = 0, ExecutionProvider executionProvider = ExecutionProvider.DirectML, ILogger logger = default) { var name = Path.GetFileNameWithoutExtension(modelFile); var configuration = new UpscaleModelSet @@ -200,9 +244,11 @@ public static ImageUpscalePipeline CreatePipeline(string modelFile, int scaleFac ExecutionProvider = executionProvider, UpscaleModelConfig = new UpscaleModelConfig { - Channels = 3, + Channels = channels, SampleSize = sampleSize, ScaleFactor = scaleFactor, + TileOverlap = tileOverlap, + TileSize = Math.Min(sampleSize, tileSize > 0 ? tileSize : sampleSize), OnnxModelPath = modelFile } }; From 01d6290de1adc688626def9b5e1e72be0cd0d9a4 Mon Sep 17 00:00:00 2001 From: sa_ddam213 Date: Sun, 31 Mar 2024 12:10:38 +1300 Subject: [PATCH 2/5] Handle both normalization types --- OnnxStack.Core/Extensions/TensorExtension.cs | 20 ++++++++ .../Common/UpscaleModel.cs | 7 ++- .../Common/UpscaleModelConfig.cs | 3 ++ .../Pipelines/ImageUpscalePipeline.cs | 46 +++++++++++++++---- 4 files changed, 65 insertions(+), 11 deletions(-) diff --git a/OnnxStack.Core/Extensions/TensorExtension.cs b/OnnxStack.Core/Extensions/TensorExtension.cs index 3b8b07f1..b7ae64a9 100644 --- a/OnnxStack.Core/Extensions/TensorExtension.cs +++ b/OnnxStack.Core/Extensions/TensorExtension.cs @@ -413,5 +413,25 @@ private static DenseTensor ConcatenateAxis2(DenseTensor tensor1, D return concatenatedTensor; } + + + /// + /// Normalizes the tensor values from range -1 to 1 to 0 to 1. + /// + /// The image tensor. 
+ public static void NormalizeOneOneToZeroOne(this DenseTensor imageTensor) + { + Parallel.For(0, (int)imageTensor.Length, (i) => imageTensor.SetValue(i, imageTensor.GetValue(i) / 2f + 0.5f)); + } + + + /// + /// Normalizes the tensor values from range 0 to 1 to -1 to 1. + /// + /// The image tensor. + public static void NormalizeZeroOneToOneOne(this DenseTensor imageTensor) + { + Parallel.For(0, (int)imageTensor.Length, (i) => imageTensor.SetValue(i, 2f * imageTensor.GetValue(i) - 1f)); + } } } diff --git a/OnnxStack.ImageUpscaler/Common/UpscaleModel.cs b/OnnxStack.ImageUpscaler/Common/UpscaleModel.cs index f8e9ec26..20d92cf8 100644 --- a/OnnxStack.ImageUpscaler/Common/UpscaleModel.cs +++ b/OnnxStack.ImageUpscaler/Common/UpscaleModel.cs @@ -1,5 +1,6 @@ using Microsoft.ML.OnnxRuntime; using OnnxStack.Core.Config; +using OnnxStack.Core.Image; using OnnxStack.Core.Model; using System; @@ -19,13 +20,15 @@ public UpscaleModel(UpscaleModelConfig configuration) : base(configuration) public int ScaleFactor => _configuration.ScaleFactor; public int TileSize => _configuration.TileSize; public int TileOverlap => _configuration.TileOverlap; + public ImageNormalizeType NormalizeType => _configuration.NormalizeType; + public bool NormalizeInput => _configuration.NormalizeInput; public static UpscaleModel Create(UpscaleModelConfig configuration) { return new UpscaleModel(configuration); } - public static UpscaleModel Create(string modelFile, int scaleFactor, int sampleSize, int tileSize = 0, int tileOverlap = 20, int channels = 3, int deviceId = 0, ExecutionProvider executionProvider = ExecutionProvider.DirectML) + public static UpscaleModel Create(string modelFile, int scaleFactor, int sampleSize, ImageNormalizeType normalizeType = ImageNormalizeType.ZeroToOne, bool normalizeInput = true, int tileSize = 0, int tileOverlap = 20, int channels = 3, int deviceId = 0, ExecutionProvider executionProvider = ExecutionProvider.DirectML) { var configuration = new UpscaleModelConfig { @@ 
-34,6 +37,8 @@ public static UpscaleModel Create(string modelFile, int scaleFactor, int sampleS ScaleFactor = scaleFactor, TileOverlap = tileOverlap, TileSize = Math.Min(sampleSize, tileSize > 0 ? tileSize : sampleSize), + NormalizeType = normalizeType, + NormalizeInput = normalizeInput, DeviceId = deviceId, ExecutionProvider = executionProvider, ExecutionMode = ExecutionMode.ORT_SEQUENTIAL, diff --git a/OnnxStack.ImageUpscaler/Common/UpscaleModelConfig.cs b/OnnxStack.ImageUpscaler/Common/UpscaleModelConfig.cs index eca0666b..afc8d75b 100644 --- a/OnnxStack.ImageUpscaler/Common/UpscaleModelConfig.cs +++ b/OnnxStack.ImageUpscaler/Common/UpscaleModelConfig.cs @@ -1,4 +1,5 @@ using OnnxStack.Core.Config; +using OnnxStack.Core.Image; namespace OnnxStack.ImageUpscaler.Common { @@ -10,5 +11,7 @@ public record UpscaleModelConfig : OnnxModelConfig public int TileSize { get; set; } public int TileOverlap { get; set; } + public ImageNormalizeType NormalizeType { get; set; } + public bool NormalizeInput { get; set; } } } diff --git a/OnnxStack.ImageUpscaler/Pipelines/ImageUpscalePipeline.cs b/OnnxStack.ImageUpscaler/Pipelines/ImageUpscalePipeline.cs index 85eb331f..f57dda30 100644 --- a/OnnxStack.ImageUpscaler/Pipelines/ImageUpscalePipeline.cs +++ b/OnnxStack.ImageUpscaler/Pipelines/ImageUpscalePipeline.cs @@ -1,5 +1,6 @@ using Microsoft.Extensions.Logging; using Microsoft.ML.OnnxRuntime.Tensors; +using Newtonsoft.Json.Linq; using OnnxStack.Core; using OnnxStack.Core.Config; using OnnxStack.Core.Image; @@ -10,6 +11,7 @@ using System.Collections.Generic; using System.IO; using System.Linq; +using System.Numerics.Tensors; using System.Runtime.CompilerServices; using System.Threading; using System.Threading.Tasks; @@ -72,7 +74,7 @@ public async Task UnloadAsync() public async Task> RunAsync(DenseTensor inputImage, CancellationToken cancellationToken = default) { var timestamp = _logger?.LogBegin("Upscale image.."); - var result = await RunInternalAsync(inputImage, 
cancellationToken); + var result = await UpscaleTensorAsync(inputImage, cancellationToken); _logger?.LogEnd("Upscale image complete.", timestamp); return result; } @@ -87,7 +89,7 @@ public async Task> RunAsync(DenseTensor inputImage, Ca public async Task RunAsync(OnnxImage inputImage, CancellationToken cancellationToken = default) { var timestamp = _logger?.LogBegin("Upscale image.."); - var result = await RunInternalAsync(inputImage, cancellationToken); + var result = await UpscaleImageAsync(inputImage, cancellationToken); _logger?.LogEnd("Upscale image complete.", timestamp); return result; } @@ -105,7 +107,7 @@ public async Task RunAsync(OnnxVideo inputVideo, CancellationToken ca var upscaledFrames = new List(); foreach (var videoFrame in inputVideo.Frames) { - upscaledFrames.Add(await RunInternalAsync(videoFrame, cancellationToken)); + upscaledFrames.Add(await UpscaleImageAsync(videoFrame, cancellationToken)); } var firstFrame = upscaledFrames.First(); @@ -131,23 +133,44 @@ public async IAsyncEnumerable RunAsync(IAsyncEnumerable im var timestamp = _logger?.LogBegin("Upscale video stream.."); await foreach (var imageFrame in imageFrames) { - yield return await RunInternalAsync(imageFrame, cancellationToken); + yield return await UpscaleImageAsync(imageFrame, cancellationToken); } _logger?.LogEnd("Upscale video stream complete.", timestamp); } + /// - /// Runs the upscale pipeline + /// Upscales the OnnxImage. /// /// The input image. /// The cancellation token. 
/// - private async Task RunInternalAsync(OnnxImage inputImage, CancellationToken cancellationToken = default) + private async Task UpscaleImageAsync(OnnxImage inputImage, CancellationToken cancellationToken = default) { - var inputTensor = inputImage.GetImageTensor(ImageNormalizeType.ZeroToOne, _upscaleModel.Channels); + var inputTensor = inputImage.GetImageTensor(_upscaleModel.NormalizeType, _upscaleModel.Channels); var outputTensor = await RunInternalAsync(inputTensor, cancellationToken); - return new OnnxImage(outputTensor, ImageNormalizeType.ZeroToOne); + return new OnnxImage(outputTensor, _upscaleModel.NormalizeType); + } + + + /// + /// Upscales the DenseTensor + /// + /// The input image. + /// The cancellation token. + /// + public async Task> UpscaleTensorAsync(DenseTensor inputImage, CancellationToken cancellationToken = default) + { + if (_upscaleModel.NormalizeInput && _upscaleModel.NormalizeType == ImageNormalizeType.ZeroToOne) + inputImage.NormalizeOneOneToZeroOne(); + + var result = await RunInternalAsync(inputImage, cancellationToken); + + if (_upscaleModel.NormalizeInput && _upscaleModel.NormalizeType == ImageNormalizeType.ZeroToOne) + result.NormalizeZeroOneToOneOne(); + + return result; } @@ -233,7 +256,7 @@ public static ImageUpscalePipeline CreatePipeline(UpscaleModelSet modelSet, ILog /// The execution provider. /// The logger. 
/// - public static ImageUpscalePipeline CreatePipeline(string modelFile, int scaleFactor, int sampleSize, int tileSize = 0, int tileOverlap = 20, int channels = 3, int deviceId = 0, ExecutionProvider executionProvider = ExecutionProvider.DirectML, ILogger logger = default) + public static ImageUpscalePipeline CreatePipeline(string modelFile, int scaleFactor, int sampleSize, ImageNormalizeType normalizeType = ImageNormalizeType.ZeroToOne, bool normalizeInput = true, int tileSize = 0, int tileOverlap = 20, int channels = 3, int deviceId = 0, ExecutionProvider executionProvider = ExecutionProvider.DirectML, ILogger logger = default) { var name = Path.GetFileNameWithoutExtension(modelFile); var configuration = new UpscaleModelSet @@ -249,10 +272,13 @@ public static ImageUpscalePipeline CreatePipeline(string modelFile, int scaleFac ScaleFactor = scaleFactor, TileOverlap = tileOverlap, TileSize = Math.Min(sampleSize, tileSize > 0 ? tileSize : sampleSize), - OnnxModelPath = modelFile + NormalizeType = normalizeType, + NormalizeInput = normalizeInput, + OnnxModelPath = modelFile, } }; return CreatePipeline(configuration, logger); } } + } From 19b12115508d362867d219b9ed0998da95bf8c9c Mon Sep 17 00:00:00 2001 From: sa_ddam213 Date: Sun, 31 Mar 2024 14:49:29 +1300 Subject: [PATCH 3/5] Support recursive tiling, tidy up logging --- .../Pipelines/ImageUpscalePipeline.cs | 43 +++++++++---------- 1 file changed, 21 insertions(+), 22 deletions(-) diff --git a/OnnxStack.ImageUpscaler/Pipelines/ImageUpscalePipeline.cs b/OnnxStack.ImageUpscaler/Pipelines/ImageUpscalePipeline.cs index f57dda30..83acfa90 100644 --- a/OnnxStack.ImageUpscaler/Pipelines/ImageUpscalePipeline.cs +++ b/OnnxStack.ImageUpscaler/Pipelines/ImageUpscalePipeline.cs @@ -1,6 +1,5 @@ using Microsoft.Extensions.Logging; using Microsoft.ML.OnnxRuntime.Tensors; -using Newtonsoft.Json.Linq; using OnnxStack.Core; using OnnxStack.Core.Config; using OnnxStack.Core.Image; @@ -11,7 +10,6 @@ using System.Collections.Generic; 
using System.IO; using System.Linq; -using System.Numerics.Tensors; using System.Runtime.CompilerServices; using System.Threading; using System.Threading.Tasks; @@ -73,9 +71,9 @@ public async Task UnloadAsync() /// public async Task> RunAsync(DenseTensor inputImage, CancellationToken cancellationToken = default) { - var timestamp = _logger?.LogBegin("Upscale image.."); + var timestamp = _logger?.LogBegin("Upscale DenseTensor.."); var result = await UpscaleTensorAsync(inputImage, cancellationToken); - _logger?.LogEnd("Upscale image complete.", timestamp); + _logger?.LogEnd("Upscale DenseTensor complete.", timestamp); return result; } @@ -88,9 +86,9 @@ public async Task> RunAsync(DenseTensor inputImage, Ca /// public async Task RunAsync(OnnxImage inputImage, CancellationToken cancellationToken = default) { - var timestamp = _logger?.LogBegin("Upscale image.."); + var timestamp = _logger?.LogBegin("Upscale OnnxImage.."); var result = await UpscaleImageAsync(inputImage, cancellationToken); - _logger?.LogEnd("Upscale image complete.", timestamp); + _logger?.LogEnd("Upscale OnnxImage complete.", timestamp); return result; } @@ -103,7 +101,7 @@ public async Task RunAsync(OnnxImage inputImage, CancellationToken ca /// public async Task RunAsync(OnnxVideo inputVideo, CancellationToken cancellationToken = default) { - var timestamp = _logger?.LogBegin("Upscale video.."); + var timestamp = _logger?.LogBegin("Upscale OnnxVideo.."); var upscaledFrames = new List(); foreach (var videoFrame in inputVideo.Frames) { @@ -117,7 +115,7 @@ public async Task RunAsync(OnnxVideo inputVideo, CancellationToken ca Height = firstFrame.Height, }; - _logger?.LogEnd("Upscale video complete.", timestamp); + _logger?.LogEnd("Upscale OnnxVideo complete.", timestamp); return new OnnxVideo(videoInfo, upscaledFrames); } @@ -130,16 +128,15 @@ public async Task RunAsync(OnnxVideo inputVideo, CancellationToken ca /// public async IAsyncEnumerable RunAsync(IAsyncEnumerable imageFrames, 
[EnumeratorCancellation] CancellationToken cancellationToken = default) { - var timestamp = _logger?.LogBegin("Upscale video stream.."); + var timestamp = _logger?.LogBegin("Upscale OnnxImage stream.."); await foreach (var imageFrame in imageFrames) { yield return await UpscaleImageAsync(imageFrame, cancellationToken); } - _logger?.LogEnd("Upscale video stream complete.", timestamp); + _logger?.LogEnd("Upscale OnnxImage stream complete.", timestamp); } - /// /// Upscales the OnnxImage. /// @@ -149,7 +146,7 @@ public async IAsyncEnumerable RunAsync(IAsyncEnumerable im private async Task UpscaleImageAsync(OnnxImage inputImage, CancellationToken cancellationToken = default) { var inputTensor = inputImage.GetImageTensor(_upscaleModel.NormalizeType, _upscaleModel.Channels); - var outputTensor = await RunInternalAsync(inputTensor, cancellationToken); + var outputTensor = await RunInternalAsync(inputTensor, inputImage.Height, inputImage.Width, cancellationToken); return new OnnxImage(outputTensor, _upscaleModel.NormalizeType); } @@ -157,15 +154,17 @@ private async Task UpscaleImageAsync(OnnxImage inputImage, Cancellati /// /// Upscales the DenseTensor /// - /// The input image. + /// The input Tensor. /// The cancellation token. 
/// - public async Task> UpscaleTensorAsync(DenseTensor inputImage, CancellationToken cancellationToken = default) + public async Task> UpscaleTensorAsync(DenseTensor inputTensor, CancellationToken cancellationToken = default) { if (_upscaleModel.NormalizeInput && _upscaleModel.NormalizeType == ImageNormalizeType.ZeroToOne) - inputImage.NormalizeOneOneToZeroOne(); + inputTensor.NormalizeOneOneToZeroOne(); - var result = await RunInternalAsync(inputImage, cancellationToken); + var height = inputTensor.Dimensions[2]; + var width = inputTensor.Dimensions[3]; + var result = await RunInternalAsync(inputTensor, height, width, cancellationToken); if (_upscaleModel.NormalizeInput && _upscaleModel.NormalizeType == ImageNormalizeType.ZeroToOne) result.NormalizeZeroOneToOneOne(); @@ -180,9 +179,9 @@ public async Task> UpscaleTensorAsync(DenseTensor inpu /// The input tensor. /// The cancellation token. /// - private async Task> RunInternalAsync(DenseTensor inputTensor, CancellationToken cancellationToken = default) + private async Task> RunInternalAsync(DenseTensor inputTensor, int height, int width, CancellationToken cancellationToken = default) { - if (inputTensor.Dimensions[2] <= _upscaleModel.TileSize && inputTensor.Dimensions[3] <= _upscaleModel.TileSize) + if (height <= _upscaleModel.TileSize && width <= _upscaleModel.TileSize) { return await RunInferenceAsync(inputTensor, cancellationToken); } @@ -193,10 +192,10 @@ private async Task> RunInternalAsync(DenseTensor input inputTiles.Width * _upscaleModel.ScaleFactor, inputTiles.Height * _upscaleModel.ScaleFactor, inputTiles.Overlap * _upscaleModel.ScaleFactor, - await RunInternalAsync(inputTiles.Tile1, cancellationToken), - await RunInternalAsync(inputTiles.Tile2, cancellationToken), - await RunInternalAsync(inputTiles.Tile3, cancellationToken), - await RunInternalAsync(inputTiles.Tile4, cancellationToken) + await RunInternalAsync(inputTiles.Tile1, inputTiles.Height, inputTiles.Width, cancellationToken), + await 
RunInternalAsync(inputTiles.Tile2, inputTiles.Height, inputTiles.Width, cancellationToken), + await RunInternalAsync(inputTiles.Tile3, inputTiles.Height, inputTiles.Width, cancellationToken), + await RunInternalAsync(inputTiles.Tile4, inputTiles.Height, inputTiles.Width, cancellationToken) ); return outputTiles.JoinImageTiles(); } From ea63f96dd846e5b97643991cb16b380971b721ef Mon Sep 17 00:00:00 2001 From: sa_ddam213 Date: Sun, 31 Mar 2024 16:10:02 +1300 Subject: [PATCH 4/5] FeatureExtractor normalization --- .../Examples/ControlNetFeatureExample.cs | 2 +- .../Examples/FeatureExtractorExample.cs | 2 +- OnnxStack.Core/Image/Extensions.cs | 13 ++- OnnxStack.Core/Image/OnnxImage.cs | 44 +++++---- .../Common/FeatureExtractorModel.cs | 13 +-- .../Common/FeatureExtractorModelConfig.cs | 5 +- .../Pipelines/FeatureExtractorPipeline.cs | 93 +++++++++++++------ 7 files changed, 117 insertions(+), 55 deletions(-) diff --git a/OnnxStack.Console/Examples/ControlNetFeatureExample.cs b/OnnxStack.Console/Examples/ControlNetFeatureExample.cs index 2341f04b..bda931e0 100644 --- a/OnnxStack.Console/Examples/ControlNetFeatureExample.cs +++ b/OnnxStack.Console/Examples/ControlNetFeatureExample.cs @@ -35,7 +35,7 @@ public async Task RunAsync() var inputImage = await OnnxImage.FromFileAsync("D:\\Repositories\\OnnxStack\\Assets\\Samples\\Img2Img_Start.bmp"); // Create Annotation pipeline - var annotationPipeline = FeatureExtractorPipeline.CreatePipeline("D:\\Repositories\\controlnet_onnx\\annotators\\depth.onnx", sampleSize: 512, normalizeOutputTensor: true); + var annotationPipeline = FeatureExtractorPipeline.CreatePipeline("D:\\Repositories\\controlnet_onnx\\annotators\\depth.onnx", sampleSize: 512, normalizeOutput: true); // Create Depth Image var controlImage = await annotationPipeline.RunAsync(inputImage); diff --git a/OnnxStack.Console/Examples/FeatureExtractorExample.cs b/OnnxStack.Console/Examples/FeatureExtractorExample.cs index e7a1ee37..a52b3ea8 100644 --- 
a/OnnxStack.Console/Examples/FeatureExtractorExample.cs +++ b/OnnxStack.Console/Examples/FeatureExtractorExample.cs @@ -37,7 +37,7 @@ public async Task RunAsync() { FeatureExtractorPipeline.CreatePipeline("D:\\Repositories\\controlnet_onnx\\annotators\\canny.onnx"), FeatureExtractorPipeline.CreatePipeline("D:\\Repositories\\controlnet_onnx\\annotators\\hed.onnx"), - FeatureExtractorPipeline.CreatePipeline("D:\\Repositories\\controlnet_onnx\\annotators\\depth.onnx", sampleSize: 512, normalizeOutputTensor: true, inputResizeMode: ImageResizeMode.Stretch), + FeatureExtractorPipeline.CreatePipeline("D:\\Repositories\\controlnet_onnx\\annotators\\depth.onnx", sampleSize: 512, normalizeOutput: true, inputResizeMode: ImageResizeMode.Stretch), FeatureExtractorPipeline.CreatePipeline("D:\\Repositories\\RMBG-1.4\\onnx\\model.onnx", sampleSize: 1024, setOutputToInputAlpha: true, inputResizeMode: ImageResizeMode.Stretch) }; diff --git a/OnnxStack.Core/Image/Extensions.cs b/OnnxStack.Core/Image/Extensions.cs index a045223c..ef8fa61b 100644 --- a/OnnxStack.Core/Image/Extensions.cs +++ b/OnnxStack.Core/Image/Extensions.cs @@ -16,6 +16,17 @@ public static class Extensions /// The image tensor. /// public static OnnxImage ToImageMask(this DenseTensor imageTensor) + { + return new OnnxImage(imageTensor.FromMaskTensor()); + } + + + /// + /// Convert from single channel mask tensor to Rgba32 (Greyscale) + /// + /// The image tensor. 
+ /// + public static Image FromMaskTensor(this DenseTensor imageTensor) { var width = imageTensor.Dimensions[3]; var height = imageTensor.Dimensions[2]; @@ -28,7 +39,7 @@ public static OnnxImage ToImageMask(this DenseTensor imageTensor) result[x, y] = new L8((byte)(imageTensor[0, 0, y, x] * 255.0f)); } } - return new OnnxImage(result.CloneAs()); + return result.CloneAs(); } } diff --git a/OnnxStack.Core/Image/OnnxImage.cs b/OnnxStack.Core/Image/OnnxImage.cs index a9d9cb2f..c3db705e 100644 --- a/OnnxStack.Core/Image/OnnxImage.cs +++ b/OnnxStack.Core/Image/OnnxImage.cs @@ -64,27 +64,35 @@ public OnnxImage(DenseTensor imageTensor, ImageNormalizeType normalizeTyp { var height = imageTensor.Dimensions[2]; var width = imageTensor.Dimensions[3]; - var hasTransparency = imageTensor.Dimensions[1] == 4; - _imageData = new Image(width, height); - for (var y = 0; y < height; y++) + var channels = imageTensor.Dimensions[1]; + if (channels == 1) { - for (var x = 0; x < width; x++) + _imageData = imageTensor.FromMaskTensor(); + } + else + { + var hasTransparency = channels == 4; + _imageData = new Image(width, height); + for (var y = 0; y < height; y++) { - if (normalizeType == ImageNormalizeType.ZeroToOne) - { - _imageData[x, y] = new Rgba32( - DenormalizeZeroToOneToByte(imageTensor, 0, y, x), - DenormalizeZeroToOneToByte(imageTensor, 1, y, x), - DenormalizeZeroToOneToByte(imageTensor, 2, y, x), - hasTransparency ? DenormalizeZeroToOneToByte(imageTensor, 3, y, x) : byte.MaxValue); - } - else + for (var x = 0; x < width; x++) { - _imageData[x, y] = new Rgba32( - DenormalizeOneToOneToByte(imageTensor, 0, y, x), - DenormalizeOneToOneToByte(imageTensor, 1, y, x), - DenormalizeOneToOneToByte(imageTensor, 2, y, x), - hasTransparency ? 
DenormalizeOneToOneToByte(imageTensor, 3, y, x) : byte.MaxValue); + if (normalizeType == ImageNormalizeType.ZeroToOne) + { + _imageData[x, y] = new Rgba32( + DenormalizeZeroToOneToByte(imageTensor, 0, y, x), + DenormalizeZeroToOneToByte(imageTensor, 1, y, x), + DenormalizeZeroToOneToByte(imageTensor, 2, y, x), + hasTransparency ? DenormalizeZeroToOneToByte(imageTensor, 3, y, x) : byte.MaxValue); + } + else + { + _imageData[x, y] = new Rgba32( + DenormalizeOneToOneToByte(imageTensor, 0, y, x), + DenormalizeOneToOneToByte(imageTensor, 1, y, x), + DenormalizeOneToOneToByte(imageTensor, 2, y, x), + hasTransparency ? DenormalizeOneToOneToByte(imageTensor, 3, y, x) : byte.MaxValue); + } } } } diff --git a/OnnxStack.FeatureExtractor/Common/FeatureExtractorModel.cs b/OnnxStack.FeatureExtractor/Common/FeatureExtractorModel.cs index e63b159b..5c029044 100644 --- a/OnnxStack.FeatureExtractor/Common/FeatureExtractorModel.cs +++ b/OnnxStack.FeatureExtractor/Common/FeatureExtractorModel.cs @@ -17,17 +17,18 @@ public FeatureExtractorModel(FeatureExtractorModelConfig configuration) public int OutputChannels => _configuration.OutputChannels; public int SampleSize => _configuration.SampleSize; - public bool NormalizeOutputTensor => _configuration.NormalizeOutputTensor; + public bool NormalizeOutput => _configuration.NormalizeOutput; public bool SetOutputToInputAlpha => _configuration.SetOutputToInputAlpha; public ImageResizeMode InputResizeMode => _configuration.InputResizeMode; - public ImageNormalizeType InputNormalization => _configuration.NormalizeInputTensor; + public ImageNormalizeType NormalizeType => _configuration.NormalizeType; + public bool NormalizeInput => _configuration.NormalizeInput; public static FeatureExtractorModel Create(FeatureExtractorModelConfig configuration) { return new FeatureExtractorModel(configuration); } - public static FeatureExtractorModel Create(string modelFile, int sampleSize = 0, int outputChannels = 1, bool normalizeOutputTensor = false, 
ImageNormalizeType normalizeInputTensor = ImageNormalizeType.ZeroToOne, ImageResizeMode inputResizeMode = ImageResizeMode.Crop, bool setOutputToInputAlpha = false, int deviceId = 0, ExecutionProvider executionProvider = ExecutionProvider.DirectML) + public static FeatureExtractorModel Create(string modelFile, int sampleSize = 0, int outputChannels = 1, ImageNormalizeType normalizeType = ImageNormalizeType.ZeroToOne, bool normalizeInput = true, bool normalizeOutput = false, ImageResizeMode inputResizeMode = ImageResizeMode.Crop, bool setOutputToInputAlpha = false, int deviceId = 0, ExecutionProvider executionProvider = ExecutionProvider.DirectML) { var configuration = new FeatureExtractorModelConfig { @@ -38,12 +39,12 @@ public static FeatureExtractorModel Create(string modelFile, int sampleSize = 0, IntraOpNumThreads = 0, OnnxModelPath = modelFile, - SampleSize = sampleSize, OutputChannels = outputChannels, - NormalizeOutputTensor = normalizeOutputTensor, + NormalizeType = normalizeType, + NormalizeInput = normalizeInput, + NormalizeOutput = normalizeOutput, SetOutputToInputAlpha = setOutputToInputAlpha, - NormalizeInputTensor = normalizeInputTensor, InputResizeMode = inputResizeMode }; return new FeatureExtractorModel(configuration); diff --git a/OnnxStack.FeatureExtractor/Common/FeatureExtractorModelConfig.cs b/OnnxStack.FeatureExtractor/Common/FeatureExtractorModelConfig.cs index 50af0f96..eabb6451 100644 --- a/OnnxStack.FeatureExtractor/Common/FeatureExtractorModelConfig.cs +++ b/OnnxStack.FeatureExtractor/Common/FeatureExtractorModelConfig.cs @@ -7,9 +7,10 @@ public record FeatureExtractorModelConfig : OnnxModelConfig { public int SampleSize { get; set; } public int OutputChannels { get; set; } - public bool NormalizeOutputTensor { get; set; } + public bool NormalizeOutput { get; set; } public bool SetOutputToInputAlpha { get; set; } public ImageResizeMode InputResizeMode { get; set; } - public ImageNormalizeType NormalizeInputTensor { get; set; } + public 
ImageNormalizeType NormalizeType { get; set; } + public bool NormalizeInput { get; set; } } } diff --git a/OnnxStack.FeatureExtractor/Pipelines/FeatureExtractorPipeline.cs b/OnnxStack.FeatureExtractor/Pipelines/FeatureExtractorPipeline.cs index 5ca8ee7d..7df52120 100644 --- a/OnnxStack.FeatureExtractor/Pipelines/FeatureExtractorPipeline.cs +++ b/OnnxStack.FeatureExtractor/Pipelines/FeatureExtractorPipeline.cs @@ -62,6 +62,20 @@ public async Task UnloadAsync() } + /// + /// Generates the feature extractor image + /// + /// The input image. + /// + public async Task> RunAsync(DenseTensor inputTensor, CancellationToken cancellationToken = default) + { + var timestamp = _logger?.LogBegin("Extracting DenseTensor feature..."); + var result = await ExtractTensorAsync(inputTensor, cancellationToken); + _logger?.LogEnd("Extracting DenseTensor feature complete.", timestamp); + return result; + } + + /// /// Generates the feature extractor image /// @@ -69,9 +83,9 @@ public async Task UnloadAsync() /// public async Task RunAsync(OnnxImage inputImage, CancellationToken cancellationToken = default) { - var timestamp = _logger?.LogBegin("Extracting image feature..."); - var result = await RunInternalAsync(inputImage, cancellationToken); - _logger?.LogEnd("Extracting image feature complete.", timestamp); + var timestamp = _logger?.LogBegin("Extracting OnnxImage feature..."); + var result = await ExtractImageAsync(inputImage, cancellationToken); + _logger?.LogEnd("Extracting OnnxImage feature complete.", timestamp); return result; } @@ -83,13 +97,13 @@ public async Task RunAsync(OnnxImage inputImage, CancellationToken ca /// public async Task RunAsync(OnnxVideo video, CancellationToken cancellationToken = default) { - var timestamp = _logger?.LogBegin("Extracting video features..."); + var timestamp = _logger?.LogBegin("Extracting OnnxVideo features..."); var featureFrames = new List(); foreach (var videoFrame in video.Frames) { featureFrames.Add(await RunAsync(videoFrame, 
cancellationToken)); } - _logger?.LogEnd("Extracting video features complete.", timestamp); + _logger?.LogEnd("Extracting OnnxVideo features complete.", timestamp); return new OnnxVideo(video.Info, featureFrames); } @@ -102,28 +116,62 @@ public async Task RunAsync(OnnxVideo video, CancellationToken cancell /// public async IAsyncEnumerable RunAsync(IAsyncEnumerable imageFrames, [EnumeratorCancellation] CancellationToken cancellationToken = default) { - var timestamp = _logger?.LogBegin("Extracting video stream features..."); + var timestamp = _logger?.LogBegin("Extracting OnnxImage stream features..."); await foreach (var imageFrame in imageFrames) { - yield return await RunInternalAsync(imageFrame, cancellationToken); + yield return await ExtractImageAsync(imageFrame, cancellationToken); } - _logger?.LogEnd("Extracting video stream features complete.", timestamp); + _logger?.LogEnd("Extracting OnnxImage stream features complete.", timestamp); } /// - /// Runs the pipeline + /// Extracts the feature to OnnxImage. /// /// The input image. /// The cancellation token. /// - private async Task RunInternalAsync(OnnxImage inputImage, CancellationToken cancellationToken = default) + private async Task ExtractImageAsync(OnnxImage inputImage, CancellationToken cancellationToken = default) { var originalWidth = inputImage.Width; var originalHeight = inputImage.Height; var inputTensor = _featureExtractorModel.SampleSize <= 0 - ? await inputImage.GetImageTensorAsync(_featureExtractorModel.InputNormalization) - : await inputImage.GetImageTensorAsync(_featureExtractorModel.SampleSize, _featureExtractorModel.SampleSize, _featureExtractorModel.InputNormalization, resizeMode: _featureExtractorModel.InputResizeMode); + ? 
await inputImage.GetImageTensorAsync(_featureExtractorModel.NormalizeType) + : await inputImage.GetImageTensorAsync(_featureExtractorModel.SampleSize, _featureExtractorModel.SampleSize, _featureExtractorModel.NormalizeType, resizeMode: _featureExtractorModel.InputResizeMode); + + var outputTensor = await RunInternalAsync(inputTensor, cancellationToken); + var imageResult = new OnnxImage(outputTensor, _featureExtractorModel.NormalizeType); + + if (_featureExtractorModel.InputResizeMode == ImageResizeMode.Stretch && (imageResult.Width != originalWidth || imageResult.Height != originalHeight)) + imageResult.Resize(originalHeight, originalWidth, _featureExtractorModel.InputResizeMode); + + return imageResult; + } + + + /// + /// Extracts the feature to DenseTensor. + /// + /// The input tensor. + /// The cancellation token. + /// + public async Task> ExtractTensorAsync(DenseTensor inputTensor, CancellationToken cancellationToken = default) + { + if (_featureExtractorModel.NormalizeInput && _featureExtractorModel.NormalizeType == ImageNormalizeType.ZeroToOne) + inputTensor.NormalizeOneOneToZeroOne(); + + return await RunInternalAsync(inputTensor, cancellationToken); + } + + + /// + /// Runs the pipeline + /// + /// The input tensor. + /// The cancellation token. 
+ /// + private async Task> RunInternalAsync(DenseTensor inputTensor, CancellationToken cancellationToken = default) + { var metadata = await _featureExtractorModel.GetMetadataAsync(); cancellationToken.ThrowIfCancellationRequested(); var outputShape = new[] { 1, _featureExtractorModel.OutputChannels, inputTensor.Dimensions[2], inputTensor.Dimensions[3] }; @@ -139,21 +187,13 @@ private async Task RunInternalAsync(OnnxImage inputImage, Cancellatio cancellationToken.ThrowIfCancellationRequested(); var outputTensor = inferenceResult.ToDenseTensor(outputShape); - if (_featureExtractorModel.NormalizeOutputTensor) + if (_featureExtractorModel.NormalizeOutput) outputTensor.NormalizeMinMax(); - var imageResult = default(OnnxImage); if (_featureExtractorModel.SetOutputToInputAlpha) - imageResult = new OnnxImage(AddAlphaChannel(inputTensor, outputTensor), _featureExtractorModel.InputNormalization); - else if (_featureExtractorModel.OutputChannels >= 3) - imageResult = new OnnxImage(outputTensor, _featureExtractorModel.InputNormalization); - else - imageResult = outputTensor.ToImageMask(); - - if (_featureExtractorModel.InputResizeMode == ImageResizeMode.Stretch && (imageResult.Width != originalWidth || imageResult.Height != originalHeight)) - imageResult.Resize(originalHeight, originalWidth, _featureExtractorModel.InputResizeMode); + return AddAlphaChannel(inputTensor, outputTensor); - return imageResult; + return outputTensor; } } } @@ -200,7 +240,7 @@ public static FeatureExtractorPipeline CreatePipeline(FeatureExtractorModelSet m /// The execution provider. /// The logger. 
/// - public static FeatureExtractorPipeline CreatePipeline(string modelFile, int sampleSize = 0, int outputChannels = 1, bool normalizeOutputTensor = false, ImageNormalizeType normalizeInputTensor = ImageNormalizeType.ZeroToOne, ImageResizeMode inputResizeMode = ImageResizeMode.Crop, bool setOutputToInputAlpha = false, int deviceId = 0, ExecutionProvider executionProvider = ExecutionProvider.DirectML, ILogger logger = default) + public static FeatureExtractorPipeline CreatePipeline(string modelFile, int sampleSize = 0, int outputChannels = 1, ImageNormalizeType normalizeType = ImageNormalizeType.ZeroToOne, bool normalizeInput = true, bool normalizeOutput = false, ImageResizeMode inputResizeMode = ImageResizeMode.Crop, bool setOutputToInputAlpha = false, int deviceId = 0, ExecutionProvider executionProvider = ExecutionProvider.DirectML, ILogger logger = default) { var name = Path.GetFileNameWithoutExtension(modelFile); var configuration = new FeatureExtractorModelSet @@ -214,9 +254,10 @@ public static FeatureExtractorPipeline CreatePipeline(string modelFile, int samp OnnxModelPath = modelFile, SampleSize = sampleSize, OutputChannels = outputChannels, - NormalizeOutputTensor = normalizeOutputTensor, + NormalizeOutput = normalizeOutput, + NormalizeInput = normalizeInput, + NormalizeType = normalizeType, SetOutputToInputAlpha = setOutputToInputAlpha, - NormalizeInputTensor = normalizeInputTensor, InputResizeMode = inputResizeMode } }; From 8695851e0b8bbf5f1c16944aa76c03968204cd4d Mon Sep 17 00:00:00 2001 From: sa_ddam213 Date: Sun, 31 Mar 2024 16:10:46 +1300 Subject: [PATCH 5/5] Update UI --- OnnxStack.UI/Models/ModelFileViewModel.cs | 14 +++++ ...UpdateFeatureExtractorModelSetViewModel.cs | 8 +-- .../UpdateStableDiffusionModelSetViewModel.cs | 63 ++++++++++++++++--- OnnxStack.UI/Services/ModelFactory.cs | 4 +- 4 files changed, 73 insertions(+), 16 deletions(-) diff --git a/OnnxStack.UI/Models/ModelFileViewModel.cs b/OnnxStack.UI/Models/ModelFileViewModel.cs index 
2b7ffb28..b6f281a3 100644 --- a/OnnxStack.UI/Models/ModelFileViewModel.cs +++ b/OnnxStack.UI/Models/ModelFileViewModel.cs @@ -13,6 +13,8 @@ public class ModelFileViewModel : INotifyPropertyChanged private int? _intraOpNumThreads; private ExecutionMode? _executionMode; private ExecutionProvider? _executionProvider; + private OnnxModelPrecision? _precision; + private int _requiredMemory; private bool _isOverrideEnabled; private bool _hasChanged; @@ -52,6 +54,18 @@ public ExecutionProvider? ExecutionProvider set { _executionProvider = value; NotifyPropertyChanged(); } } + public OnnxModelPrecision? Precision + { + get { return _precision; } + set { _precision = value; NotifyPropertyChanged(); } + } + + public int RequiredMemory + { + get { return _requiredMemory; } + set { _requiredMemory = value; NotifyPropertyChanged(); } + } + public bool IsOverrideEnabled { get { return _isOverrideEnabled; } diff --git a/OnnxStack.UI/Models/UpdateFeatureExtractorModelSetViewModel.cs b/OnnxStack.UI/Models/UpdateFeatureExtractorModelSetViewModel.cs index a928f1cc..dc92f1ed 100644 --- a/OnnxStack.UI/Models/UpdateFeatureExtractorModelSetViewModel.cs +++ b/OnnxStack.UI/Models/UpdateFeatureExtractorModelSetViewModel.cs @@ -101,9 +101,9 @@ public static UpdateFeatureExtractorModelSetViewModel FromModelSet(FeatureExtrac ControlNetType = controlNetType, ModelFile = modelset.FeatureExtractorConfig.OnnxModelPath, - Normalize = modelset.FeatureExtractorConfig.Normalize, + Normalize = modelset.FeatureExtractorConfig.NormalizeOutput, SampleSize = modelset.FeatureExtractorConfig.SampleSize, - Channels = modelset.FeatureExtractorConfig.Channels, + Channels = modelset.FeatureExtractorConfig.OutputChannels, }; } @@ -120,8 +120,8 @@ public static FeatureExtractorModelSet ToModelSet(UpdateFeatureExtractorModelSet IntraOpNumThreads = modelset.IntraOpNumThreads, FeatureExtractorConfig = new FeatureExtractorModelConfig { - Channels = modelset.Channels, - Normalize = modelset.Normalize, + OutputChannels 
= modelset.Channels, + NormalizeOutput = modelset.Normalize, SampleSize = modelset.SampleSize, OnnxModelPath = modelset.ModelFile } diff --git a/OnnxStack.UI/Models/UpdateStableDiffusionModelSetViewModel.cs b/OnnxStack.UI/Models/UpdateStableDiffusionModelSetViewModel.cs index 267592dc..cd54ba81 100644 --- a/OnnxStack.UI/Models/UpdateStableDiffusionModelSetViewModel.cs +++ b/OnnxStack.UI/Models/UpdateStableDiffusionModelSetViewModel.cs @@ -1,4 +1,5 @@ -using Microsoft.ML.OnnxRuntime; +using MathNet.Numerics; +using Microsoft.ML.OnnxRuntime; using OnnxStack.Core.Config; using OnnxStack.StableDiffusion.Config; using OnnxStack.StableDiffusion.Enums; @@ -25,8 +26,10 @@ public class UpdateStableDiffusionModelSetViewModel : INotifyPropertyChanged private bool _enableControlNet; private bool _enableControlNetImage; private DiffuserPipelineType _pipelineType; - private int _sampleSize; + private MemoryModeType _memoryMode; + private ModelType _modelType; + private OnnxModelPrecision _precision; private ModelFileViewModel _unetModel; private ModelFileViewModel _vaeEncoderModel; @@ -55,6 +58,7 @@ public int PadTokenId get { return _padTokenId; } set { _padTokenId = value; NotifyPropertyChanged(); } } + public int BlankTokenId { get { return _blankTokenId; } @@ -79,7 +83,6 @@ public int TokenizerLimit set { _tokenizerLimit = value; NotifyPropertyChanged(); } } - public int Tokenizer2Length { get { return _tokenizer2Length; } @@ -165,23 +168,23 @@ public DiffuserPipelineType PipelineType set { _pipelineType = value; NotifyPropertyChanged(); } } - private MemoryModeType _memoryMode; - public MemoryModeType MemoryMode { get { return _memoryMode; } set { _memoryMode = value; NotifyPropertyChanged(); } } - - private ModelType _modelType; - public ModelType ModelType { get { return _modelType; } set { _modelType = value; NotifyPropertyChanged(); } } + public OnnxModelPrecision Precision + { + get { return _precision; } + set { _precision = value; NotifyPropertyChanged(); } + } 
public ModelFileViewModel UnetModel { @@ -272,125 +275,151 @@ public static UpdateStableDiffusionModelSetViewModel FromModelSet(StableDiffusio Tokenizer2Length = modelset.Tokenizer2Config?.TokenizerLength ?? 1280, ModelType = modelset.UnetConfig.ModelType, ScaleFactor = modelset.VaeDecoderConfig.ScaleFactor, - + Precision = modelset.Precision, UnetModel = new ModelFileViewModel { OnnxModelPath = modelset.UnetConfig.OnnxModelPath, + RequiredMemory = modelset.UnetConfig.RequiredMemory, DeviceId = modelset.UnetConfig.DeviceId ?? modelset.DeviceId, ExecutionMode = modelset.UnetConfig.ExecutionMode ?? modelset.ExecutionMode, ExecutionProvider = modelset.UnetConfig.ExecutionProvider ?? modelset.ExecutionProvider, InterOpNumThreads = modelset.UnetConfig.InterOpNumThreads ?? modelset.InterOpNumThreads, IntraOpNumThreads = modelset.UnetConfig.IntraOpNumThreads ?? modelset.IntraOpNumThreads, + Precision = modelset.UnetConfig.Precision ?? modelset.Precision, IsOverrideEnabled = modelset.UnetConfig.DeviceId.HasValue || modelset.UnetConfig.ExecutionMode.HasValue || modelset.UnetConfig.ExecutionProvider.HasValue || modelset.UnetConfig.IntraOpNumThreads.HasValue || modelset.UnetConfig.InterOpNumThreads.HasValue + || modelset.UnetConfig.Precision.HasValue }, TokenizerModel = new ModelFileViewModel { OnnxModelPath = modelset.TokenizerConfig.OnnxModelPath, + RequiredMemory = modelset.TokenizerConfig.RequiredMemory, DeviceId = modelset.TokenizerConfig.DeviceId ?? modelset.DeviceId, ExecutionMode = modelset.TokenizerConfig.ExecutionMode ?? modelset.ExecutionMode, ExecutionProvider = modelset.TokenizerConfig.ExecutionProvider ?? modelset.ExecutionProvider, InterOpNumThreads = modelset.TokenizerConfig.InterOpNumThreads ?? modelset.InterOpNumThreads, IntraOpNumThreads = modelset.TokenizerConfig.IntraOpNumThreads ?? modelset.IntraOpNumThreads, + Precision = modelset.TokenizerConfig.Precision ?? 
modelset.Precision, + IsOverrideEnabled = modelset.TokenizerConfig.DeviceId.HasValue || modelset.TokenizerConfig.ExecutionMode.HasValue || modelset.TokenizerConfig.ExecutionProvider.HasValue || modelset.TokenizerConfig.IntraOpNumThreads.HasValue || modelset.TokenizerConfig.InterOpNumThreads.HasValue + || modelset.TokenizerConfig.Precision.HasValue }, Tokenizer2Model = modelset.Tokenizer2Config is null ? default : new ModelFileViewModel { OnnxModelPath = modelset.Tokenizer2Config.OnnxModelPath, + RequiredMemory = modelset.Tokenizer2Config.RequiredMemory, DeviceId = modelset.Tokenizer2Config.DeviceId ?? modelset.DeviceId, ExecutionMode = modelset.Tokenizer2Config.ExecutionMode ?? modelset.ExecutionMode, ExecutionProvider = modelset.Tokenizer2Config.ExecutionProvider ?? modelset.ExecutionProvider, InterOpNumThreads = modelset.Tokenizer2Config.InterOpNumThreads ?? modelset.InterOpNumThreads, IntraOpNumThreads = modelset.Tokenizer2Config.IntraOpNumThreads ?? modelset.IntraOpNumThreads, + Precision = modelset.Tokenizer2Config.Precision ?? modelset.Precision, + IsOverrideEnabled = modelset.Tokenizer2Config.DeviceId.HasValue || modelset.Tokenizer2Config.ExecutionMode.HasValue || modelset.Tokenizer2Config.ExecutionProvider.HasValue || modelset.Tokenizer2Config.IntraOpNumThreads.HasValue || modelset.Tokenizer2Config.InterOpNumThreads.HasValue + || modelset.Tokenizer2Config.Precision.HasValue }, TextEncoderModel = new ModelFileViewModel { OnnxModelPath = modelset.TextEncoderConfig.OnnxModelPath, + RequiredMemory = modelset.TextEncoderConfig.RequiredMemory, DeviceId = modelset.TextEncoderConfig.DeviceId ?? modelset.DeviceId, ExecutionMode = modelset.TextEncoderConfig.ExecutionMode ?? modelset.ExecutionMode, ExecutionProvider = modelset.TextEncoderConfig.ExecutionProvider ?? modelset.ExecutionProvider, InterOpNumThreads = modelset.TextEncoderConfig.InterOpNumThreads ?? modelset.InterOpNumThreads, IntraOpNumThreads = modelset.TextEncoderConfig.IntraOpNumThreads ?? 
modelset.IntraOpNumThreads, + Precision = modelset.TextEncoderConfig.Precision ?? modelset.Precision, + IsOverrideEnabled = modelset.TextEncoderConfig.DeviceId.HasValue || modelset.TextEncoderConfig.ExecutionMode.HasValue || modelset.TextEncoderConfig.ExecutionProvider.HasValue || modelset.TextEncoderConfig.IntraOpNumThreads.HasValue || modelset.TextEncoderConfig.InterOpNumThreads.HasValue + || modelset.TextEncoderConfig.Precision.HasValue }, TextEncoder2Model = modelset.TextEncoder2Config is null ? default : new ModelFileViewModel { OnnxModelPath = modelset.TextEncoder2Config.OnnxModelPath, + RequiredMemory = modelset.TextEncoder2Config.RequiredMemory, DeviceId = modelset.TextEncoder2Config.DeviceId ?? modelset.DeviceId, ExecutionMode = modelset.TextEncoder2Config.ExecutionMode ?? modelset.ExecutionMode, ExecutionProvider = modelset.TextEncoder2Config.ExecutionProvider ?? modelset.ExecutionProvider, InterOpNumThreads = modelset.TextEncoder2Config.InterOpNumThreads ?? modelset.InterOpNumThreads, IntraOpNumThreads = modelset.TextEncoder2Config.IntraOpNumThreads ?? modelset.IntraOpNumThreads, + Precision = modelset.TextEncoder2Config.Precision ?? modelset.Precision, + IsOverrideEnabled = modelset.TextEncoder2Config.DeviceId.HasValue || modelset.TextEncoder2Config.ExecutionMode.HasValue || modelset.TextEncoder2Config.ExecutionProvider.HasValue || modelset.TextEncoder2Config.IntraOpNumThreads.HasValue || modelset.TextEncoder2Config.InterOpNumThreads.HasValue + || modelset.TextEncoder2Config.Precision.HasValue }, VaeDecoderModel = new ModelFileViewModel { OnnxModelPath = modelset.VaeDecoderConfig.OnnxModelPath, + RequiredMemory = modelset.VaeDecoderConfig.RequiredMemory, DeviceId = modelset.VaeDecoderConfig.DeviceId ?? modelset.DeviceId, ExecutionMode = modelset.VaeDecoderConfig.ExecutionMode ?? modelset.ExecutionMode, ExecutionProvider = modelset.VaeDecoderConfig.ExecutionProvider ?? 
modelset.ExecutionProvider, InterOpNumThreads = modelset.VaeDecoderConfig.InterOpNumThreads ?? modelset.InterOpNumThreads, IntraOpNumThreads = modelset.VaeDecoderConfig.IntraOpNumThreads ?? modelset.IntraOpNumThreads, + Precision = modelset.VaeDecoderConfig.Precision ?? modelset.Precision, IsOverrideEnabled = modelset.VaeDecoderConfig.DeviceId.HasValue || modelset.VaeDecoderConfig.ExecutionMode.HasValue || modelset.VaeDecoderConfig.ExecutionProvider.HasValue || modelset.VaeDecoderConfig.IntraOpNumThreads.HasValue || modelset.VaeDecoderConfig.InterOpNumThreads.HasValue + || modelset.VaeDecoderConfig.Precision.HasValue }, VaeEncoderModel = new ModelFileViewModel { OnnxModelPath = modelset.VaeEncoderConfig.OnnxModelPath, + RequiredMemory = modelset.VaeEncoderConfig.RequiredMemory, DeviceId = modelset.VaeEncoderConfig.DeviceId ?? modelset.DeviceId, ExecutionMode = modelset.VaeEncoderConfig.ExecutionMode ?? modelset.ExecutionMode, ExecutionProvider = modelset.VaeEncoderConfig.ExecutionProvider ?? modelset.ExecutionProvider, InterOpNumThreads = modelset.VaeEncoderConfig.InterOpNumThreads ?? modelset.InterOpNumThreads, IntraOpNumThreads = modelset.VaeEncoderConfig.IntraOpNumThreads ?? modelset.IntraOpNumThreads, + Precision = modelset.VaeEncoderConfig.Precision ??
modelset.Precision, + IsOverrideEnabled = modelset.VaeEncoderConfig.DeviceId.HasValue || modelset.VaeEncoderConfig.ExecutionMode.HasValue || modelset.VaeEncoderConfig.ExecutionProvider.HasValue || modelset.VaeEncoderConfig.IntraOpNumThreads.HasValue || modelset.VaeEncoderConfig.InterOpNumThreads.HasValue + || modelset.VaeEncoderConfig.Precision.HasValue } }; @@ -411,18 +440,20 @@ public static StableDiffusionModelSet ToModelSet(UpdateStableDiffusionModelSetVi ExecutionProvider = modelset.ExecutionProvider, InterOpNumThreads = modelset.InterOpNumThreads, IntraOpNumThreads = modelset.IntraOpNumThreads, - + Precision = modelset.Precision, MemoryMode = modelset.MemoryMode, UnetConfig = new UNetConditionModelConfig { ModelType = modelset.ModelType, OnnxModelPath = modelset.UnetModel.OnnxModelPath, + RequiredMemory = modelset.UnetModel.RequiredMemory, DeviceId = modelset.UnetModel.IsOverrideEnabled && modelset.DeviceId != modelset.UnetModel.DeviceId ? modelset.UnetModel.DeviceId : default, ExecutionMode = modelset.UnetModel.IsOverrideEnabled && modelset.ExecutionMode != modelset.UnetModel.ExecutionMode ? modelset.UnetModel.ExecutionMode : default, ExecutionProvider = modelset.UnetModel.IsOverrideEnabled && modelset.ExecutionProvider != modelset.UnetModel.ExecutionProvider ? modelset.UnetModel.ExecutionProvider : default, IntraOpNumThreads = modelset.UnetModel.IsOverrideEnabled && modelset.IntraOpNumThreads != modelset.UnetModel.IntraOpNumThreads ? modelset.UnetModel.IntraOpNumThreads : default, InterOpNumThreads = modelset.UnetModel.IsOverrideEnabled && modelset.InterOpNumThreads != modelset.UnetModel.InterOpNumThreads ? modelset.UnetModel.InterOpNumThreads : default, + Precision = modelset.UnetModel.IsOverrideEnabled && modelset.Precision != modelset.UnetModel.Precision ? 
modelset.UnetModel.Precision : default, }, TokenizerConfig = new TokenizerModelConfig @@ -432,11 +463,13 @@ public static StableDiffusionModelSet ToModelSet(UpdateStableDiffusionModelSetVi TokenizerLimit = modelset.TokenizerLimit, TokenizerLength = modelset.TokenizerLength, OnnxModelPath = modelset.TokenizerModel.OnnxModelPath, + RequiredMemory = modelset.TokenizerModel.RequiredMemory, DeviceId = modelset.TokenizerModel.IsOverrideEnabled && modelset.DeviceId != modelset.TokenizerModel.DeviceId ? modelset.TokenizerModel.DeviceId : default, ExecutionMode = modelset.TokenizerModel.IsOverrideEnabled && modelset.ExecutionMode != modelset.TokenizerModel.ExecutionMode ? modelset.TokenizerModel.ExecutionMode : default, ExecutionProvider = modelset.TokenizerModel.IsOverrideEnabled && modelset.ExecutionProvider != modelset.TokenizerModel.ExecutionProvider ? modelset.TokenizerModel.ExecutionProvider : default, IntraOpNumThreads = modelset.TokenizerModel.IsOverrideEnabled && modelset.IntraOpNumThreads != modelset.TokenizerModel.IntraOpNumThreads ? modelset.TokenizerModel.IntraOpNumThreads : default, InterOpNumThreads = modelset.TokenizerModel.IsOverrideEnabled && modelset.InterOpNumThreads != modelset.TokenizerModel.InterOpNumThreads ? modelset.TokenizerModel.InterOpNumThreads : default, + Precision = modelset.TokenizerModel.IsOverrideEnabled && modelset.Precision != modelset.TokenizerModel.Precision ? modelset.TokenizerModel.Precision : default, }, Tokenizer2Config = modelset.Tokenizer2Model is null ? default : new TokenizerModelConfig @@ -446,53 +479,63 @@ public static StableDiffusionModelSet ToModelSet(UpdateStableDiffusionModelSetVi TokenizerLimit = modelset.TokenizerLimit, TokenizerLength = modelset.Tokenizer2Length, OnnxModelPath = modelset.Tokenizer2Model.OnnxModelPath, + RequiredMemory = modelset.Tokenizer2Model.RequiredMemory, DeviceId = modelset.Tokenizer2Model.IsOverrideEnabled && modelset.DeviceId != modelset.Tokenizer2Model.DeviceId ? 
modelset.Tokenizer2Model.DeviceId : default, ExecutionMode = modelset.Tokenizer2Model.IsOverrideEnabled && modelset.ExecutionMode != modelset.Tokenizer2Model.ExecutionMode ? modelset.Tokenizer2Model.ExecutionMode : default, ExecutionProvider = modelset.Tokenizer2Model.IsOverrideEnabled && modelset.ExecutionProvider != modelset.Tokenizer2Model.ExecutionProvider ? modelset.Tokenizer2Model.ExecutionProvider : default, IntraOpNumThreads = modelset.Tokenizer2Model.IsOverrideEnabled && modelset.IntraOpNumThreads != modelset.Tokenizer2Model.IntraOpNumThreads ? modelset.Tokenizer2Model.IntraOpNumThreads : default, InterOpNumThreads = modelset.Tokenizer2Model.IsOverrideEnabled && modelset.InterOpNumThreads != modelset.Tokenizer2Model.InterOpNumThreads ? modelset.Tokenizer2Model.InterOpNumThreads : default, + Precision = modelset.Tokenizer2Model.IsOverrideEnabled && modelset.Precision != modelset.Tokenizer2Model.Precision ? modelset.Tokenizer2Model.Precision : default, }, TextEncoderConfig = new TextEncoderModelConfig { OnnxModelPath = modelset.TextEncoderModel.OnnxModelPath, + RequiredMemory = modelset.TextEncoderModel.RequiredMemory, DeviceId = modelset.TextEncoderModel.IsOverrideEnabled && modelset.DeviceId != modelset.TextEncoderModel.DeviceId ? modelset.TextEncoderModel.DeviceId : default, ExecutionMode = modelset.TextEncoderModel.IsOverrideEnabled && modelset.ExecutionMode != modelset.TextEncoderModel.ExecutionMode ? modelset.TextEncoderModel.ExecutionMode : default, ExecutionProvider = modelset.TextEncoderModel.IsOverrideEnabled && modelset.ExecutionProvider != modelset.TextEncoderModel.ExecutionProvider ? modelset.TextEncoderModel.ExecutionProvider : default, IntraOpNumThreads = modelset.TextEncoderModel.IsOverrideEnabled && modelset.IntraOpNumThreads != modelset.TextEncoderModel.IntraOpNumThreads ? 
modelset.TextEncoderModel.IntraOpNumThreads : default, InterOpNumThreads = modelset.TextEncoderModel.IsOverrideEnabled && modelset.InterOpNumThreads != modelset.TextEncoderModel.InterOpNumThreads ? modelset.TextEncoderModel.InterOpNumThreads : default, + Precision = modelset.TextEncoderModel.IsOverrideEnabled && modelset.Precision != modelset.TextEncoderModel.Precision ? modelset.TextEncoderModel.Precision : default, }, TextEncoder2Config = modelset.TextEncoder2Model is null ? default : new TextEncoderModelConfig { OnnxModelPath = modelset.TextEncoder2Model.OnnxModelPath, + RequiredMemory = modelset.TextEncoder2Model.RequiredMemory, DeviceId = modelset.TextEncoder2Model.IsOverrideEnabled && modelset.DeviceId != modelset.TextEncoder2Model.DeviceId ? modelset.TextEncoder2Model.DeviceId : default, ExecutionMode = modelset.TextEncoder2Model.IsOverrideEnabled && modelset.ExecutionMode != modelset.TextEncoder2Model.ExecutionMode ? modelset.TextEncoder2Model.ExecutionMode : default, ExecutionProvider = modelset.TextEncoder2Model.IsOverrideEnabled && modelset.ExecutionProvider != modelset.TextEncoder2Model.ExecutionProvider ? modelset.TextEncoder2Model.ExecutionProvider : default, IntraOpNumThreads = modelset.TextEncoder2Model.IsOverrideEnabled && modelset.IntraOpNumThreads != modelset.TextEncoder2Model.IntraOpNumThreads ? modelset.TextEncoder2Model.IntraOpNumThreads : default, InterOpNumThreads = modelset.TextEncoder2Model.IsOverrideEnabled && modelset.InterOpNumThreads != modelset.TextEncoder2Model.InterOpNumThreads ? modelset.TextEncoder2Model.InterOpNumThreads : default, + Precision = modelset.TextEncoder2Model.IsOverrideEnabled && modelset.Precision != modelset.TextEncoder2Model.Precision ? 
modelset.TextEncoder2Model.Precision : default, }, VaeDecoderConfig = new AutoEncoderModelConfig { ScaleFactor = modelset.ScaleFactor, OnnxModelPath = modelset.VaeDecoderModel.OnnxModelPath, + RequiredMemory = modelset.VaeDecoderModel.RequiredMemory, DeviceId = modelset.VaeDecoderModel.IsOverrideEnabled && modelset.DeviceId != modelset.VaeDecoderModel.DeviceId ? modelset.VaeDecoderModel.DeviceId : default, ExecutionMode = modelset.VaeDecoderModel.IsOverrideEnabled && modelset.ExecutionMode != modelset.VaeDecoderModel.ExecutionMode ? modelset.VaeDecoderModel.ExecutionMode : default, ExecutionProvider = modelset.VaeDecoderModel.IsOverrideEnabled && modelset.ExecutionProvider != modelset.VaeDecoderModel.ExecutionProvider ? modelset.VaeDecoderModel.ExecutionProvider : default, IntraOpNumThreads = modelset.VaeDecoderModel.IsOverrideEnabled && modelset.IntraOpNumThreads != modelset.VaeDecoderModel.IntraOpNumThreads ? modelset.VaeDecoderModel.IntraOpNumThreads : default, InterOpNumThreads = modelset.VaeDecoderModel.IsOverrideEnabled && modelset.InterOpNumThreads != modelset.VaeDecoderModel.InterOpNumThreads ? modelset.VaeDecoderModel.InterOpNumThreads : default, + Precision = modelset.VaeDecoderModel.IsOverrideEnabled && modelset.Precision != modelset.VaeDecoderModel.Precision ? modelset.VaeDecoderModel.Precision : default, }, VaeEncoderConfig = new AutoEncoderModelConfig { ScaleFactor = modelset.ScaleFactor, OnnxModelPath = modelset.VaeEncoderModel.OnnxModelPath, + RequiredMemory = modelset.VaeEncoderModel.RequiredMemory, DeviceId = modelset.VaeEncoderModel.IsOverrideEnabled && modelset.DeviceId != modelset.VaeEncoderModel.DeviceId ? modelset.VaeEncoderModel.DeviceId : default, ExecutionMode = modelset.VaeEncoderModel.IsOverrideEnabled && modelset.ExecutionMode != modelset.VaeEncoderModel.ExecutionMode ? 
modelset.VaeEncoderModel.ExecutionMode : default, ExecutionProvider = modelset.VaeEncoderModel.IsOverrideEnabled && modelset.ExecutionProvider != modelset.VaeEncoderModel.ExecutionProvider ? modelset.VaeEncoderModel.ExecutionProvider : default, IntraOpNumThreads = modelset.VaeEncoderModel.IsOverrideEnabled && modelset.IntraOpNumThreads != modelset.VaeEncoderModel.IntraOpNumThreads ? modelset.VaeEncoderModel.IntraOpNumThreads : default, InterOpNumThreads = modelset.VaeEncoderModel.IsOverrideEnabled && modelset.InterOpNumThreads != modelset.VaeEncoderModel.InterOpNumThreads ? modelset.VaeEncoderModel.InterOpNumThreads : default, + Precision = modelset.VaeEncoderModel.IsOverrideEnabled && modelset.Precision != modelset.VaeEncoderModel.Precision ? modelset.VaeEncoderModel.Precision : default, } }; diff --git a/OnnxStack.UI/Services/ModelFactory.cs b/OnnxStack.UI/Services/ModelFactory.cs index 3e22d72a..f719e758 100644 --- a/OnnxStack.UI/Services/ModelFactory.cs +++ b/OnnxStack.UI/Services/ModelFactory.cs @@ -178,8 +178,8 @@ public FeatureExtractorModelSet CreateFeatureExtractorModelSet(string name, bool IntraOpNumThreads = _settings.DefaultIntraOpNumThreads, FeatureExtractorConfig = new FeatureExtractorModelConfig { - Channels = channels, - Normalize = normalize, + OutputChannels = channels, + NormalizeOutput = normalize, SampleSize = sampleSize, OnnxModelPath = modelFilename }