Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Hashlib specification #27038

Closed
wants to merge 5 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from datasets import load_dataset
from minhash_deduplication import deduplicate_dataset

from src.transformers.utils.generic import HASHLIB_KWARGS
from transformers import AutoTokenizer, HfArgumentParser


Expand All @@ -21,7 +22,7 @@

def get_hash(example):
"""Get hash of content field."""
return {"hash": hashlib.md5(re.sub(PATTERN, "", example["content"]).encode("utf-8")).hexdigest()}
return {"hash": hashlib.md5(re.sub(PATTERN, "", example["content"]).encode("utf-8"), **HASHLIB_KWARGS).hexdigest()}


def line_stats(example):
Expand Down
6 changes: 4 additions & 2 deletions examples/research_projects/lxmert/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,8 @@
from tqdm.auto import tqdm
from yaml import Loader, dump, load

from src.transformers.utils.generic import HASHLIB_KWARGS


try:
import torch
Expand Down Expand Up @@ -402,12 +404,12 @@ def _resumable_file_manager():

def url_to_filename(url, etag=None):
url_bytes = url.encode("utf-8")
url_hash = sha256(url_bytes)
url_hash = sha256(url_bytes, **HASHLIB_KWARGS)
filename = url_hash.hexdigest()

if etag:
etag_bytes = etag.encode("utf-8")
etag_hash = sha256(etag_bytes)
etag_hash = sha256(etag_bytes, **HASHLIB_KWARGS)
filename += "." + etag_hash.hexdigest()

if url.endswith(".h5"):
Expand Down
6 changes: 4 additions & 2 deletions examples/research_projects/visual_bert/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,8 @@
from tqdm.auto import tqdm
from yaml import Loader, dump, load

from src.transformers.utils.generic import HASHLIB_KWARGS


try:
import torch
Expand Down Expand Up @@ -402,12 +404,12 @@ def _resumable_file_manager():

def url_to_filename(url, etag=None):
url_bytes = url.encode("utf-8")
url_hash = sha256(url_bytes)
url_hash = sha256(url_bytes, **HASHLIB_KWARGS)
filename = url_hash.hexdigest()

if etag:
etag_bytes = etag.encode("utf-8")
etag_hash = sha256(etag_bytes)
etag_hash = sha256(etag_bytes, **HASHLIB_KWARGS)
filename += "." + etag_hash.hexdigest()

if url.endswith(".h5"):
Expand Down
5 changes: 3 additions & 2 deletions src/transformers/models/whisper/convert_openai_to_hf.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from torch import nn
from tqdm import tqdm

from src.transformers.utils.generic import HASHLIB_KWARGS
from transformers import WhisperConfig, WhisperForConditionalGeneration


Expand Down Expand Up @@ -102,7 +103,7 @@ def _download(url: str, root: str) -> bytes:

if os.path.isfile(download_target):
model_bytes = open(download_target, "rb").read()
if hashlib.sha256(model_bytes).hexdigest() == expected_sha256:
if hashlib.sha256(model_bytes, **HASHLIB_KWARGS).hexdigest() == expected_sha256:
return model_bytes
else:
warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file")
Expand All @@ -120,7 +121,7 @@ def _download(url: str, root: str) -> bytes:
loop.update(len(buffer))

model_bytes = open(download_target, "rb").read()
if hashlib.sha256(model_bytes).hexdigest() != expected_sha256:
if hashlib.sha256(model_bytes, **HASHLIB_KWARGS).hexdigest() != expected_sha256:
raise RuntimeError(
"Model has been downloaded but the SHA256 checksum does not not match. Please retry loading the model."
)
Expand Down
5 changes: 5 additions & 0 deletions src/transformers/utils/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
"""

import inspect
import sys
import tempfile
from collections import OrderedDict, UserDict
from collections.abc import MutableMapping
Expand Down Expand Up @@ -673,3 +674,7 @@ def infer_framework(model_class):
return "flax"
else:
raise TypeError(f"Could not infer framework from class {model_class}.")


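# hashlib constructors accept a `usedforsecurity` keyword argument starting with Python 3.9. Passing `usedforsecurity=False` marks a hash such as md5 as a non-security use (e.g. checksums or cache keys), so it remains usable in restricted environments. On older Python versions the kwarg does not exist, so nothing extra is passed.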
HASHLIB_KWARGS = {"usedforsecurity": False} if sys.version_info >= (3, 9) else {}
3 changes: 2 additions & 1 deletion tests/pipelines/test_pipelines_depth_estimation.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import hashlib
import unittest

from src.transformers.utils.generic import HASHLIB_KWARGS
from transformers import MODEL_FOR_DEPTH_ESTIMATION_MAPPING, is_torch_available, is_vision_available
from transformers.pipelines import DepthEstimationPipeline, pipeline
from transformers.testing_utils import (
Expand Down Expand Up @@ -44,7 +45,7 @@ def open(*args, **kwargs):


def hashimage(image: Image) -> str:
m = hashlib.md5(image.tobytes())
m = hashlib.md5(image.tobytes(), **HASHLIB_KWARGS)
return m.hexdigest()


Expand Down
3 changes: 2 additions & 1 deletion tests/pipelines/test_pipelines_image_segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import requests
from datasets import load_dataset

from src.transformers.utils.generic import HASHLIB_KWARGS
from transformers import (
MODEL_FOR_IMAGE_SEGMENTATION_MAPPING,
MODEL_FOR_INSTANCE_SEGMENTATION_MAPPING,
Expand Down Expand Up @@ -59,7 +60,7 @@ def open(*args, **kwargs):


def hashimage(image: Image) -> str:
m = hashlib.md5(image.tobytes())
m = hashlib.md5(image.tobytes(), **HASHLIB_KWARGS)
return m.hexdigest()[:10]


Expand Down
3 changes: 2 additions & 1 deletion tests/pipelines/test_pipelines_mask_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

import numpy as np

from src.transformers.utils.generic import HASHLIB_KWARGS
from transformers import (
MODEL_FOR_MASK_GENERATION_MAPPING,
TF_MODEL_FOR_MASK_GENERATION_MAPPING,
Expand Down Expand Up @@ -46,7 +47,7 @@ def open(*args, **kwargs):


def hashimage(image: Image) -> str:
m = hashlib.md5(image.tobytes())
m = hashlib.md5(image.tobytes(), **HASHLIB_KWARGS)
return m.hexdigest()[:10]


Expand Down