# implement kornia-dnn with RTDETR detector #129

Open: edgarriba wants to merge 6 commits into `main` from `dnn-module`.
**New file: `kornia-dnn` crate manifest (+21 lines)**

```toml
[package]
name = "kornia-dnn"
authors.workspace = true
description = "ONNX Deep Neural Network (DNN) library for Rust"
edition.workspace = true
homepage.workspace = true
license.workspace = true
publish = true
repository.workspace = true
rust-version.workspace = true
version.workspace = true

[features]
ort-load-dynamic = ["ort/load-dynamic"]
ort-cuda = ["ort/cuda"]

[dependencies]
kornia-core = { workspace = true }
kornia-image = { workspace = true }
ort = { version = "2.0.0-rc.4", default-features = false }
thiserror = "1"
```
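The `ort-load-dynamic` and `ort-cuda` flags simply forward to the corresponding `ort` features. As a hypothetical illustration (the path is made up), a downstream crate could opt into them like:

```toml
[dependencies]
kornia-dnn = { path = "../kornia-dnn", features = ["ort-load-dynamic", "ort-cuda"] }
```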
**New file: `kornia-dnn` error module (+14 lines)**

```rust
#[derive(thiserror::Error, Debug)]
pub enum DnnError {
    #[error("Please set the ORT_DYLIB_PATH environment variable to the path of the ORT dylib. Error: {0}")]
    OrtDylibError(String),

    #[error("Failed to create ORT session")]
    OrtError(#[from] ort::Error),

    #[error("Image error")]
    ImageError(#[from] kornia_image::ImageError),

    #[error("Tensor error")]
    TensorError(#[from] kornia_core::TensorError),
}
```
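The `#[from]` attributes above make `?` convert the underlying `ort`, image, and tensor errors into `DnnError` automatically. As a dependency-free sketch (names reused from the PR, but the `From` impl written out by hand, which is roughly what `thiserror` generates):

```rust
use std::fmt;

// Hand-written equivalent of the #[from]/#[error] pattern, for illustration.
#[derive(Debug)]
enum DnnError {
    OrtDylibError(String),
}

impl fmt::Display for DnnError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            DnnError::OrtDylibError(msg) => write!(
                f,
                "Please set the ORT_DYLIB_PATH environment variable to the path of the ORT dylib. Error: {msg}"
            ),
        }
    }
}

impl std::error::Error for DnnError {}

// This is the kind of impl #[from] would generate for a VarError variant.
impl From<std::env::VarError> for DnnError {
    fn from(e: std::env::VarError) -> Self {
        DnnError::OrtDylibError(e.to_string())
    }
}

fn ort_dylib_path() -> Result<String, DnnError> {
    // `?` uses the From impl above to convert the VarError.
    Ok(std::env::var("ORT_DYLIB_PATH")?)
}

fn main() {
    match ort_dylib_path() {
        Ok(p) => println!("dylib at {p}"),
        Err(e) => println!("error: {e}"),
    }
}
```

Note that the real `new` constructor below uses `map_err` rather than `?`-conversion for this particular variant, since `OrtDylibError` wraps a `String` rather than the `VarError` itself.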
**New file: `kornia-dnn` crate root (+30 lines)**

```rust
//! # Kornia DNN
//!
//! This module contains DNN (Deep Neural Network) related functionality.

/// Error type for the dnn module.
pub mod error;

/// This module contains the RT-DETR model.
pub mod rtdetr;

// re-export ort execution providers
pub use ort::{CPUExecutionProvider, CUDAExecutionProvider, TensorRTExecutionProvider};

// TODO: put this into some sort of structs pool module
/// Represents a detected object in an image.
#[derive(Debug)]
pub struct Detection {
    /// The class label of the detected object.
    pub label: u32,
    /// The confidence score of the detection (typically between 0 and 1).
    pub score: f32,
    /// The x-coordinate of the top-left corner of the bounding box.
    pub x: f32,
    /// The y-coordinate of the top-left corner of the bounding box.
    pub y: f32,
    /// The width of the bounding box.
    pub w: f32,
    /// The height of the bounding box.
    pub h: f32,
}
```
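The detector flattens the model output into `[label, score, x, y, w, h]` groups of six floats and maps each group to one `Detection`. A minimal self-contained sketch of that parsing step (struct fields copied from the PR, `parse_detections` is a hypothetical helper name):

```rust
#[derive(Debug)]
struct Detection {
    label: u32,
    score: f32,
    x: f32,
    y: f32,
    w: f32,
    h: f32,
}

// Parse a flat buffer of [label, score, x, y, w, h] records into detections.
// Any trailing partial record is silently dropped by chunks_exact.
fn parse_detections(raw: &[f32]) -> Vec<Detection> {
    raw.chunks_exact(6)
        .map(|c| Detection {
            label: c[0] as u32,
            score: c[1],
            x: c[2],
            y: c[3],
            w: c[4],
            h: c[5],
        })
        .collect()
}

fn main() {
    // One fake detection: class 1 with score 0.9 at (10, 20), size 30x40.
    let raw = [1.0, 0.9, 10.0, 20.0, 30.0, 40.0];
    println!("{:?}", parse_detections(&raw));
}
```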
**New file: `rtdetr` module (+206 lines)**

```rust
//! # RT-DETR
//!
//! This module contains the RT-DETR model.
//!
//! The RT-DETR model is a state-of-the-art object detection model.

use std::{path::PathBuf, sync::Arc};

use crate::{error::DnnError, CPUExecutionProvider, Detection};
use kornia_core::{CpuAllocator, Tensor};
use kornia_image::Image;
use ort::{ExecutionProviderDispatch, GraphOptimizationLevel, Session};

/// Builder for the RT-DETR detector.
///
/// This struct provides a convenient way to configure and create an `RTDETRDetector` instance.
pub struct RTDETRDetectorBuilder {
    /// Path to the RT-DETR model file.
    pub model_path: PathBuf,
    /// Number of threads to use for inference.
    pub num_threads: usize,
    /// Execution providers to use for inference.
    pub execution_providers: Vec<ExecutionProviderDispatch>,
}

impl RTDETRDetectorBuilder {
    /// Creates a new `RTDETRDetectorBuilder` with default settings.
    ///
    /// # Arguments
    ///
    /// * `model_path` - Path to the RT-DETR model file.
    ///
    /// # Returns
    ///
    /// A `Result` containing the `RTDETRDetectorBuilder` if successful, or a `DnnError` if an error occurred.
    pub fn new(model_path: PathBuf) -> Result<Self, DnnError> {
        Ok(Self {
            model_path,
            num_threads: 4,
            execution_providers: vec![CPUExecutionProvider::default().build()],
        })
    }

    /// Sets the number of threads to use for inference.
    ///
    /// # Arguments
    ///
    /// * `num_threads` - The number of threads to use.
    ///
    /// # Returns
    ///
    /// The updated `RTDETRDetectorBuilder` instance.
    pub fn with_num_threads(mut self, num_threads: usize) -> Self {
        self.num_threads = num_threads;
        self
    }

    /// Sets the execution providers to use for inference.
    ///
    /// # Arguments
    ///
    /// * `execution_providers` - The execution providers to use.
    ///
    /// # Returns
    ///
    /// The updated `RTDETRDetectorBuilder` instance.
    pub fn with_execution_providers(
        mut self,
        execution_providers: Vec<ExecutionProviderDispatch>,
    ) -> Self {
        self.execution_providers = execution_providers;
        self
    }

    /// Builds and returns an `RTDETRDetector` instance.
    ///
    /// # Returns
    ///
    /// A `Result` containing the `RTDETRDetector` if successful, or a `DnnError` if an error occurred.
    pub fn build(self) -> Result<RTDETRDetector, DnnError> {
        RTDETRDetector::new(self.model_path, self.num_threads, self.execution_providers)
    }
}

/// RT-DETR object detector.
///
/// This struct represents an instance of the RT-DETR object detection model.
pub struct RTDETRDetector {
    session: Arc<Session>,
}

impl RTDETRDetector {
    // TODO: default to hf hub
    /// Creates a new `RTDETRDetector` instance.
    ///
    /// # Arguments
    ///
    /// * `model_path` - Path to the RT-DETR model file.
    /// * `num_threads` - Number of threads to use for inference.
    ///
    /// # Returns
    ///
    /// A `Result` containing the `RTDETRDetector` if successful, or a `DnnError` if an error occurred.
    ///
    /// Pre-requisites:
    /// - The ORT_DYLIB_PATH environment variable must be set to the path of the ORT dylib.
    pub fn new(
        model_path: PathBuf,
        num_threads: usize,
        execution_providers: Vec<ExecutionProviderDispatch>,
    ) -> Result<Self, DnnError> {
        // get the ort dylib path from the environment variable
        let dylib_path =
            std::env::var("ORT_DYLIB_PATH").map_err(|e| DnnError::OrtDylibError(e.to_string()))?;

        // set the ort dylib path
        ort::init_from(dylib_path).commit()?;

        // create the ort session
        let session = Session::builder()?
            .with_optimization_level(GraphOptimizationLevel::Level3)?
            .with_intra_threads(num_threads)?
            .with_execution_providers(execution_providers)?
            .commit_from_file(model_path)?;

        let session = Arc::new(session);
        // TODO: perform a dummy run to warm up the model
        // let session_clone = session.clone();
        // std::thread::spawn(move || -> Result<(), DnnError> {
        //     let dummy_input =
        //         ort::Tensor::from_array(([480, 640 * 3], vec![0.0f32; 480 * 640 * 3]))?;
        //     session_clone.run(ort::inputs!["input" => dummy_input]?)?;
        //     Ok(())
        // });

        Ok(Self { session })
    }

    /// Runs object detection on the given image.
    ///
    /// # Arguments
    ///
    /// * `image` - The input image as an `Image<u8, 3>`.
    ///
    /// # Returns
    ///
    /// A `Result` containing a vector of `Detection` objects if successful, or a `DnnError` if an error occurred.
    pub fn run(&self, image: &Image<u8, 3>) -> Result<Vec<Detection>, DnnError> {
        // TODO: explore pre-allocating memory for the image
        // cast and scale the image to f32
        let mut image_hwc_f32 = Image::from_size_val(image.size(), 0.0f32)?;
        kornia_image::ops::cast_and_scale(image, &mut image_hwc_f32, 1.0 / 255.)?;

        // convert HWC -> CHW
        let image_chw = image_hwc_f32.permute_axes([2, 0, 1]).as_contiguous();

        // TODO: create a Tensor::insert_axis in kornia-rs
        let image_nchw = Tensor::from_shape_vec(
            [
                1,
                image_chw.shape[0],
                image_chw.shape[1],
                image_chw.shape[2],
            ],
            image_chw.into_vec(),
            CpuAllocator,
        )?;

        // make the ort tensor
        let ort_tensor = ort::Tensor::from_array((image_nchw.shape, image_nchw.into_vec()))?;

        // run the model
        let outputs = self.session.run(ort::inputs!["input" => ort_tensor]?)?;

        // extract the output tensor
        let (out_shape, out_ort) = outputs[0].try_extract_raw_tensor::<f32>()?;

        let out_tensor = Tensor::<f32, 3>::from_shape_vec(
            [
                out_shape[0] as usize,
                out_shape[1] as usize,
                out_shape[2] as usize,
            ],
            out_ort.to_vec(),
            CpuAllocator,
        )?;

        // parse the output tensor
        // we expect the output tensor to be a tensor of shape [1, N, 6]
        // where each element is a detection [label, score, x, y, w, h]
        let detections = out_tensor
            .as_slice()
            .chunks_exact(6)
            .map(|chunk| Detection {
                label: chunk[0] as u32,
                score: chunk[1],
                x: chunk[2],
                y: chunk[3],
                w: chunk[4],
                h: chunk[5],
            })
            .collect::<Vec<_>>();

        Ok(detections)
    }
}
```

> Review comment on the `self.session.run(...)` call: "@decahedron1 how could we pre-allocate the output tensor here?"
>
> Reply: "This seems like the perfect use case for …"
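The `permute_axes([2, 0, 1])` call in `run` relayouts the image from interleaved HWC storage to planar CHW, which is what the NCHW model input expects. A self-contained sketch of the same relayout written with explicit indexing (the helper name is hypothetical):

```rust
// HWC -> CHW relayout: out[c][y][x] = in[y][x][c].
fn hwc_to_chw(data: &[f32], h: usize, w: usize, c: usize) -> Vec<f32> {
    assert_eq!(data.len(), h * w * c);
    let mut out = vec![0.0f32; data.len()];
    for y in 0..h {
        for x in 0..w {
            for ch in 0..c {
                out[ch * h * w + y * w + x] = data[(y * w + x) * c + ch];
            }
        }
    }
    out
}

fn main() {
    // A 1x2 RGB image stored interleaved: [r0, g0, b0, r1, g1, b1].
    let hwc = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
    let chw = hwc_to_chw(&hwc, 1, 2, 3);
    // Planar layout groups each channel: [r0, r1, g0, g1, b0, b1].
    assert_eq!(chw, [1.0, 4.0, 2.0, 5.0, 3.0, 6.0]);
}
```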
**New file: `rtdetr` example manifest (+18 lines)**

```toml
[package]
name = "rtdetr"
version = "0.1.0"
authors = ["Edgar Riba <edgar.riba@gmail.com>"]
license = "Apache-2.0"
edition = "2021"
publish = false

[dependencies]
clap = { version = "4.5.4", features = ["derive"] }
ctrlc = "3.4.4"
kornia = { workspace = true, features = [
    "gstreamer",
    "ort-load-dynamic",
    "ort-cuda",
] }
rerun = "0.18"
tokio = { version = "1" }
```
**New file: `rtdetr` example README (+27 lines)**

An example showing how to use the RT-DETR model with the `kornia::dnn` module and the webcam with the `kornia::io` module, with the ability to cancel the feed after a certain amount of time. This example displays the webcam feed in a [`rerun`](https://github.com/rerun-io/rerun) window.

NOTE: This example requires the gstreamer backend to be enabled. To enable the gstreamer backend, use the `gstreamer` feature flag when building the `kornia` crate and its dependencies.

## Prerequisites

Mainly you need to download onnxruntime from: <https://github.com/microsoft/onnxruntime/releases>

## Usage

```bash
Usage: rtdetr [OPTIONS] --model-path <MODEL_PATH>

Options:
  -c, --camera-id <CAMERA_ID>              [default: 0]
  -f, --fps <FPS>                          [default: 5]
  -m, --model-path <MODEL_PATH>
  -n, --num-threads <NUM_THREADS>          [default: 8]
  -s, --score-threshold <SCORE_THRESHOLD>  [default: 0.75]
  -h, --help                               Print help
```

Example:

```bash
ORT_DYLIB_PATH=/path/to/libonnxruntime.so cargo run --bin rtdetr --release -- --camera-id 0 --model-path rtdetr.onnx --num-threads 8 --score-threshold 0.75
```
## Review discussion

> **Reviewer:** Execution providers can be registered through the environment, so this is super minor, but WDYT about an `execution_providers: Vec<ExecutionProviderDispatch>` field and a `with_execution_providers` method to configure EPs specifically for the RT-DETR session?

> **Author:** I wanted to somehow make ort transparent to the user, avoid making them explicitly pass the `ort::ExecutionProvider`, and have something custom instead. I'm still experimenting with the execution providers. What is the point of defining multiple as a Vec, just because of a fallback provider? I've played a bit with it and noticed that cuda/tensorrt takes a few seconds to run the first frames. A couple of questions about `commit_from_file`: the idea is that we will have a bunch of operators/models in our Kornia HF hub which can use `commit_from_url`, but we somehow also want to let the user give a local onnx file. Any tips here?

> **Reviewer:** What about re-exporting EPs like `pub use ort::{CPUExecutionProvider, CUDAExecutionProvider, TensorRTExecutionProvider}` so users can still configure each EP's options instead of being limited to defaults? That should still keep things neat. Yes. You could run it on one dummy frame inside the constructor; that should get the graph warmed up. Yes, and for CUDA it is determining the most optimal cuDNN convolution kernels. By default it performs an exhaustive search, which gets the best performance at the cost of significant warmup time; this can be configured with `CUDAExecutionProvider::with_conv_algorithm_search`. TensorRT graphs can theoretically be cached with `TensorRTExecutionProvider::with_engine_cache`, but some users in the pyke Discord have reported that ONNX Runtime sometimes doesn't respect this option, and session creation can still take a few seconds despite using a cached engine. Your mileage may vary, though; I personally haven't been able to reproduce the issue, but it's just something to keep in mind. How about something like this? (very roughly)

> **Author:** Actually, one more idea would be to automatically set them based on feature flags?
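The feature-flag idea could look roughly like the sketch below. This is hypothetical: `Provider` is a stand-in for ort's `ExecutionProviderDispatch` so the example stays dependency-free, and the `ort-tensorrt` feature name is invented (only `ort-load-dynamic` and `ort-cuda` exist in the manifest); in the real crate each `cfg` branch would push something like `CUDAExecutionProvider::default().build()`.

```rust
// Hypothetical stand-in for ort's ExecutionProviderDispatch.
#[allow(dead_code)]
#[derive(Debug, PartialEq)]
enum Provider {
    Cpu,
    Cuda,
    TensorRt,
}

// Pick execution providers at compile time from cargo feature flags.
// Earlier entries are preferred; CPU always goes last as the fallback.
fn default_providers() -> Vec<Provider> {
    let mut eps = Vec::new();
    #[cfg(feature = "ort-tensorrt")] // hypothetical feature name
    eps.push(Provider::TensorRt);
    #[cfg(feature = "ort-cuda")]
    eps.push(Provider::Cuda);
    eps.push(Provider::Cpu);
    eps
}

fn main() {
    println!("{:?}", default_providers());
}
```

One design note: keeping the `with_execution_providers` builder method alongside such a default would let users override the compile-time choice per session.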