Skip to content

Commit e946615

Browse files
tjruwaseMasahiro Tanaka
authored andcommitted
Control trace cache warnings (#7039)
Make trace cache warnings configurable, and disabled by default. Fix #6985, #4081, #5033, #5006, #5662 --------- Signed-off-by: Olatunji Ruwase <olruwase@microsoft.com> Signed-off-by: Masahiro Tanaka <mtanaka@microsoft.com>
1 parent e3ea926 commit e946615

File tree

6 files changed

+54
-21
lines changed

6 files changed

+54
-21
lines changed

deepspeed/runtime/engine.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -993,6 +993,9 @@ def zero_quantized_gradients(self):
993993
def zeropp_loco_param(self):
994994
return self._config.zero_config.zeropp_loco_param
995995

996+
def zero_log_trace_cache_warnings(self):
997+
return self._config.zero_config.log_trace_cache_warnings
998+
996999
def dump_state(self):
9971000
return self._config.dump_state
9981001

@@ -1702,6 +1705,7 @@ def _configure_zero_optimizer(self, optimizer):
17021705
zero_quantized_weights=self.zero_quantized_weights(),
17031706
zero_quantized_nontrainable_weights=self.zero_quantized_nontrainable_weights(),
17041707
zero_module_granularity_threshold=self.zero_module_granularity_threshold(),
1708+
log_trace_cache_warnings=self.zero_log_trace_cache_warnings(),
17051709
)
17061710
else:
17071711
log_dist(
@@ -1750,6 +1754,7 @@ def _configure_zero_optimizer(self, optimizer):
17501754
zero_quantized_nontrainable_weights=self.zero_quantized_nontrainable_weights(),
17511755
zero_module_granularity_threshold=self.zero_module_granularity_threshold(),
17521756
zeropp_loco_param=self.zeropp_loco_param(),
1757+
log_trace_cache_warnings=self.zero_log_trace_cache_warnings(),
17531758
)
17541759

17551760
else:

deepspeed/runtime/zero/config.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@
4545
"memory_efficient_linear": [true|false],
4646
"override_module_apply": [true|false],
4747
"zeropp_loco_param": {...},
48+
"log_trace_cache_warnings" : [true|false],
4849
}
4950
}
5051
"""
@@ -340,6 +341,11 @@ class DeepSpeedZeroConfig(DeepSpeedConfigModel):
340341
Override nn.Module apply function, for Stage 3.
341342
"""
342343

344+
log_trace_cache_warnings: bool = False
345+
"""
346+
Whether to log warnings from trace cache, such as invalidation events.
347+
"""
348+
343349
# Validators
344350
@model_validator(mode="after")
345351
def overlap_comm_valid(self):

deepspeed/runtime/zero/parameter_offload.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,7 @@ def __init__(
103103
zero_quantized_weights=False,
104104
zero_quantized_nontrainable_weights=False,
105105
zero_module_granularity_threshold=0,
106+
log_trace_cache_warnings=False,
106107
):
107108

108109
see_memory_usage("DeepSpeedZeRoOffload initialize [begin]", force=True)
@@ -118,6 +119,7 @@ def __init__(
118119
self.zero_param_parallel_group = zero_param_parallel_group
119120
self.zero_quantized_weights = zero_quantized_weights
120121
self.zero_quantized_nontrainable_weights = zero_quantized_nontrainable_weights
122+
self.log_trace_cache_warnings = log_trace_cache_warnings
121123

122124
if offload_param_config is not None and offload_param_config.device != OffloadDeviceEnum.none:
123125
self.offload_device = offload_param_config.device
@@ -165,7 +167,9 @@ def __init__(
165167
timers=self.timers,
166168
zero_quantized_weights=self.zero_quantized_weights,
167169
zero_quantized_nontrainable_weights=self.zero_quantized_nontrainable_weights,
168-
fast_sharding_for_leaf_module=self.fast_sharding_for_leaf_module)
170+
fast_sharding_for_leaf_module=self.fast_sharding_for_leaf_module,
171+
log_trace_cache_warnings=self.log_trace_cache_warnings,
172+
)
169173

170174
self.forward_hooks = []
171175
self.backward_hooks = []

deepspeed/runtime/zero/partitioned_param_coordinator.py

Lines changed: 19 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -76,17 +76,20 @@ class __ParamInTrace:
7676
param: Parameter
7777
step_id_last_used_at: int
7878

