Commit
Merge pull request #9 from diningphil/torch-load-fix
modifying torch.load to add weights_only=True
diningphil authored Sep 10, 2024
2 parents 3c9527b + 0a15421 commit f416bc9
Showing 8 changed files with 28 additions and 18 deletions.
7 changes: 7 additions & 0 deletions CHANGELOG.md
@@ -1,3 +1,10 @@
# Changelog

+ ## [1.0.1] Minor Improvements
+
+ ## Changed
+
+ - Improvements to workflow files
+ - Added `weights_only=True` to all `torch.load()` calls to address torch warnings
+
## [1.0.0] First Release
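A minimal sketch of the pattern this commit applies, using a hypothetical file path: `weights_only=True` restricts `torch.load` to tensors, primitive containers, and explicitly allow-listed classes, which also avoids the FutureWarning that recent PyTorch releases emit for unrestricted `torch.load` calls.

```python
import torch

# Sketch of the pattern applied across the codebase in this commit.
state = {"weights": torch.randn(3, 3), "epoch": 7}
torch.save(state, "example.pt")  # hypothetical path

# weights_only=True avoids executing arbitrary pickled code and silences
# the FutureWarning raised by recent PyTorch versions for plain torch.load.
loaded = torch.load("example.pt", weights_only=True)
print(loaded["epoch"])
```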
4 changes: 2 additions & 2 deletions mlwiz/data/dataset.py
@@ -78,7 +78,7 @@ def __init__(

else:
# Simply load the dataset in memory
- self.dataset = torch.load(self.dataset_filepath)
+ self.dataset = torch.load(self.dataset_filepath, weights_only=True)

@property
def name(self) -> str:
@@ -508,7 +508,7 @@ def __iter__(self):
else len(self.shuffled_urls)
)
for url in self.shuffled_urls[self.start_index : end_index]:
- url_data = torch.load(url)
+ url_data = torch.load(url, weights_only=True)

if not isinstance(url_data, list):
url_data = [url_data]
2 changes: 1 addition & 1 deletion mlwiz/data/splitter.py
@@ -204,7 +204,7 @@ def load(cls, path: str):
Returns:
a :class:`~mlwiz.data.splitter.Splitter` object
"""
- splits = torch.load(path)
+ splits = torch.load(path, weights_only=True)

splitter_classname = splits.get("splitter_class", "Splitter")
splitter_class = s2c(splitter_classname)
2 changes: 1 addition & 1 deletion mlwiz/data/util.py
@@ -156,7 +156,7 @@ def load_dataset(
storage_folder, dataset_name, "processed", "dataset_kwargs.pt"
)

- dataset_args = torch.load(kwargs_path)
+ dataset_args = torch.load(kwargs_path, weights_only=True)

# Overwrite original data_root field, which may have changed
dataset_args["storage_folder"] = storage_folder
22 changes: 11 additions & 11 deletions mlwiz/evaluation/evaluator.py
@@ -11,6 +11,9 @@
import ray
import requests
import torch
+ from torch_geometric.data import Data
+ from torch_geometric.data.data import DataEdgeAttr, DataTensorAttr
+ from torch_geometric.data.storage import GlobalStorage

from mlwiz.data.provider import DataProvider
from mlwiz.evaluation.config import Config
@@ -102,7 +105,7 @@ def run_valid(
print(e)
elapsed = -1
else:
- _, _, elapsed = torch.load(fold_results_torch_path)
+ _, _, elapsed = torch.load(fold_results_torch_path, weights_only=True)
return (
dataset_getter.outer_k,
dataset_getter.inner_k,
@@ -171,7 +174,7 @@ def run_test(
print(e)
elapse = -1
else:
- res = torch.load(final_run_torch_path)
+ res = torch.load(final_run_torch_path, weights_only=True)
elapsed = res[-1]
return outer_k, run_id, elapsed

@@ -235,6 +238,9 @@ def __init__(
torch.cuda.manual_seed(self.base_seed)
random.seed(self.base_seed)

+ # Add Data to serializable objects
+ torch.serialization.add_safe_globals([Data, DataEdgeAttr, DataTensorAttr, GlobalStorage])

self.outer_folds = outer_folds
self.inner_folds = inner_folds
self.experiment_class = experiment_class
@@ -630,9 +636,6 @@ def model_selection(self, kfold_folder: str, outer_k: int, debug: bool):
),
fold_run_results_torch_path,
)
- # else:
- # res = torch.load(fold_results_torch_path)
- # elapsed = res[-1]

if debug:
self.process_model_selection_runs(fold_exp_folder, k)
@@ -748,9 +751,6 @@ def run_final_model(self, outer_k: int, debug: bool):
(training_res, val_res, test_res, elapsed),
final_run_torch_path,
)
- # else:
- # res = torch.load(final_run_torch_path)
- # elapsed = res[-1]
if debug:
self.process_final_runs(outer_k)

@@ -782,7 +782,7 @@ def process_model_selection_runs(self, fold_exp_folder: str, inner_k: int):
)

training_res, validation_res, _ = torch.load(
- fold_run_results_torch_path
+ fold_run_results_torch_path, weights_only=True
)

training_loss, validation_loss = (
@@ -887,7 +887,7 @@ def process_config(self, config_folder: str, config: Config):
)

training_res, validation_res, _ = torch.load(
- fold_results_torch_path
+ fold_results_torch_path, weights_only=True
)

training_loss, validation_loss = (
@@ -1038,7 +1038,7 @@ def process_final_runs(self, outer_k: int):
final_run_torch_path = osp.join(
final_run_exp_path, f"run_{i + 1}_results.torch"
)
- res = torch.load(final_run_torch_path)
+ res = torch.load(final_run_torch_path, weights_only=True)

tr_res, vl_res, te_res = {}, {}, {}

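The `add_safe_globals` registration added to the evaluator is what keeps `weights_only=True` working for datasets stored as `torch_geometric` objects: under the restricted loader, classes that are not allow-listed cannot be unpickled. A minimal sketch of the mechanism, assuming `torch_geometric` is installed and using a hypothetical file name:

```python
import torch
from torch_geometric.data import Data
from torch_geometric.data.data import DataEdgeAttr, DataTensorAttr
from torch_geometric.data.storage import GlobalStorage

# Allow-list the PyG classes so the restricted loader can rebuild them.
torch.serialization.add_safe_globals(
    [Data, DataEdgeAttr, DataTensorAttr, GlobalStorage]
)

graph = Data(x=torch.randn(4, 2), edge_index=torch.tensor([[0, 1, 2], [1, 2, 3]]))
torch.save(graph, "graph.pt")  # hypothetical path

# Without the add_safe_globals call above, this would fail under weights_only=True.
restored = torch.load("graph.pt", weights_only=True)
print(restored)
```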
3 changes: 2 additions & 1 deletion mlwiz/evaluation/util.py
@@ -529,7 +529,8 @@ def load_checkpoint(
:param device (torch.device): the device, e.g, "cpu" or "cuda"
"""
ckpt_dict = torch.load(
- checkpoint_path, map_location="cpu" if device == "cpu" else None
+ checkpoint_path, map_location="cpu" if device == "cpu" else None,
+ weights_only = True
)
model_state = ckpt_dict[MODEL_STATE]

2 changes: 1 addition & 1 deletion mlwiz/training/callback/plotter.py
@@ -34,7 +34,7 @@ def __init__(
self.stored_metrics = {"losses": {}, "scores": {}}
self.stored_metrics_path = Path(self.exp_path, "metrics_data.torch")
if os.path.exists(self.stored_metrics_path):
- self.stored_metrics = torch.load(self.stored_metrics_path)
+ self.stored_metrics = torch.load(self.stored_metrics_path, weights_only=True)

def on_epoch_end(self, state: State):
"""
4 changes: 3 additions & 1 deletion mlwiz/training/engine.py
@@ -778,7 +778,8 @@ def _restore_checkpoint_and_best_results(
# specify explicitly the map location as cpu. The other way around
# (cpu to cuda) poses no problem since GPUs are visible.
ckpt_dict = torch.load(
- ckpt_filename, map_location="cpu" if self.device == "cpu" else None
+ ckpt_filename, map_location="cpu" if self.device == "cpu" else None,
+ weights_only = True
)

self.state.update(
@@ -803,6 +804,7 @@ def _restore_checkpoint_and_best_results(
best_ckpt_dict = torch.load(
best_ckpt_filename,
map_location="cpu" if self.device == "cpu" else None,
+ weights_only = True
)
self.state.update(best_epoch_results=best_ckpt_dict)

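A usage note on the checkpoint-restoring changes: combining `map_location` with `weights_only=True` lets a checkpoint written on a GPU machine be restored on a CPU-only one without unpickling arbitrary objects. The sketch below uses a hypothetical file name and dictionary keys rather than MLWiz's actual checkpoint schema.

```python
import torch
import torch.nn as nn

model = nn.Linear(8, 2)
ckpt = {"model_state": model.state_dict(), "epoch": 12}
torch.save(ckpt, "checkpoint.pt")  # hypothetical path

device = "cpu"
# Mirrors the pattern in engine.py: map CUDA-saved tensors onto the CPU when
# needed, while keeping unpickling restricted to allow-listed types.
ckpt_dict = torch.load(
    "checkpoint.pt",
    map_location="cpu" if device == "cpu" else None,
    weights_only=True,
)
model.load_state_dict(ckpt_dict["model_state"])
```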
