Metric API re-design #344
Comments
why not provide all the basic blocks and let people do whatever order they want?

```python
metric.update()
metric.update().compute()  # same as the current `compute_on_step` behaviour
metric.compute()
metric.sync().compute()
metric.sync().compute().unsync()
```

where each returns `self` so they can be chained. This looks more pytorch-y to me - similar to how you get a numpy tensor by chaining calls (e.g. `tensor.cpu().numpy()`).
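A minimal sketch of what such chainable primitives could look like (the class and its averaging states are invented for illustration; here `compute` returns the value rather than `self`, so in this sketch `unsync()` after `compute()` would be a separate call):

```python
import torch


class ChainableMetric:
    """Sketch of a metric whose update/sync/unsync return ``self`` so calls can be chained."""

    def __init__(self) -> None:
        self.total = torch.tensor(0.0)
        self.count = torch.tensor(0.0)
        self._cache = None

    def update(self, values: torch.Tensor) -> "ChainableMetric":
        self.total += values.sum()
        self.count += values.numel()
        return self  # enables metric.update(x).compute()

    def sync(self) -> "ChainableMetric":
        # placeholder: a real implementation would all_gather/reduce states across processes
        self._cache = (self.total.clone(), self.count.clone())
        return self

    def unsync(self) -> "ChainableMetric":
        if self._cache is not None:
            self.total, self.count = self._cache
            self._cache = None
        return self

    def compute(self) -> torch.Tensor:
        return self.total / self.count
```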
Related issue: #126
After a chat with @carmocca, we were thinking about this API. We include two options (A, B) - but prefer (B):

```python
# basic update
metric.update(...)

# using `__call__` / `forward`
metric(..., accumulated=False)  # update + compute

# sync for batch states and accumulated states
# (A)
metric.compute(sync=True)
# (B)
metric.sync().compute().unsync()

# sync for batch states only
# (A)
metric.compute(accumulated=False, sync=True)
# (B)
metric.sync().compute(accumulated=False).unsync()
```

For (B), the signatures would be:

```python
# default is to sync both per-step states and accumulated states;
# setting a bool allows choosing one or the other
metric.sync(accumulated: Optional[bool] = None)

# defaults to `accumulated=True`
metric.compute(accumulated: bool = True)
# can be changed
metric.compute(accumulated=False)

# unsyncs all - could also take the `accumulated` flag if we find it's useful
metric.unsync()
```
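For illustration, usage of option (B) at epoch end could look like the following (the metric class, `batches`, and the tensors are placeholders; which states get synced follows the flags above):

```python
metric = MyAccuracy()                      # hypothetical metric subclass

for preds, target in batches:              # `batches` is a placeholder
    metric.update(preds, target)           # accumulate states
    # or, for a per-batch value as part of the same call:
    # step_value = metric(preds, target, accumulated=False)

# epoch end: gather states across processes, compute, then restore the local states
metric.sync()                              # default: sync both per-step and accumulated states
epoch_value = metric.compute()             # accumulated=True by default
metric.unsync()
```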
@tchaton would this API mean that …?
Yes, it can still be used - I've updated Thomas' snippet.
Why does the metric need to distinguish these? What exactly goes into accumulated states? Are there other assumptions we could simplify/clarify in the current API? For example, could we remove …?
Given that we have functional metrics APIs, I'm wondering if the extra complexity of allowing Module-based metrics to be computed on one batch is worth it. Practically speaking, the only extra thing we get from batch-level compute is when we do syncing, and it does make sense to reduce code there. So how about something like:

```python
class Metric:
    @classmethod
    def compute_for_batch(cls, sync: bool, *args, **kwargs):
        metric = cls()
        return metric.update(*args, **kwargs).sync().compute()
```

That way, we will make it much clearer to users that computation happens per-batch with no state accumulation (as they will call …).
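Assuming such a classmethod, plus `update`/`sync` returning `self` as proposed earlier in the thread, single-batch usage could look like this (`MyAccuracy`, `preds`, and `target` are placeholders):

```python
# one-off, stateless batch computation: construct, update, sync, compute, discard
batch_value = MyAccuracy.compute_for_batch(True, preds, target)  # sync=True
```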
Here, does it mean we'll end up paying the full sync cost even if we only want a single-batch computation? In general I like the more explicit API with sync/unsync, but I worry that users will forget to unsync and will get into weird states. What happens if we run update() after sync/compute without unsync? Is it a meaningful operation? In the case of …
A few more concerns/feature requests: …
Can you elaborate on how one function does not return anything but is responsible for returning a dictionary?
Any reason why here you would need to repeat the …?
@tchaton ^^
Yep, that would also be quite hard to debug - so some context wrapper around it?
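One way to make the sync/unsync pair harder to forget, along the lines suggested above, would be a context manager; a rough sketch, assuming the metric already exposes `sync()` and `unsync()`:

```python
from contextlib import contextmanager


@contextmanager
def synced(metric):
    """Temporarily sync a metric's states, guaranteeing unsync even on error."""
    metric.sync()
    try:
        yield metric
    finally:
        metric.unsync()


# usage: states are gathered only inside the block
with synced(metric) as m:
    value = m.compute()
```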
In this case, would the values of these two be the same?
We would still provide the …
I like these 5 primitives.
In this state machine, after …
To @SkafteNicki's point: …
@carmocca @tchaton - could you describe why the metric needs to track accumulated states? This could be useful for keeping track of the last N inputs to …
@ananthsub I think, to have minimal breaking changes, we could still allow forward to run …
To add to @ananthsub and @SkafteNicki: forward isn't required, but PyTorch users expect forward to return the current batch output. I think we should support these 6 primitives: …
Metrics are accumulators. They accumulate the states from multiple batches and are used to perform the reduction at epoch end.

@justusschock @SkafteNicki We could even imagine this design where we keep track of the states history:

```python
metric[-1].compute()    # latest batch
metric.compute()        # all batches, equivalent to metric[:].compute()
metric[-5:].compute()   # on the 5 latest batches
```

```python
class Metric:

    def __init__(self, keep_history=False):
        ...

    def _wrap_update(self, *args, **kwargs):
        batch_states = update(*args, **kwargs)
        if keep_history:
            # keep the per-batch states history; users can re-compute
            # for any specific batch, slice, or all batches
            self.states.append(batch_states)
        else:
            # uses the reduction function
            self.add_to_rank_accumulated_states(batch_states)


class MyMetric(Metric):

    def update(self) -> Dict[str, Tensor]:
        ...
        return {"state_1": state_1, ...}


metric = MyMetric()
```

IMO, I like having the history available.
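Under the `keep_history` sketch above, usage could look roughly like this (the slicing semantics are the ones proposed in the comment, not an existing API; `batches` and the tensors are placeholders):

```python
metric = MyMetric(keep_history=True)

for preds, target in batches:        # `batches` is a placeholder dataloader
    metric.update(preds, target)

last_batch = metric[-1].compute()    # latest batch only
last_five = metric[-5:].compute()    # five latest batches
epoch_value = metric.compute()       # all batches, same as metric[:].compute()
```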
Above, the metric API isn't expected to return anything from …
No, it won't :) The update function is wrapped into a _wrap_update function which captures the state, stores it, and returns None.
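A rough sketch of that wrapping (the names and bookkeeping here are invented for illustration; the real torchmetrics implementation differs): the subclass's `update` returns a dict of batch states, while the wrapper installed in `__init__` captures and stores them and returns `None` to the caller.

```python
from functools import wraps


class Metric:
    def __init__(self) -> None:
        self._batch_states = []
        # replace the subclass's `update` with a wrapper at runtime
        self.update = self._wrap_update(self.update)

    def _wrap_update(self, update_fn):
        @wraps(update_fn)
        def wrapped(*args, **kwargs) -> None:
            batch_states = update_fn(*args, **kwargs)  # subclass returns a dict of states
            self._batch_states.append(batch_states)    # metric stores them internally
            return None                                # caller sees no return value
        return wrapped
```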
If we're cleaning up the API, do we want to flip this? As in: users implement `_update`/`_compute`, and the base class exposes public `update`/`compute` methods that call them (see the sketch below). Would this be easier/simpler than needing to wrap the function like this?
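For comparison, the "flipped" shape being asked about could look like this (a sketch, not the actual torchmetrics code; `_store` is a hypothetical hook):

```python
class Metric:
    """Sketch of the flipped design: the base class owns the public entry point."""

    def update(self, *args, **kwargs) -> None:
        # bookkeeping lives in the base class, so no runtime re-wrapping is needed
        batch_states = self._update(*args, **kwargs)
        self._store(batch_states)

    def _update(self, *args, **kwargs) -> dict:
        # what metric authors implement; returns a dict of batch states
        raise NotImplementedError

    def _store(self, batch_states: dict) -> None:
        # placeholder for accumulation / history logic
        ...
```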
also related to #143
That is exactly what I was asking, and it's not clear... #344 (comment)
So, from the user-facing API, does calling update return anything or not?
@ananthsub I would prefer to keep the …
I agree with @SkafteNicki there. We could do more by adding a third function, but at the same time the API is simple to understand. I would prefer to keep wrapping the …
Actually, I think I'm with @ananthsub on this - given that …
That being said, I still think that forcing two different things (per-batch and accumulated metric computation) into the same API with flags gives us more hassle than it's worth. Separating the single-batch compute operation into a separate (static) function makes life easier and the code more readable by being more explicit about intent. It is a bigger BC-impacting change, though (although we can easily make it in a few BC-friendly steps).
I think I agree with @ananthsub and @maximsch2 that a separate function is better. Just on the naming of this function: IMO it should not be a protected function, since it is exposed to the user when implementing new metrics.
Here is a document summarizing the current discussions for the Metric API refactor: https://docs.google.com/document/d/19Xm8z-Rf2d745z3plq5BrnB4pknFiJSJGJ0gNY_Fzrk/edit?usp=sharing
Hello! Adrian from Lightning core here <3
In #840 we changed two core Metric APIs: update and compute.
The change: …
My recommendation is to revert #840, and I feel quite strongly about this. My reasoning: …
These are the main points I would like to raise, and my recommendation is to revert the recent changes to update and compute. Some of these were discussed here on the issue, but some of the arguments were also discussed on Slack. Please also see the message from @williamFalcon on the linked PR.
cc @PyTorchLightning/core-metrics @tchaton @williamFalcon @ananthsub @carmocca
Happy to hear everybody's thoughts! <3
EDIT: …
@awaelchli I see points 3 and 4 as significantly more valid than point 2. To be clear, I do not mind, and even prefer, wrapping in a static fashion with a decorator or context manager, but I am against the actual wrapping at runtime, which is impossible to debug in an IDE without starting the program...
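One "static" alternative to runtime wrapping, along those lines, could be a decorator applied where the metric is defined, so an IDE can step into it; a hedged sketch (all names here are invented for illustration):

```python
from functools import wraps


def capture_states(update_fn):
    """Stores the returned batch states on the metric and hides them from the caller."""
    @wraps(update_fn)
    def wrapped(self, *args, **kwargs) -> None:
        self._batch_states.append(update_fn(self, *args, **kwargs))
        return None
    return wrapped


class Metric:
    def __init__(self) -> None:
        self._batch_states = []


class MyMetric(Metric):
    @capture_states
    def update(self, preds, target):
        # the author still writes an `update` that returns batch states
        return {"correct": (preds == target).sum(), "total": target.numel()}
```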
Agree - just to clarify, I do not like the runtime overwrite/wrap of the implemented method.
I think it is not about names, but about how we shall resolve/handle the updated/wrapped public …
@Borda Yes, in my post I forgot to emphasize that I'm not against removing the wrapping. This is an internal detail, but when choosing between one or the other I would put more weight on the user-facing API, then work backwards from there.
My two cents:

- It is far better to be making these API decisions now, pre-1.0, than realizing this too late. Isn't the expectation that a library before 1.0 can make changes like this?
- For Metric specifically, what started out as a pure interface has turned into a concrete class. This has muddled what exactly the base class does for users vs. what users are expected to extend and implement themselves. I believe there's a need to clarify this for developers.
- IMO, the most glaring need is that the current Metric API is not principled about how it declares metric states (i.e. COMPUTE, SYNC, UPDATED) and handles the state transitions in a very ad-hoc manner. There are warnings/errors sprinkled around the code, but it's not immediately obvious to users what this state diagram is. This leaks out to callers needing to reach into private metric attributes to properly handle this. For example, here: …
- Supporting this work becomes much easier to do after the change in #840, because Metric developers know exactly what to implement and how the parent class relies on it. To that end, I see the corollary of update/_update being what the Lightning Trainer does for fit and _fit_impl.

```python
def update(self, *args, **kwargs) -> None:
    _validate_state_transition(src=self._current_state, dest=state.UPDATED)
    return self._update(*args, **kwargs)
```

  I find this^ much easier to follow than the current …
- If I'm only using metrics, then this entire discussion is irrelevant, because there are no changes from this POV.
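To make the state-machine point concrete, here is a hedged sketch of what an explicit transition table and a `_validate_state_transition` helper could look like (the state names follow the comment above; the allowed transitions are illustrative, not a proposal):

```python
from enum import Enum, auto


class MetricState(Enum):
    INITIALIZED = auto()
    UPDATED = auto()
    SYNCED = auto()
    COMPUTED = auto()


# illustrative allowed transitions; the real rules would need to be agreed upon
_ALLOWED = {
    MetricState.INITIALIZED: {MetricState.UPDATED},
    MetricState.UPDATED: {MetricState.UPDATED, MetricState.SYNCED, MetricState.COMPUTED},
    MetricState.SYNCED: {MetricState.COMPUTED, MetricState.UPDATED},
    MetricState.COMPUTED: {MetricState.UPDATED, MetricState.INITIALIZED},
}


def _validate_state_transition(src: MetricState, dest: MetricState) -> None:
    if dest not in _ALLOWED[src]:
        raise RuntimeError(f"Invalid metric state transition: {src.name} -> {dest.name}")
```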
Yes, that is why we are doing it now rather than later :)
🚀 Feature

Re-design the internal API of the base Metric class to return batch states from the update function.

The following proposal is a BC!

Limitations:

Currently, the Metric API is confusing for several reasons:
- the `update` function actually performs a `compute` if `compute_on_step=True`
- the `update` function actually performs 2 updates, which is an expensive operation for some metrics (FID, for example)

Proposal: `compute` and `compute_step`

Here is the internal re-design of Metric: …

Additional context

sync = store cache + all_gather + reduction
unsync = restore cache
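As a rough illustration of "sync = store cache + all_gather + reduction, unsync = restore cache" (this assumes an initialized torch.distributed process group; the state names and the sum reduction are made up for the example):

```python
import torch
import torch.distributed as dist


class SyncedStates:
    """Illustrative only: a metric with two tensor states, summed across processes on sync."""

    def __init__(self) -> None:
        self.total = torch.tensor(0.0)
        self.count = torch.tensor(0.0)
        self._cache = None

    def sync(self) -> None:
        # store cache: keep the local states so they can be restored later
        self._cache = {"total": self.total.clone(), "count": self.count.clone()}
        # all_gather + reduction (sum) for each state
        for name in ("total", "count"):
            local = getattr(self, name)
            gathered = [torch.zeros_like(local) for _ in range(dist.get_world_size())]
            dist.all_gather(gathered, local)
            setattr(self, name, torch.stack(gathered).sum(dim=0))

    def unsync(self) -> None:
        # restore cache: put back the pre-sync local states
        self.total = self._cache["total"]
        self.count = self._cache["count"]
        self._cache = None
```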