Skip to content

Commit f4b0f58

Browse files
tjruwase authored and gyou2021 committed
Control trace cache warnings (deepspeedai#7039)
Make trace cache warnings configurable, and disabled by default.
Fixes deepspeedai#6985, deepspeedai#4081, deepspeedai#5033, deepspeedai#5006, deepspeedai#5662.

Signed-off-by: Olatunji Ruwase <olruwase@microsoft.com>
Signed-off-by: gyou2021 <ganmei.you@intel.com>
1 parent ba8ef57 commit f4b0f58

File tree

6 files changed

+54
-21
lines changed

6 files changed

+54
-21
lines changed

deepspeed/runtime/engine.py

Lines changed: 5 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -983,6 +983,9 @@ def zero_quantized_gradients(self):
983983
def zeropp_loco_param(self):
984984
return self._config.zero_config.zeropp_loco_param
985985

986+
def zero_log_trace_cache_warnings(self):
987+
return self._config.zero_config.log_trace_cache_warnings
988+
986989
def dump_state(self):
987990
return self._config.dump_state
988991

@@ -1692,6 +1695,7 @@ def _configure_zero_optimizer(self, optimizer):
16921695
zero_quantized_weights=self.zero_quantized_weights(),
16931696
zero_quantized_nontrainable_weights=self.zero_quantized_nontrainable_weights(),
16941697
zero_module_granularity_threshold=self.zero_module_granularity_threshold(),
1698+
log_trace_cache_warnings=self.zero_log_trace_cache_warnings(),
16951699
)
16961700
else:
16971701
log_dist(
@@ -1740,6 +1744,7 @@ def _configure_zero_optimizer(self, optimizer):
17401744
zero_quantized_nontrainable_weights=self.zero_quantized_nontrainable_weights(),
17411745
zero_module_granularity_threshold=self.zero_module_granularity_threshold(),
17421746
zeropp_loco_param=self.zeropp_loco_param(),
1747+
log_trace_cache_warnings=self.zero_log_trace_cache_warnings(),
17431748
)
17441749

17451750
else:

deepspeed/runtime/zero/config.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@
4545
"memory_efficient_linear": [true|false],
4646
"override_module_apply": [true|false],
4747
"zeropp_loco_param": {...},
48+
"log_trace_cache_warnings" : [true|false],
4849
}
4950
}
5051
"""
@@ -340,6 +341,11 @@ class DeepSpeedZeroConfig(DeepSpeedConfigModel):
340341
Override nn.Module apply function, for Stage 3.
341342
"""
342343

344+
log_trace_cache_warnings: bool = False
345+
"""
346+
Whether to log warnings from trace cache, such as invalidation events.
347+
"""
348+
343349
# Validators
344350
@model_validator(mode="after")
345351
def overlap_comm_valid(self):

deepspeed/runtime/zero/parameter_offload.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,7 @@ def __init__(
103103
zero_quantized_weights=False,
104104
zero_quantized_nontrainable_weights=False,
105105
zero_module_granularity_threshold=0,
106+
log_trace_cache_warnings=False,
106107
):
107108

108109
see_memory_usage("DeepSpeedZeRoOffload initialize [begin]", force=True)
@@ -118,6 +119,7 @@ def __init__(
118119
self.zero_param_parallel_group = zero_param_parallel_group
119120
self.zero_quantized_weights = zero_quantized_weights
120121
self.zero_quantized_nontrainable_weights = zero_quantized_nontrainable_weights
122+
self.log_trace_cache_warnings = log_trace_cache_warnings
121123

122124
if offload_param_config is not None and offload_param_config.device != OffloadDeviceEnum.none:
123125
self.offload_device = offload_param_config.device
@@ -165,7 +167,9 @@ def __init__(
165167
timers=self.timers,
166168
zero_quantized_weights=self.zero_quantized_weights,
167169
zero_quantized_nontrainable_weights=self.zero_quantized_nontrainable_weights,
168-
fast_sharding_for_leaf_module=self.fast_sharding_for_leaf_module)
170+
fast_sharding_for_leaf_module=self.fast_sharding_for_leaf_module,
171+
log_trace_cache_warnings=self.log_trace_cache_warnings,
172+
)
169173

170174
self.forward_hooks = []
171175
self.backward_hooks = []

deepspeed/runtime/zero/partitioned_param_coordinator.py

Lines changed: 19 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -76,17 +76,20 @@ class __ParamInTrace:
7676
param: Parameter
7777
step_id_last_used_at: int
7878

79-
def __init__(self,
80-
prefetch_bucket_sz: int,
81-
max_reuse_distance_in_numel: int,
82-
max_available_parameters_in_numel: int,
83-
allgather_stream: get_accelerator().Stream,
84-
inflight_param_registry: InflightParamRegistry,
85-
prefetch_nvme: bool = False,
86-
timers=None,
87-
zero_quantized_weights=False,
88-
zero_quantized_nontrainable_weights=False,
89-
fast_sharding_for_leaf_module=False) -> None:
79+
def __init__(
80+
self,
81+
prefetch_bucket_sz: int,
82+
max_reuse_distance_in_numel: int,
83+
max_available_parameters_in_numel: int,
84+
allgather_stream: get_accelerator().Stream,
85+
inflight_param_registry: InflightParamRegistry,
86+
prefetch_nvme: bool = False,
87+
timers=None,
88+
zero_quantized_weights=False,
89+
zero_quantized_nontrainable_weights=False,
90+
fast_sharding_for_leaf_module=False,
91+
log_trace_cache_warnings=False,
92+
) -> None:
9093
# mapping of param -> handle for each param that is currently in flight
9194
self.__inflight_param_registry = inflight_param_registry
9295
# keeps track of the number of submodules invoked so far.
@@ -129,6 +132,9 @@ def __init__(self,
129132
self.__max_ongoing_fetch_events: int = 2
130133
self.__profiler = PartitionedParameterProfiler(timers if ENABLE_PROFILER else None)
131134

135+
# Whether to log trace cache warnings, e.g. invalidation events
136+
self.__log_trace_cache_warnings = log_trace_cache_warnings
137+
132138
# whether to enable fast fetch for the z3 leaf module.
133139
# this will improve fetch speed but will not break down leaf module parameters to alleviate memory pressure.
134140
self.fast_sharding_for_leaf_module = fast_sharding_for_leaf_module
@@ -177,7 +183,7 @@ def trace_prologue(self, sub_module: Module) -> None:
177183
print_rank_0(
178184
f"Invalidate trace cache @ step {self.__step_id} and module {sub_module.ds_id}: "
179185
f"cache has only {len(self.__submodule_order)} modules",
180-
force=True)
186+
force=self.__log_trace_cache_warnings)
181187
self._invalidate_trace()
182188
return
183189

@@ -186,7 +192,7 @@ def trace_prologue(self, sub_module: Module) -> None:
186192
print_rank_0(
187193
f"Invalidate trace cache @ step {self.__step_id}: "
188194
f"expected module {expected_module_id}, but got module {sub_module.ds_id}",
189-
force=True)
195+
force=self.__log_trace_cache_warnings)
190196
self._invalidate_trace()
191197

192198
@compiler.disable

deepspeed/runtime/zero/stage3.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -160,6 +160,7 @@ def __init__(
160160
zero_quantized_nontrainable_weights=False,
161161
zero_module_granularity_threshold=0,
162162
zeropp_loco_param=None,
163+
log_trace_cache_warnings=False,
163164
):
164165
see_memory_usage("Stage 3 initialize beginning", force=True)
165166

@@ -231,7 +232,9 @@ def __init__(
231232
zero_param_parallel_group=zero_param_parallel_group,
232233
zero_quantized_weights=zero_quantized_weights,
233234
zero_quantized_nontrainable_weights=zero_quantized_nontrainable_weights,
234-
zero_module_granularity_threshold=zero_module_granularity_threshold)
235+
zero_module_granularity_threshold=zero_module_granularity_threshold,
236+
log_trace_cache_warnings=log_trace_cache_warnings,
237+
)
235238

236239
self.persistent_parameters = self.parameter_offload.persistent_parameters
237240
self._configure_offloading(offload_optimizer_config, offload_param_config)
@@ -465,6 +468,7 @@ def initialize_ds_offload(
465468
zero_quantized_weights,
466469
zero_quantized_nontrainable_weights,
467470
zero_module_granularity_threshold,
471+
log_trace_cache_warnings,
468472
):
469473
return DeepSpeedZeRoOffload(module=module,
470474
timers=timers,
@@ -481,7 +485,8 @@ def initialize_ds_offload(
481485
zero_param_parallel_group=zero_param_parallel_group,
482486
zero_quantized_weights=zero_quantized_weights,
483487
zero_quantized_nontrainable_weights=zero_quantized_nontrainable_weights,
484-
zero_module_granularity_threshold=zero_module_granularity_threshold)
488+
zero_module_granularity_threshold=zero_module_granularity_threshold,
489+
log_trace_cache_warnings=log_trace_cache_warnings)
485490

486491
def _get_trainable_parameter_groups(self):
487492
param_groups = []

docs/_pages/config-json.md

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -371,11 +371,12 @@ Enabling and configuring ZeRO memory optimizations
371371
"sub_group_size" : 1e12,
372372
"elastic_checkpoint" : [true|false],
373373
"stage3_gather_16bit_weights_on_model_save": [true|false],
374-
"ignore_unused_parameters": [true|false]
375-
"round_robin_gradients": [true|false]
376-
"zero_hpz_partition_size": 1
377-
"zero_quantized_weights": [true|false]
378-
"zero_quantized_gradients": [true|false]
374+
"ignore_unused_parameters": [true|false],
375+
"round_robin_gradients": [true|false],
376+
"zero_hpz_partition_size": 1,
377+
"zero_quantized_weights": [true|false],
378+
"zero_quantized_gradients": [true|false],
379+
"log_trace_cache_warnings": [true|false],
379380
}
380381
```
381382

@@ -512,6 +513,12 @@ Enabling and configuring ZeRO memory optimizations
512513
| ----------------------------------------------------------------------------------------------------------------------------------- | ------- |
513514
|Boolean indicating whether to enable communication efficient quantized gradients of ZeRO++. | `False` |
514515

516+
<i>**log_trace_cache_warnings**</i>: [boolean]
517+
518+
| Description | Default |
519+
| ------------------------------------------------------------------------------------------------------------------- | ------- |
520+
| Log warnings from trace cache optimization of parameter sharding, such as cache invalidation events. | `False` |
521+
515522
***cpu_offload***: [boolean]
516523

517524
**Deprecated:** **cpu_offload** is deprecated and will be removed in future, please use `offload_optimizer` instead.

0 commit comments

Comments
 (0)