Skip to content

Commit 5246cda

Browse files
authored
Bug fix in trial.py has_passed_step (#140)
* Fix for a bug of last complete step * More logs and fixes to trial and index reader
1 parent e29065b commit 5246cda

File tree

2 files changed

+30
-8
lines changed

2 files changed

+30
-8
lines changed

smdebug/core/index_reader.py

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -310,6 +310,7 @@ def read_index_files(
310310
self.logger.debug(f'Loaded Index Files: {",".join(index_files)}')
311311
for index_file in index_files:
312312
if self.index_file_cache.has_not_read(index_file):
313+
313314
step = IndexFileLocationUtils.parse_step_from_index_file_name(index_file)
314315
if (
315316
range_steps is not None and step_in_range(range_steps, step)
@@ -319,9 +320,15 @@ def read_index_files(
319320
object_requests.append(
320321
ReadObjectRequest(format(f"s3://{self.bucket_name}/") + index_file)
321322
)
322-
self.index_file_cache.add(index_file, start_after_key)
323+
self.logger.debug(f"Will read index_file: {index_file}")
324+
self.index_file_cache.add(index_file, start_after_key)
325+
else:
326+
self.logger.debug(
327+
f"index_file:{index_file} Indexcache contents:{self.index_file_cache.lookup_set}"
328+
)
323329

324330
responses = S3Handler.get_objects(object_requests)
331+
assert len(responses) == len(object_requests)
325332
return responses, steps, start_after_key, workers
326333

327334
def list_index_files(self, start_after_key=None):
@@ -416,7 +423,11 @@ def read_index_files(
416423
start_after_index = bisect_left(index_files, start_after_key)
417424
else:
418425
start_after_index = 0
426+
self.logger.debug(f"Found index_files:{index_files}")
419427
index_files = index_files[start_after_index:] # ignore files we have already read
428+
self.logger.debug(
429+
f"Curtailed Found index_files to :{index_files} start_after_index:{start_after_index} start_after_key:{start_after_key}"
430+
)
420431
for index_file in index_files:
421432
if self.index_file_cache.has_not_read(index_file):
422433
step = IndexFileLocationUtils.parse_step_from_index_file_name(index_file)
@@ -428,9 +439,15 @@ def read_index_files(
428439
self.logger.debug(
429440
f"Sagemaker-Debugger: Read {os.path.getsize(index_file)} bytes from file {index_file}"
430441
)
442+
self.logger.debug(f"Will read index file:{index_file}")
431443
with open(index_file) as f:
432444
responses.append(f.read().encode())
433-
self.index_file_cache.add(index_file, start_after_key)
445+
self.index_file_cache.add(index_file, start_after_key)
446+
else:
447+
self.logger.debug(
448+
f"IndexFile:{index_file} Indexcache contents:{self.index_file_cache.lookup_set}"
449+
)
450+
434451
if len(index_files) > 0:
435452
start_after_key = index_files[-1] # Last file that we have read
436453
return responses, steps, start_after_key, workers

smdebug/trials/trial.py

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -253,11 +253,14 @@ def _populate_workers_for_global_step(self, step, worker) -> None:
253253
if step not in self.workers_for_global_step:
254254
self.workers_for_global_step[step] = set()
255255
self.workers_for_global_step[step].add(worker)
256+
self.logger.debug(f"Populated workers for global step:{step} worker: {worker}")
257+
256258
if (
257259
len(self.workers_for_global_step[step]) == self.num_workers
258260
and step > self.last_complete_step
259261
):
260262
self.last_complete_step = step
263+
self.logger.debug(f"Populating last completing step to: {step}")
261264

262265
def _populate_global_step_to_tensor_name_map(self, tensor: TensorLocation, step_num) -> None:
263266
"""
@@ -514,13 +517,14 @@ def has_passed_step(self, step, mode=ModeKeys.GLOBAL) -> StepState:
514517
"""
515518
all_steps = self.steps(mode=mode, show_incomplete_steps=True)
516519
bisect_idx = bisect_left(all_steps, step)
520+
g_step = self._global_step_currently(mode, step)
521+
517522
if bisect_idx < len(all_steps):
518523
if all_steps[bisect_idx] > step:
519-
if self.last_complete_step > step:
524+
if self.last_complete_step > g_step:
520525
return StepState.UNAVAILABLE
521526
return StepState.NOT_YET_AVAILABLE
522527
elif all_steps[bisect_idx] == step:
523-
g_step = self.global_step(mode, step)
524528
if len(self.workers_for_global_step[g_step]) == self.num_workers:
525529
return StepState.AVAILABLE
526530
elif self.loaded_all_steps is True:
@@ -531,9 +535,9 @@ def has_passed_step(self, step, mode=ModeKeys.GLOBAL) -> StepState:
531535
f"Step {step} of mode {mode} was marked complete because the job is complete"
532536
)
533537
return StepState.AVAILABLE
534-
elif step <= self.last_complete_step:
538+
elif g_step <= self.last_complete_step:
535539
self.logger.info(
536-
f"Step {step} of mode {mode} was written only by workers: {self.workers_for_global_step[step]}"
540+
f"Step {step} of mode {mode} was written only by workers: {self.workers_for_global_step[g_step]}"
537541
)
538542
self.logger.info(
539543
f"Step {step} of mode {mode} was marked complete because the last complete step is {self.last_complete_step}"
@@ -552,7 +556,7 @@ def _load_tensors(self):
552556
def _update_last_index_token(self, new_index_token: str) -> None:
553557
"""
554558
This function updates the last_index_token in the following scenarios:
555-
1. last_complete_step > last_index_token_step :
559+
1. last_complete_step >= last_index_token_step :
556560
this means that the token isn't pointing to the latest completed step
557561
2. number of steps available ( complete or incomplete ) - (last_completed_step+1) > window_size_limit:
558562
we maintain a window to stop querying for older steps that have not completed.
@@ -569,7 +573,7 @@ def _update_last_index_token(self, new_index_token: str) -> None:
569573
)
570574

571575
# Case 1:
572-
if self.last_complete_step > last_index_token_step:
576+
if self.last_complete_step >= last_index_token_step:
573577
prefix = IndexFileLocationUtils.get_prefix_from_index_file(new_index_token)
574578
# sort lexicographically and select the last worker
575579
last_worker = sorted(list(self.worker_set))[-1]
@@ -579,6 +583,7 @@ def _update_last_index_token(self, new_index_token: str) -> None:
579583
self.last_index_token = IndexFileLocationUtils.get_index_key_for_step(
580584
prefix, self.last_complete_step, last_worker_serialized
581585
)
586+
self.logger.debug(f"Updated last index token to:{self.last_index_token}")
582587

583588
# Case 2:
584589
available_step = self._global_to_mode.keys()

0 commit comments

Comments
 (0)