Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 8 additions & 2 deletions smdebug/trials/trial.py
Original file line number Diff line number Diff line change
Expand Up @@ -341,6 +341,9 @@ def tensor_names(self, *, step=None, mode=ModeKeys.GLOBAL, regex=None, collectio
ts.update(self.mode_to_tensors_map[mode])
else:
ts.update(self._tensors_for_step(step, mode))
self.logger.debug(
f"getting tensor_names with params: step:{step} mode:{mode} regex:{regex} collection:{collection}"
)

if regex is None and collection is None:
return sorted(list(ts))
Expand All @@ -357,7 +360,9 @@ def tensor_names(self, *, step=None, mode=ModeKeys.GLOBAL, regex=None, collectio
xs = self._tensors_matching_regex(regex)
matching_tensors_saved = ts.intersection(xs)
if len(matching_tensors_saved) == 0:
self.logger.warning(f"No tensors matching the regex pattern given were saved")
self.logger.warning(
f"No tensors matching the regex pattern:{regex} given were saved"
)
return sorted(list(matching_tensors_saved))

def _tensors_for_step(self, step, mode=ModeKeys.GLOBAL) -> list:
Expand Down Expand Up @@ -515,7 +520,8 @@ def has_passed_step(self, step, mode=ModeKeys.GLOBAL) -> StepState:
return StepState.UNAVAILABLE
return StepState.NOT_YET_AVAILABLE
elif all_steps[bisect_idx] == step:
if len(self.workers_for_global_step[step]) == self.num_workers:
g_step = self.global_step(mode, step)
if len(self.workers_for_global_step[g_step]) == self.num_workers:
return StepState.AVAILABLE
elif self.loaded_all_steps is True:
self.logger.info(
Expand Down