Allow broadcasting picklable Python objects in distributed setup #538
Comments
Another important consideration is that such special attributes should be treated with care, especially when we use functions like these:

```python
from typing import Callable, Sequence

import torch


# The methods below are meant to be defined on a torchmetrics.Metric subclass.

def _safe_apply(self, fn: Callable, cur_v):
    # Apply fn, but fall back to the original value if fn cannot handle it
    # (e.g. a device-conversion fn applied to a plain Python object).
    try:
        return fn(cur_v)
    except Exception:
        return cur_v


def _module_apply(self, fn):
    # Re-implementation of nn.Module._apply that skips the "ids" parameter key.
    for module in self.children():
        module._apply(fn)

    def compute_should_use_set_data(tensor, tensor_applied):
        if torch._has_compatible_shallow_copy_type(tensor, tensor_applied):
            # If the new tensor has a compatible tensor type as the existing tensor,
            # the current behavior is to change the tensor in-place using `.data =`,
            # and the future behavior is to overwrite the existing tensor. However,
            # changing the current behavior is a BC-breaking change, and we want it
            # to happen in future releases. So for now we introduce the
            # `torch.__future__.get_overwrite_module_params_on_conversion()`
            # global flag to let the user control whether they want the future
            # behavior of overwriting the existing tensor or not.
            return not torch.__future__.get_overwrite_module_params_on_conversion()
        return False

    for key, param in self._parameters.items():
        if key != "ids" and param is not None:
            # Tensors stored in modules are graph leaves, and we don't want to
            # track autograd history of `param_applied`, so we have to use
            # `with torch.no_grad():`
            with torch.no_grad():
                param_applied = fn(param)
            should_use_set_data = compute_should_use_set_data(param, param_applied)
            if should_use_set_data:
                param.data = param_applied
            else:
                assert isinstance(param, torch.nn.Parameter)
                assert param.is_leaf
                self._parameters[key] = torch.nn.Parameter(param_applied, param.requires_grad)

            if param.grad is not None:
                with torch.no_grad():
                    grad_applied = fn(param.grad)
                should_use_set_data = compute_should_use_set_data(param.grad, grad_applied)
                if should_use_set_data:
                    param.grad.data = grad_applied
                else:
                    assert param.grad.is_leaf
                    self._parameters[key].grad = grad_applied.requires_grad_(param.grad.requires_grad)

    for key, buf in self._buffers.items():
        if buf is not None:
            self._buffers[key] = fn(buf)

    return self


def _apply(self, fn: Callable) -> torch.nn.Module:
    """Overwrite the _apply function such that we can also move metric states
    to the correct device when `.to`, `.cuda`, etc. methods are called.
    """
    this = self._module_apply(fn)

    # Also apply fn to metric states and defaults
    for key, value in self._defaults.items():
        if isinstance(value, torch.Tensor):
            this._defaults[key] = fn(value)
        elif isinstance(value, Sequence):
            this._defaults[key] = [fn(v) for v in value]

        current_val = getattr(this, key)
        if isinstance(current_val, torch.Tensor):
            setattr(this, key, fn(current_val))
        elif isinstance(current_val, Sequence):
            setattr(this, key, [self._safe_apply(fn, cur_v) for cur_v in current_val])
        else:
            raise TypeError(
                "Expected metric state to be either a Tensor "
                f"or a list of Tensor, but encountered {current_val}"
            )

    return this
```

Again, very hacky but it works...
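For illustration, here is a small standalone sketch of the `_safe_apply` fallback behaviour; everything besides the try/except logic itself is made up:

```python
import torch

def fn(v):
    return v.to("cpu")  # stand-in for the conversion fn passed by .to()/.cuda()

def safe_apply(fn, cur_v):
    # Same logic as _safe_apply above, without self: apply fn where it works,
    # keep the original value where it raises.
    try:
        return fn(cur_v)
    except Exception:
        return cur_v

values = [torch.ones(2), ("question-1", "video-1")]
print([safe_apply(fn, v) for v in values])
# -> [tensor([1., 1.]), ('question-1', 'video-1')]
```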
Hi @aleSuglia, thanks for raising this issue. It is a great addition to torchmetrics, with the only downside that the feature is only compatible with newer versions of pytorch (I think from v1.8), meaning that we still need to do some checking in the …
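A minimal sketch of the kind of version guard this would need, assuming the check lives next to a gather helper; the names `_TORCH_GREATER_EQUAL_1_8` and `gather_all_objects` are illustrative, not existing torchmetrics symbols:

```python
from packaging.version import Version

import torch

# all_gather_object was added in torch 1.8, so older versions need a fallback.
_TORCH_GREATER_EQUAL_1_8 = Version(torch.__version__) >= Version("1.8.0")


def gather_all_objects(obj, group=None):
    """Gather a picklable object from every rank into a list (one entry per rank)."""
    if not _TORCH_GREATER_EQUAL_1_8:
        raise RuntimeError("Gathering arbitrary Python objects requires torch>=1.8")
    world_size = torch.distributed.get_world_size(group=group)
    gathered = [None] * world_size
    torch.distributed.all_gather_object(gathered, obj, group=group)
    return gathered
```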
This issue has been automatically marked as stale because it has not had recent activity. It will be closed if no further activity occurs. Thank you for your contributions.
🚀 Feature
A metric should support any picklable Python object as a state variable.
Motivation
Currently a `Metric` class cannot have a `List[Tuple[str]]` as a state variable. The synchronization will be skipped in that case and the `compute` method will only have the values computed in the current rank. This will inevitably invalidate the synchronization among processes for that metric.

Pitch
Add support for the built-in `all_gather_object` available in PyTorch (please see: pytorch/pytorch#42189). I've tried to trace the current implementation and it seems to me that `_sync_dist` should be generalised so that `apply_to_collection` can apply different functions for different keys. For instance, say we have a class like the one sketched below: we want to make sure that `video_embeddings` and `text_embeddings` use the current `all_gather` for `torch.Tensor`s, while `all_gather_object` is used for the variable `ids`.
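As an illustration, a minimal sketch of such a metric: only the state names `video_embeddings`, `text_embeddings`, and `ids` and the `add_state` API come from the discussion above, everything else (class name, update/compute logic) is an assumption:

```python
from typing import List, Tuple

import torch
from torchmetrics import Metric


class VideoTextRetrievalMetric(Metric):
    """Hypothetical metric with two tensor states and one non-tensor state."""

    def __init__(self):
        super().__init__()
        # Tensor states: these can be synced with the usual all_gather path.
        self.add_state("video_embeddings", default=[], dist_reduce_fx="cat")
        self.add_state("text_embeddings", default=[], dist_reduce_fx="cat")
        # Non-tensor state: a list of (question_id, video_id) string tuples.
        # This is the state that needs all_gather_object to be synced.
        self.add_state("ids", default=[], dist_reduce_fx=None)

    def update(self, video_emb: torch.Tensor, text_emb: torch.Tensor,
               ids: List[Tuple[str, str]]) -> None:
        self.video_embeddings.append(video_emb)
        self.text_embeddings.append(text_emb)
        self.ids.extend(ids)

    def compute(self) -> torch.Tensor:
        video = torch.cat(self.video_embeddings, dim=0)
        text = torch.cat(self.text_embeddings, dim=0)
        # Placeholder computation: mean cosine similarity of paired embeddings.
        return torch.nn.functional.cosine_similarity(video, text).mean()
```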
For reference, this is my current (hacky) working implementation of `_sync_dist`:
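What follows is a minimal sketch of a `_sync_dist` override along these lines, not the exact original: it assumes `torch.distributed.all_gather_object` is available (torch>=1.8), that `dist_sync_fn` gathers a tensor across ranks and returns a list with one tensor per rank, and that state names can be read from the metric's internal `_defaults` dict; the hardcoded `"ids"` check mirrors the hack in the `_apply` snippet above.

```python
from collections.abc import Sequence
from typing import Any, Callable, Optional

import torch
from torch import Tensor


def _sync_dist(self, dist_sync_fn: Callable, process_group: Optional[Any] = None) -> None:
    group = process_group if process_group is not None else self.process_group
    world_size = torch.distributed.get_world_size(group=group)

    for attr in self._defaults:  # state names registered via add_state
        value = getattr(self, attr)

        if attr == "ids":
            # Arbitrary picklable state: gather it with all_gather_object and
            # flatten the per-rank lists back into a single list.
            gathered = [None] * world_size
            torch.distributed.all_gather_object(gathered, value, group=group)
            setattr(self, attr, [item for rank_items in gathered for item in rank_items])
        elif isinstance(value, Tensor):
            # Single-tensor state: gather the per-rank tensors. A full
            # implementation would apply the state's registered reduction
            # (sum, cat, ...); here we simply stack them as an example.
            setattr(self, attr, torch.stack(dist_sync_fn(value, group=group)))
        elif isinstance(value, Sequence):
            # List-of-tensor state: gather each element and flatten.
            gathered_lists = [dist_sync_fn(v, group=group) for v in value]
            setattr(self, attr, [t for per_rank in gathered_lists for t in per_rank])
```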
P.S. `apply_to_collection` shouldn't silently ignore that the function wasn't applied on one of the fields!
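To illustrate that silent pass-through, a small sketch; the import path is assumed to match the Lightning utilities of that era:

```python
import torch
from pytorch_lightning.utilities.apply_func import apply_to_collection  # path is an assumption

states = {"video_embeddings": torch.ones(2), "ids": [("q1", "v1")]}

# Only entries matching the requested dtype (Tensor) are transformed.
moved = apply_to_collection(states, torch.Tensor, lambda t: t * 2)

# "ids" comes back unchanged, with no warning that the function was never
# applied to it.
print(moved)
```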