Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[TorchFix] Add weights_only to torch.load #8105

Merged
Merged 9 commits on Dec 14, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions references/classification/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,8 @@ def load_data(traindir, valdir, args):
if args.cache_dataset and os.path.exists(cache_path):
# Attention, as the transforms are also cached!
print(f"Loading dataset_train from {cache_path}")
dataset, _ = torch.load(cache_path)
# TODO: this could probably be weights_only=True
dataset, _ = torch.load(cache_path, weights_only=False)
else:
# We need a default value for the variables below because args may come
# from train_quantization.py which doesn't define them.
Expand Down Expand Up @@ -159,7 +160,8 @@ def load_data(traindir, valdir, args):
if args.cache_dataset and os.path.exists(cache_path):
# Attention, as the transforms are also cached!
print(f"Loading dataset_test from {cache_path}")
dataset_test, _ = torch.load(cache_path)
# TODO: this could probably be weights_only=True
dataset_test, _ = torch.load(cache_path, weights_only=False)
else:
if args.weights and args.test_only:
weights = torchvision.models.get_weight(args.weights)
Expand Down Expand Up @@ -337,7 +339,7 @@ def collate_fn(batch):
model_ema = utils.ExponentialMovingAverage(model_without_ddp, device=device, decay=1.0 - alpha)

if args.resume:
checkpoint = torch.load(args.resume, map_location="cpu")
checkpoint = torch.load(args.resume, map_location="cpu", weights_only=True)
model_without_ddp.load_state_dict(checkpoint["model"])
if not args.test_only:
optimizer.load_state_dict(checkpoint["optimizer"])
Expand Down
2 changes: 1 addition & 1 deletion references/classification/train_quantization.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ def main(args):
model_without_ddp = model.module

if args.resume:
checkpoint = torch.load(args.resume, map_location="cpu")
checkpoint = torch.load(args.resume, map_location="cpu", weights_only=True)
model_without_ddp.load_state_dict(checkpoint["model"])
optimizer.load_state_dict(checkpoint["optimizer"])
lr_scheduler.load_state_dict(checkpoint["lr_scheduler"])
Expand Down
5 changes: 2 additions & 3 deletions references/classification/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -287,8 +287,7 @@ def average_checkpoints(inputs):
for fpath in inputs:
with open(fpath, "rb") as f:
state = torch.load(
f,
map_location=(lambda s, _: torch.serialization.default_restore_location(s, "cpu")),
f, map_location=(lambda s, _: torch.serialization.default_restore_location(s, "cpu")), weights_only=True
)
# Copies over the settings from the first checkpoint
if new_state is None:
Expand Down Expand Up @@ -367,7 +366,7 @@ def store_model_weights(model, checkpoint_path, checkpoint_key="model", strict=T

# Deep copy to avoid side effects on the model object.
model = copy.deepcopy(model)
checkpoint = torch.load(checkpoint_path, map_location="cpu")
checkpoint = torch.load(checkpoint_path, map_location="cpu", weights_only=True)

# Load the weights to the model to validate that everything works
# and remove unnecessary weights (such as auxiliaries, etc.)
Expand Down
2 changes: 1 addition & 1 deletion references/depth/stereo/cascade_evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -262,7 +262,7 @@ def load_checkpoint(args):
utils.setup_ddp(args)

if not args.weights:
checkpoint = torch.load(args.checkpoint, map_location=torch.device("cpu"))
checkpoint = torch.load(args.checkpoint, map_location=torch.device("cpu"), weights_only=True)
if "model" in checkpoint:
experiment_args = checkpoint["args"]
model = torchvision.prototype.models.depth.stereo.__dict__[experiment_args.model](weights=None)
Expand Down
2 changes: 1 addition & 1 deletion references/depth/stereo/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -498,7 +498,7 @@ def main(args):
# load them from checkpoint if needed
args.start_step = 0
if args.resume_path is not None:
checkpoint = torch.load(args.resume_path, map_location="cpu")
checkpoint = torch.load(args.resume_path, map_location="cpu", weights_only=True)
if "model" in checkpoint:
# this means the user requested to resume from a training checkpoint
model_without_ddp.load_state_dict(checkpoint["model"])
Expand Down
2 changes: 1 addition & 1 deletion references/detection/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -288,7 +288,7 @@ def main(args):
)

