Skip to content

Commit

Permalink
Remove unused memory checks to speed up compute (#2719)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #2719

The memory checks here have non-significant overhead in every compute step as there are a lot of tensor size calls involved here.
In our runs, this accounted for around 20% time spent in the rec metric compute step.

Given that this is not being used anymore, let's remove this call.

This diff removes the call from the metric_module. In the next set of diffs, I'll remove the argument from the callsites.

Differential Revision: D68995122
  • Loading branch information
atuljangra authored and facebook-github-bot committed Feb 1, 2025
1 parent 27fdfd6 commit ecdabfd
Show file tree
Hide file tree
Showing 2 changed files with 2 additions and 175 deletions.
41 changes: 2 additions & 39 deletions torchrec/metrics/metric_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,9 +109,6 @@
MODEL_METRIC_LABEL: str = "model"


MEMORY_AVG_WARNING_PERCENTAGE = 20
MEMORY_AVG_WARNING_WARMUP = 100

MetricValue = Union[torch.Tensor, float]


Expand Down Expand Up @@ -146,7 +143,7 @@ class RecMetricModule(nn.Module):
throughput_metric (Optional[ThroughputMetric]): the ThroughputMetric.
state_metrics (Optional[Dict[str, StateMetric]]): the dict of StateMetrics.
compute_interval_steps (int): the intervals between two compute calls in the unit of batch number
memory_usage_limit_mb (float): the memory usage limit for OOM check
memory_usage_limit_mb (float): [Unused] the memory usage limit for OOM check
Call Args:
Not supported.
Expand Down Expand Up @@ -177,8 +174,6 @@ class RecMetricModule(nn.Module):
rec_metrics: RecMetricList
throughput_metric: Optional[ThroughputMetric]
state_metrics: Dict[str, StateMetric]
memory_usage_limit_mb: float
memory_usage_mb_avg: float
oom_count: int
compute_count: int
last_compute_time: float
Expand All @@ -195,6 +190,7 @@ def __init__(
compute_interval_steps: int = 100,
min_compute_interval: float = 0.0,
max_compute_interval: float = float("inf"),
# Unused, but needed for backwards compatibility. TODO: Remove from callsites
memory_usage_limit_mb: float = 512,
) -> None:
super().__init__()
Expand All @@ -205,8 +201,6 @@ def __init__(
self.trained_batches: int = 0
self.batch_size = batch_size
self.world_size = world_size
self.memory_usage_limit_mb = memory_usage_limit_mb
self.memory_usage_mb_avg = 0.0
self.oom_count = 0
self.compute_count = 0

Expand All @@ -230,37 +224,6 @@ def __init__(
)
self.last_compute_time = -1.0

def get_memory_usage(self) -> int:
r"""Total memory of unique RecMetric tensors in bytes"""
total = {}
for metric in self.rec_metrics.rec_metrics:
total.update(metric.get_memory_usage())
return sum(total.values())

def check_memory_usage(self, compute_count: int) -> None:
memory_usage_mb = self.get_memory_usage() / (10**6)
if memory_usage_mb > self.memory_usage_limit_mb:
self.oom_count += 1
logger.warning(
f"MetricModule is using {memory_usage_mb}MB. "
f"This is larger than the limit{self.memory_usage_limit_mb}MB. "
f"This is the f{self.oom_count}th OOM."
)

if (
compute_count > MEMORY_AVG_WARNING_WARMUP
and memory_usage_mb
> self.memory_usage_mb_avg * ((100 + MEMORY_AVG_WARNING_PERCENTAGE) / 100)
):
logger.warning(
f"MetricsModule is using more than {MEMORY_AVG_WARNING_PERCENTAGE}% of "
f"the average memory usage. Current usage: {memory_usage_mb}MB."
)

self.memory_usage_mb_avg = (
self.memory_usage_mb_avg * (compute_count - 1) + memory_usage_mb
) / compute_count

def _update_rec_metrics(
self, model_out: Dict[str, torch.Tensor], **kwargs: Any
) -> None:
Expand Down
136 changes: 0 additions & 136 deletions torchrec/metrics/tests/test_metric_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -353,142 +353,6 @@ def test_initial_states_rank0_checkpointing(self) -> None:
lc, entrypoint=self._run_trainer_initial_states_checkpointing
)()

def test_empty_memory_usage(self) -> None:
mock_optimizer = MockOptimizer()
config = EmptyMetricsConfig
metric_module = generate_metric_module(
TestMetricModule,
metrics_config=config,
batch_size=128,
world_size=64,
my_rank=0,
state_metrics_mapping={StateMetricEnum.OPTIMIZERS: mock_optimizer},
device=torch.device("cpu"),
)
self.assertEqual(metric_module.get_memory_usage(), 0)

def test_ne_memory_usage(self) -> None:
mock_optimizer = MockOptimizer()
config = DefaultMetricsConfig
metric_module = generate_metric_module(
TestMetricModule,
metrics_config=config,
batch_size=128,
world_size=64,
my_rank=0,
state_metrics_mapping={StateMetricEnum.OPTIMIZERS: mock_optimizer},
device=torch.device("cpu"),
)
# Default NEMetric's dtype is
# float64 (8 bytes) * 16 tensors of size 1 = 128 bytes
# Tensors in NeMetricComputation:
# 8 in _default, 8 specific attributes: 4 attributes, 4 window
self.assertEqual(metric_module.get_memory_usage(), 128)
metric_module.update(gen_test_batch(128))
self.assertEqual(metric_module.get_memory_usage(), 160)

def test_calibration_memory_usage(self) -> None:
mock_optimizer = MockOptimizer()
config = dataclasses.replace(
DefaultMetricsConfig,
rec_metrics={
RecMetricEnum.CALIBRATION: RecMetricDef(
rec_tasks=[DefaultTaskInfo], window_size=_DEFAULT_WINDOW_SIZE
)
},
)
metric_module = generate_metric_module(
TestMetricModule,
metrics_config=config,
batch_size=128,
world_size=64,
my_rank=0,
state_metrics_mapping={StateMetricEnum.OPTIMIZERS: mock_optimizer},
device=torch.device("cpu"),
)
# Default calibration metric dtype is
# float64 (8 bytes) * 8 tensors, size 1 = 64 bytes
# Tensors in CalibrationMetricComputation:
# 4 in _default, 4 specific attributes: 2 attribute, 2 window
self.assertEqual(metric_module.get_memory_usage(), 64)
metric_module.update(gen_test_batch(128))
self.assertEqual(metric_module.get_memory_usage(), 80)

def test_auc_memory_usage(self) -> None:
mock_optimizer = MockOptimizer()
config = dataclasses.replace(
DefaultMetricsConfig,
rec_metrics={
RecMetricEnum.AUC: RecMetricDef(
rec_tasks=[DefaultTaskInfo], window_size=_DEFAULT_WINDOW_SIZE
)
},
)
metric_module = generate_metric_module(
TestMetricModule,
metrics_config=config,
batch_size=128,
world_size=64,
my_rank=0,
state_metrics_mapping={StateMetricEnum.OPTIMIZERS: mock_optimizer},
device=torch.device("cpu"),
)
# 3 (tensors) * 4 (float)
self.assertEqual(metric_module.get_memory_usage(), 12)
metric_module.update(gen_test_batch(128))
# 3 (tensors) * 128 (batch_size) * 4 (float)
self.assertEqual(metric_module.get_memory_usage(), 1536)

# Test memory usage over multiple updates does not increase unexpectedly, we don't need to force OOM as just knowing if the memory usage is increeasing how we expect is enough
for _ in range(10):
metric_module.update(gen_test_batch(128))

# 3 tensors * 128 batch size * 4 float * 11 updates
self.assertEqual(metric_module.get_memory_usage(), 16896)

# Ensure reset frees memory correctly
metric_module.reset()
self.assertEqual(metric_module.get_memory_usage(), 12)

def test_check_memory_usage(self) -> None:
mock_optimizer = MockOptimizer()
config = DefaultMetricsConfig
metric_module = generate_metric_module(
TestMetricModule,
metrics_config=config,
batch_size=128,
world_size=64,
my_rank=0,
state_metrics_mapping={StateMetricEnum.OPTIMIZERS: mock_optimizer},
device=torch.device("cpu"),
)
metric_module.update(gen_test_batch(128))
with patch("torchrec.metrics.metric_module.logger") as logger_mock:
# Memory usage is fine.
metric_module.memory_usage_mb_avg = 160 / (10**6)
metric_module.check_memory_usage(1000)
self.assertEqual(metric_module.oom_count, 0)
self.assertEqual(logger_mock.warning.call_count, 0)

# OOM but memory usage does not exceed avg.
metric_module.memory_usage_limit_mb = 0.000001
metric_module.memory_usage_mb_avg = 160 / (10**6)
metric_module.check_memory_usage(1000)
self.assertEqual(metric_module.oom_count, 1)
self.assertEqual(logger_mock.warning.call_count, 1)

# OOM and memory usage exceed avg but warmup is not over.
metric_module.memory_usage_mb_avg = 160 / (10**6) / 10
metric_module.check_memory_usage(2)
self.assertEqual(metric_module.oom_count, 2)
self.assertEqual(logger_mock.warning.call_count, 2)

# OOM and memory usage exceed avg and warmup is over.
metric_module.memory_usage_mb_avg = 160 / (10**6) / 1.25
metric_module.check_memory_usage(1002)
self.assertEqual(metric_module.oom_count, 3)
self.assertEqual(logger_mock.warning.call_count, 4)

def test_should_compute(self) -> None:
metric_module = generate_metric_module(
TestMetricModule,
Expand Down

0 comments on commit ecdabfd

Please sign in to comment.