dist_sync_fn seems dysfunctional #1183

Closed
yellowdolphin opened this issue Aug 13, 2022 · 2 comments · Fixed by #1301

yellowdolphin commented Aug 13, 2022

🐛 Bug

dist_sync_fn is a documented kwarg of the Metric class, and the documentation suggests passing an alternative to torch.distributed.all_gather. I suppose the intent is to allow the use of distributed contexts other than torch.distributed, for which the default works fine. This use case fails due to the following issues:

  • the passed dist_sync_fn is never called, because the hardcoded check jit_distributed_available() returns False in any other sync context (see the sketch after this list)
  • this check is assigned as the default for the distributed_available kwarg of sync() and sync_context(), but neither can be called by the user, because all metrics already wrap their compute() method at init
  • sync_context can be applied only once to a function, otherwise the inner wrapper raises TorchMetricsUserError "The Metric has already been synced."
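
For reference, jit_distributed_available amounts roughly to the following check (paraphrased here, so treat the exact body as an assumption). It only reflects the state of torch.distributed, so it returns False under torch_xla / XLA multiprocessing even though a distributed context exists:

import torch

def jit_distributed_available() -> bool:
    # True only when the default torch.distributed process group is initialized,
    # which is never the case when the distributed context comes from torch_xla.
    return torch.distributed.is_available() and torch.distributed.is_initialized()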

To Reproduce

Code for testing Accuracy and MetricCollection in a TPU context (colab): https://colab.research.google.com/drive/1MlxWSrkKKuZ3WSb9duf1c0MoO2A8jDAE?usp=sharing

Code sample

acc = torchmetrics.Accuracy(dist_sync_fn=gather_fn).to(device)
acc.update(scores.detach(), labels)
acc_value = acc.compute()
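
Here gather_fn is a user-supplied replacement for torch.distributed.all_gather. A minimal sketch of what it could look like in a torch_xla (TPU) context, assuming the same contract as the default gather_all_tensors (take a tensor and an optional process group, return a list of per-replica tensors):

import torch
import torch_xla.core.xla_model as xm

def gather_fn(tensor: torch.Tensor, group=None):
    # xm.all_gather concatenates replicas along the given dim, so add a leading
    # replica dim first, then split the result back into a list of tensors,
    # which is what torchmetrics expects from its dist_sync_fn.
    gathered = xm.all_gather(tensor.unsqueeze(0), dim=0)
    return list(gathered.unbind(dim=0))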

Expected behavior

The metric should sync automatically across processes and compute the correct value.

Environment

  • TorchMetrics version (and how you installed TM, e.g. conda, pip, build from source): 0.10dev
  • Python & PyTorch Version (e.g., 1.0):
  • Any other relevant information such as OS (e.g., Linux): e.g. torch_xla.distributed

Additional context

I found that with minimal changes in metrics.py, dist_sync_fn can actually be used to run torchmetrics on TPUs (torch_xla), as in the code sample above. I could send a PR (see the fork in the colab notebook).

yellowdolphin added the bug / fix and help wanted labels on Aug 13, 2022
github-actions commented

Hi! Thanks for your contribution, great first issue!

SkafteNicki added this to the v0.11 milestone on Sep 14, 2022
Borda commented Oct 19, 2022

@justusschock, mind having a look at it? 🦦
