Feat/Split Operator #2490

Merged
merged 35 commits into from
Nov 21, 2024
242094e
Create high-level function signatures for `split()` and `split_with_s…
agelas Nov 14, 2024
64df3d9
Add checks for `split()` and `split_with_sizes()`
agelas Nov 14, 2024
6e2ca04
Remove pointless negative checks, usize is an int
agelas Nov 14, 2024
21d06fe
Add element-type specific function signatures
agelas Nov 14, 2024
4e30256
Implement split and split_with_sizes for tch backend
agelas Nov 14, 2024
d4cf026
Add split ops to burn-autodiff
agelas Nov 14, 2024
3cfdd32
Add split ops to list of basic operations in burn-book
agelas Nov 14, 2024
5a14745
Add fallbacks for split and split_with_size since a few backends dont…
agelas Nov 14, 2024
185703f
Cleanup split_with_sizes and fix arguments
agelas Nov 14, 2024
2b1c152
Clippy fixes
agelas Nov 14, 2024
01272f1
Finish todos in tensor.rs
agelas Nov 14, 2024
334fa0b
Add documentation to split functions
agelas Nov 14, 2024
bd5343f
Small consistency/grammar fixes for chunk
agelas Nov 14, 2024
236b919
Punctuation fix
agelas Nov 14, 2024
c5d08ee
Small doc fixes and change to split_with_sizes
agelas Nov 14, 2024
824cb30
Add split functions to float, int, and bool BasicOps impls
agelas Nov 14, 2024
1deee50
Add split functions to int, bool, and quantized tensors
agelas Nov 14, 2024
a99a306
Minor doc + code fixes
agelas Nov 15, 2024
9ac5fc0
Add documentation
agelas Nov 15, 2024
3093b5c
Start adding tests
agelas Nov 15, 2024
36a6ff5
Try to get this to compile
agelas Nov 15, 2024
3e790df
Fix docs and add examples
agelas Nov 15, 2024
d1624f0
Get initial tests to a working state
agelas Nov 15, 2024
d647a9e
Fix doc tests
agelas Nov 15, 2024
54ef092
Add more split tests
agelas Nov 15, 2024
2aedd20
Add more tests for 3D tensor splitting
agelas Nov 15, 2024
b3ed77a
Check with full dims so no indexing panic
agelas Nov 15, 2024
1775c1a
Correct flow of check
agelas Nov 15, 2024
7ddab38
Specify which panic to expect in tests
agelas Nov 15, 2024
3d064f8
Add split_with_sizes tests
agelas Nov 15, 2024
9fb3aad
Small fix to chunk tests
agelas Nov 15, 2024
dfb571a
Resolve merge conflicts
agelas Nov 19, 2024
e98240a
Fix redundant closure
agelas Nov 19, 2024
86f5d31
Fix redundant closure even more
agelas Nov 19, 2024
fe8bd67
Resolve merge conflict
agelas Nov 20, 2024
4 changes: 3 additions & 1 deletion burn-book/src/building-blocks/tensor.md
@@ -142,6 +142,8 @@ Those operations are available for all tensor kinds: `Int`, `Float`, and `Bool`.
| `tensor.any()` | `tensor.any()` |
| `tensor.any_dim(dim)` | `tensor.any(dim)` |
| `tensor.chunk(num_chunks, dim)` | `tensor.chunk(num_chunks, dim)` |
| `tensor.split(split_size, dim)` | `tensor.split(split_size, dim)` |
| `tensor.split_with_sizes(split_sizes, dim)` | `tensor.split([split_sizes], dim)` |
| `tensor.device()` | `tensor.device` |
| `tensor.dims()` | `tensor.size()` |
| `tensor.equal(other)` | `x == y` |
@@ -195,7 +197,7 @@ Those operations are available for numeric tensor kinds: `Float` and `Int`.
| `tensor.div(other)` or `tensor / other` | `tensor / other` |
| `tensor.div_scalar(scalar)` or `tensor / scalar` | `tensor / scalar` |
| `tensor.equal_elem(other)` | `tensor.eq(other)` |
| `tensor.full_like(fill_value)` | `torch.full_like(tensor, fill_value)` |
| `tensor.full_like(fill_value)` | `torch.full_like(tensor, fill_value)` |
| `tensor.gather(dim, indices)` | `torch.gather(tensor, dim, indices)` |
| `tensor.greater(other)` | `tensor.gt(other)` |
| `tensor.greater_elem(scalar)` | `tensor.gt(scalar)` |
12 changes: 12 additions & 0 deletions crates/burn-autodiff/src/ops/bool_tensor.rs
@@ -88,6 +88,18 @@ impl<B: Backend, C: CheckpointStrategy> BoolTensorOps<Self> for Autodiff<B, C> {
B::bool_chunk(tensor, chunks, dim)
}

fn bool_split(tensor: BoolTensor<B>, split_size: usize, dim: usize) -> Vec<BoolTensor<B>> {
B::bool_split(tensor, split_size, dim)
}

fn bool_split_with_sizes(
tensor: BoolTensor<B>,
split_sizes: Vec<usize>,
dim: usize,
) -> Vec<BoolTensor<B>> {
B::bool_split_with_sizes(tensor, split_sizes, dim)
}

fn bool_permute(tensor: BoolTensor<Self>, axes: &[usize]) -> BoolTensor<Self> {
B::bool_permute(tensor, axes)
}
16 changes: 16 additions & 0 deletions crates/burn-autodiff/src/ops/int_tensor.rs
@@ -285,6 +285,22 @@ impl<B: Backend, C: CheckpointStrategy> IntTensorOps<Self> for Autodiff<B, C> {
B::int_chunk(tensor, chunks, dim)
}

fn int_split(
tensor: <Autodiff<B> as Backend>::IntTensorPrimitive,
split_size: usize,
dim: usize,
) -> Vec<<Autodiff<B> as Backend>::IntTensorPrimitive> {
B::int_split(tensor, split_size, dim)
}

fn int_split_with_sizes(
tensor: <Autodiff<B> as Backend>::IntTensorPrimitive,
split_sizes: Vec<usize>,
dim: usize,
) -> Vec<<Autodiff<B> as Backend>::IntTensorPrimitive> {
B::int_split_with_sizes(tensor, split_sizes, dim)
}

fn int_random(
shape: Shape,
distribution: Distribution,
23 changes: 23 additions & 0 deletions crates/burn-tch/src/ops/base.rs
@@ -416,6 +416,29 @@ impl TchOps {
.collect()
}

pub fn split(tensor: TchTensor, split_size: usize, dim: usize) -> Vec<TchTensor> {
tensor
.tensor
.split(split_size as i64, dim as i64)
.into_iter()
.map(TchTensor::new)
.collect()
}

pub fn split_with_sizes(
tensor: TchTensor,
split_sizes: Vec<usize>,
dim: usize,
) -> Vec<TchTensor> {
let split_sizes_i64: Vec<i64> = split_sizes.iter().map(|&s| s as i64).collect();
tensor
.tensor
.split_with_sizes(split_sizes_i64, dim as i64)
.into_iter()
.map(TchTensor::new)
.collect()
}
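The `as i64` casts in the two functions above silently wrap if a `usize` value exceeds `i64::MAX`. That is unlikely for real tensor dimensions, but a checked conversion is one possible alternative. The following is a minimal sketch only; `sizes_to_i64` is a hypothetical helper that does not appear in this PR.

```rust
use std::num::TryFromIntError;

/// Hypothetical helper (not part of this PR): converts split sizes to `i64`
/// and reports an error instead of wrapping when a value does not fit.
fn sizes_to_i64(split_sizes: &[usize]) -> Result<Vec<i64>, TryFromIntError> {
    split_sizes
        .iter()
        // `i64::try_from` fails for values larger than `i64::MAX`.
        .map(|&size| i64::try_from(size))
        // Collecting into `Result` stops at the first failed conversion.
        .collect()
}
```

With this sketch, `sizes_to_i64(&[2, 3])` yields `Ok(vec![2, 3])`, while an oversized entry surfaces as an `Err` that the caller can turn into a descriptive panic.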

pub fn powf(tensor: TchTensor, exponent: TchTensor) -> TchTensor {
TchTensor::binary_ops_tensor(
tensor,
12 changes: 12 additions & 0 deletions crates/burn-tch/src/ops/bool_tensor.rs
@@ -93,6 +93,18 @@ impl<E: TchElement, Q: QuantElement> BoolTensorOps<Self> for LibTorch<E, Q> {
TchOps::chunk(tensor, chunks, dim)
}

fn bool_split(tensor: TchTensor, split_size: usize, dim: usize) -> Vec<TchTensor> {
TchOps::split(tensor, split_size, dim)
}

fn bool_split_with_sizes(
tensor: TchTensor,
split_sizes: Vec<usize>,
dim: usize,
) -> Vec<TchTensor> {
TchOps::split_with_sizes(tensor, split_sizes, dim)
}

fn bool_permute(tensor: TchTensor, axes: &[usize]) -> TchTensor {
TchOps::permute(tensor, axes)
}
12 changes: 12 additions & 0 deletions crates/burn-tch/src/ops/int_tensor.rs
@@ -349,6 +349,18 @@ impl<E: TchElement, Q: QuantElement> IntTensorOps<Self> for LibTorch<E, Q> {
TchOps::chunk(tensor, chunks, dim)
}

fn int_split(tensor: TchTensor, split_size: usize, dim: usize) -> Vec<TchTensor> {
TchOps::split(tensor, split_size, dim)
}

fn int_split_with_sizes(
tensor: TchTensor,
split_sizes: Vec<usize>,
dim: usize,
) -> Vec<TchTensor> {
TchOps::split_with_sizes(tensor, split_sizes, dim)
}

fn int_random(shape: Shape, distribution: Distribution, device: &LibTorchDevice) -> TchTensor {
match distribution {
Distribution::Default => {
12 changes: 12 additions & 0 deletions crates/burn-tch/src/ops/tensor.rs
@@ -417,6 +417,18 @@ impl<E: TchElement, Q: QuantElement> FloatTensorOps<Self> for LibTorch<E, Q> {
TchOps::chunk(tensor, chunks, dim)
}

fn float_split(tensor: TchTensor, split_size: usize, dim: usize) -> Vec<TchTensor> {
TchOps::split(tensor, split_size, dim)
}

fn float_split_with_sizes(
tensor: TchTensor,
split_sizes: Vec<usize>,
dim: usize,
) -> Vec<TchTensor> {
TchOps::split_with_sizes(tensor, split_sizes, dim)
}

fn float_powf(lhs: TchTensor, rhs: TchTensor) -> TchTensor {
TchOps::powf(lhs, rhs)
}
191 changes: 189 additions & 2 deletions crates/burn-tensor/src/tensor/api/base.rs
@@ -1199,7 +1199,7 @@ where
Self::new(narrow::<B, K>(self.primitive, dim, start, length))
}

/// Attempts to split the tensor along the given dimension into chunks.
/// Attempts to split the tensor into a specified number of chunks along a given dimension.
/// May return fewer chunks than requested if the tensor size is not divisible by the number of chunks.
///
/// When the given dimension is evenly divisible by the number of chunks, the chunks will be of equal size.
@@ -1247,6 +1247,91 @@ where
.collect()
}

/// Splits the tensor into chunks of a specified size along a given dimension.
/// Each chunk is a view of the original tensor.
///
/// If the tensor size along the given dimension is not divisible by `split_size`,
/// then the last chunk will be smaller.
///
/// # Panics
///
/// If the specified dimension to split along is greater than the number of dimensions of the tensor.
///
/// # Returns
///
/// A vector of tensors.
///
/// # Example
/// ```rust
/// use burn_tensor::backend::Backend;
/// use burn_tensor::Tensor;
///
/// fn example<B: Backend>() {
/// let device = Default::default();
/// // Create a 1D tensor with 5 elements
/// let tensor = Tensor::<B, 1>::from_data([0.0, 1.0, 2.0, 3.0, 4.0], &device);
/// // Split the tensor into chunks of size 2 along dimension 0
/// let chunks = tensor.split(2, 0);
/// // The result is a vector of tensors:
/// // [Tensor([0.0, 1.0]), Tensor([2.0, 3.0]), Tensor([4.0])]
/// println!("{:?}", chunks);
/// }
/// ```
pub fn split(self, split_size: usize, dim: usize) -> Vec<Self> {
check!(TensorCheck::split::<D>(
self.shape().dims.as_ref(),
split_size,
dim
));
K::split(self.primitive, split_size, dim)
.into_iter()
.map(Self::new)
.collect()
}
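The "last chunk will be smaller" rule documented for `split` can be illustrated without any tensor library by computing the index ranges each chunk would cover. This is a sketch under assumptions: `split_ranges` is a hypothetical helper, not a Burn API, and it requires `split_size > 0`, which the checks in this PR may or may not enforce.

```rust
use std::ops::Range;

/// Hypothetical helper (not a Burn API): index ranges covered by each chunk
/// when a dimension of length `dim_len` is split into pieces of `split_size`.
fn split_ranges(dim_len: usize, split_size: usize) -> Vec<Range<usize>> {
    // Assumption of this sketch: a zero split size is rejected up front.
    assert!(split_size > 0, "split_size must be greater than 0");
    let mut ranges = Vec::new();
    let mut start = 0;
    while start < dim_len {
        // Clip the end to the dimension length, so the final range may be shorter.
        let end = usize::min(start + split_size, dim_len);
        ranges.push(start..end);
        start = end;
    }
    ranges
}
```

For the 5-element example in the doc comment, `split_ranges(5, 2)` gives `[0..2, 2..4, 4..5]`, matching the three chunks `[0.0, 1.0]`, `[2.0, 3.0]`, and `[4.0]`.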

/// Splits the tensor into chunks with the specified sizes along a given dimension.
/// Each chunk is a view of the original tensor.
///
/// The sizes of the chunks are specified in the `split_sizes` vector. The sum of the sizes
/// in `split_sizes` must equal the size of the tensor along the specified dimension.
///
/// # Panics
///
/// If the specified dimension to split along is greater than the number of dimensions of the tensor or
/// if the sum of `split_sizes` does not equal the size of the tensor along `dim`.
///
/// # Returns
///
/// A vector of tensors.
///
/// # Example
/// ```rust
/// use burn_tensor::backend::Backend;
/// use burn_tensor::Tensor;
///
/// fn example<B: Backend>() {
/// let device = Default::default();
/// // Create a 1D tensor with 5 elements
/// let tensor = Tensor::<B, 1>::from_data([0.0, 1.0, 2.0, 3.0, 4.0], &device);
/// // Split the tensor into chunks with sizes [2, 3] along dimension 0
/// let chunks = tensor.split_with_sizes(vec![2, 3], 0);
/// // The result is a vector of tensors:
/// // [Tensor([0.0, 1.0]), Tensor([2.0, 3.0, 4.0])]
/// println!("{:?}", chunks);
/// }
/// ```
pub fn split_with_sizes(self, split_sizes: Vec<usize>, dim: usize) -> Vec<Self> {
check!(TensorCheck::split_with_sizes::<D>(
self.shape().dims.as_ref(),
&split_sizes,
dim
));
K::split_with_sizes(self.primitive, split_sizes, dim)
.into_iter()
.map(Self::new)
.collect()
}
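The precondition documented for `split_with_sizes` (the sizes must sum to the length of the dimension) can be sketched in plain Rust. `ranges_from_sizes` is a hypothetical helper, not a Burn API, and it returns a `Result` purely for illustration; the method in this PR panics through `check!` instead.

```rust
use std::ops::Range;

/// Hypothetical helper (not a Burn API): turns explicit chunk sizes into
/// index ranges, or returns an error if they do not add up to `dim_len`.
fn ranges_from_sizes(
    dim_len: usize,
    split_sizes: &[usize],
) -> Result<Vec<Range<usize>>, String> {
    let total: usize = split_sizes.iter().sum();
    if total != dim_len {
        return Err(format!("split sizes sum to {total}, expected {dim_len}"));
    }
    // Walk the sizes, advancing a running start offset for each chunk.
    let mut start = 0;
    let ranges = split_sizes
        .iter()
        .map(|&size| {
            let range = start..start + size;
            start += size;
            range
        })
        .collect();
    Ok(ranges)
}
```

For the doc example, `ranges_from_sizes(5, &[2, 3])` returns `Ok([0..2, 2..5])`, while `ranges_from_sizes(5, &[2, 2])` returns an error because the sizes only cover 4 of the 5 elements.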

/// Tests if any element in the `tensor` evaluates to True.
///
/// # Arguments
@@ -2111,10 +2196,58 @@ pub trait BasicOps<B: Backend>: TensorKind<B> {
/// with static dispatch. It is not designed for direct usage by users, and not recommended to import
/// or use this function directly.
///
/// To split a tensor, users should prefer the [Tensor::chunk](Tensor::chunk) function,
/// To chunk a tensor, users should prefer the [Tensor::chunk](Tensor::chunk) function,
/// which is more high-level and designed for public use.
fn chunk(tensor: Self::Primitive, chunks: usize, dim: usize) -> Vec<Self::Primitive>;

/// Splits the tensor into chunks of a specified size along a given dimension.
/// Each chunk is a view of the original tensor.
///
/// # Panics
///
/// If the dimension to split along is greater than the number of dimensions of the tensor.
///
/// # Returns
///
/// A vector of tensors.
///
/// # Remarks
/// This is a low-level function used internally by the library to call different backend functions
/// with static dispatch. It is not designed for direct usage by users, and not recommended to import
/// or use this function directly.
///
/// To split a tensor, users should prefer the [Tensor::split](Tensor::split) function,
/// which is more high-level and designed for public use.
fn split(tensor: Self::Primitive, split_size: usize, dim: usize) -> Vec<Self::Primitive>;
Review comment (Member):

What's the difference between this version of split and chunk? Is it only that split takes the size of the new chunks while chunk takes the number of chunks?

Reply (Contributor Author):

Correct. `split_with_sizes` then goes a step further and lets you specify the size of each chunk.
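
The distinction discussed in this thread can be made concrete by comparing the chunk sizes each call would produce. The sketch below is an illustration only: both functions are hypothetical helpers, and `sizes_for_chunk` assumes the PyTorch-style rule where the chunk size is the ceiling of the length divided by the requested count. Burn's actual `chunk` implementation is not shown in this diff.

```rust
/// Hypothetical helper: chunk sizes when splitting by a fixed `split_size`.
/// The last chunk is smaller when `dim_len` is not a multiple of `split_size`.
fn sizes_for_split(dim_len: usize, split_size: usize) -> Vec<usize> {
    assert!(split_size > 0, "split_size must be greater than 0");
    (0..dim_len)
        .step_by(split_size)
        .map(|start| usize::min(split_size, dim_len - start))
        .collect()
}

/// Hypothetical helper: chunk sizes when asking for `num_chunks` chunks,
/// assuming chunk size = ceil(dim_len / num_chunks) as in PyTorch.
fn sizes_for_chunk(dim_len: usize, num_chunks: usize) -> Vec<usize> {
    assert!(num_chunks > 0, "num_chunks must be greater than 0");
    // Integer ceiling division without relying on newer std helpers.
    let chunk_size = (dim_len + num_chunks - 1) / num_chunks;
    if chunk_size == 0 {
        // An empty dimension produces no chunks in this sketch.
        return Vec::new();
    }
    sizes_for_split(dim_len, chunk_size)
}
```

Under these assumptions, `sizes_for_split(5, 2)` is `[2, 2, 1]`, `sizes_for_chunk(5, 2)` is `[3, 2]`, and `sizes_for_chunk(5, 4)` is `[2, 2, 1]`, which has only three chunks even though four were requested.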


/// Splits the tensor into chunks with the specified sizes along a given dimension.
/// Each chunk is a view of the original tensor.
///
/// The sizes of the chunks are specified in the `split_sizes` vector. The sum of the sizes
/// in `split_sizes` must equal the size of the tensor along the specified dimension.
///
/// # Panics
///
/// If the dimension to split along is greater than the number of dimensions of the tensor or
/// if the sum of `split_sizes` does not equal the size of the tensor along `dim`.
///
/// # Returns
///
/// A vector of tensors.
///
/// # Remarks
/// This is a low-level function used internally by the library to call different backend functions
/// with static dispatch. It is not designed for direct usage by users, and not recommended to import
/// or use this function directly.
///
/// To split a tensor, users should prefer the [Tensor::split_with_sizes](Tensor::split_with_sizes) function,
/// which is more high-level and designed for public use.
fn split_with_sizes(
tensor: Self::Primitive,
split_sizes: Vec<usize>,
dim: usize,
) -> Vec<Self::Primitive>;
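
One of the commits mentions fallbacks for backends without a native split, but that code is not part of the diff shown here. As a rough sketch of the general pattern, and not the PR's actual implementation, a trait can provide a default `split` built on a `narrow`-style primitive that individual backends may override. `ToyTensor` and its methods are hypothetical and use a one-dimensional `Vec<i32>` as a stand-in for a tensor.

```rust
/// Toy stand-in for a backend tensor type; not a Burn trait.
trait ToyTensor: Sized {
    /// Length along the only dimension.
    fn dim_len(&self) -> usize;

    /// Returns the elements in `[start, start + length)`.
    fn narrow(&self, start: usize, length: usize) -> Self;

    /// Default (fallback) split built on `narrow`. An implementor with a
    /// native split operation could override this method.
    fn split(&self, split_size: usize) -> Vec<Self> {
        assert!(split_size > 0, "split_size must be greater than 0");
        let mut parts = Vec::new();
        let mut start = 0;
        while start < self.dim_len() {
            // The last piece takes whatever remains, so it may be shorter.
            let length = usize::min(split_size, self.dim_len() - start);
            parts.push(self.narrow(start, length));
            start += length;
        }
        parts
    }
}

impl ToyTensor for Vec<i32> {
    fn dim_len(&self) -> usize {
        self.len()
    }

    fn narrow(&self, start: usize, length: usize) -> Self {
        self[start..start + length].to_vec()
    }
}
```

Calling `ToyTensor::split(&vec![0, 1, 2, 3, 4], 2)` on this toy type yields `[[0, 1], [2, 3], [4]]`. Unlike the "view of the original tensor" wording in the doc comments, this toy `narrow` copies its data.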

/// Equates the given tensors.
///
/// # Arguments
@@ -2437,6 +2570,36 @@ impl<B: Backend> BasicOps<B> for Float {
.collect(),
}
}

fn split(tensor: Self::Primitive, split_size: usize, dim: usize) -> Vec<Self::Primitive> {
match tensor {
TensorPrimitive::Float(tensor) => B::float_split(tensor, split_size, dim)
.into_iter()
.map(TensorPrimitive::Float)
.collect(),
TensorPrimitive::QFloat(tensor) => B::q_split(tensor, split_size, dim)
.into_iter()
.map(TensorPrimitive::QFloat)
.collect(),
}
}

fn split_with_sizes(
tensor: Self::Primitive,
split_sizes: Vec<usize>,
dim: usize,
) -> Vec<Self::Primitive> {
match tensor {
TensorPrimitive::Float(tensor) => B::float_split_with_sizes(tensor, split_sizes, dim)
.into_iter()
.map(TensorPrimitive::Float)
.collect(),
TensorPrimitive::QFloat(tensor) => B::q_split_with_sizes(tensor, split_sizes, dim)
.into_iter()
.map(TensorPrimitive::QFloat)
.collect(),
}
}
}

impl<B: Backend> BasicOps<B> for Int {
@@ -2536,6 +2699,18 @@ impl<B: Backend> BasicOps<B> for Int {
fn chunk(tensor: Self::Primitive, chunks: usize, dim: usize) -> Vec<Self::Primitive> {
B::int_chunk(tensor, chunks, dim)
}

fn split(tensor: Self::Primitive, split_size: usize, dim: usize) -> Vec<Self::Primitive> {
B::int_split(tensor, split_size, dim)
}

fn split_with_sizes(
tensor: Self::Primitive,
split_sizes: Vec<usize>,
dim: usize,
) -> Vec<Self::Primitive> {
B::int_split_with_sizes(tensor, split_sizes, dim)
}
}

impl<B: Backend> BasicOps<B> for Bool {
@@ -2635,6 +2810,18 @@ impl<B: Backend> BasicOps<B> for Bool {
fn chunk(tensor: Self::Primitive, chunks: usize, dim: usize) -> Vec<Self::Primitive> {
B::bool_chunk(tensor, chunks, dim)
}

fn split(tensor: Self::Primitive, split_size: usize, dim: usize) -> Vec<Self::Primitive> {
B::bool_split(tensor, split_size, dim)
}

fn split_with_sizes(
tensor: Self::Primitive,
split_sizes: Vec<usize>,
dim: usize,
) -> Vec<Self::Primitive> {
B::bool_split_with_sizes(tensor, split_sizes, dim)
}
}

/// Trait used for movedim arguments