@@ -253,11 +253,14 @@ def _populate_workers_for_global_step(self, step, worker) -> None:
253253 if step not in self .workers_for_global_step :
254254 self .workers_for_global_step [step ] = set ()
255255 self .workers_for_global_step [step ].add (worker )
256+ self .logger .debug (f"Populated workers for global step:{ step } worker: { worker } " )
257+
256258 if (
257259 len (self .workers_for_global_step [step ]) == self .num_workers
258260 and step > self .last_complete_step
259261 ):
260262 self .last_complete_step = step
263+ self .logger .debug (f"Populating last completing step to: { step } " )
261264
262265 def _populate_global_step_to_tensor_name_map (self , tensor : TensorLocation , step_num ) -> None :
263266 """
@@ -514,13 +517,14 @@ def has_passed_step(self, step, mode=ModeKeys.GLOBAL) -> StepState:
514517 """
515518 all_steps = self .steps (mode = mode , show_incomplete_steps = True )
516519 bisect_idx = bisect_left (all_steps , step )
520+ g_step = self ._global_step_currently (mode , step )
521+
517522 if bisect_idx < len (all_steps ):
518523 if all_steps [bisect_idx ] > step :
519- if self .last_complete_step > step :
524+ if self .last_complete_step > g_step :
520525 return StepState .UNAVAILABLE
521526 return StepState .NOT_YET_AVAILABLE
522527 elif all_steps [bisect_idx ] == step :
523- g_step = self .global_step (mode , step )
524528 if len (self .workers_for_global_step [g_step ]) == self .num_workers :
525529 return StepState .AVAILABLE
526530 elif self .loaded_all_steps is True :
@@ -531,9 +535,9 @@ def has_passed_step(self, step, mode=ModeKeys.GLOBAL) -> StepState:
531535 f"Step { step } of mode { mode } was marked complete because the job is complete"
532536 )
533537 return StepState .AVAILABLE
534- elif step <= self .last_complete_step :
538+ elif g_step <= self .last_complete_step :
535539 self .logger .info (
536- f"Step { step } of mode { mode } was written only by workers: { self .workers_for_global_step [step ]} "
540+ f"Step { step } of mode { mode } was written only by workers: { self .workers_for_global_step [g_step ]} "
537541 )
538542 self .logger .info (
539543 f"Step { step } of mode { mode } was marked complete because the last complete step is { self .last_complete_step } "
@@ -552,7 +556,7 @@ def _load_tensors(self):
552556 def _update_last_index_token (self , new_index_token : str ) -> None :
553557 """
554558 This function updates the last_index_token in the following scenarios:
555- 1. last_complete_step > last_index_token_step :
559+ 1. last_complete_step >= last_index_token_step :
556560 this means that the token isn't pointing to the latest completed step
557561 2. number of steps available ( complete or incomplete ) - (last_completed_step+1) > window_size_limit:
558562 we maintain a window to stop querying for older steps that have not completed.
@@ -569,7 +573,7 @@ def _update_last_index_token(self, new_index_token: str) -> None:
569573 )
570574
571575 # Case 1:
572- if self .last_complete_step > last_index_token_step :
576+ if self .last_complete_step >= last_index_token_step :
573577 prefix = IndexFileLocationUtils .get_prefix_from_index_file (new_index_token )
574578 # sort lexicographically and select the last worker
575579 last_worker = sorted (list (self .worker_set ))[- 1 ]
@@ -579,6 +583,7 @@ def _update_last_index_token(self, new_index_token: str) -> None:
579583 self .last_index_token = IndexFileLocationUtils .get_index_key_for_step (
580584 prefix , self .last_complete_step , last_worker_serialized
581585 )
586+ self .logger .debug (f"Updated last index token to:{ self .last_index_token } " )
582587
583588 # Case 2:
584589 available_step = self ._global_to_mode .keys ()
0 commit comments