Allow broadcasting picklable Python objects in distributed setup #538

Closed
aleSuglia opened this issue Sep 20, 2021 · 3 comments

aleSuglia commented Sep 20, 2021

🚀 Feature

A metric should support any picklable Python object as a state variable.

Motivation

Currently, a Metric subclass cannot have a List[Tuple[str]] as a state variable. Synchronization is skipped for such a state, so the compute method only sees the values accumulated on the current rank. This silently invalidates the cross-process synchronization for that metric.
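For illustration, a minimal sketch of the failing case (the metric and its update logic are made up for this example):

from typing import List, Tuple
from torchmetrics import Metric

class IdCollector(Metric):  # hypothetical metric, for illustration only
    def __init__(self):
        super().__init__()
        # non-tensor state: a list of tuples of string ids
        self.add_state("ids", default=[], dist_reduce_fx=None)

    def update(self, ids: List[Tuple[str]]) -> None:
        self.ids.extend(ids)

    def compute(self) -> List[Tuple[str]]:
        # under DDP this only contains the ids seen on the current rank,
        # because non-tensor list elements are not gathered across processes
        return self.ids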

Pitch

Add support for the built-in all_gather_object available in PyTorch (please see pytorch/pytorch#42189).
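For context, a minimal sketch of how torch.distributed.all_gather_object works (available since PyTorch 1.8; assumes an initialised process group):

import torch.distributed as dist

# rank-local, picklable data
local_ids = [("video_0", "caption_0"), ("video_1", "caption_1")]

# all_gather_object expects a pre-sized output list with one slot per rank
gathered = [None for _ in range(dist.get_world_size())]
dist.all_gather_object(gathered, local_ids)
# gathered now holds every rank's list: [ids_from_rank_0, ids_from_rank_1, ...]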

I've traced the current implementation, and it seems to me that _sync_dist should be generalised so that apply_to_collection can apply different functions to different keys. For instance, say we have the following class:

from typing import Any, Callable, List, Optional, Tuple

import torch
from torchmetrics import Metric


class Recall(Metric):
    video_embeddings: List[torch.Tensor]
    text_embeddings: List[torch.Tensor]
    ids: List[Tuple[str]]

    def __init__(self,
                 compute_on_step: bool = True,
                 dist_sync_on_step: bool = False,
                 process_group: Optional[Any] = None,
                 dist_sync_fn: Optional[Callable] = None):
        super().__init__(
            compute_on_step=compute_on_step,
            dist_sync_on_step=dist_sync_on_step,
            process_group=process_group,
            dist_sync_fn=dist_sync_fn
        )
        # tensor states (lists of torch.Tensor)
        self.add_state("video_embeddings", default=[], dist_reduce_fx=None)
        self.add_state("text_embeddings", default=[], dist_reduce_fx=None)
        # non-tensor state (list of tuples of strings) -- the one this issue is about
        self.add_state("ids", default=[], dist_reduce_fx=None)

We want video_embeddings and text_embeddings to keep using the current all_gather for torch.Tensor states, while all_gather_object is used for the ids variable.

For reference, this is my current (hacky) working implementation of _sync_dist:

    def _sync_dist(self, dist_sync_fn: Callable = gather_all_tensors, process_group: Optional[Any] = None) -> None:
        input_dict = {attr: getattr(self, attr) for attr in self._reductions}

        for attr, reduction_fn in self._reductions.items():
            # pre-concatenate metric states that are lists to reduce number of all_gather operations
            if reduction_fn == dim_zero_cat and isinstance(input_dict[attr], list) and len(input_dict[attr]) > 1:
                input_dict[attr] = [dim_zero_cat(input_dict[attr])]

        curr_group = process_group or self.process_group
        output_dict = apply_to_collection(
            input_dict,
            torch.Tensor,
            dist_sync_fn,
            group=curr_group,
        )

        output_dict["ids"] = [None for _ in range(torch.distributed.get_world_size(curr_group))]
        all_gather_object(
            output_dict["ids"],
            obj=[x for xs in input_dict["ids"] for x in xs],
            group=curr_group
        )

        for attr, reduction_fn in self._reductions.items():
            # pre-processing ops (stack or flatten for inputs)
            if isinstance(output_dict[attr][0], torch.Tensor):
                output_dict[attr] = torch.stack(output_dict[attr])
            elif isinstance(output_dict[attr][0], list):
                output_dict[attr] = _flatten(output_dict[attr])

            if not (callable(reduction_fn) or reduction_fn is None):
                raise TypeError('reduction_fn must be callable or None')
            reduced = reduction_fn(output_dict[attr]) if reduction_fn is not None else output_dict[attr]
            setattr(self, attr, reduced)

P.S. apply_to_collection shouldn't silently ignore fields that the function was never applied to!
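To illustrate the P.S., a small sketch (the import path shown is one place apply_to_collection lives; adjust to wherever your version pulls it from):

import torch
from pytorch_lightning.utilities.apply_func import apply_to_collection

states = {"embeddings": [torch.ones(2)], "ids": [("a", "b")]}
out = apply_to_collection(states, torch.Tensor, lambda t: t * 2)
# out["embeddings"][0] is doubled, but out["ids"] comes back untouched,
# with no warning that the function was never applied to it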

@aleSuglia added the "enhancement" (New feature or request) and "help wanted" (Extra attention is needed) labels on Sep 20, 2021
@Borda removed the "help wanted" (Extra attention is needed) label on Sep 20, 2021

aleSuglia commented Sep 21, 2021

Another important consideration is that such special attributes need to be handled with care, especially by functions like _apply() from torch.nn.Module, which assume that every state element is a Tensor. For instance, I had to re-implement the Metric _apply method as follows:

    def _safe_apply(self, fn: Callable, cur_v):
        # apply `fn` where possible, falling back to the original value for
        # non-tensor states (e.g. the string ids) that `fn` cannot handle
        try:
            return fn(cur_v)
        except Exception:
            return cur_v

    def _module_apply(self, fn):
        import torch

        for module in self.children():
            module._apply(fn)

        def compute_should_use_set_data(tensor, tensor_applied):
            if torch._has_compatible_shallow_copy_type(tensor, tensor_applied):
                # If the new tensor has compatible tensor type as the existing tensor,
                # the current behavior is to change the tensor in-place using `.data =`,
                # and the future behavior is to overwrite the existing tensor. However,
                # changing the current behavior is a BC-breaking change, and we want it
                # to happen in future releases. So for now we introduce the
                # `torch.__future__.get_overwrite_module_params_on_conversion()`
                # global flag to let the user control whether they want the future
                # behavior of overwriting the existing tensor or not.
                return not torch.__future__.get_overwrite_module_params_on_conversion()
            else:
                return False

        for key, param in self._parameters.items():
            if key != "ids" and param is not None:
                # Tensors stored in modules are graph leaves, and we don't want to
                # track autograd history of `param_applied`, so we have to use
                # `with torch.no_grad():`
                with torch.no_grad():
                    param_applied = fn(param)
                should_use_set_data = compute_should_use_set_data(param, param_applied)
                if should_use_set_data:
                    param.data = param_applied
                else:
                    assert isinstance(param, torch.nn.Parameter)
                    assert param.is_leaf
                    self._parameters[key] = torch.nn.Parameter(param_applied, param.requires_grad)

                if param.grad is not None:
                    with torch.no_grad():
                        grad_applied = fn(param.grad)
                    should_use_set_data = compute_should_use_set_data(param.grad, grad_applied)
                    if should_use_set_data:
                        param.grad.data = grad_applied
                    else:
                        assert param.grad.is_leaf
                        self._parameters[key].grad = grad_applied.requires_grad_(param.grad.requires_grad)

        for key, buf in self._buffers.items():
            if buf is not None:
                self._buffers[key] = fn(buf)

        return self

    def _apply(self, fn: Callable) -> torch.nn.Module:
        """Overwrite _apply function such that we can also move metric states
        to the correct device when `.to`, `.cuda`, etc methods are called
        """
        this = self._module_apply(fn)

        # Also apply fn to metric states and defaults
        for key, value in self._defaults.items():
            if isinstance(value, torch.Tensor):
                this._defaults[key] = fn(value)
            elif isinstance(value, Sequence):
                this._defaults[key] = [fn(v) for v in value]

            current_val = getattr(this, key)
            if isinstance(current_val, torch.Tensor):
                setattr(this, key, fn(current_val))
            elif isinstance(current_val, Sequence):
                setattr(this, key, [self._safe_apply(fn, cur_v) for cur_v in current_val])
            else:
                raise TypeError(
                    "Expected metric state to be either a Tensor "
                    f"or a list of Tensor, but encountered {current_val}"
                )
        return this

Again, very hacky but it works...

SkafteNicki (Member) commented

Hi @aleSuglia, thanks for raising this issue. It would be a great addition to torchmetrics, with the only downside that the feature is only compatible with newer versions of PyTorch (I think from v1.8), meaning that we still need to do some checking in the self.add_state function.
Let's try including it in the bigger API change #344 that is currently going on.
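A minimal sketch of the kind of check meant here (illustrative only, not the torchmetrics implementation; the helper name is hypothetical):

import torch

# all_gather_object only exists from PyTorch 1.8 onwards
_GATHER_OBJECT_AVAILABLE = torch.distributed.is_available() and hasattr(
    torch.distributed, "all_gather_object"
)

def _check_state_default(default) -> None:
    # tensor (and list-of-tensor) states keep using the existing all_gather path;
    # any other picklable default needs all_gather_object to be synced
    if not isinstance(default, (torch.Tensor, list)) and not _GATHER_OBJECT_AVAILABLE:
        raise RuntimeError(
            "Non-tensor metric states require torch.distributed.all_gather_object (PyTorch >= 1.8)"
        )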

@Borda added this to the v0.7 milestone on Oct 15, 2021
@Borda modified the milestones: v0.7, v0.8 on Jan 6, 2022
@Borda modified the milestones: v0.8, v0.9 on Mar 22, 2022
@SkafteNicki removed this from the v0.9 milestone on May 12, 2022

stale bot commented Jul 22, 2022

This issue has been automatically marked as stale because it has not had recent activity. It will be closed if no further activity occurs. Thank you for your contributions.

The stale bot added the wontfix label on Jul 22, 2022
The stale bot closed this as completed on Jul 31, 2022