Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added warning message for dataset license #1846

Merged
merged 4 commits into from
Feb 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 10 additions & 1 deletion src/super_gradients/training/pretrained_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,6 @@
"yolo_nas_pose_l_coco_pose": "https://sghub.deci.ai/models/yolo_nas_pose_l_coco_pose.pth",
}


PRETRAINED_NUM_CLASSES = {
"imagenet": 1000,
"imagenet21k": 21843,
Expand All @@ -71,3 +70,13 @@
"coco_pose": 17,
"cifar10": 10,
}

DATASET_LICENSES = {
"imagenet": "https://www.image-net.org/download.php",
"imagenet21k": "https://github.com/Alibaba-MIIL/ImageNet21K",
"coco": "https://cocodataset.org/#termsofuse",
"coco_segmentation_subclass": "https://cocodataset.org/#termsofuse",
"coco_pose": "https://cocodataset.org/#termsofuse",
"cityscapes": "https://www.cs.toronto.edu/~kriz/cifar.html",
"objects365": "https://www.objects365.org/download.html",
}
14 changes: 10 additions & 4 deletions src/super_gradients/training/utils/checkpoint_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from super_gradients.common.data_types import StrictLoad
from super_gradients.common.decorators.explicit_params_validator import explicit_params_validation
from super_gradients.module_interfaces import HasPredict
from super_gradients.training.pretrained_models import MODEL_URLS
from super_gradients.training.pretrained_models import MODEL_URLS, DATASET_LICENSES
from super_gradients.training.utils.distributed_training_utils import wait_for_the_master
from super_gradients.common.environment.ddp_utils import get_local_rank
from super_gradients.training.utils.utils import unwrap_model
Expand All @@ -24,7 +24,6 @@
except (ModuleNotFoundError, ImportError, NameError):
from torch.hub import _download_url_to_file as download_url_to_file


logger = get_logger(__name__)


Expand Down Expand Up @@ -52,8 +51,8 @@ def transfer_weights(model: nn.Module, model_state_dict: Mapping[str, Tensor]) -
percentage_of_checkpoint = transfered_weights / len(model_state_dict)
percentage_of_model = transfered_weights / len(model.state_dict())
logger.debug(
f"Transfered {transfered_weights} ({(100*percentage_of_checkpoint):.2f}%) weights from the checkpoint. "
f"{(100*percentage_of_model):.2f}% of the model layers were initialized using checkpoint."
f"Transfered {transfered_weights} ({(100 * percentage_of_checkpoint):.2f}%) weights from the checkpoint. "
f"{(100 * percentage_of_model):.2f}% of the model layers were initialized using checkpoint."
)


Expand Down Expand Up @@ -1562,6 +1561,13 @@ def load_pretrained_weights(model: torch.nn.Module, architecture: str, pretraine
if model_url_key not in MODEL_URLS.keys():
raise MissingPretrainedWeightsException(model_url_key)

if pretrained_weights in DATASET_LICENSES:
logger.warning(
f":warning: The pre-trained models provided by SuperGradients may have their own licenses or terms and "
"conditions derived from the dataset used for pre-training.\n It is your responsibility to determine whether you "
"have permission to use the models for your use case.\n The model you have requested was pre-trained on the "
f"{pretrained_weights} dataset, published under the following terms: {DATASET_LICENSES[pretrained_weights]}"
)
url = MODEL_URLS[model_url_key]

if architecture in {Models.YOLO_NAS_S, Models.YOLO_NAS_M, Models.YOLO_NAS_L}:
Expand Down
Loading