diff --git a/OnnxStack.Console/Examples/BackgroundRemovalImageExample.cs b/OnnxStack.Console/Examples/BackgroundRemovalImageExample.cs new file mode 100644 index 0000000..40551b8 --- /dev/null +++ b/OnnxStack.Console/Examples/BackgroundRemovalImageExample.cs @@ -0,0 +1,51 @@ +using OnnxStack.Core.Image; +using OnnxStack.FeatureExtractor.Pipelines; +using System.Diagnostics; + +namespace OnnxStack.Console.Runner + { + public sealed class BackgroundRemovalImageExample : IExampleRunner + { + private readonly string _outputDirectory; + + public BackgroundRemovalImageExample() + { + _outputDirectory = Path.Combine(Directory.GetCurrentDirectory(), "Examples", "BackgroundRemovalExample"); + Directory.CreateDirectory(_outputDirectory); + } + + public int Index => 20; + + public string Name => "Image Background Removal Example"; + + public string Description => "Remove a background from an image"; + + /// + /// Background Removal Example + /// + public async Task RunAsync() + { + OutputHelpers.WriteConsole("Please enter an image file path and press ENTER", ConsoleColor.Yellow); + var imageFile = OutputHelpers.ReadConsole(ConsoleColor.Cyan); + + var timestamp = Stopwatch.GetTimestamp(); + + OutputHelpers.WriteConsole($"Load Image", ConsoleColor.Gray); + var inputImage = await OnnxImage.FromFileAsync(imageFile); + + OutputHelpers.WriteConsole($"Create Pipeline", ConsoleColor.Gray); + var pipeline = BackgroundRemovalPipeline.CreatePipeline("D:\\Repositories\\RMBG-1.4\\onnx\\model.onnx", sampleSize: 1024); + + OutputHelpers.WriteConsole($"Run Pipeline", ConsoleColor.Gray); + var imageFeature = await pipeline.RunAsync(inputImage); + + OutputHelpers.WriteConsole($"Save Image", ConsoleColor.Gray); + await imageFeature.SaveAsync(Path.Combine(_outputDirectory, $"{pipeline.Name}.png")); + + OutputHelpers.WriteConsole($"Unload pipeline", ConsoleColor.Gray); + await pipeline.UnloadAsync(); + + OutputHelpers.WriteConsole($"Elapsed: {Stopwatch.GetElapsedTime(timestamp)}ms", ConsoleColor.Yellow); 
+ } + } +} diff --git a/OnnxStack.Console/Examples/BackgroundRemovalVideoExample.cs b/OnnxStack.Console/Examples/BackgroundRemovalVideoExample.cs new file mode 100644 index 0000000..788b8d1 --- /dev/null +++ b/OnnxStack.Console/Examples/BackgroundRemovalVideoExample.cs @@ -0,0 +1,54 @@ +using OnnxStack.Core.Video; +using OnnxStack.FeatureExtractor.Pipelines; +using System.Diagnostics; + +namespace OnnxStack.Console.Runner + { + public sealed class BackgroundRemovalVideoExample : IExampleRunner + { + private readonly string _outputDirectory; + + public BackgroundRemovalVideoExample() + { + _outputDirectory = Path.Combine(Directory.GetCurrentDirectory(), "Examples", "BackgroundRemovalExample"); + Directory.CreateDirectory(_outputDirectory); + } + + public int Index => 21; + + public string Name => "Video Background Removal Example"; + + public string Description => "Remove a background from a video"; + + public async Task RunAsync() + { + OutputHelpers.WriteConsole("Please enter a video/gif file path and press ENTER", ConsoleColor.Yellow); + var videoFile = OutputHelpers.ReadConsole(ConsoleColor.Cyan); + + var timestamp = Stopwatch.GetTimestamp(); + + OutputHelpers.WriteConsole($"Read Video", ConsoleColor.Gray); + var videoInfo = await VideoHelper.ReadVideoInfoAsync(videoFile); + + OutputHelpers.WriteConsole($"Create Pipeline", ConsoleColor.Gray); + var pipeline = BackgroundRemovalPipeline.CreatePipeline("D:\\Repositories\\RMBG-1.4\\onnx\\model.onnx", sampleSize: 1024); + + OutputHelpers.WriteConsole($"Load Pipeline", ConsoleColor.Gray); + await pipeline.LoadAsync(); + + OutputHelpers.WriteConsole($"Create Video Stream", ConsoleColor.Gray); + var videoStream = VideoHelper.ReadVideoStreamAsync(videoFile, videoInfo.FrameRate); + + OutputHelpers.WriteConsole($"Create Pipeline Stream", ConsoleColor.Gray); + var pipelineStream = pipeline.RunAsync(videoStream); + + OutputHelpers.WriteConsole($"Write Video Stream", ConsoleColor.Gray); + await 
VideoHelper.WriteVideoStreamAsync(videoInfo, pipelineStream, Path.Combine(_outputDirectory, $"Result.mp4"), true); + + OutputHelpers.WriteConsole($"Unload", ConsoleColor.Gray); + await pipeline.UnloadAsync(); + + OutputHelpers.WriteConsole($"Elapsed: {Stopwatch.GetElapsedTime(timestamp)}ms", ConsoleColor.Yellow); + } + } +} diff --git a/OnnxStack.Core/Extensions/TensorExtension.cs b/OnnxStack.Core/Extensions/TensorExtension.cs index db21a27..6e93467 100644 --- a/OnnxStack.Core/Extensions/TensorExtension.cs +++ b/OnnxStack.Core/Extensions/TensorExtension.cs @@ -397,7 +397,7 @@ private static DenseTensor ConcatenateAxis1(DenseTensor tensor1, D // Copy data from the second tensor for (int i = 0; i < dimensions[0]; i++) - for (int j = 0; j < tensor1.Dimensions[1]; j++) + for (int j = 0; j < tensor2.Dimensions[1]; j++) concatenatedTensor[i, j + tensor1.Dimensions[1]] = tensor2[i, j]; return concatenatedTensor; diff --git a/OnnxStack.Core/Image/OnnxImage.cs b/OnnxStack.Core/Image/OnnxImage.cs index d0998e8..643e514 100644 --- a/OnnxStack.Core/Image/OnnxImage.cs +++ b/OnnxStack.Core/Image/OnnxImage.cs @@ -64,6 +64,7 @@ public OnnxImage(DenseTensor imageTensor, ImageNormalizeType normalizeTyp { var height = imageTensor.Dimensions[2]; var width = imageTensor.Dimensions[3]; + var hasTransparency = imageTensor.Dimensions[1] == 4; _imageData = new Image(width, height); for (var y = 0; y < height; y++) { @@ -74,14 +75,16 @@ public OnnxImage(DenseTensor imageTensor, ImageNormalizeType normalizeTyp _imageData[x, y] = new Rgba32( DenormalizeZeroToOneToByte(imageTensor, 0, y, x), DenormalizeZeroToOneToByte(imageTensor, 1, y, x), - DenormalizeZeroToOneToByte(imageTensor, 2, y, x)); + DenormalizeZeroToOneToByte(imageTensor, 2, y, x), + hasTransparency ? 
DenormalizeZeroToOneToByte(imageTensor, 3, y, x) : byte.MaxValue); } else { _imageData[x, y] = new Rgba32( DenormalizeOneToOneToByte(imageTensor, 0, y, x), DenormalizeOneToOneToByte(imageTensor, 1, y, x), - DenormalizeOneToOneToByte(imageTensor, 2, y, x)); + DenormalizeOneToOneToByte(imageTensor, 2, y, x), + hasTransparency ? DenormalizeOneToOneToByte(imageTensor, 3, y, x) : byte.MaxValue); } } } @@ -337,6 +340,7 @@ private DenseTensor NormalizeToZeroToOne(ReadOnlySpan dimensions) var width = dimensions[3]; var height = dimensions[2]; var channels = dimensions[1]; + var hasTransparency = channels == 4; var imageArray = new DenseTensor(new[] { 1, channels, height, width }); _imageData.ProcessPixelRows(img => { @@ -348,6 +352,8 @@ private DenseTensor NormalizeToZeroToOne(ReadOnlySpan dimensions) imageArray[0, 0, y, x] = (pixelSpan[x].R / 255.0f); imageArray[0, 1, y, x] = (pixelSpan[x].G / 255.0f); imageArray[0, 2, y, x] = (pixelSpan[x].B / 255.0f); + if (hasTransparency) + imageArray[0, 3, y, x] = (pixelSpan[x].A / 255.0f); } } }); @@ -366,6 +372,7 @@ private DenseTensor NormalizeToOneToOne(ReadOnlySpan dimensions) var width = dimensions[3]; var height = dimensions[2]; var channels = dimensions[1]; + var hasTransparency = channels == 4; var imageArray = new DenseTensor(new[] { 1, channels, height, width }); _imageData.ProcessPixelRows(img => { @@ -377,6 +384,8 @@ private DenseTensor NormalizeToOneToOne(ReadOnlySpan dimensions) imageArray[0, 0, y, x] = (pixelSpan[x].R / 255.0f) * 2.0f - 1.0f; imageArray[0, 1, y, x] = (pixelSpan[x].G / 255.0f) * 2.0f - 1.0f; imageArray[0, 2, y, x] = (pixelSpan[x].B / 255.0f) * 2.0f - 1.0f; + if (hasTransparency) + imageArray[0, 3, y, x] = (pixelSpan[x].A / 255.0f) * 2.0f - 1.0f; } } }); diff --git a/OnnxStack.Core/Video/OnnxVideo.cs b/OnnxStack.Core/Video/OnnxVideo.cs index 8569dd6..ee8f3e9 100644 --- a/OnnxStack.Core/Video/OnnxVideo.cs +++ b/OnnxStack.Core/Video/OnnxVideo.cs @@ -137,9 +137,9 @@ public void Resize(int height, int 
width) /// The filename. /// The cancellation token. /// - public Task SaveAsync(string filename, CancellationToken cancellationToken = default) + public Task SaveAsync(string filename, bool preserveTransparency = false, CancellationToken cancellationToken = default) { - return VideoHelper.WriteVideoFramesAsync(this, filename, cancellationToken); + return VideoHelper.WriteVideoFramesAsync(this, filename, preserveTransparency, cancellationToken); } diff --git a/OnnxStack.Core/Video/VideoHelper.cs b/OnnxStack.Core/Video/VideoHelper.cs index a574837..b8c6721 100644 --- a/OnnxStack.Core/Video/VideoHelper.cs +++ b/OnnxStack.Core/Video/VideoHelper.cs @@ -32,9 +32,9 @@ public static void SetConfiguration(OnnxStackConfig configuration) /// The onnx video. /// The filename. /// The cancellation token. - public static async Task WriteVideoFramesAsync(OnnxVideo onnxVideo, string filename, CancellationToken cancellationToken = default) + public static async Task WriteVideoFramesAsync(OnnxVideo onnxVideo, string filename, bool preserveTransparency = false, CancellationToken cancellationToken = default) { - await WriteVideoFramesAsync(onnxVideo.Frames, filename, onnxVideo.FrameRate, onnxVideo.AspectRatio, cancellationToken); + await WriteVideoFramesAsync(onnxVideo.Frames, filename, onnxVideo.FrameRate, onnxVideo.AspectRatio, preserveTransparency, cancellationToken); } @@ -45,11 +45,11 @@ public static async Task WriteVideoFramesAsync(OnnxVideo onnxVideo, string filen /// The filename. /// The frame rate. /// The cancellation token. 
- public static async Task WriteVideoFramesAsync(IEnumerable onnxImages, string filename, float frameRate = 15, CancellationToken cancellationToken = default) + public static async Task WriteVideoFramesAsync(IEnumerable onnxImages, string filename, float frameRate = 15, bool preserveTransparency = false, CancellationToken cancellationToken = default) { var firstImage = onnxImages.First(); var aspectRatio = (double)firstImage.Width / firstImage.Height; - await WriteVideoFramesAsync(onnxImages, filename, frameRate, aspectRatio, cancellationToken); + await WriteVideoFramesAsync(onnxImages, filename, frameRate, aspectRatio, preserveTransparency, cancellationToken); } @@ -61,12 +61,12 @@ public static async Task WriteVideoFramesAsync(IEnumerable onnxImages /// The frame rate. /// The aspect ratio. /// The cancellation token. - private static async Task WriteVideoFramesAsync(IEnumerable onnxImages, string filename, float frameRate, double aspectRatio, CancellationToken cancellationToken = default) + private static async Task WriteVideoFramesAsync(IEnumerable onnxImages, string filename, float frameRate, double aspectRatio, bool preserveTransparency, CancellationToken cancellationToken = default) { if (File.Exists(filename)) File.Delete(filename); - using (var videoWriter = CreateWriter(filename, frameRate, aspectRatio)) + using (var videoWriter = CreateWriter(filename, frameRate, aspectRatio, preserveTransparency)) { // Start FFMPEG videoWriter.Start(); @@ -91,12 +91,12 @@ private static async Task WriteVideoFramesAsync(IEnumerable onnxImage /// The frame rate. /// The aspect ratio. /// The cancellation token. 
- public static async Task WriteVideoStreamAsync(VideoInfo videoInfo, IAsyncEnumerable videoStream, string filename, CancellationToken cancellationToken = default) + public static async Task WriteVideoStreamAsync(VideoInfo videoInfo, IAsyncEnumerable videoStream, string filename, bool preserveTransparency = false, CancellationToken cancellationToken = default) { if (File.Exists(filename)) File.Delete(filename); - using (var videoWriter = CreateWriter(filename, videoInfo.FrameRate, videoInfo.AspectRatio)) + using (var videoWriter = CreateWriter(filename, videoInfo.FrameRate, videoInfo.AspectRatio, preserveTransparency)) { // Start FFMPEG videoWriter.Start(); @@ -323,11 +323,13 @@ private static Process CreateReader(string inputFile, float fps) /// The FPS. /// The aspect ratio. /// - private static Process CreateWriter(string outputFile, float fps, double aspectRatio) + private static Process CreateWriter(string outputFile, float fps, double aspectRatio, bool preserveTransparency) { var ffmpegProcess = new Process(); + var codec = preserveTransparency ? "png" : "libx264"; + var format = preserveTransparency ? 
"yuva420p" : "yuv420p"; ffmpegProcess.StartInfo.FileName = _configuration.FFmpegPath; - ffmpegProcess.StartInfo.Arguments = $"-hide_banner -loglevel error -framerate {fps:F4} -i - -c:v libx264 -movflags +faststart -vf format=yuv420p -aspect {aspectRatio} {outputFile}"; + ffmpegProcess.StartInfo.Arguments = $"-hide_banner -loglevel error -framerate {fps:F4} -i - -c:v {codec} -movflags +faststart -vf format={format} -aspect {aspectRatio} {outputFile}"; ffmpegProcess.StartInfo.RedirectStandardInput = true; ffmpegProcess.StartInfo.UseShellExecute = false; ffmpegProcess.StartInfo.CreateNoWindow = true; diff --git a/OnnxStack.FeatureExtractor/Pipelines/BackgroundRemovalPipeline.cs b/OnnxStack.FeatureExtractor/Pipelines/BackgroundRemovalPipeline.cs new file mode 100644 index 0000000..7cd69b8 --- /dev/null +++ b/OnnxStack.FeatureExtractor/Pipelines/BackgroundRemovalPipeline.cs @@ -0,0 +1,205 @@ +using Microsoft.Extensions.Logging; +using Microsoft.ML.OnnxRuntime.Tensors; +using OnnxStack.Core; +using OnnxStack.Core.Config; +using OnnxStack.Core.Image; +using OnnxStack.Core.Model; +using OnnxStack.Core.Video; +using OnnxStack.FeatureExtractor.Common; +using System; +using System.Collections.Generic; +using System.IO; +using System.Linq; +using System.Runtime.CompilerServices; +using System.Threading; +using System.Threading.Tasks; + +namespace OnnxStack.FeatureExtractor.Pipelines +{ + public class BackgroundRemovalPipeline + { + private readonly string _name; + private readonly ILogger _logger; + private readonly FeatureExtractorModel _model; + + /// + /// Initializes a new instance of the class. + /// + /// The name. + /// The model. + /// The logger. + public BackgroundRemovalPipeline(string name, FeatureExtractorModel model, ILogger logger = default) + { + _name = name; + _logger = logger; + _model = model; + } + + + /// + /// Gets the name. + /// + /// + public string Name => _name; + + + /// + /// Loads the model. 
+ /// + /// + public Task LoadAsync() + { + return _model.LoadAsync(); + } + + + /// + /// Unloads the models. + /// + public async Task UnloadAsync() + { + await Task.Yield(); + _model?.Dispose(); + } + + + /// + /// Generates the background removal image result + /// + /// The input image. + /// + public async Task RunAsync(OnnxImage inputImage, CancellationToken cancellationToken = default) + { + var timestamp = _logger?.LogBegin("Removing image background..."); + var result = await RunInternalAsync(inputImage, cancellationToken); + _logger?.LogEnd("Removing image background complete.", timestamp); + return result; + } + + + /// + /// Generates the background removal video result + /// + /// The input video. + /// + public async Task RunAsync(OnnxVideo video, CancellationToken cancellationToken = default) + { + var timestamp = _logger?.LogBegin("Removing video background..."); + var videoFrames = new List(); + foreach (var videoFrame in video.Frames) + { + videoFrames.Add(await RunInternalAsync(videoFrame, cancellationToken)); + } + _logger?.LogEnd("Removing video background complete.", timestamp); + return new OnnxVideo(video.Info with + { + Height = videoFrames[0].Height, + Width = videoFrames[0].Width, + }, videoFrames); + } + + + /// + /// Generates the background removal video stream + /// + /// The image frames. + /// The cancellation token. + /// + public async IAsyncEnumerable RunAsync(IAsyncEnumerable imageFrames, [EnumeratorCancellation] CancellationToken cancellationToken = default) + { + var timestamp = _logger?.LogBegin("Extracting video stream features..."); + await foreach (var imageFrame in imageFrames) + { + yield return await RunInternalAsync(imageFrame, cancellationToken); + } + _logger?.LogEnd("Extracting video stream features complete.", timestamp); + } + + + /// + /// Runs the pipeline + /// + /// The input image. + /// The cancellation token. 
+ /// + private async Task RunInternalAsync(OnnxImage inputImage, CancellationToken cancellationToken = default) + { + var souceImageTenssor = await inputImage.GetImageTensorAsync(_model.SampleSize, _model.SampleSize, ImageNormalizeType.ZeroToOne); + var metadata = await _model.GetMetadataAsync(); + cancellationToken.ThrowIfCancellationRequested(); + var outputShape = new[] { 1, _model.Channels, _model.SampleSize, _model.SampleSize }; + var outputBuffer = metadata.Outputs[0].Value.Dimensions.Length == 4 ? outputShape : outputShape[1..]; + using (var inferenceParameters = new OnnxInferenceParameters(metadata)) + { + inferenceParameters.AddInputTensor(souceImageTenssor); + inferenceParameters.AddOutputBuffer(outputBuffer); + + var results = await _model.RunInferenceAsync(inferenceParameters); + using (var result = results.First()) + { + cancellationToken.ThrowIfCancellationRequested(); + + var imageTensor = AddAlphaChannel(souceImageTenssor, result.GetTensorDataAsSpan()); + return new OnnxImage(imageTensor, ImageNormalizeType.ZeroToOne); + } + } + } + + + /// + /// Adds an alpha channel to the RGB tensor. + /// + /// The source image. + /// The alpha channel. + /// + private static DenseTensor AddAlphaChannel(DenseTensor sourceImage, ReadOnlySpan alphaChannel) + { + var resultTensor = new DenseTensor(new int[] { 1, 4, sourceImage.Dimensions[2], sourceImage.Dimensions[3] }); + sourceImage.Buffer.Span.CopyTo(resultTensor.Buffer[..(int)sourceImage.Length].Span); + alphaChannel.CopyTo(resultTensor.Buffer[(int)sourceImage.Length..].Span); + return resultTensor; + } + + + /// + /// Creates the pipeline from a FeatureExtractorModelSet. + /// + /// The model set. + /// The logger. 
+ /// + public static BackgroundRemovalPipeline CreatePipeline(FeatureExtractorModelSet modelSet, ILogger logger = default) + { + var model = new FeatureExtractorModel(modelSet.FeatureExtractorConfig.ApplyDefaults(modelSet)); + return new BackgroundRemovalPipeline(modelSet.Name, model, logger); + } + + + /// + /// Creates the pipeline from the specified file. + /// + /// The model file. + /// The device identifier. + /// The execution provider. + /// The logger. + /// + public static BackgroundRemovalPipeline CreatePipeline(string modelFile, int sampleSize = 512, int deviceId = 0, ExecutionProvider executionProvider = ExecutionProvider.DirectML, ILogger logger = default) + { + var name = Path.GetFileNameWithoutExtension(modelFile); + var configuration = new FeatureExtractorModelSet + { + Name = name, + IsEnabled = true, + DeviceId = deviceId, + ExecutionProvider = executionProvider, + FeatureExtractorConfig = new FeatureExtractorModelConfig + { + OnnxModelPath = modelFile, + SampleSize = sampleSize, + Normalize = false, + Channels = 1 + } + }; + return CreatePipeline(configuration, logger); + } + } +} diff --git a/OnnxStack.FeatureExtractor/README.md b/OnnxStack.FeatureExtractor/README.md index a09360b..35ebc2d 100644 --- a/OnnxStack.FeatureExtractor/README.md +++ b/OnnxStack.FeatureExtractor/README.md @@ -14,6 +14,10 @@ ### OpenPose (TODO) * https://huggingface.co/axodoxian/controlnet_onnx/resolve/main/annotators/openpose.onnx +### Background Removal +* https://huggingface.co/briaai/RMBG-1.4/resolve/main/onnx/model.onnx + + # Image Example ```csharp // Load Input Image