Fuse "states" level tensors to reduce all gather during metrics compute #2892
Summary:
Contexts
During S503023, we found that a metrics `compute` call can take up to 30s on a large OMNI FM model. (Note that this is not the metrics `update` that runs every iteration. This is `compute`, which triggers a sync and refreshes the metrics on TensorBoard. The frequency of `compute` can be set by `compute_interval_steps` (default=100).) If Zoomer happens to capture an iteration that includes a metrics `compute`, the trace file can become too big to open.

The metrics `compute` took so long because it issued too many all_gather calls. See the screenshots of a metrics `compute` from the above SEV: that single `compute` takes 30s, and zooming in shows that it is composed of hundreds of all_gather calls.
{F1976937278}
{F1976937301}
Therefore, this diff fuses the state tensors before all_gather during metrics `compute`. With this diff plus `FUSED_TASKS_COMPUTATION` to fuse task-level tensors, "RecMetricModule compute" time drops from 1s 287 ms to 35 ms (roughly a 36X reduction) on a shrunk OMNI FM model.

What has been changed/added?
- Added `FUSED_TASKS_AND_STATES_COMPUTATION` in RecComputeMode. When turned on, it fuses both the task tensors and the state tensors.
- Added `fuse_state_tensors` to the Metric class. It is turned on once `FUSED_TASKS_AND_STATES_COMPUTATION` is set from config. Then, when a metric's `compute` is called, it fuses (stacks) the state tensors before all_gather and afterwards reconstructs the output tensor into the desired format for reduction.
- For metrics whose states are a List (e.g. auc) or a 2D tensor (e.g. multiclass_ne), `FUSED_TASKS_AND_STATES_COMPUTATION` shouldn't be used, or should at least fall back to either `FUSED_TASKS_COMPUTATION` or `UNFUSED_TASKS_COMPUTATION`.

Reviewed By: iamzainhuda
Differential Revision: D72010614
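
For illustration, the idea of stacking state tensors before all_gather can be sketched as below. This is a minimal single-process sketch and not the code in this diff: `FakeCollective` is a hypothetical stand-in for a real collective such as `torch.distributed.all_gather`, the state names are illustrative, and every state is assumed to be a 1D tensor of the same shape (one value per task), which is the case where fusing applies.

```python
import torch


class FakeCollective:
    """Hypothetical stand-in for an all_gather across ranks.

    All ranks are simulated in one process so the sketch runs without a
    process group. It only counts how many collective calls are issued.
    """

    def __init__(self):
        self.calls = 0

    def all_gather(self, per_rank_tensors):
        # One "collective": every rank would receive the tensors of all ranks.
        self.calls += 1
        return [t.clone() for t in per_rank_tensors]


def gather_unfused(rank_states, comm):
    """One all_gather per state, then a sum reduction over ranks."""
    out = {}
    for name in rank_states[0]:
        gathered = comm.all_gather([states[name] for states in rank_states])
        out[name] = torch.stack(gathered).sum(dim=0)
    return out


def gather_fused(rank_states, comm):
    """Stack all states into one tensor per rank, then a single all_gather."""
    names = list(rank_states[0])
    # Each rank builds one (n_states, n_tasks) tensor from its 1D states.
    fused = [torch.stack([states[n] for n in names]) for states in rank_states]
    gathered = comm.all_gather(fused)  # single collective call
    reduced = torch.stack(gathered).sum(dim=0)  # (n_states, n_tasks)
    # Reconstruct the per-state layout expected by the reduction step.
    return dict(zip(names, reduced.unbind(0)))


world_size, n_tasks = 4, 3
# Illustrative state names, assumed for this sketch only.
names = ["cross_entropy_sum", "weighted_num_samples", "pos_labels", "neg_labels"]
torch.manual_seed(0)
rank_states = [
    {n: torch.rand(n_tasks, dtype=torch.float64) for n in names}
    for _ in range(world_size)
]

unfused_comm, fused_comm = FakeCollective(), FakeCollective()
unfused = gather_unfused(rank_states, unfused_comm)
fused = gather_fused(rank_states, fused_comm)
print(unfused_comm.calls, fused_comm.calls)  # 4 1
```

In this sketch the number of collective calls drops from one per state to one per metric while the reduced values stay the same. States that are Lists or 2D tensors do not fit the `torch.stack` step as written, which is consistent with the fallback note above.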