Fix incorrect use of collections.OrderedDict in type annotation #1141

Merged: 1 commit, merged Jun 6, 2023

src/super_gradients/training/utils/checkpoint_utils.py: 12 changes (6 additions, 6 deletions)
@@ -1,20 +1,20 @@
+import collections
 import os
 import tempfile
+from typing import Union, Mapping

 import pkg_resources
-import collections
 import torch
+from torch import nn, Tensor

 from super_gradients.common.abstractions.abstract_logger import get_logger
 from super_gradients.common.data_interface.adnn_model_repository_data_interface import ADNNModelRepositoryDataInterfaces
+from super_gradients.common.data_types import StrictLoad
 from super_gradients.common.decorators.explicit_params_validator import explicit_params_validation
 from super_gradients.module_interfaces import HasPredict
 from super_gradients.training.pretrained_models import MODEL_URLS
-from super_gradients.common.data_types import StrictLoad
 from super_gradients.training.utils.distributed_training_utils import get_local_rank, wait_for_the_master

-from torch import nn, Tensor
-from typing import Union, Mapping

 try:
     from torch.hub import download_url_to_file, load_state_dict_from_url
 except (ModuleNotFoundError, ImportError, NameError):
@@ -127,7 +127,7 @@ def copy_ckpt_to_local_folder(
     return ckpt_file_full_local_path


-def read_ckpt_state_dict(ckpt_path: str, device="cpu") -> collections.OrderedDict[str, torch.Tensor]:
+def read_ckpt_state_dict(ckpt_path: str, device="cpu") -> Mapping[str, torch.Tensor]:
     """
     Reads a checkpoint state dict from a given path or url

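
Why the annotation had to change (an explanatory note added here, not text from the PR): subscripting collections.OrderedDict relies on PEP 585 generics, which exist only in Python 3.9 and later. On older interpreters, collections.OrderedDict[str, torch.Tensor] raises TypeError: 'type' object is not subscriptable at function-definition time, i.e. when checkpoint_utils is imported (assuming the module does not use from __future__ import annotations), even if read_ckpt_state_dict is never called. typing.Mapping has been subscriptable since it was introduced, so the new return annotation evaluates on every supported version. Below is a minimal, dependency-free sketch of the difference; the helper name is hypothetical and float stands in for torch.Tensor.

import collections
import sys
from typing import Mapping

if sys.version_info < (3, 9):
    try:
        # Before PEP 585 (Python 3.9), standard-library container classes such as
        # collections.OrderedDict are not subscriptable, so this line reproduces
        # the failure the PR removes.
        collections.OrderedDict[str, float]
    except TypeError as exc:
        print(f"old annotation fails at definition time: {exc}")

# typing.Mapping is a generic alias on every supported Python version, so an
# annotation shaped like the one in the PR evaluates without error.
def read_state_dict_sketch(path: str) -> Mapping[str, float]:  # hypothetical helper
    # A loaded state dict is typically an OrderedDict, which already satisfies
    # the Mapping contract, so callers see the same read-only view of the result.
    return collections.OrderedDict()

A stricter alternative with the same runtime behavior would have been typing.OrderedDict[str, torch.Tensor], which is subscriptable from Python 3.7.2 onward; Mapping is the looser contract and signals that callers should treat the returned state dict as read-only.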