Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for repeat #278

Merged
merged 8 commits into from
Dec 27, 2023
Merged

Add support for repeat #278

merged 8 commits into from
Dec 27, 2023

Conversation

Bahaatbb
Copy link
Contributor

Proposed changes

resolves issue #258

Add op:

  • mx.repeat

same behavior as np.repeat

Checklist

Put an x in the boxes that apply.

  • I have read the CONTRIBUTING document
  • I have run pre-commit run --all-files to format my code / installed pre-commit prior to committing changes
  • I have added tests that prove my fix is effective or that my feature works
  • I have updated the necessary documentation (if needed)

@dc-dc-dc dc-dc-dc mentioned this pull request Dec 26, 2023
Copy link
Member

@awni awni left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for adding this! I left a few minor comments. But mainly, the core implementation needs to change a bit as the way you have it, the number of arrays/ops scales with the size of the array which won't work well. I left a comment with an example of what I mean. Let me know if it's not clear.

mlx/ops.cpp Outdated
@@ -1,11 +1,10 @@
// Copyright © 2023 Apple Inc.

#include "mlx/ops.h"
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why move this?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sorry, this seems like a slip

@@ -2233,3 +2233,40 @@ TEST_CASE("test quantize dequantize") {
CHECK(max_diff <= 127.0 / (1 << i));
}
}

TEST_CASE("repeat test with axis") {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just call this test repeat (since it also checks without axis)

mlx/ops.cpp Outdated
}

array repeat(const array& arr, int repeats, StreamOrDevice s) {
return flatten(repeat(arr, repeats, arr.ndim() - 1, s));
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add the stream to the call to flatten. Also I would flatten first then repeat (with axis 0)..since you have to flatten anyway seems slightly more efficient to flatten first.

mlx/ops.cpp Outdated
Comment on lines 734 to 745
std::vector<array> arrays_to_concat;
arrays_to_concat.reserve(repeats * arr.shape(axis));

for (int i = 0; i < arr.shape(axis); ++i) {
std::vector<int> start_indices(arr.ndim(), 0);
std::vector<int> stop_indices = arr.shape();
start_indices[axis] = i;
stop_indices[axis] = i + 1;
for (int j = 0; j < repeats; ++j) {
arrays_to_concat.push_back(slice(arr, start_indices, stop_indices, s));
}
}
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This needs to be implemented with a reshape and a concatenate that is order repeats and critically does not scale with the shape of the repeating axis. For example if you are repeating:

a = array([[0, 1], [2, 3]])

along axis 0 (in numpy like psuedo-code) do:

concatenate([a]*repeats, -1).reshape(-1, a.shape[1])

and along axis 1 do:

stack([a]*repeats, -1).reshape(a.shape[0], -1)

Copy link
Contributor Author

@Bahaatbb Bahaatbb Dec 27, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

hey just to wrap my head around this, do you mean I should handle the cases of 0 and 1 axis repetitions alone and the other axis with the same logic or for the other axis with only a reshape and a concatenate that is order repeats

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yea, what I'm saying is we should avoid ever making repeats * array.shape(axis) subarrays to concatenate using something like the strategy I outlined above for a 2D array which just makes repeats arrays. I think it should generalize to ND arrays but let me know if I'm missing something.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes we definitely can generalize this for ND array using something like this for axis 0 or 1

def repeat(arr, repeats, axis):
    new_shape = np.array(arr.shape)
    new_shape[axis] *= repeats
    if axis == 0:
        repeated = np.concatenate([arr]*repeats, axis=-(len(arr.shape)-1)).flatten()
        return repeated.reshape(new_shape)
    elif axis == 1:
        repeated = np.stack([arr for _ in range(repeats)], axis=-(len(arr.shape)-1)).flatten()
        return repeated.reshape(new_shape)

but what I'm trying to understand is if you think there is a better way for other than making slices for axis bigger than 0 or 1 for bigger arrays.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yea exactly, I think it should work for any axis of any n-darray. You have to concatenate along the correct dimension followed by a reshape.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hey please check the new commit it should be exactly what you suggested

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Much better thanks!! I left a few more comments. I think we should be good to go after you address them.

mlx/ops.cpp Outdated
repeated_arrays.push_back(expand_dims(arr, -1, s));
}
array repeated =
flatten(concatenate(repeated_arrays, -1 * concat_axes, s), s);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why do you flatten the array here?

Also I think it's cleaner if you replace -1 * concate_axes with axis + 1.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why do you flatten the array here?

Oh, thank you for noticing we don't really need it here,
I first implement it in python and didn't realize that I added the dim by hand, but here we have expand_dims.
Fixed it, what a keen eye 🥇

R"pbdoc(
repeat(array: array, repeats: int, axis: Optional[int] = None, *, stream: Union[None, Stream, Device] = None) -> array

Repeate an array along a specified axis
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit period after this.

Args:
array (array): Input array.
repeats (int): The number of repetitions for each element.
axis (int, optional): The axis in which to repeat the array along. Defaults to ``None``.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you say more about the default behavior here?

mlx/ops.cpp Outdated
Comment on lines 725 to 729
if (repeats <= 0) {
std::vector<int> new_shape(arr.shape());
new_shape[axis] = repeats > 0 ? repeats : 0;
return zeros(new_shape, arr.dtype(), s);
}
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

  • Throw on negative repeats (as in NumPy) and add a test case to check that it throws.
  • For repeats==0 return an empty array with just one dimensions. Which I think you can do with array({}, arr.dtype())

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you remove this. I think you added by mistake.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

am not sure what that is, but always happy to remove things 😅

@Bahaatbb Bahaatbb requested a review from awni December 27, 2023 19:36
mlx/ops.cpp Outdated
axis = normalize_axis(axis, arr.ndim());

if (repeats < 0) {
throw std::invalid_argument("Number of repeats cannot be negative");
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

One more nit: add "[repeat]" to the error message

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed 👍

repeats (int): The number of repetitions for each element.
axis (int, optional): The axis in which to repeat the array along. If
unspecified it uses the flattened array of the input and repeates
along the 0 axis.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: "along axis 0"

Copy link
Member

@awni awni left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for adding this! Looks great. Could you fix the last two nits and then I will merge it?

@awni awni merged commit ff2b58e into ml-explore:main Dec 27, 2023
@awni awni mentioned this pull request Jan 2, 2024
Jyun1998 pushed a commit to Jyun1998/mlx that referenced this pull request Jan 7, 2024
* add repeat function

* fix styling

* optimizing repeat

* fixed minor issues

* not sure why that folder is there xD

* fixed now for sure

* test repeat not repeat test

* Fixed

---------

Co-authored-by: Bahaa Eddin tabbakha <bahaa@Bahaas-MacBook-Pro.local>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants