@@ -76,17 +76,20 @@ class __ParamInTrace:
7676 param : Parameter
7777 step_id_last_used_at : int
7878
79- def __init__ (self ,
80- prefetch_bucket_sz : int ,
81- max_reuse_distance_in_numel : int ,
82- max_available_parameters_in_numel : int ,
83- allgather_stream : get_accelerator ().Stream ,
84- inflight_param_registry : InflightParamRegistry ,
85- prefetch_nvme : bool = False ,
86- timers = None ,
87- zero_quantized_weights = False ,
88- zero_quantized_nontrainable_weights = False ,
89- fast_sharding_for_leaf_module = False ) -> None :
79+ def __init__ (
80+ self ,
81+ prefetch_bucket_sz : int ,
82+ max_reuse_distance_in_numel : int ,
83+ max_available_parameters_in_numel : int ,
84+ allgather_stream : get_accelerator ().Stream ,
85+ inflight_param_registry : InflightParamRegistry ,
86+ prefetch_nvme : bool = False ,
87+ timers = None ,
88+ zero_quantized_weights = False ,
89+ zero_quantized_nontrainable_weights = False ,
90+ fast_sharding_for_leaf_module = False ,
91+ log_trace_cache_warnings = False ,
92+ ) -> None :
9093 # mapping of param -> handle for each param that is currently in flight
9194 self .__inflight_param_registry = inflight_param_registry
9295 # keeps track of the number of submodules invoked so far.
@@ -129,6 +132,9 @@ def __init__(self,
129132 self .__max_ongoing_fetch_events : int = 2
130133 self .__profiler = PartitionedParameterProfiler (timers if ENABLE_PROFILER else None )
131134
135+ # Whether to log trace cache warnings, e.g. invalidation events
136+ self .__log_trace_cache_warnings = log_trace_cache_warnings
137+
132138 # whether to enable fast fetch for the z3 leaf module.
133139 # this will improve fetch speed but will not break down leaf module parameters to alleviate memory pressure.
134140 self .fast_sharding_for_leaf_module = fast_sharding_for_leaf_module
@@ -177,7 +183,7 @@ def trace_prologue(self, sub_module: Module) -> None:
177183 print_rank_0 (
178184 f"Invalidate trace cache @ step { self .__step_id } and module { sub_module .ds_id } : "
179185 f"cache has only { len (self .__submodule_order )} modules" ,
180- force = True )
186+ force = self . __log_trace_cache_warnings )
181187 self ._invalidate_trace ()
182188 return
183189
@@ -186,7 +192,7 @@ def trace_prologue(self, sub_module: Module) -> None:
186192 print_rank_0 (
187193 f"Invalidate trace cache @ step { self .__step_id } : "
188194 f"expected module { expected_module_id } , but got module { sub_module .ds_id } " ,
189- force = True )
195+ force = self . __log_trace_cache_warnings )
190196 self ._invalidate_trace ()
191197
192198 @compiler .disable
0 commit comments