Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Gather #63

Merged
merged 2 commits into from
Apr 16, 2022
Merged

Gather #63

merged 2 commits into from
Apr 16, 2022

Conversation

dfdx
Copy link
Collaborator

@dfdx dfdx commented Apr 16, 2022

It has been the hardest and the most time-consuming operation so far, so I really hope this time I've done it correctly 😄

As a side effect, I've added the take() function, which is similar to numpy.take() and seems to be useful e.g. for taking slices of input tensors in NLP models. So we may want to move it to a more general-purpose package such as NNlib once our implementation is battle tested in real graphs.

dfdx added 2 commits April 2, 2022 19:09
…s of data. Note that ONNX's Gather is different from gather() function found in many deep learning frameworks (including NNlib)
@dfdx dfdx merged commit 5ae53c6 into master Apr 16, 2022
@dfdx dfdx deleted the gather branch April 16, 2022 20:47
@ToucheSir
Copy link
Member

Sounds like a tour de force! Is there a summary of how take differs from NNlib.gather?

@dfdx
Copy link
Collaborator Author

dfdx commented Apr 16, 2022

In short, it's just a different operation.

NNlib.gather can work on individual indices, e.g. you can write:

x = rand(3, 4)
@assert NNlib.gather(y, [(1, 1), (1, 2), (2, 2)]) == [x[1, 1], x[1, 2], x[2, 2]]

For example, it's useful to find implement an NLL loss:

y = softmax(rand(3, 4))   # classifier predictions
c = [1, 3, 1, 2]                 # correct classes
idxs = [(c, i) for (i, c) in enumerate(c)]
loss = mean(NNlib.gather(y, idxs))

NNlib.gather is also the reverse of NNlib.scatter and equivalent of ONNX's GatherND or maybe GatherElements - I haven't figured it out yet.

take() works on slices along a specific axis (dim in Julia version). Say, data is 3D and dim = 2. Then for every entry i in idxs it takes data[:, i, :]. It's like extracting individual slices from a ND array and re-organizing them according to the structure in idxs. ONNX's docs are pretty confusing in this regard, but NumPy explain it well.

take() also reduces to getindex() when data is 1D and length(idxs) == 1 - it turns out this is how ONNX treats expressions like size(X)[1].

I tried t to implement take() using NNlib.gather() at some point, but I spent most of the time figuring out what exactly ONNX's Gather does in each case, and when I got the idea, implementing it using array views turned to be easier than adjusting the indices (but it may still be an option in future).

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants