Fuse "states" level tensors to reduce all gather during metrics compute (#2892)
Summary:
Pull Request resolved: #2892
# Context
During S503023, we found that a single metrics `compute` call can take up to **30s** on a large OMNI FM model. (Note that this is not the metrics `update` that runs every iteration. It is `compute`, which triggers a sync and refreshes the metrics on TensorBoard. Its frequency is set by `compute_interval_steps` (default=100).) If Zoomer happens to capture an iteration that runs metrics `compute`, the trace file can become too large to open.
Metrics `compute` took this long because it issued too many all_gather calls. See the screenshots of metrics `compute` from the above SEV: that single metrics `compute` takes 30s, and zooming in shows it is composed of hundreds of all_gather calls.
{F1976937278}
{F1976937301}
Therefore, this diff fuses the tensors before all_gather during metrics `compute`. Combined with the task-level tensor fusion of `FUSED_TASKS_COMPUTATION`, it reduces "RecMetricModule compute" time from 1s 287 ms to 35 ms (36X reduction) on a shrunk OMNI FM model.
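The effect of fusing on the number of collectives can be illustrated with a small single-process simulation. This is only a sketch, not the code in this diff: `FakeProcessGroup`, `gather_unfused`, `gather_fused`, and the state names are hypothetical, and plain Python lists stand in for 1D state tensors and for `torch.distributed` all_gather.

```python
from typing import Dict, List

Vector = List[float]  # stand-in for a 1D state tensor
States = Dict[str, Vector]


class FakeProcessGroup:
    """Hypothetical stand-in for a process group; counts all_gather calls."""

    def __init__(self) -> None:
        self.num_calls = 0

    def all_gather(self, per_rank: List[list]) -> List[list]:
        # Every rank ends up with a copy of every rank's payload.
        self.num_calls += 1
        return [list(payload) for payload in per_rank]


def gather_unfused(pg: FakeProcessGroup, states_per_rank: List[States]) -> Dict[str, List[Vector]]:
    """One all_gather per state, as in the unfused path."""
    names = list(states_per_rank[0])
    return {
        name: pg.all_gather([rank[name] for rank in states_per_rank])
        for name in names
    }


def gather_fused(pg: FakeProcessGroup, states_per_rank: List[States]) -> Dict[str, List[Vector]]:
    """Stack all states into one payload per rank, gather once, then unstack."""
    names = list(states_per_rank[0])
    # Each rank stacks its 1D states into a single (num_states x length) payload.
    stacked = [[rank[name] for name in names] for rank in states_per_rank]
    gathered = pg.all_gather(stacked)
    # Reconstruct the per-state layout the reduction step expects.
    return {
        name: [rank_payload[i] for rank_payload in gathered]
        for i, name in enumerate(names)
    }
```

With S states per metric, the unfused path issues S collectives while the fused path issues one, and both return the same gathered values.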
# What has been changed/added?
1. Add `FUSED_TASKS_AND_STATES_COMPUTATION` in RecComputeMode.
When turned on, this fuses both task tensors and state tensors.
2. Add `fuse_state_tensors` to [Metric](https://www.internalfb.com/code/fbsource/[0c89c01039abfadd62e8ec1b34eb24b249b99f3f]/fbcode/pytorch_lightning_deprecated/metrics/torchmetrics/metric.py?lines=43) class.
This is turned on when `FUSED_TASKS_AND_STATES_COMPUTATION` is set in the config. When a metric's `compute` is called, it fuses (stacks) the state tensors before all_gather, then reconstructs the output tensor into the desired format for reduction.
3. Currently only fusing/stacking 1D state tensors is supported. Therefore, for metrics whose states are a `List` (e.g. auc) or a 2D tensor (e.g. multiclass_ne), `FUSED_TASKS_AND_STATES_COMPUTATION` shouldn't be used, or should at least fall back to either `FUSED_TASKS_COMPUTATION` or `UNFUSED_TASKS_COMPUTATION`.
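The fallback described in item 3 could look roughly like the sketch below. It is an illustration under assumptions, not the logic in this diff: the `RecComputeMode` enum here is a local stand-in that only mirrors the mode names above, `can_fuse_states` and `resolve_compute_mode` are hypothetical helpers, and a state is described by its shape tuple, with `None` meaning a `List` state. The equal-length check is an extra assumption of the sketch, since stacking requires identical shapes.

```python
from enum import Enum
from typing import Dict, Optional, Tuple


class RecComputeMode(Enum):
    # Local stand-in mirroring the mode names in this commit message.
    FUSED_TASKS_COMPUTATION = "fused_tasks"
    UNFUSED_TASKS_COMPUTATION = "unfused_tasks"
    FUSED_TASKS_AND_STATES_COMPUTATION = "fused_tasks_and_states"


# Shape of a tensor state, or None for a List state (e.g. auc).
StateShape = Optional[Tuple[int, ...]]


def can_fuse_states(state_shapes: Dict[str, StateShape]) -> bool:
    """True only if every state is a 1D tensor and all share one length."""
    shapes = list(state_shapes.values())
    all_1d = all(s is not None and len(s) == 1 for s in shapes)
    return all_1d and len(set(shapes)) <= 1


def resolve_compute_mode(
    requested: RecComputeMode, state_shapes: Dict[str, StateShape]
) -> RecComputeMode:
    """Fall back to task-only fusion when the states cannot be stacked."""
    if (
        requested is RecComputeMode.FUSED_TASKS_AND_STATES_COMPUTATION
        and not can_fuse_states(state_shapes)
    ):
        return RecComputeMode.FUSED_TASKS_COMPUTATION
    return requested
```

Falling back to `FUSED_TASKS_COMPUTATION` rather than `UNFUSED_TASKS_COMPUTATION` is a choice made for this sketch; item 3 allows either.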
Reviewed By: iamzainhuda
Differential Revision: D72010614
fbshipit-source-id: 5ac77088cf9737ad783ad7e1740daf6afc0224a8