-
Notifications
You must be signed in to change notification settings - Fork 1.1k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add support for repeat #278
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for adding this! I left a few minor comments. But mainly, the core implementation needs to change a bit as the way you have it, the number of arrays/ops scales with the size of the array which won't work well. I left a comment with an example of what I mean. Let me know if it's not clear.
mlx/ops.cpp
Outdated
@@ -1,11 +1,10 @@ | |||
// Copyright © 2023 Apple Inc. | |||
|
|||
#include "mlx/ops.h" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why move this?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
sorry, this seems like a slip
tests/ops_tests.cpp
Outdated
@@ -2233,3 +2233,40 @@ TEST_CASE("test quantize dequantize") { | |||
CHECK(max_diff <= 127.0 / (1 << i)); | |||
} | |||
} | |||
|
|||
TEST_CASE("repeat test with axis") { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Just call this test repeat
(since it also checks without axis)
mlx/ops.cpp
Outdated
} | ||
|
||
array repeat(const array& arr, int repeats, StreamOrDevice s) { | ||
return flatten(repeat(arr, repeats, arr.ndim() - 1, s)); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Add the stream to the call to flatten
. Also I would flatten first then repeat (with axis 0)..since you have to flatten anyway seems slightly more efficient to flatten first.
mlx/ops.cpp
Outdated
std::vector<array> arrays_to_concat; | ||
arrays_to_concat.reserve(repeats * arr.shape(axis)); | ||
|
||
for (int i = 0; i < arr.shape(axis); ++i) { | ||
std::vector<int> start_indices(arr.ndim(), 0); | ||
std::vector<int> stop_indices = arr.shape(); | ||
start_indices[axis] = i; | ||
stop_indices[axis] = i + 1; | ||
for (int j = 0; j < repeats; ++j) { | ||
arrays_to_concat.push_back(slice(arr, start_indices, stop_indices, s)); | ||
} | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This needs to be implemented with a reshape and a concatenate that is order repeats
and critically does not scale with the shape of the repeating axis. For example if you are repeating:
a = array([[0, 1], [2, 3]])
along axis 0
(in numpy like psuedo-code) do:
concatenate([a]*repeats, -1).reshape(-1, a.shape[1])
and along axis 1
do:
stack([a]*repeats, -1).reshape(a.shape[0], -1)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
hey just to wrap my head around this, do you mean I should handle the cases of 0 and 1 axis repetitions alone and the other axis with the same logic or for the other axis with only a reshape and a concatenate that is order repeats
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yea, what I'm saying is we should avoid ever making repeats * array.shape(axis)
subarrays to concatenate using something like the strategy I outlined above for a 2D array which just makes repeats
arrays. I think it should generalize to ND arrays but let me know if I'm missing something.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes we definitely can generalize this for ND array using something like this for axis 0 or 1
def repeat(arr, repeats, axis):
new_shape = np.array(arr.shape)
new_shape[axis] *= repeats
if axis == 0:
repeated = np.concatenate([arr]*repeats, axis=-(len(arr.shape)-1)).flatten()
return repeated.reshape(new_shape)
elif axis == 1:
repeated = np.stack([arr for _ in range(repeats)], axis=-(len(arr.shape)-1)).flatten()
return repeated.reshape(new_shape)
but what I'm trying to understand is if you think there is a better way for other than making slices for axis bigger than 0 or 1 for bigger arrays.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yea exactly, I think it should work for any axis of any n-darray. You have to concatenate along the correct dimension followed by a reshape.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hey please check the new commit it should be exactly what you suggested
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Much better thanks!! I left a few more comments. I think we should be good to go after you address them.
mlx/ops.cpp
Outdated
repeated_arrays.push_back(expand_dims(arr, -1, s)); | ||
} | ||
array repeated = | ||
flatten(concatenate(repeated_arrays, -1 * concat_axes, s), s); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why do you flatten the array here?
Also I think it's cleaner if you replace -1 * concate_axes
with axis + 1
.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why do you flatten the array here?
Oh, thank you for noticing we don't really need it here,
I first implement it in python and didn't realize that I added the dim by hand, but here we have expand_dims
.
Fixed it, what a keen eye 🥇
python/src/ops.cpp
Outdated
R"pbdoc( | ||
repeat(array: array, repeats: int, axis: Optional[int] = None, *, stream: Union[None, Stream, Device] = None) -> array | ||
|
||
Repeate an array along a specified axis |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit period after this.
python/src/ops.cpp
Outdated
Args: | ||
array (array): Input array. | ||
repeats (int): The number of repetitions for each element. | ||
axis (int, optional): The axis in which to repeat the array along. Defaults to ``None``. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can you say more about the default behavior here?
mlx/ops.cpp
Outdated
if (repeats <= 0) { | ||
std::vector<int> new_shape(arr.shape()); | ||
new_shape[axis] = repeats > 0 ? repeats : 0; | ||
return zeros(new_shape, arr.dtype(), s); | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
- Throw on negative repeats (as in NumPy) and add a test case to check that it throws.
- For
repeats==0
return an empty array with just one dimensions. Which I think you can do witharray({}, arr.dtype())
_deps/doctest-src
Outdated
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could you remove this. I think you added by mistake.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
am not sure what that is, but always happy to remove things 😅
mlx/ops.cpp
Outdated
axis = normalize_axis(axis, arr.ndim()); | ||
|
||
if (repeats < 0) { | ||
throw std::invalid_argument("Number of repeats cannot be negative"); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
One more nit: add "[repeat]" to the error message
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Fixed 👍
python/src/ops.cpp
Outdated
repeats (int): The number of repetitions for each element. | ||
axis (int, optional): The axis in which to repeat the array along. If | ||
unspecified it uses the flattened array of the input and repeates | ||
along the 0 axis. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: "along axis 0
"
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for adding this! Looks great. Could you fix the last two nits and then I will merge it?
* add repeat function * fix styling * optimizing repeat * fixed minor issues * not sure why that folder is there xD * fixed now for sure * test repeat not repeat test * Fixed --------- Co-authored-by: Bahaa Eddin tabbakha <bahaa@Bahaas-MacBook-Pro.local>
Proposed changes
resolves issue #258
Add op:
mx.repeat
same behavior as np.repeat
Checklist
Put an
x
in the boxes that apply.pre-commit run --all-files
to format my code / installed pre-commit prior to committing changes