@ge0405 ge0405 commented Apr 17, 2025


Summary:
# Contexts
During S503023, we found that the metrics `compute` time can take up to **30s** on a large OMNI FM model. (Note this is not the per-iteration metrics `update`; it is `compute`, which creates a sync and refreshes metrics on TensorBoard. The frequency of `compute` can be set by `compute_interval_steps` (default=100).) If Zoomer happens to capture an iteration containing metrics `compute`, the trace file becomes too big to open.


Metrics `compute` was slow because of too many all_gather calls. See the screenshot of metrics `compute` from the above SEV: that single metrics `compute` takes 30s, and zooming in shows it is composed of hundreds of all_gather calls.
 {F1976937278} 
 {F1976937301} 

Therefore, this diff fuses the state tensors before all_gather during metrics `compute`. Combined with `FUSED_TASKS_COMPUTATION` to fuse task-level tensors, it reduces "RecMetricModule compute" time from 1s 287 ms to 35 ms (a 36X reduction) on a shrunk OMNI FM model.


# What has been changed/added?
1. Add `FUSED_TASKS_AND_STATES_COMPUTATION` to RecComputeMode.
When turned on, this fuses both task tensors and state tensors.
2. Add `fuse_state_tensors` to the [Metric](https://www.internalfb.com/code/fbsource/[0c89c01039abfadd62e8ec1b34eb24b249b99f3f]/fbcode/pytorch_lightning_deprecated/metrics/torchmetrics/metric.py?lines=43) class.
It is turned on when `FUSED_TASKS_AND_STATES_COMPUTATION` is set in the config. When a metric's `compute` is called, the metric fuses (stacks) its state tensors before all_gather, then reconstructs the gathered output into the desired format for reduction.
3. Note that we currently only support fusing/stacking 1D state tensors. For states that are a `List` (e.g. auc) or a 2D tensor (e.g. multiclass_ne), `FUSED_TASKS_AND_STATES_COMPUTATION` shouldn't be used; fall back to either `FUSED_TASKS_COMPUTATION` or `UNFUSED_TASKS_COMPUTATION` instead.
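The fuse-then-gather idea in step 2 can be sketched in plain Python. This is a simulation only: the real change stacks torch tensors and issues a single `torch.distributed.all_gather` instead of one collective per state, and all function names below are illustrative, not the actual APIs in this diff.

```python
def fuse_states(states):
    """Flatten N same-length 1D states into one buffer,
    so one collective call replaces N of them (cf. torch.stack)."""
    fused = []
    for s in states:
        fused.extend(s)
    return fused

def simulated_all_gather(per_rank_buffers):
    """Stand-in for torch.distributed.all_gather:
    every rank ends up seeing the fused buffer from all ranks."""
    return list(per_rank_buffers)

def unfuse_and_reduce(gathered, num_states, state_len):
    """Split each rank's fused buffer back into per-state chunks
    and reduce (sum) each state across ranks."""
    reduced = [[0.0] * state_len for _ in range(num_states)]
    for buf in gathered:
        for i in range(num_states):
            chunk = buf[i * state_len:(i + 1) * state_len]
            for j, v in enumerate(chunk):
                reduced[i][j] += v
    return reduced

# Two ranks, each with two 1D states of length 2:
rank0 = fuse_states([[1.0, 2.0], [3.0, 4.0]])
rank1 = fuse_states([[10.0, 20.0], [30.0, 40.0]])
gathered = simulated_all_gather([rank0, rank1])
print(unfuse_and_reduce(gathered, num_states=2, state_len=2))
# One simulated collective moved both states; unfused would have needed two.
```

This also makes the 1D-only limitation in step 3 concrete: the flatten/split bookkeeping relies on each state being a fixed-length 1D tensor, which list-valued or 2D states break.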

Reviewed By: iamzainhuda

Differential Revision: D72010614
@facebook-github-bot
This pull request was exported from Phabricator. Differential Revision: D72010614
