Still working out daemon issues (main vs. imported)
jsschreck committed Dec 24, 2024
Commit 428cc1d (parent: 8232708)
Showing 2 changed files with 23 additions and 4 deletions.
credit/datasets/era5_multistep_batcher.py (15 additions, 3 deletions)
@@ -379,6 +379,12 @@ def __getitem__(self, _):
                 if value.ndimension() == 0:
                     value = value.unsqueeze(0)  # Unsqueeze to make it a 1D tensor
 
+                # Add the time dimension, which has length 1 in all datasets in this example.
+                # This is needed because we use DataLoaderLite, which does not add
+                # the extra dimension, so the batch and time dimensions would otherwise get mixed up.
+                if value.ndim in (4, 5):
+                    value = value.unsqueeze(1)
+
                 if key not in batch:
                     batch[key] = value  # Initialize the key in the batch dictionary
                 else:
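
For illustration, a minimal sketch (not from the repo) of what the new unsqueeze does; the (batch, channels, lat, lon) layout for the 4D case is an assumption here:

    import torch

    # Hypothetical 4D surface field with no time axis: (batch, channels, lat, lon)
    value = torch.randn(2, 4, 640, 1280)
    if value.ndim in (4, 5):
        value = value.unsqueeze(1)  # insert a singleton time axis at dim 1
    print(value.shape)  # torch.Size([2, 1, 4, 640, 1280])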
@@ -567,17 +573,23 @@ def worker_process(self, k, index_pair, result_dict):
                 result_dict[k] = sample  # Store the result keyed by its order index
             except FileNotFoundError:
                 # Log the error but continue processing
-                logger.warning(f"Ignoring transient connection error for index {k}.")
+                # logger.warning(f"Ignoring transient connection error for index {k}.")
+                self.stop_event.set()
+                return
             except Exception as e:
                 logger.warning(f"Error in worker process for index {k}: {e}.\n"
                                "This is likely due to the end of training, or you killed the program. Exiting!\n"
                                "Not what you expected? Email schreck@ucar.edu for support")
-                self.stop_event.set()
-                return
                 # logger.error(f"Error in worker process for index {k}: {e}")
                 # raise RuntimeError(f"Error in worker process for index {k}: {e}") from e
+                # Ensure proper cleanup before exiting worker process
+                if hasattr(self, 'shutdown') and callable(self.shutdown):
+                    logger.info("Initiating shutdown sequence.")
+                    self.shutdown()
                 return
+            except:
+                raise RuntimeError(f"Error in worker process for index {k}: {e}") from e
 
 
     def _fetch_batch(self):
         """
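
The stop-event handling above follows the standard multiprocessing pattern, and the "main vs imported" problem named in the commit title is the classic one: under the spawn start method, child processes re-import the module, so anything that launches processes must sit behind a __main__ guard. A minimal, self-contained sketch of both ideas (the worker and queue names are illustrative, not from the repo):

    import multiprocessing as mp

    def worker(stop_event, queue):
        # Produce results until the parent signals shutdown
        while not stop_event.is_set():
            queue.put("sample")
            break  # a single item is enough for the demo

    if __name__ == "__main__":  # guard so spawned children do not rerun this block on import
        stop_event = mp.Event()
        queue = mp.Queue()
        p = mp.Process(target=worker, args=(stop_event, queue), daemon=True)
        p.start()
        print(queue.get())
        stop_event.set()  # ask the worker to exit cleanly
        p.join()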
credit/datasets/load_dataset_and_dataloader.py (8 additions, 1 deletion)
@@ -28,6 +28,12 @@ def __len__(self):
         return len(self.dataset)  # Otherwise, fall back to the dataset's length
 
 
+def collate_fn(batch):
+    # Only used with ERA5_MultiStep_Batcher.
+    # Prevents the time and batch dimensions from getting flipped.
+    return batch[0]
+
+
 def load_dataset(conf, rank=0, world_size=1, is_train=True):
     """
     Load the dataset based on the configuration.
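
Why collate_fn returns batch[0]: the batcher's __getitem__ already returns fully batched tensors, so PyTorch's default collate would stack each sample and prepend an extra leading dimension that downstream code then misreads as time. A minimal sketch of the unwrap, with made-up shapes; the (batch, time, channels, lat, lon) ordering is an assumption here:

    import torch
    from torch.utils.data import DataLoader, Dataset

    class PrebatchedDataset(Dataset):
        # Illustrative stand-in: each __getitem__ already returns a full batch
        def __len__(self):
            return 4

        def __getitem__(self, idx):
            return {"x": torch.randn(2, 1, 4, 8, 8)}  # (batch, time, channels, lat, lon)

    def collate_fn(batch):
        return batch[0]  # unwrap the singleton list instead of stacking

    loader = DataLoader(PrebatchedDataset(), batch_size=1, collate_fn=collate_fn)
    sample = next(iter(loader))
    print(sample["x"].shape)  # torch.Size([2, 1, 4, 8, 8]); no extra leading dim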
@@ -220,6 +226,7 @@ def load_dataloader(conf, dataset, rank=0, world_size=1, is_train=True):
         dataloader = DataLoader(
             dataset,
             num_workers=1,  # Must be 1 to use prefetching
+            collate_fn=collate_fn,
             prefetch_factor=prefetch_factor
         )
     elif type(dataset) is MultiprocessingBatcher:
@@ -315,7 +322,7 @@ def load_dataloader(conf, dataset, rank=0, world_size=1, is_train=True):
 
     # Iterate through the dataloader and print samples
     for (k, sample) in enumerate(dataloader):
-        print(k, sample['index'], sample['datetime'], sample['forecast_step'], sample['stop_forecast'])
+        print(k, sample['index'], sample['datetime'], sample['forecast_step'], sample['stop_forecast'], sample["x"].shape, sample["x_surf"].shape)
         if k == 20:
             break