if args.resume:
checkpoint = torch.load(args.resume, map_location="cpu")
checkpoint = torch.load(args.resume, map_location="cpu", weights_only=True)
model_without_ddp.load_state_dict(checkpoint["model"])
optimizer.load_state_dict(checkpoint["optimizer"])
lr_scheduler.load_state_dict(checkpoint["lr_scheduler"])
Expand Down
2 changes: 1 addition & 1 deletion references/optical_flow/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,7 +226,7 @@ def main(args):
model_without_ddp = model

if args.resume is not None:
checkpoint = torch.load(args.resume, map_location="cpu")
checkpoint = torch.load(args.resume, map_location="cpu", weights_only=True)
model_without_ddp.load_state_dict(checkpoint["model"])

if args.test_only:
Expand Down
2 changes: 1 addition & 1 deletion references/segmentation/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,7 +223,7 @@ def main(args):
lr_scheduler = main_lr_scheduler

if args.resume:
checkpoint = torch.load(args.resume, map_location="cpu")
checkpoint = torch.load(args.resume, map_location="cpu", weights_only=True)
model_without_ddp.load_state_dict(checkpoint["model"], strict=not args.test_only)
if not args.test_only:
optimizer.load_state_dict(checkpoint["optimizer"])
Expand Down
2 changes: 1 addition & 1 deletion references/similarity/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ def main(args):

model = EmbeddingNet()
if args.resume:
model.load_state_dict(torch.load(args.resume))
model.load_state_dict(torch.load(args.resume, weights_only=True))

model.to(device)

Expand Down
6 changes: 3 additions & 3 deletions references/video_classification/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,7 @@ def main(args):

if args.cache_dataset and os.path.exists(cache_path):
print(f"Loading dataset_train from {cache_path}")
dataset, _ = torch.load(cache_path)
dataset, _ = torch.load(cache_path, weights_only=True)
dataset.transform = transform_train
else:
if args.distributed:
Expand Down Expand Up @@ -201,7 +201,7 @@ def main(args):

if args.cache_dataset and os.path.exists(cache_path):
print(f"Loading dataset_test from {cache_path}")
dataset_test, _ = torch.load(cache_path)
dataset_test, _ = torch.load(cache_path, weights_only=True)
dataset_test.transform = transform_test
else:
if args.distributed:
Expand Down Expand Up @@ -295,7 +295,7 @@ def main(args):
model_without_ddp = model.module

if args.resume:
checkpoint = torch.load(args.resume, map_location="cpu")
checkpoint = torch.load(args.resume, map_location="cpu", weights_only=True)
model_without_ddp.load_state_dict(checkpoint["model"])
optimizer.load_state_dict(checkpoint["optimizer"])
lr_scheduler.load_state_dict(checkpoint["lr_scheduler"])
Expand Down
3 changes: 2 additions & 1 deletion test/test_functional_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -1024,7 +1024,8 @@ def test_gaussian_blur(device, image_size, dt, ksize, sigma, fn):
# "23_23_1.7": ...
# }
p = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets", "gaussian_blur_opencv_results.pt")
true_cv2_results = torch.load(p)

true_cv2_results = torch.load(p, weights_only=False)

