Metric API re-design #344

Closed
tchaton opened this issue Jul 1, 2021 · 29 comments
tchaton (Contributor) commented Jul 1, 2021

🚀 Feature

Re-design the internal API of the base Metric class so that the update function returns batch states.

The following proposal is a breaking change (BC)!

Limitations:

Currently, the Metric API is confusing for several reasons:

  • The update function actually performs a compute if compute_on_step=True
  • The update function actually performs 2 updates, which is an expensive operation for some metrics, FID for example
  • Users don't have a clear API to perform computation
  • The Metric internals are tailored to Lightning. IMO, TorchMetrics should define its own API to better facilitate metrics computation, and Lightning should adapt to it.

Proposal:

  • The update function doesn't return anything
  • There are 2 functions for reduction: compute and compute_step
  • The update function is responsible for returning a dictionary containing the batch states

Here is the internal re-design of Metric

from typing import Dict
from torch import Tensor

class Metric:

    def _wrap_update(self, *args, **kwargs):
        batch_states = self.update(*args, **kwargs)
        self.add_to_rank_accumulated_states(batch_states)  # uses reduction function
        self.batch_states = batch_states

class MyMetric(Metric):
    def update(self) -> Dict[str, Tensor]:
        ...
        return {"state_1": state_1, ...}

metric = MyMetric()

metric(...) # compute batch states, add them to accumulated states using reduction functions
metric.compute()                          # accumulated_states computed on all ranks
metric.compute_on_step()                  # batch_states computed on all ranks
metric.compute(sync_dist=False)           # accumulated_states computed per rank
metric.compute_on_step(sync_dist=False)   # batch_states computed per rank

class Accuracy(Metric):

    def __init__(self):
        super().__init__()
        self.add_state("correct", torch.tensor(0.), sync_dist_fn=torch.sum)
        self.add_state("total", torch.tensor(0.), sync_dist_fn=torch.sum)

    def update(self, preds, targets):
        return {"total": preds.shape[0], "correct": (preds == targets).sum()}

    def compute(self):
        return self.correct / self.total


metric = Accuracy()

metric.update([0, 1], [0, 0])            # returns None

metric([0, 1], [0, 0], sync_dist=True)   # -> 0.5; compute batch states and cache them, add batch states to accumulated states
metric([0, 0], [0, 0], sync_dist=True)   # -> 1.0

metric.compute()                   # -> 0.75
metric.compute(accumulated=False)  # -> 1.0
# accumulated=True means computing accuracy over the 3 batches
# accumulated=False means computing accuracy on the latest batch
acc()
acc()
acc()

Additional context

sync = store cache + all_gather + reduction
unsync = restore cache
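
For illustration, a minimal single-process sketch of these sync/unsync semantics could look like the following. Everything here (the SyncableMetric class, the _all_gather stand-in) is illustrative, not the actual TorchMetrics implementation.

from typing import Callable, Dict

import torch
from torch import Tensor


class SyncableMetric:
    """Illustrative sketch of the proposed sync/unsync semantics."""

    def __init__(self):
        self._states: Dict[str, Tensor] = {}
        self._reductions: Dict[str, Callable[[Tensor], Tensor]] = {}
        self._cache: Dict[str, Tensor] = {}

    def add_state(self, name: str, default: Tensor, sync_dist_fn: Callable[[Tensor], Tensor]) -> None:
        self._states[name] = default
        self._reductions[name] = sync_dist_fn

    def sync(self) -> "SyncableMetric":
        # sync = store cache + all_gather + reduction
        self._cache = {name: state.clone() for name, state in self._states.items()}
        for name, state in self._states.items():
            gathered = self._all_gather(state)  # one entry per rank
            self._states[name] = self._reductions[name](gathered)
        return self

    def unsync(self) -> "SyncableMetric":
        # unsync = restore cache (cheap, no communication)
        self._states, self._cache = self._cache, {}
        return self

    def _all_gather(self, state: Tensor) -> Tensor:
        # single-process stand-in; a real implementation would use torch.distributed
        return state.unsqueeze(0)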

tchaton added the enhancement (New feature or request) and help wanted (Extra attention is needed) labels on Jul 1, 2021
carmocca (Contributor) commented Jul 1, 2021

Why not provide all the basic building blocks and let people compose them in whatever order they want?

metric.update()
metric.update().compute()  # same as the current `compute_on_step` behaviour
metric.compute()
metric.sync().compute()
metric.sync().compute().unsync()

where each returns self so they can be chained

This looks more PyTorch-y to me - similar to how, to get a NumPy array, you do tensor.detach().cpu().numpy()
rather than tensor.numpy(device=cpu, detach=True)
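
For illustration, a rough sketch of that chainable style (purely illustrative, not a final design):

class ChainableMetric:
    """Every primitive returns ``self`` so the calls can be chained."""

    def update(self, *args, **kwargs) -> "ChainableMetric":
        ...  # accumulate the batch statistics into the metric states
        return self

    def sync(self) -> "ChainableMetric":
        ...  # cache the per-rank states, then all_gather + reduce across ranks
        return self

    def unsync(self) -> "ChainableMetric":
        ...  # restore the cached per-rank states
        return self

    def compute(self) -> "ChainableMetric":
        ...  # reduce the accumulated states to the metric value and cache it
        return self  # whether compute should return self or the value is an open question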

carmocca changed the title from "Metric API resign" to "Metric API re-design" on Jul 1, 2021
carmocca added the API / design label and removed the help wanted (Extra attention is needed) label on Jul 1, 2021
SkafteNicki (Member) commented Jul 1, 2021

Related issue: #126

tchaton (Contributor, Author) commented Jul 1, 2021

After a chat with @carmocca, we were thinking about this API.

sync: cache states + all_gather + reductions # expensive
unsync: restore cache states. # not expensive

We include two options (A, B) - but prefer (B)

# basic update
metric.update(...)
# using `__call__`/`forward`
metric(..., accumulated=False) # update + compute


# Sync for batch states and accumulated_states
(A)
metric.compute(sync=True)
(B)
metric.sync().compute().unsync()


# Sync for batch states
(A)
metric.compute(accumulated=False, sync=True)
(B)
metric.sync().compute(accumulated=False).unsync()


(B)
# default is to sync both per-step states and accumulated states.
# setting a bool allows choosing one or the other
metric.sync(accumulated: Optional[bool] = None)
# defaults `accumulated=True`
metric.compute(accumulated: bool = True)
# can be changed
metric.compute(accumulated=False)
# unsyncs all - could also have the accumulated flag if we find it's useful
metric.unsync()

SkafteNicki (Member) commented:

@tchaton would this API mean that forward is no longer a method that is intended to be used?

carmocca (Contributor) commented Jul 1, 2021

would this API mean that forward is no longer a method that is intended to be used?

Yes, it can still be used - I've updated Thomas' snippet

Borda added the Important milestonish label on Jul 1, 2021
ananthsub (Contributor) commented Jul 1, 2021

default is to sync both per-step states and accumulated states.

Why does the metric need to distinguish these? What exactly goes into accumulated states?

Are there other assumptions we could simplify/clarify in the current API? For example, could we remove dist_sync_on_step from the constructor in favor of metric.update(...).sync()?

maximsch2 (Contributor) commented Jul 1, 2021

Given that we have functional metrics APIs, I'm wondering whether the extra complexity of allowing Module-based metrics to be computed on one batch is worth it. Practically speaking, the only extra thing we get from batch-level compute is the syncing, and it does make sense to reduce code there. So how about something like:

class Metric:
    @classmethod
    def compute_for_batch(cls, sync: bool, *args, **kwargs):
        metric = cls().update(*args, **kwargs)
        if sync:
            metric = metric.sync()
        return metric.compute()

That way, we will make it much clearer to users that computation happens per-batch with no state accumulation (as they will call Accuracy.compute_for_batch instead of calling a method on an instance of the metric).
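
Usage could then look something like this (assuming the Accuracy example above gains chainable update/sync/compute primitives):

import torch

preds = torch.tensor([0, 1])
target = torch.tensor([0, 0])

# one-off per-batch value; no metric instance is kept around or accumulated
batch_acc = Accuracy.compute_for_batch(True, preds, target)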

maximsch2 (Contributor) commented Jul 1, 2021

@tchaton,

metric.sync().compute(accumulated=False).unsync()

In here, does it mean we'll end up paying full sync cost even if we want to do just one batch computation?

In general I like the more explicit API with sync/unsync, but worry that users will forget to unsync and will get into weird states. What happens if we run update() after sync/compute without unsync? Is it a meaningful operation?

In case of sync taking accumulated argument, what happens with metric.sync(accumulated=False).compute(accumulated=True)?

maximsch2 (Contributor) commented:

A few more concerns/feature requests:

  • sync not being idempotent is annoying. Should we make semantics of sync such that it syncs + resets on all ranks except rank 0? That way we can sync multiple times in a row and don't need unsync() any more either.
  • In the future, as things are scaled up, we might want to implement sharded metrics. Say we have a problem with 1M classes and we want to compute PRAUC for each class. This might be pretty expensive if we store data for each class on each worker, but we can shard the values and accumulate the values for each class on a separate worker. In this world, sync becomes different and checkpointing becomes tricky (similar to the FSDP training issue).

Borda (Member) commented Jul 1, 2021

  • The update function doesn't return anything
  • ...
  • The update function is responsible for returning a dictionary containing the batch states

Can you elaborate on how one function does not return anything but is responsible for returning a dictionary?

# Sync for batch states
(A)
metric.compute(accumulated=False, sync=True)
(B)
metric.sync().compute(accumulated=False).unsync()

Any reason why you would need to repeat accumulated=False here?

Borda (Member) commented Jul 1, 2021

In here, does it mean we'll end up paying full sync cost even if we want to do just one batch computation?

@tchaton ^^

In general I like the more explicit API with sync/unsync, but worry that users will forget to unsync and will get into weird states. What happens if we run update() after sync/compute without unsync? Is it a meaningful operation?

Yep, that would also be quite hard to debug, so maybe some context wrapper around it?

carmocca (Contributor) commented Jul 1, 2021

In case of sync taking accumulated argument, what happens with metric.sync(accumulated=False).compute(accumulated=True)?

In this case, would the values of these two be the same?

metric.sync(accumulated=False).compute(accumulated=True) == metric.compute(accumulated=True)

Yep, that would also be quite hard to debug, so maybe some context wrapper around it?

We would still provide the sync_context context manager we currently have. It could also be renamed to autosync (like autocast).
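
For illustration, such a context manager could be as small as this (a sketch assuming the sync/unsync primitives discussed above, not the existing implementation):

from contextlib import contextmanager


@contextmanager
def sync_context(metric, should_sync: bool = True):
    """Sync the metric states on entry and restore the per-rank states on exit."""
    if should_sync:
        metric.sync()
    try:
        yield metric
    finally:
        if should_sync:
            metric.unsync()


# usage:
# with sync_context(metric) as m:
#     value = m.compute()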

ananthsub (Contributor) commented Jul 2, 2021

In general I like the more explicit API with sync/unsync, but worry that users will forget to unsync and will get into weird states. What happens if we run update() after sync/compute without unsync? Is it a meaningful operation?

I like these 5 primitives: reset, update, compute, sync, unsync

In this state machine, after sync() the only valid operations on the metric are either reset, unsync, or compute - update or sync can't be allowed. Making this available with a context manager that calls either reset or unsync sounds good to me. Users could elect to chain calls themselves, but the metric should raise exceptions on invalid call sequences (sync().sync(), sync().update() )
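
For example, a minimal sketch of such guards (the _is_synced flag and the error messages are illustrative):

class Metric:
    def __init__(self):
        self._is_synced = False

    def sync(self) -> "Metric":
        if self._is_synced:
            raise RuntimeError("sync() called on an already synced metric; call unsync() or reset() first.")
        ...  # cache per-rank states, then all_gather + reduce
        self._is_synced = True
        return self

    def update(self, *args, **kwargs) -> "Metric":
        if self._is_synced:
            raise RuntimeError("update() is not allowed on a synced metric; call unsync() or reset() first.")
        ...  # accumulate batch states
        return self

    def unsync(self) -> "Metric":
        ...  # restore cached per-rank states
        self._is_synced = False
        return self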

To @SkafteNicki 's point, forward isn't strictly needed then. Does forward need to be called at all? Is this at the discretion of the subclass's implementation? Or does the framework need to provide a default implementation?

@carmocca @tchaton - could you describe why the metric needs to track accumulated states? this could be useful for keeping track of the last N inputs to update but do all metrics need this by default?

justusschock (Member) commented:

@ananthsub I think to have minimal breaking changes we could still allow forward to run update and compute on the current batch only. But it was not required before and wouldn't be required now.

tchaton (Contributor, Author) commented Jul 2, 2021

In general I like the more explicit API with sync/unsync, but worry that users will forget to unsync and will get into weird states. What happens if we run update() after sync/compute without unsync? Is it a meaningful operation?

I like these 5 primitives: reset, update, compute, sync, unsync

In this state machine, after sync() the only valid operations on the metric are either reset, unsync, or compute - update or sync can't be allowed. Making this available with a context manager that calls either reset or unsync sounds good to me. Users could elect to chain calls themselves, but the metric should raise exceptions on invalid call sequences (sync().sync(), sync().update())

To @SkafteNicki 's point, forward isn't strictly needed then. Does forward need to be called at all? Is this at the discretion of the subclass's implementation? Or does the framework need to provide a default implementation?

@carmocca @tchaton - could you describe why the metric needs to track accumulated states? this could be useful for keeping track of the last N inputs to update but do all metrics need this by default?

To add to @ananthsub's and @SkafteNicki's points: forward isn't required, but PyTorch users expect forward to return the current batch output. I think we should support these 6 primitives:

reset, update, compute, sync, unsync, __call__/forward (update + compute)

@ananthsub could you describe why the metric needs to track accumulated states? this could be useful for keeping track of the last N inputs to update but do all metrics need this by default?

The metrics are accumulators. They accumulate the states from multiple batches and are used to perform the reduction at epoch end.

@justusschock @SkafteNicki We could even imagine this design where we keep track of the states history:

metric[-1].compute()   # latest batch
metric.compute()       # all batches, equivalent to metric[:].compute()
metric[-5:].compute()  # on the 5 latest batches

class Metric:

    def __init__(self, keep_history=False):
        self.keep_history = keep_history
        self.states = []

    def _wrap_update(self, *args, **kwargs):
        batch_states = self.update(*args, **kwargs)
        if self.keep_history:
            self.states.append(batch_states)  # save the states history; users can re-compute for any specific batch, slice, or all
        else:
            self.add_to_rank_accumulated_states(batch_states)  # uses reduction function

class MyMetric(Metric):
    def update(self) -> Dict[str, Tensor]:
        ...
        return {"state_1": state_1, ...}

metric = MyMetric()

IMO, I like having the history available.
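
For illustration, the indexing above could be supported with a small __getitem__ that returns a view over a slice of the stored batch states. The names below (HistoryView, reduce_states, compute_from_states) are hypothetical:

from typing import Dict, List, Union

from torch import Tensor


class HistoryView:
    """Computes the metric over a selected slice of the stored batch states."""

    def __init__(self, metric: "Metric", selected: List[Dict[str, Tensor]]):
        self._metric = metric
        self._selected = selected

    def compute(self):
        reduced = self._metric.reduce_states(self._selected)   # hypothetical helper
        return self._metric.compute_from_states(reduced)       # hypothetical helper


class Metric:
    def __init__(self, keep_history: bool = False):
        self.keep_history = keep_history
        self.states: List[Dict[str, Tensor]] = []

    def __getitem__(self, index: Union[int, slice]) -> HistoryView:
        selected = self.states[index]
        if isinstance(index, int):
            selected = [selected]
        return HistoryView(self, selected)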

ananthsub (Contributor) commented:

def update(self) -> Dict[str, Tensor]:

Above, the metric API isn't expected to return anything from update but here it's returning the state?

tchaton (Contributor, Author) commented Jul 2, 2021

def update(self) -> Dict[str, Tensor]:

Above, the metric API isn't expected to return anything from update but here it's returning the state?

No, it won't :) The update function is wrapped by _wrap_update, which captures the state, stores it, and returns None.

ananthsub (Contributor) commented Jul 2, 2021

If we're cleaning up the API, do we want to flip this? As in:

  • We define an abstract _update_impl that metric authors/implementors must fill out
  • The base Metric defines update which calls the implementation's _update_impl which captures the state and stores it.
  • Users continue calling update as is

Would this be easier/simpler than needing to wrap the function like this?
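
Something like this, as a rough sketch (the _accumulate helper is illustrative):

from abc import ABC, abstractmethod
from typing import Dict

from torch import Tensor


class Metric(ABC):
    def update(self, *args, **kwargs) -> None:
        # the base class owns the public entry point ...
        batch_states = self._update_impl(*args, **kwargs)
        # ... and takes care of storing/accumulating the returned states
        self._accumulate(batch_states)

    @abstractmethod
    def _update_impl(self, *args, **kwargs) -> Dict[str, Tensor]:
        """Implemented by metric authors; returns the batch states."""

    def _accumulate(self, batch_states: Dict[str, Tensor]) -> None:
        ...  # add the batch states to the accumulated (per-rank) states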

Borda (Member) commented Jul 2, 2021

Also related to #143

Borda (Member) commented Jul 2, 2021

Above, the metric API isn't expected to return anything from update but here it's returning the state?

That is exactly what I was asking, and it is not clear... #344 (comment)

No, it won't :) The update function is wrapped by _wrap_update, which captures the state, stores it, and returns None.

So from the user-facing API, does calling update return anything or not?

SkafteNicki (Member) commented:

@ananthsub I would prefer to keep update, compute just because it would mean minimal breaking changes from the user's perspective. It is also very simple for the user to remember the API when the methods they are implementing are exactly the ones they are calling.

Borda added this to the v0.5 milestone on Jul 2, 2021
tchaton (Contributor, Author) commented Jul 2, 2021

I agree with @SkafteNicki there. We could do more by adding a third function, but at the same time, the API is simple to understand. I would prefer to keep wrapping the update and compute functions internally for minimal breaking changes and user API simplicity.

maximsch2 (Contributor) commented:

Actually I think I'm with @ananthsub on this - given that the update users implement is not the same as the one they call (even the type signature is different!), having them share the same name will bring confusion. In terms of BC issues for this change: for all metrics implemented inside torchmetrics there will be no issue; the only issue will be for metrics implemented outside of it. As long as we check in __init__ and throw an error when a user overrides update in a new metric class, this issue will be immediately obvious to people who implement their custom metrics, and the change will be trivial to make.
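
A sketch of such a check (assuming update becomes base-class-owned and metric authors implement a differently named method, e.g. the _update_impl proposed above):

class Metric:
    def __init__(self):
        # fail fast if a custom metric still overrides the base-class-owned `update`
        if type(self).update is not Metric.update:
            raise RuntimeError(
                f"{type(self).__name__} overrides `update`. Custom metrics should now implement "
                "`_update_impl` instead; `update` is provided by the base class."
            )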

That being said, I still think that forcing two different things (per-batch and accumulated metric computation) into the same API with flags is giving us more hassle than it's worth. Separating the single-batch compute operation into a separate (static) function makes life easier and the code more readable by being more explicit about intent. It is a bigger BC-impacting change though (although we can easily make it in a few BC-friendly steps).

justusschock (Member) commented:

I think I agree with @ananthsub and @maximsch2 that a separate function is better. Just for the naming of this function: IMO it should not be a protected function, since this is exposed to the user upon implementation of new metrics.

tchaton (Contributor, Author) commented Nov 8, 2021

Here is a document summarizing the current discussions for the Metric API refactor: https://docs.google.com/document/d/19Xm8z-Rf2d745z3plq5BrnB4pknFiJSJGJ0gNY_Fzrk/edit?usp=sharing

Borda modified the milestones: v0.7, v0.8 on Jan 6, 2022
Borda unpinned this issue on Jan 9, 2022
awaelchli (Contributor) commented Mar 3, 2022

Hello! Adrian from Lightning core here <3

In #840 we changed two core Metric APIs: update was renamed to _update and compute was renamed to _compute. The dynamic wrapping around these methods was removed.

The change

  • has no impact on users of built-in metrics
  • is a breaking change for users who have custom metrics

My recommendation is to revert #840 and I feel quite strongly about this. My reasoning:

  1. update and compute are essential methods for implementing custom metrics. They are not optional. Having them as protected overrides is deviating too strongly from a familiar API. This pattern does not exist anywhere in our core Lightning APIs and also not in PyTorch (e.g. nn.Module). In general, in Lightning we use protected attributes and methods to signal to the user that it is an internal api that should normally not be accessed by them directly. I'm pretty sure we don't want to send this signal for the update and compute methods in Metrics.

  2. For a metric developer it looks strange that they have to implement _update (protected) but then have to call update (public) to use the metric. More so for compute, as it is more common to call this one.

  3. The argument that "patching methods is a bad practice" is weak in my opinion when talking specifically about how Metric did it. In the update/compute methods we were wrapping the methods. It was essentially the same as having a decorator (side note: one of the main motivations for implementing it this way was to avoid decorators in the first place). The call to the wrapper was always guaranteed to call the inner method (contract). Plus, the Metric class itself is responsible for applying the wrapper, i.e., there are no external components patching a method on the Metric object dynamically (an example in stark contrast here: Trainer was patching methods on the LightningModule, which was bad and we removed that). These three properties make the patching that Metric did well behaved, in my opinion (see the sketch after this list). The only inconvenience I could take from the conversation here was that the type annotations were mismatched. This is the only place where I don't have a good resolution, but I'm not sure how big of an issue it is? Comments about this would be very appreciated <3

  4. The argument that "the wrapped methods are hard to debug" does not hold, as far as I can see. I'm struggling to find evidence of this. Does anybody have a concrete example where a) the stack trace is insane, misleading, or contains wrong information, or b) the debugger can't handle breakpoints inside the wrapped methods? Now, of course, if the error originates from inside the wrapper instead of the user-defined method, then that's on us to fix or provide better error messages. In which cases would that occur?
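
To make point 3 above concrete, the wrapping was roughly equivalent to the class decorating its own methods at construction time, along these lines (a simplified sketch, not the exact TorchMetrics code):

import functools


class Metric:
    def __init__(self):
        # the class wraps its *own* update/compute; no external component patches the object
        self.update = self._wrap_update(self.update)
        self.compute = self._wrap_compute(self.compute)

    def _wrap_update(self, update):
        @functools.wraps(update)
        def wrapped(*args, **kwargs):
            ...  # internal bookkeeping before the user-defined update
            return update(*args, **kwargs)  # the inner method is always called (the contract)
        return wrapped

    def _wrap_compute(self, compute):
        @functools.wraps(compute)
        def wrapped(*args, **kwargs):
            ...  # e.g. sync states across ranks, cache the computed value
            return compute(*args, **kwargs)
        return wrapped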

These are the main points I would like to raise, along with my recommendation to revert the recent changes to update and compute. Some of these were discussed here on the issue, but some of the arguments were also discussed on Slack. Please also see the message from @williamFalcon on the linked PR.

cc @PyTorchLightning/core-metrics @tchaton @williamFalcon @ananthsub @carmocca

Happy to hear everybody's thoughts! <3

EDIT:

  • not saying we should stop thinking about solutions to avoid the wrapping.
  • can we solve the problem by just choosing a different name?

Borda (Member) commented Mar 3, 2022

@awaelchli I see points 3 and 4 as significantly more valid than point 2. To be clear, I do not mind and even prefer wrapping in a static fashion with a decorator or context manager, but I am against the actual wrapping at runtime, which is impossible to debug in an IDE without starting the program...
And I would also push back on point 2, as it is general programming practice (maybe less common in PyTorch but still widely used in Python) to override protected methods to steer the child's behavior without changing the public API; see for example Access Modifiers in Python: Public, Private and Protected

not saying we should stop thinking about solutions to avoid the wrapping.

Agreed; just to clarify, I do not like the runtime overwrite/wrapping of the implemented method

can we solve the problem by just choosing a different name?

I think it is not about names, but about how we should resolve/handle the wrapped public update/compute

awaelchli (Contributor) commented Mar 3, 2022

@Borda Yes, in my post I neglected to emphasize that I'm not against removing the wrapping. That is an internal detail, but when choosing between one or the other I would put more weight on the user-facing API and then work backwards from there.

Borda (Member) commented Mar 3, 2022 via email
