-
Notifications
You must be signed in to change notification settings - Fork 481
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add tensor sorting operations #1488
Conversation
944c1c3
to
12c0f67
Compare
Codecov ReportAttention: Patch coverage is
Additional details and impacted files@@ Coverage Diff @@
## main #1488 +/- ##
==========================================
+ Coverage 85.80% 85.86% +0.06%
==========================================
Files 650 659 +9
Lines 72524 74082 +1558
==========================================
+ Hits 62229 63612 +1383
- Misses 10295 10470 +175 ☔ View full report in Codecov by Sentry. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I would add a ticket to remove the async
version of those methods, since it adds complexity and is only viable because we don't have a wgpu implementation yet. We could call read_sync
or an equivalent in the default implementation.
pub async fn sort(self, dim: usize) -> Tensor<B, D> { | ||
Tensor::new(sort::<B, D, Float>(self.primitive, dim, /*descending*/ false).await) | ||
} | ||
|
||
/// Sort the elements by value in ascending order along a given dimension. | ||
/// Also returns the indices. | ||
/// | ||
/// This sort is unstable (i.e., may reorder equal elements). | ||
#[cfg(all(not(feature = "wasm-sync"), target_family = "wasm"))] | ||
pub async fn sort_with_indices(self, dim: usize) -> (Tensor<B, D>, Tensor<B, D, Int>) { | ||
check!(TensorCheck::sort_dim::<D>("Sort_with_indices", dim)); | ||
let (values, indices) = | ||
sort_with_indices::<B, D, Float>(self.primitive, dim, /*descending*/ false).await; | ||
(Tensor::new(values), Tensor::new(indices)) | ||
} | ||
|
||
/// Returns the indices that sort the elements by value in ascending order along a given dimension. | ||
/// | ||
/// This sort is unstable (i.e., may reorder equal elements). | ||
#[cfg(all(not(feature = "wasm-sync"), target_family = "wasm"))] | ||
pub async fn argsort(self, dim: usize) -> Tensor<B, D, Int> { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could we put those into numeric
to avoid duplication? With newer version of Rust, we can add async
in traits.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Oh I guess I should update my rust version, it wouldn't let me do it initially.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nevermind it was a warning not an error 🤦
use of `async fn` in public traits is discouraged as auto trait bounds cannot be specified
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Will leave it like this for now to avoid the warnings. This will be temporary anyways with #1490.
pub trait ElementComparison { | ||
/// Returns and [Ordering] between `self` and `other`. | ||
fn cmp(&self, other: &Self) -> Ordering; | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Perfect 🔥
Checklist
run-checks all
script has been executed.Related Issues/PRs
Closes #1369
Unblocks easy implementation for
topk
issue #1421Changes
Added
sort(dim)
,sort_with_indices(dim)
andargsort(dim)
numeric tensor ops.Also added
sort_descending(dim)
,sort_descending_with_indices(dim)
andargsort_descending(dim)
for sorting in descending order.ElementComparison
trait forElement
so elements can return an ordering (as required bysort
methods)tch
backendTODO
run_par
,iter_par
anditer_range_par
macros defined in burn-ndarrayTesting
Forward and backward unit tests.