if image_size == "small":
tensor = (
Expand Down
6 changes: 3 additions & 3 deletions test/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@ def _assert_expected(output, name, prec=None, atol=None, rtol=None):
if binary_size > MAX_PICKLE_SIZE:
raise RuntimeError(f"The output for {filename}, is larger than 50kb - got {binary_size}kb")
else:
expected = torch.load(expected_file)
expected = torch.load(expected_file, weights_only=True)
rtol = rtol or prec # keeping prec param for legacy reason, but could be removed ideally
atol = atol or prec
torch.testing.assert_close(output, expected, rtol=rtol, atol=atol, check_dtype=False, check_device=False)
Expand Down Expand Up @@ -747,7 +747,7 @@ def check_out(out):
# so instead of validating the probability scores, check that the class
# predictions match.
expected_file = _get_expected_file(model_name)
expected = torch.load(expected_file)
expected = torch.load(expected_file, weights_only=True)
torch.testing.assert_close(
out.argmax(dim=1), expected.argmax(dim=1), rtol=prec, atol=prec, check_device=False
)
Expand Down Expand Up @@ -847,7 +847,7 @@ def compute_mean_std(tensor):
# as in NMSTester.test_nms_cuda to see if this is caused by duplicate
# scores.
expected_file = _get_expected_file(model_name)
expected = torch.load(expected_file)
expected = torch.load(expected_file, weights_only=True)
torch.testing.assert_close(
output[0]["scores"], expected[0]["scores"], rtol=prec, atol=prec, check_device=False, check_dtype=False
)
Expand Down
2 changes: 1 addition & 1 deletion test/test_prototype_datasets_builtin.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,7 +215,7 @@ def test_save_load(self, dataset_mock, config):
with io.BytesIO() as buffer:
torch.save(sample, buffer)
buffer.seek(0)
assert_samples_equal(torch.load(buffer), sample)
assert_samples_equal(torch.load(buffer, weights_only=True), sample)

@parametrize_dataset_mocks(DATASET_MOCKS)
def test_infinite_buffer_size(self, dataset_mock, config):
Expand Down
3 changes: 2 additions & 1 deletion test/test_transforms_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -3176,7 +3176,8 @@ def test__get_params(self, sigma):
# "26_28_1__23_23_1.7": cv2.GaussianBlur(np_img2, ksize=(23, 23), sigmaX=1.7),
# }
REFERENCE_GAUSSIAN_BLUR_IMAGE_RESULTS = torch.load(
Path(__file__).parent / "assets" / "gaussian_blur_opencv_results.pt"
Path(__file__).parent / "assets" / "gaussian_blur_opencv_results.pt",
weights_only=False,
)

@pytest.mark.parametrize(
Expand Down
2 changes: 1 addition & 1 deletion test/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -375,7 +375,7 @@ def test_flow_to_image(batch):
assert img.shape == (2, 3, h, w) if batch else (3, h, w)

path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets", "expected_flow.pt")
expected_img = torch.load(path, map_location="cpu")
expected_img = torch.load(path, map_location="cpu", weights_only=True)

if batch:
expected_img = torch.stack([expected_img, expected_img])
Expand Down
2 changes: 1 addition & 1 deletion torchvision/datasets/imagenet.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ def load_meta_file(root: str, file: Optional[str] = None) -> Tuple[Dict[str, str
file = os.path.join(root, file)

if check_integrity(file):
return torch.load(file)
return torch.load(file, weights_only=True)
else:
msg = (
"The meta file {} is not present in the root directory or is corrupted. "
Expand Down
2 changes: 1 addition & 1 deletion torchvision/datasets/mnist.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ def _load_legacy_data(self):
# This is for backward compatibility (BC) only. We no longer cache the data in a custom binary, but simply read from the raw data
# directly.
data_file = self.training_file if self.train else self.test_file
return torch.load(os.path.join(self.processed_folder, data_file))
return torch.load(os.path.join(self.processed_folder, data_file), weights_only=True)

def _load_data(self):
image_file = f"{'train' if self.train else 't10k'}-images-idx3-ubyte"
Expand Down
2 changes: 1 addition & 1 deletion torchvision/datasets/phototour.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ def __init__(
self.cache()

# load the serialized data
self.data, self.labels, self.matches = torch.load(self.data_file)
self.data, self.labels, self.matches = torch.load(self.data_file, weights_only=True)

def __getitem__(self, index: int) -> Union[torch.Tensor, Tuple[Any, Any, torch.Tensor]]:
"""
Expand Down
Loading