Faulty broadcasting when encountering successive singleton dimensions #382
Explicitly broadcasting indeed leads to an overhead. So do you have any quick fixes?
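(For reference, a minimal sketch of how that overhead could be measured — hypothetical code, not taken from the thread. The `.contiguous()` call is an assumption: `expand` itself creates a free view, so any cost would come from materializing a contiguous copy for KeOps.)

```python
# Hypothetical benchmark (not from the thread): time the KeOps reduction
# with the original singleton C versus an explicitly broadcast, copied C.
import time
import torch
from pykeops.torch import LazyTensor

device = "cuda"
M, nh, n, c, d = 64, 32, 1, 256, 2
X = torch.randn(M, nh, n, d, device=device)
C = torch.randn(M, 1, c, d, device=device)

def run_singleton():
    X_i = LazyTensor(X.unsqueeze(3))             # (M, nh, n, 1, d)
    C_j = LazyTensor(C.unsqueeze(2))             # (M, 1, 1, c, d)
    return ((X_i - C_j) ** 2).sum(4).argmin(dim=3)

def run_expanded():
    C_full = C.expand(M, nh, c, d).contiguous()  # the copy happens here
    X_i = LazyTensor(X.unsqueeze(3))             # (M, nh, n, 1, d)
    C_j = LazyTensor(C_full.unsqueeze(2))        # (M, nh, 1, c, d)
    return ((X_i - C_j) ** 2).sum(4).argmin(dim=3)

for label, fn in [("singleton C", run_singleton), ("expanded C ", run_expanded)]:
    fn()  # warm-up, including any KeOps compilation
    torch.cuda.synchronize()
    t0 = time.perf_counter()
    fn()
    torch.cuda.synchronize()
    print(label, f"{time.perf_counter() - t0:.4f} s")
```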
Hello @Xp-speit2018, here is a workaround: define `X_i` as a symbolic LazyTensor instead of wrapping the tensor directly.

```python
import torch
import pykeops.torch

device = "cuda"
M, nh = 64, 32  # batch dimensions
n, c = 1, 256   # number dimensions
d = 2           # spatial dimension
X = torch.randn(M, nh, n, d, device=device)
C = torch.randn(M, 1, c, d, device=device)
print("X.shape", X.shape)
print("C.shape", C.shape)

print("cdist shape", torch.cdist(X, C, p=2).shape)
codes_torch = torch.cdist(X, C, p=2).argmin(dim=3)
print("codes_torch.shape", codes_torch.shape)
print()

# define X_i as a symbolic LazyTensor with ind=0, dim=2 and cat=0 (i.e. i-indexed):
X_i = pykeops.torch.LazyTensor((0, 2, 0))
C_j = pykeops.torch.LazyTensor(C.unsqueeze(2))
D_ij = ((X_i - C_j) ** 2).sum(4)
# now we input X for the actual computation
# (since D_ij.argmin(dim=3) is symbolic, it outputs a function instead of a tensor,
# and we call it with the actual tensor X as input):
codes_keops = D_ij.argmin(dim=3)(X).squeeze(3)
print("codes_keops.shape", codes_keops.shape)
print()

print("Equal?", torch.allclose(codes_torch, codes_keops))
print("Match percentage", (codes_torch == codes_keops).float().mean().item())
```
Hi @joanglaunes, I'm truly grateful for your time and effort in addressing this. Hope this issue can be resolved in the future!
Appreciate your excellent work! It accelerated my code by 7x.
That said, I've found that the broadcasting is a little buggy. Here's a minimal demo to reproduce the bug:
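(The demo script itself was lost in the page extraction. The following is a plausible reconstruction, inferred from the shapes and LazyTensor calls quoted in @joanglaunes's reply above — treat it as a sketch, not the original code.)

```python
import torch
from pykeops.torch import LazyTensor

device = "cuda"
M, nh = 64, 32  # batch dimensions
n, c = 1, 256   # number dimensions
d = 2           # spatial dimension
X = torch.randn(M, nh, n, d, device=device)
C = torch.randn(M, 1, c, d, device=device)

# reference result with plain PyTorch (broadcasting handles C's singleton dim)
codes_torch = torch.cdist(X, C, p=2).argmin(dim=3)  # (M, nh, n)

# KeOps version: unsqueeze to the usual (batch..., i, j, feature) layout
X_i = LazyTensor(X.unsqueeze(3))  # (M, nh, n, 1, d)
C_j = LazyTensor(C.unsqueeze(2))  # (M, 1, 1, c, d) <- two successive singleton dims
D_ij = ((X_i - C_j) ** 2).sum(4)
codes_keops = D_ij.argmin(dim=3).squeeze(3)  # (M, nh, n)

print("Equal?", torch.allclose(codes_torch, codes_keops))
print("Match percentage", (codes_torch == codes_keops).float().mean().item())
```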
Note that I'm creating `C` with its second batch dimension set to 1, and `C` is later unsqueezed, so the subtraction involves two successive implicit broadcasts. The above code prints a match percentage below 1: the KeOps result disagrees with the PyTorch reference.
I've tested the above code with both the CUDA and CPU backends, and the bug appears to be independent of the computing device.
Several workarounds can bring the match percentage back to 1:

- Explicitly broadcasting `C`: setting `C = C.expand(M, nh, c, d)` resolves the issue. Since expansion in PyTorch does not lead to a direct memory copy, this approach has zero overhead (see the sketch after this list).
- Avoiding the singleton dimension in `X`: setting `n > 1` also avoids two successive singleton dimensions. However, this is not practical in my case, as I'm producing exactly one vector of `X` during each iteration.
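For concreteness, here is how the first workaround slots into the reconstructed demo above (same assumptions; depending on the KeOps version, a `.contiguous()` copy of the expanded view may additionally be required, which would no longer be free):

```python
# Workaround 1: materialize C's second batch dimension before wrapping it,
# so KeOps never sees two successive singleton dimensions.
C_full = C.expand(M, nh, c, d)           # a view: expand() copies no memory
X_i = LazyTensor(X.unsqueeze(3))         # (M, nh, n, 1, d)
C_j = LazyTensor(C_full.unsqueeze(2))    # (M, nh, 1, c, d): single singleton dim
D_ij = ((X_i - C_j) ** 2).sum(4)
codes_keops = D_ij.argmin(dim=3).squeeze(3)  # now agrees with codes_torch
```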