
PyTorch implementation of np.lexsort() in sort_csc util function #7743

Closed
ebrahimpichka opened this issue Jul 14, 2023 · 3 comments · Fixed by #7775

@ebrahimpichka
Contributor

ebrahimpichka commented Jul 14, 2023

🛠 Proposed Refactor

Hi,

I came across the np.lexsort call in the sort_csc utility function.

An unofficial PyTorch implementation of lexsort is proposed here:
https://discuss.pytorch.org/t/numpy-lexsort-equivalent-in-pytorch/47850/5?u=ebrahim.pichka

As below (with minor changes):

import torch


def torch_lexsort(keys, dim=-1):
    """PyTorch equivalent of np.lexsort: the last key in `keys` is the primary sort key."""
    if len(keys) < 2:
        raise ValueError(f"keys must be at least 2 sequences, but {len(keys)=}.")

    # Stable sort by the least significant key first ...
    idx = keys[0].argsort(dim=dim, stable=True)
    # ... then re-sort stably by each more significant key, composing the
    # permutations via gather.
    for k in keys[1:]:
        idx = idx.gather(dim, k.gather(dim, idx).argsort(dim=dim, stable=True))

    return idx
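
As with np.lexsort, the last key in the list acts as the primary sort key and earlier keys break ties. A minimal sanity check on a toy input (not part of the original post, assuming torch_lexsort is defined as above):

import numpy as np
import torch

a = torch.tensor([0.3, 0.1, 0.2, 0.1])  # secondary key (breaks ties)
b = torch.tensor([1, 0, 1, 0])          # primary key (last in the list)

idx_pt = torch_lexsort([a, b])
idx_np = np.lexsort([a.numpy(), b.numpy()])

assert (idx_pt.numpy() == idx_np).all()
print(idx_pt)  # tensor([1, 3, 2, 0])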

I also ran some tests on the proposed implementation as follows:

In [1]: import torch; import numpy as np

In [2]: N = 1000000

In [3]: a = np.random.rand(N); b = np.random.randint(N // 4, size=N)

In [4]: a_t = torch.tensor(a); b_t = torch.tensor(b)

In [5]: a_t_cu, b_t_cu = a_t.to(torch.device("cuda")), b_t.to(torch.device("cuda"))

In [6]: def torch_lexsort(keys, dim=-1):
   ...:     # defined as above

In [7]: %timeit -n 2 -r 20 np.lexsort([a, b])
302 ms ± 31.2 ms per loop (mean ± std. dev. of 20 runs, 2 loops each)

In [8]: %timeit -n 2 -r 20 torch_lexsort([a_t, b_t])
293 ms ± 35.1 ms per loop (mean ± std. dev. of 20 runs, 2 loops each)

In [9]: %timeit -n 20 -r 100 torch_lexsort([a_t_cu, b_t_cu])
The slowest run took 5.27 times longer than the fastest. This could mean that an intermediate result is being cached.
3.97 ms ± 334 µs per loop (mean ± std. dev. of 100 runs, 20 loops each)

In [10]: idx_np = np.lexsort([a, b]); idx_pt = torch_lexsort([a_t, b_t])

In [11]: (idx_np == idx_pt.numpy()).all()
Out[11]: True

It also adds GPU support, as shown by the CUDA timing above.

I thought it could replace the current implementation, which detaches the tensor to NumPy and casts the result back.
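
As a rough sketch of the replacement (the helper name and signature below are hypothetical, not the actual torch_geometric code), the permutation could be computed directly on the input tensors' device:

def sort_csc_like(row: torch.Tensor, col: torch.Tensor):
    # Hypothetical helper, not the actual torch_geometric API.
    # Sort edges by column (the primary key, i.e. last in the list), breaking
    # ties by row, entirely in PyTorch instead of detaching to np.lexsort.
    perm = torch_lexsort([row, col])
    return row[perm], col[perm], perm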

Suggest a potential alternative/fix

Mentioned Above

@rusty1s
Member

rusty1s commented Jul 19, 2023

Amazing. Do you wanna send a PR for this? Otherwise, I can go ahead and make the corresponding change as well.

@ebrahimpichka
Contributor Author

@rusty1s I'd like to make a PR.

@rusty1s
Member

rusty1s commented Jul 19, 2023

Super :)

rusty1s added a commit that referenced this issue Jul 19, 2023
resolves #7743
Added Pytorch implementation of numpy lexsort.

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: rusty1s <matthias.fey@tu-dortmund.de>