diff --git a/OnnxStack.Console/Examples/BackgroundRemovalImageExample.cs b/OnnxStack.Console/Examples/BackgroundRemovalImageExample.cs
new file mode 100644
index 0000000..40551b8
--- /dev/null
+++ b/OnnxStack.Console/Examples/BackgroundRemovalImageExample.cs
@@ -0,0 +1,51 @@
+using OnnxStack.Core.Image;
+using OnnxStack.FeatureExtractor.Pipelines;
+using System.Diagnostics;
+
+namespace OnnxStack.Console.Runner
+{
+ public sealed class BackgroundRemovalImageExample : IExampleRunner
+ {
+ private readonly string _outputDirectory;
+
+ public BackgroundRemovalImageExample()
+ {
+ _outputDirectory = Path.Combine(Directory.GetCurrentDirectory(), "Examples", "BackgroundRemovalExample");
+ Directory.CreateDirectory(_outputDirectory);
+ }
+
+ public int Index => 20;
+
+ public string Name => "Image Background Removal Example";
+
+ public string Description => "Remove a background from an image";
+
+ ///
+ /// Background Removal Example
+ ///
+ public async Task RunAsync()
+ {
+ OutputHelpers.WriteConsole("Please enter an image file path and press ENTER", ConsoleColor.Yellow);
+ var imageFile = OutputHelpers.ReadConsole(ConsoleColor.Cyan);
+
+ var timestamp = Stopwatch.GetTimestamp();
+
+ OutputHelpers.WriteConsole($"Load Image", ConsoleColor.Gray);
+ var inputImage = await OnnxImage.FromFileAsync(imageFile);
+
+ OutputHelpers.WriteConsole($"Create Pipeline", ConsoleColor.Gray);
+ var pipeline = BackgroundRemovalPipeline.CreatePipeline("D:\\Repositories\\RMBG-1.4\\onnx\\model.onnx", sampleSize: 1024);
+
+ OutputHelpers.WriteConsole($"Run Pipeline", ConsoleColor.Gray);
+ var imageFeature = await pipeline.RunAsync(inputImage);
+
+ OutputHelpers.WriteConsole($"Save Image", ConsoleColor.Gray);
+ await imageFeature.SaveAsync(Path.Combine(_outputDirectory, $"{pipeline.Name}.png"));
+
+ OutputHelpers.WriteConsole($"Unload pipeline", ConsoleColor.Gray);
+ await pipeline.UnloadAsync();
+
+ OutputHelpers.WriteConsole($"Elapsed: {Stopwatch.GetElapsedTime(timestamp)}", ConsoleColor.Yellow);
+ }
+ }
+}
diff --git a/OnnxStack.Console/Examples/BackgroundRemovalVideoExample.cs b/OnnxStack.Console/Examples/BackgroundRemovalVideoExample.cs
new file mode 100644
index 0000000..788b8d1
--- /dev/null
+++ b/OnnxStack.Console/Examples/BackgroundRemovalVideoExample.cs
@@ -0,0 +1,54 @@
+using OnnxStack.Core.Video;
+using OnnxStack.FeatureExtractor.Pipelines;
+using System.Diagnostics;
+
+namespace OnnxStack.Console.Runner
+{
+ public sealed class BackgroundRemovalVideoExample : IExampleRunner
+ {
+ private readonly string _outputDirectory;
+
+ public BackgroundRemovalVideoExample()
+ {
+ _outputDirectory = Path.Combine(Directory.GetCurrentDirectory(), "Examples", "BackgroundRemovalExample");
+ Directory.CreateDirectory(_outputDirectory);
+ }
+
+ public int Index => 21;
+
+ public string Name => "Video Background Removal Example";
+
+ public string Description => "Remove a background from a video";
+
+ public async Task RunAsync()
+ {
+ OutputHelpers.WriteConsole("Please enter a video/gif file path and press ENTER", ConsoleColor.Yellow);
+ var videoFile = OutputHelpers.ReadConsole(ConsoleColor.Cyan);
+
+ var timestamp = Stopwatch.GetTimestamp();
+
+ OutputHelpers.WriteConsole($"Read Video", ConsoleColor.Gray);
+ var videoInfo = await VideoHelper.ReadVideoInfoAsync(videoFile);
+
+ OutputHelpers.WriteConsole($"Create Pipeline", ConsoleColor.Gray);
+ var pipeline = BackgroundRemovalPipeline.CreatePipeline("D:\\Repositories\\RMBG-1.4\\onnx\\model.onnx", sampleSize: 1024);
+
+ OutputHelpers.WriteConsole($"Load Pipeline", ConsoleColor.Gray);
+ await pipeline.LoadAsync();
+
+ OutputHelpers.WriteConsole($"Create Video Stream", ConsoleColor.Gray);
+ var videoStream = VideoHelper.ReadVideoStreamAsync(videoFile, videoInfo.FrameRate);
+
+ OutputHelpers.WriteConsole($"Create Pipeline Stream", ConsoleColor.Gray);
+ var pipelineStream = pipeline.RunAsync(videoStream);
+
+ OutputHelpers.WriteConsole($"Write Video Stream", ConsoleColor.Gray);
+ await VideoHelper.WriteVideoStreamAsync(videoInfo, pipelineStream, Path.Combine(_outputDirectory, $"Result.mp4"), true);
+
+ OutputHelpers.WriteConsole($"Unload", ConsoleColor.Gray);
+ await pipeline.UnloadAsync();
+
+ OutputHelpers.WriteConsole($"Elapsed: {Stopwatch.GetElapsedTime(timestamp)}", ConsoleColor.Yellow);
+ }
+ }
+}
diff --git a/OnnxStack.Core/Extensions/TensorExtension.cs b/OnnxStack.Core/Extensions/TensorExtension.cs
index db21a27..6e93467 100644
--- a/OnnxStack.Core/Extensions/TensorExtension.cs
+++ b/OnnxStack.Core/Extensions/TensorExtension.cs
@@ -397,7 +397,7 @@ private static DenseTensor ConcatenateAxis1(DenseTensor tensor1, D
// Copy data from the second tensor
for (int i = 0; i < dimensions[0]; i++)
- for (int j = 0; j < tensor1.Dimensions[1]; j++)
+ for (int j = 0; j < tensor2.Dimensions[1]; j++)
concatenatedTensor[i, j + tensor1.Dimensions[1]] = tensor2[i, j];
return concatenatedTensor;
diff --git a/OnnxStack.Core/Image/OnnxImage.cs b/OnnxStack.Core/Image/OnnxImage.cs
index d0998e8..643e514 100644
--- a/OnnxStack.Core/Image/OnnxImage.cs
+++ b/OnnxStack.Core/Image/OnnxImage.cs
@@ -64,6 +64,7 @@ public OnnxImage(DenseTensor imageTensor, ImageNormalizeType normalizeTyp
{
var height = imageTensor.Dimensions[2];
var width = imageTensor.Dimensions[3];
+ var hasTransparency = imageTensor.Dimensions[1] == 4;
_imageData = new Image(width, height);
for (var y = 0; y < height; y++)
{
@@ -74,14 +75,16 @@ public OnnxImage(DenseTensor imageTensor, ImageNormalizeType normalizeTyp
_imageData[x, y] = new Rgba32(
DenormalizeZeroToOneToByte(imageTensor, 0, y, x),
DenormalizeZeroToOneToByte(imageTensor, 1, y, x),
- DenormalizeZeroToOneToByte(imageTensor, 2, y, x));
+ DenormalizeZeroToOneToByte(imageTensor, 2, y, x),
+ hasTransparency ? DenormalizeZeroToOneToByte(imageTensor, 3, y, x) : byte.MaxValue);
}
else
{
_imageData[x, y] = new Rgba32(
DenormalizeOneToOneToByte(imageTensor, 0, y, x),
DenormalizeOneToOneToByte(imageTensor, 1, y, x),
- DenormalizeOneToOneToByte(imageTensor, 2, y, x));
+ DenormalizeOneToOneToByte(imageTensor, 2, y, x),
+ hasTransparency ? DenormalizeOneToOneToByte(imageTensor, 3, y, x) : byte.MaxValue);
}
}
}
@@ -337,6 +340,7 @@ private DenseTensor NormalizeToZeroToOne(ReadOnlySpan dimensions)
var width = dimensions[3];
var height = dimensions[2];
var channels = dimensions[1];
+ var hasTransparency = channels == 4;
var imageArray = new DenseTensor(new[] { 1, channels, height, width });
_imageData.ProcessPixelRows(img =>
{
@@ -348,6 +352,8 @@ private DenseTensor NormalizeToZeroToOne(ReadOnlySpan dimensions)
imageArray[0, 0, y, x] = (pixelSpan[x].R / 255.0f);
imageArray[0, 1, y, x] = (pixelSpan[x].G / 255.0f);
imageArray[0, 2, y, x] = (pixelSpan[x].B / 255.0f);
+ if (hasTransparency)
+ imageArray[0, 3, y, x] = (pixelSpan[x].A / 255.0f);
}
}
});
@@ -366,6 +372,7 @@ private DenseTensor NormalizeToOneToOne(ReadOnlySpan dimensions)
var width = dimensions[3];
var height = dimensions[2];
var channels = dimensions[1];
+ var hasTransparency = channels == 4;
var imageArray = new DenseTensor(new[] { 1, channels, height, width });
_imageData.ProcessPixelRows(img =>
{
@@ -377,6 +384,8 @@ private DenseTensor NormalizeToOneToOne(ReadOnlySpan dimensions)
imageArray[0, 0, y, x] = (pixelSpan[x].R / 255.0f) * 2.0f - 1.0f;
imageArray[0, 1, y, x] = (pixelSpan[x].G / 255.0f) * 2.0f - 1.0f;
imageArray[0, 2, y, x] = (pixelSpan[x].B / 255.0f) * 2.0f - 1.0f;
+ if (hasTransparency)
+ imageArray[0, 3, y, x] = (pixelSpan[x].A / 255.0f) * 2.0f - 1.0f;
}
}
});
diff --git a/OnnxStack.Core/Video/OnnxVideo.cs b/OnnxStack.Core/Video/OnnxVideo.cs
index 8569dd6..ee8f3e9 100644
--- a/OnnxStack.Core/Video/OnnxVideo.cs
+++ b/OnnxStack.Core/Video/OnnxVideo.cs
@@ -137,9 +137,9 @@ public void Resize(int height, int width)
/// The filename.
/// The cancellation token.
///
- public Task SaveAsync(string filename, CancellationToken cancellationToken = default)
+ public Task SaveAsync(string filename, bool preserveTransparency = false, CancellationToken cancellationToken = default)
{
- return VideoHelper.WriteVideoFramesAsync(this, filename, cancellationToken);
+ return VideoHelper.WriteVideoFramesAsync(this, filename, preserveTransparency, cancellationToken);
}
diff --git a/OnnxStack.Core/Video/VideoHelper.cs b/OnnxStack.Core/Video/VideoHelper.cs
index a574837..b8c6721 100644
--- a/OnnxStack.Core/Video/VideoHelper.cs
+++ b/OnnxStack.Core/Video/VideoHelper.cs
@@ -32,9 +32,9 @@ public static void SetConfiguration(OnnxStackConfig configuration)
/// The onnx video.
/// The filename.
/// The cancellation token.
- public static async Task WriteVideoFramesAsync(OnnxVideo onnxVideo, string filename, CancellationToken cancellationToken = default)
+ public static async Task WriteVideoFramesAsync(OnnxVideo onnxVideo, string filename, bool preserveTransparency = false, CancellationToken cancellationToken = default)
{
- await WriteVideoFramesAsync(onnxVideo.Frames, filename, onnxVideo.FrameRate, onnxVideo.AspectRatio, cancellationToken);
+ await WriteVideoFramesAsync(onnxVideo.Frames, filename, onnxVideo.FrameRate, onnxVideo.AspectRatio, preserveTransparency, cancellationToken);
}
@@ -45,11 +45,11 @@ public static async Task WriteVideoFramesAsync(OnnxVideo onnxVideo, string filen
/// The filename.
/// The frame rate.
/// The cancellation token.
- public static async Task WriteVideoFramesAsync(IEnumerable onnxImages, string filename, float frameRate = 15, CancellationToken cancellationToken = default)
+ public static async Task WriteVideoFramesAsync(IEnumerable onnxImages, string filename, float frameRate = 15, bool preserveTransparency = false, CancellationToken cancellationToken = default)
{
var firstImage = onnxImages.First();
var aspectRatio = (double)firstImage.Width / firstImage.Height;
- await WriteVideoFramesAsync(onnxImages, filename, frameRate, aspectRatio, cancellationToken);
+ await WriteVideoFramesAsync(onnxImages, filename, frameRate, aspectRatio, preserveTransparency, cancellationToken);
}
@@ -61,12 +61,12 @@ public static async Task WriteVideoFramesAsync(IEnumerable onnxImages
/// The frame rate.
/// The aspect ratio.
/// The cancellation token.
- private static async Task WriteVideoFramesAsync(IEnumerable onnxImages, string filename, float frameRate, double aspectRatio, CancellationToken cancellationToken = default)
+ private static async Task WriteVideoFramesAsync(IEnumerable onnxImages, string filename, float frameRate, double aspectRatio, bool preserveTransparency, CancellationToken cancellationToken = default)
{
if (File.Exists(filename))
File.Delete(filename);
- using (var videoWriter = CreateWriter(filename, frameRate, aspectRatio))
+ using (var videoWriter = CreateWriter(filename, frameRate, aspectRatio, preserveTransparency))
{
// Start FFMPEG
videoWriter.Start();
@@ -91,12 +91,12 @@ private static async Task WriteVideoFramesAsync(IEnumerable onnxImage
/// The frame rate.
/// The aspect ratio.
/// The cancellation token.
- public static async Task WriteVideoStreamAsync(VideoInfo videoInfo, IAsyncEnumerable videoStream, string filename, CancellationToken cancellationToken = default)
+ public static async Task WriteVideoStreamAsync(VideoInfo videoInfo, IAsyncEnumerable videoStream, string filename, bool preserveTransparency = false, CancellationToken cancellationToken = default)
{
if (File.Exists(filename))
File.Delete(filename);
- using (var videoWriter = CreateWriter(filename, videoInfo.FrameRate, videoInfo.AspectRatio))
+ using (var videoWriter = CreateWriter(filename, videoInfo.FrameRate, videoInfo.AspectRatio, preserveTransparency))
{
// Start FFMPEG
videoWriter.Start();
@@ -323,11 +323,13 @@ private static Process CreateReader(string inputFile, float fps)
/// The FPS.
/// The aspect ratio.
///
- private static Process CreateWriter(string outputFile, float fps, double aspectRatio)
+ private static Process CreateWriter(string outputFile, float fps, double aspectRatio, bool preserveTransparency)
{
var ffmpegProcess = new Process();
+ var codec = preserveTransparency ? "png" : "libx264";
+ var format = preserveTransparency ? "yuva420p" : "yuv420p";
ffmpegProcess.StartInfo.FileName = _configuration.FFmpegPath;
- ffmpegProcess.StartInfo.Arguments = $"-hide_banner -loglevel error -framerate {fps:F4} -i - -c:v libx264 -movflags +faststart -vf format=yuv420p -aspect {aspectRatio} {outputFile}";
+ ffmpegProcess.StartInfo.Arguments = $"-hide_banner -loglevel error -framerate {fps:F4} -i - -c:v {codec} -movflags +faststart -vf format={format} -aspect {aspectRatio} {outputFile}"; // NOTE(review): ffmpeg's png encoder expects rgba (not yuva420p), and alpha in an .mp4 container is nonstandard — verify transparency output end-to-end
ffmpegProcess.StartInfo.RedirectStandardInput = true;
ffmpegProcess.StartInfo.UseShellExecute = false;
ffmpegProcess.StartInfo.CreateNoWindow = true;
diff --git a/OnnxStack.FeatureExtractor/Pipelines/BackgroundRemovalPipeline.cs b/OnnxStack.FeatureExtractor/Pipelines/BackgroundRemovalPipeline.cs
new file mode 100644
index 0000000..7cd69b8
--- /dev/null
+++ b/OnnxStack.FeatureExtractor/Pipelines/BackgroundRemovalPipeline.cs
@@ -0,0 +1,205 @@
+using Microsoft.Extensions.Logging;
+using Microsoft.ML.OnnxRuntime.Tensors;
+using OnnxStack.Core;
+using OnnxStack.Core.Config;
+using OnnxStack.Core.Image;
+using OnnxStack.Core.Model;
+using OnnxStack.Core.Video;
+using OnnxStack.FeatureExtractor.Common;
+using System;
+using System.Collections.Generic;
+using System.IO;
+using System.Linq;
+using System.Runtime.CompilerServices;
+using System.Threading;
+using System.Threading.Tasks;
+
+namespace OnnxStack.FeatureExtractor.Pipelines
+{
+ public class BackgroundRemovalPipeline
+ {
+ private readonly string _name;
+ private readonly ILogger _logger;
+ private readonly FeatureExtractorModel _model;
+
+ ///
+ /// Initializes a new instance of the class.
+ ///
+ /// The name.
+ /// The model.
+ /// The logger.
+ public BackgroundRemovalPipeline(string name, FeatureExtractorModel model, ILogger logger = default)
+ {
+ _name = name;
+ _logger = logger;
+ _model = model;
+ }
+
+
+ ///
+ /// Gets the name.
+ ///
+ ///
+ public string Name => _name;
+
+
+ ///
+ /// Loads the model.
+ ///
+ ///
+ public Task LoadAsync()
+ {
+ return _model.LoadAsync();
+ }
+
+
+ ///
+ /// Unloads the models.
+ ///
+ public async Task UnloadAsync()
+ {
+ await Task.Yield();
+ _model?.Dispose();
+ }
+
+
+ ///
+ /// Generates the background removal image result
+ ///
+ /// The input image.
+ ///
+ public async Task RunAsync(OnnxImage inputImage, CancellationToken cancellationToken = default)
+ {
+ var timestamp = _logger?.LogBegin("Removing image background...");
+ var result = await RunInternalAsync(inputImage, cancellationToken);
+ _logger?.LogEnd("Removing image background complete.", timestamp);
+ return result;
+ }
+
+
+ ///
+ /// Generates the background removal video result
+ ///
+ /// The input video.
+ ///
+ public async Task RunAsync(OnnxVideo video, CancellationToken cancellationToken = default)
+ {
+ var timestamp = _logger?.LogBegin("Removing video background...");
+ var videoFrames = new List();
+ foreach (var videoFrame in video.Frames)
+ {
+ videoFrames.Add(await RunInternalAsync(videoFrame, cancellationToken));
+ }
+ _logger?.LogEnd("Removing video background complete.", timestamp);
+ return new OnnxVideo(video.Info with
+ {
+ Height = videoFrames[0].Height,
+ Width = videoFrames[0].Width,
+ }, videoFrames);
+ }
+
+
+ ///
+ /// Generates the background removal video stream
+ ///
+ /// The image frames.
+ /// The cancellation token.
+ ///
+ public async IAsyncEnumerable RunAsync(IAsyncEnumerable imageFrames, [EnumeratorCancellation] CancellationToken cancellationToken = default)
+ {
+ var timestamp = _logger?.LogBegin("Removing video stream background...");
+ await foreach (var imageFrame in imageFrames)
+ {
+ yield return await RunInternalAsync(imageFrame, cancellationToken);
+ }
+ _logger?.LogEnd("Removing video stream background complete.", timestamp);
+ }
+
+
+ ///
+ /// Runs the pipeline
+ ///
+ /// The input image.
+ /// The cancellation token.
+ ///
+ private async Task RunInternalAsync(OnnxImage inputImage, CancellationToken cancellationToken = default)
+ {
+ var sourceImageTensor = await inputImage.GetImageTensorAsync(_model.SampleSize, _model.SampleSize, ImageNormalizeType.ZeroToOne);
+ var metadata = await _model.GetMetadataAsync();
+ cancellationToken.ThrowIfCancellationRequested();
+ var outputShape = new[] { 1, _model.Channels, _model.SampleSize, _model.SampleSize };
+ var outputBuffer = metadata.Outputs[0].Value.Dimensions.Length == 4 ? outputShape : outputShape[1..];
+ using (var inferenceParameters = new OnnxInferenceParameters(metadata))
+ {
+ inferenceParameters.AddInputTensor(sourceImageTensor);
+ inferenceParameters.AddOutputBuffer(outputBuffer);
+
+ var results = await _model.RunInferenceAsync(inferenceParameters);
+ using (var result = results.First())
+ {
+ cancellationToken.ThrowIfCancellationRequested();
+
+ var imageTensor = AddAlphaChannel(sourceImageTensor, result.GetTensorDataAsSpan());
+ return new OnnxImage(imageTensor, ImageNormalizeType.ZeroToOne);
+ }
+ }
+ }
+
+
+ ///
+ /// Adds an alpha channel to the RGB tensor.
+ ///
+ /// The source image.
+ /// The alpha channel.
+ ///
+ private static DenseTensor AddAlphaChannel(DenseTensor sourceImage, ReadOnlySpan alphaChannel)
+ {
+ var resultTensor = new DenseTensor(new int[] { 1, 4, sourceImage.Dimensions[2], sourceImage.Dimensions[3] });
+ sourceImage.Buffer.Span.CopyTo(resultTensor.Buffer[..(int)sourceImage.Length].Span);
+ alphaChannel.CopyTo(resultTensor.Buffer[(int)sourceImage.Length..].Span);
+ return resultTensor;
+ }
+
+
+ ///
+ /// Creates the pipeline from a FeatureExtractorModelSet.
+ ///
+ /// The model set.
+ /// The logger.
+ ///
+ public static BackgroundRemovalPipeline CreatePipeline(FeatureExtractorModelSet modelSet, ILogger logger = default)
+ {
+ var model = new FeatureExtractorModel(modelSet.FeatureExtractorConfig.ApplyDefaults(modelSet));
+ return new BackgroundRemovalPipeline(modelSet.Name, model, logger);
+ }
+
+
+ ///
+ /// Creates the pipeline from the specified file.
+ ///
+ /// The model file.
+ /// The device identifier.
+ /// The execution provider.
+ /// The logger.
+ ///
+ public static BackgroundRemovalPipeline CreatePipeline(string modelFile, int sampleSize = 512, int deviceId = 0, ExecutionProvider executionProvider = ExecutionProvider.DirectML, ILogger logger = default)
+ {
+ var name = Path.GetFileNameWithoutExtension(modelFile);
+ var configuration = new FeatureExtractorModelSet
+ {
+ Name = name,
+ IsEnabled = true,
+ DeviceId = deviceId,
+ ExecutionProvider = executionProvider,
+ FeatureExtractorConfig = new FeatureExtractorModelConfig
+ {
+ OnnxModelPath = modelFile,
+ SampleSize = sampleSize,
+ Normalize = false,
+ Channels = 1
+ }
+ };
+ return CreatePipeline(configuration, logger);
+ }
+ }
+}
diff --git a/OnnxStack.FeatureExtractor/README.md b/OnnxStack.FeatureExtractor/README.md
index a09360b..35ebc2d 100644
--- a/OnnxStack.FeatureExtractor/README.md
+++ b/OnnxStack.FeatureExtractor/README.md
@@ -14,6 +14,10 @@
### OpenPose (TODO)
* https://huggingface.co/axodoxian/controlnet_onnx/resolve/main/annotators/openpose.onnx
+### Background Removal
+* https://huggingface.co/briaai/RMBG-1.4/resolve/main/onnx/model.onnx
+
+
# Image Example
```csharp
// Load Input Image