79-
def __init__(self,
80-
prefetch_bucket_sz: int,
81-
max_reuse_distance_in_numel: int,
82-
max_available_parameters_in_numel: int,
83-
allgather_stream: get_accelerator().Stream,
84-
inflight_param_registry: InflightParamRegistry,
85-
prefetch_nvme: bool = False,
86-
timers=None,
87-
zero_quantized_weights=False,
88-
zero_quantized_nontrainable_weights=False,
89-
fast_sharding_for_leaf_module=False) -> None:
79+
def __init__(
80+
self,
81+
prefetch_bucket_sz: int,
82+
max_reuse_distance_in_numel: int,
83+
max_available_parameters_in_numel: int,
84+
allgather_stream: get_accelerator().Stream,
85+
inflight_param_registry: InflightParamRegistry,
86+
prefetch_nvme: bool = False,
87+
timers=None,
88+
zero_quantized_weights=False,
89+
zero_quantized_nontrainable_weights=False,
90+
fast_sharding_for_leaf_module=False,
91+
log_trace_cache_warnings=False,
92+
) -> None:
9093
# mapping of param -> handle for each param that is currently in flight
9194
self.__inflight_param_registry = inflight_param_registry
9295
# keeps track of the number of submodules invoked so far.
@@ -129,6 +132,9 @@ def __init__(self,
129132
self.__max_ongoing_fetch_events: int = 2
130133
self.__profiler = PartitionedParameterProfiler(timers if ENABLE_PROFILER else None)
131134

135+
# Whether to log trace cache warnings, e.g. invalidation events
136+
self.__log_trace_cache_warnings = log_trace_cache_warnings
137+
132138
# whether to enable fast fetch for the z3 leaf module.
133139
# this will improve fetch speed but will not break down leaf module parameters to alleviate memory pressure.
134140
self.fast_sharding_for_leaf_module = fast_sharding_for_leaf_module
@@ -177,7 +183,7 @@ def trace_prologue(self, sub_module: Module) -> None:
177183
print_rank_0(
178184
f"Invalidate trace cache @ step {self.__step_id} and module {sub_module.ds_id}: "
179185
f"cache has only {len(self.__submodule_order)} modules",
180-
force=True)
186+
force=self.__log_trace_cache_warnings)
181187
self._invalidate_trace()
182188
return
183189

@@ -186,7 +192,7 @@ def trace_prologue(self, sub_module: Module) -> None:
186192
print_rank_0(
187193
f"Invalidate trace cache @ step {self.__step_id}: "
188194
f"expected module {expected_module_id}, but got module {sub_module.ds_id}",
189-
force=True)
195+
force=self.__log_trace_cache_warnings)
190196
self._invalidate_trace()
191197

192198
@compiler.disable

deepspeed/runtime/zero/stage3.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -176,6 +176,7 @@ def __init__(
176176
zero_quantized_nontrainable_weights=False,
177177
zero_module_granularity_threshold=0,
178178
zeropp_loco_param=None,
179+
log_trace_cache_warnings=False,
179180
):
180181
see_memory_usage("Stage 3 initialize beginning", force=True)
181182

@@ -247,7 +248,9 @@ def __init__(
247248
zero_param_parallel_group=zero_param_parallel_group,
248249
zero_quantized_weights=zero_quantized_weights,
249250
zero_quantized_nontrainable_weights=zero_quantized_nontrainable_weights,
250-
zero_module_granularity_threshold=zero_module_granularity_threshold)
251+
zero_module_granularity_threshold=zero_module_granularity_threshold,
252+
log_trace_cache_warnings=log_trace_cache_warnings,
253+
)
251254

252255
self.persistent_parameters = self.parameter_offload.persistent_parameters
253256
self._configure_offloading(offload_optimizer_config, offload_param_config)
@@ -486,6 +489,7 @@ def initialize_ds_offload(
486489
zero_quantized_weights,
487490
zero_quantized_nontrainable_weights,
488491
zero_module_granularity_threshold,
492+
log_trace_cache_warnings,
489493
):
490494
return DeepSpeedZeRoOffload(module=module,
491495
timers=timers,
@@ -502,7 +506,8 @@ def initialize_ds_offload(
502506
zero_param_parallel_group=zero_param_parallel_group,
503507
zero_quantized_weights=zero_quantized_weights,
504508
zero_quantized_nontrainable_weights=zero_quantized_nontrainable_weights,
505-
zero_module_granularity_threshold=zero_module_granularity_threshold)
509+
zero_module_granularity_threshold=zero_module_granularity_threshold,
510+
log_trace_cache_warnings=log_trace_cache_warnings)
506511

507512
def _get_trainable_parameter_groups(self):
508513
param_groups = []

docs/_pages/config-json.md

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -371,11 +371,12 @@ Enabling and configuring ZeRO memory optimizations
371371
"sub_group_size" : 1e12,
372372
"elastic_checkpoint" : [true|false],
373373
"stage3_gather_16bit_weights_on_model_save": [true|false],
374-
"ignore_unused_parameters": [true|false]
375-
"round_robin_gradients": [true|false]
376-
"zero_hpz_partition_size": 1
377-
"zero_quantized_weights": [true|false]
378-
"zero_quantized_gradients": [true|false]
374+
"ignore_unused_parameters": [true|false],
375+
"round_robin_gradients": [true|false],
376+
"zero_hpz_partition_size": 1,
377+
"zero_quantized_weights": [true|false],
378+
"zero_quantized_gradients": [true|false],
379+
"log_trace_cache_warnings": [true|false],
379380
}
380381
```
381382

@@ -512,6 +513,12 @@ Enabling and configuring ZeRO memory optimizations
512513
| ----------------------------------------------------------------------------------------------------------------------------------- | ------- |
513514
|Boolean indicating whether to enable communication efficient quantized gradients of ZeRO++. | `False` |
514515

516+
<i>**log_trace_cache_warnings**</i>: [boolean]
517+
518+
| Description | Default |
519+
| ------------------------------------------------------------------------------------------------------------------- | ------- |
520+
| Log warnings from trace cache optimization of parameter sharding, such as cache invalidation events. | `False` |
521+
515522
***cpu_offload***: [boolean]
516523

517524
**Deprecated:** **cpu_offload** is deprecated and will be removed in future, please use `offload_optimizer` instead.

0 commit comments

Comments
 (0)