Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Disallow pickle.load unless TRUST_REMOTE_CODE=True #27776

Merged
merged 5 commits into from
Dec 4, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 8 additions & 2 deletions docs/source/en/model_doc/transfo-xl.md
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,17 @@ This model is in maintenance mode only, so we won't accept any new PRs changing

We recommend switching to more recent models for improved security.

In case you would still like to use `TransfoXL` in your experiments, we recommend using the [Hub checkpoint](https://huggingface.co/transfo-xl-wt103) with a specific revision to ensure you are downloading safe files from the Hub:
In case you would still like to use `TransfoXL` in your experiments, we recommend using the [Hub checkpoint](https://huggingface.co/transfo-xl-wt103) with a specific revision to ensure you are downloading safe files from the Hub.

```
Loading this checkpoint calls `pickle.load()`, which can execute arbitrary code. To allow it, you must
explicitly set the environment variable `TRUST_REMOTE_CODE` to `True`:

```python
import os
from transformers import TransfoXLTokenizer, TransfoXLLMHeadModel

os.environ["TRUST_REMOTE_CODE"] = "True"

checkpoint = 'transfo-xl-wt103'
revision = '40a186da79458c9f9de846edfaea79c412137f97'

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
is_torch_available,
logging,
requires_backends,
strtobool,
torch_only_method,
)

Expand Down Expand Up @@ -212,6 +213,14 @@ def __init__(
vocab_dict = None
if pretrained_vocab_file is not None:
# Priority on pickle files (support PyTorch and TF)
if not strtobool(os.environ.get("TRUST_REMOTE_CODE", "False")):
raise ValueError(
"This part uses `pickle.load` which is insecure and will execute arbitrary code that is "
"potentially malicious. It's recommended to never unpickle data that could have come from an "
"untrusted source, or that could have been tampered with. If you already verified the pickle "
"data and decided to use it, you can set the environment variable "
"`TRUST_REMOTE_CODE` to `True` to allow it."
)
with open(pretrained_vocab_file, "rb") as f:
vocab_dict = pickle.load(f)

Expand Down Expand Up @@ -790,6 +799,13 @@ def get_lm_corpus(datadir, dataset):
corpus = torch.load(fn_pickle)
elif os.path.exists(fn):
logger.info("Loading cached dataset from pickle...")
if not strtobool(os.environ.get("TRUST_REMOTE_CODE", "False")):
raise ValueError(
"This part uses `pickle.load` which is insecure and will execute arbitrary code that is potentially "
"malicious. It's recommended to never unpickle data that could have come from an untrusted source, or "
"that could have been tampered with. If you already verified the pickle data and decided to use it, "
"you can set the environment variable `TRUST_REMOTE_CODE` to `True` to allow it."
)
with open(fn, "rb") as fp:
corpus = pickle.load(fp)
else:
Expand Down
16 changes: 15 additions & 1 deletion src/transformers/models/rag/retrieval_rag.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@

from ...tokenization_utils import PreTrainedTokenizer
from ...tokenization_utils_base import BatchEncoding
from ...utils import cached_file, is_datasets_available, is_faiss_available, logging, requires_backends
from ...utils import cached_file, is_datasets_available, is_faiss_available, logging, requires_backends, strtobool
from .configuration_rag import RagConfig
from .tokenization_rag import RagTokenizer

Expand Down Expand Up @@ -131,6 +131,13 @@ def _resolve_path(self, index_path, filename):
def _load_passages(self):
    """Load the pickled passages file located under `self.index_path`.

    Returns:
        The object stored in the passages pickle file.

    Raises:
        ValueError: if `strtobool` does not evaluate the `TRUST_REMOTE_CODE` environment variable
            (default `"False"`) as true, because unpickling can execute arbitrary code.
    """
    logger.info(f"Loading passages from {self.index_path}")
    passages_path = self._resolve_path(self.index_path, self.PASSAGE_FILENAME)
    # Unpickling is opt-in: refuse unless the user enabled it through the environment.
    pickle_allowed = strtobool(os.environ.get("TRUST_REMOTE_CODE", "False"))
    if not pickle_allowed:
        raise ValueError(
            "This part uses `pickle.load` which is insecure and will execute arbitrary code that is potentially "
            "malicious. It's recommended to never unpickle data that could have come from an untrusted source, or "
            "that could have been tampered with. If you already verified the pickle data and decided to use it, "
            "you can set the environment variable `TRUST_REMOTE_CODE` to `True` to allow it."
        )
    with open(passages_path, "rb") as passages_file:
        return pickle.load(passages_file)
Expand All @@ -140,6 +147,13 @@ def _deserialize_index(self):
resolved_index_path = self._resolve_path(self.index_path, self.INDEX_FILENAME + ".index.dpr")
self.index = faiss.read_index(resolved_index_path)
resolved_meta_path = self._resolve_path(self.index_path, self.INDEX_FILENAME + ".index_meta.dpr")
if not strtobool(os.environ.get("TRUST_REMOTE_CODE", "False")):
raise ValueError(
"This part uses `pickle.load` which is insecure and will execute arbitrary code that is potentially "
"malicious. It's recommended to never unpickle data that could have come from an untrusted source, or "
"that could have been tampered with. If you already verified the pickle data and decided to use it, "
"you can set the environment variable `TRUST_REMOTE_CODE` to `True` to allow it."
)
with open(resolved_meta_path, "rb") as metadata_file:
self.index_id_to_db_id = pickle.load(metadata_file)
assert (
Expand Down
59 changes: 0 additions & 59 deletions tests/models/rag/test_retrieval_rag.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@

import json
import os
import pickle
import shutil
import tempfile
from unittest import TestCase
Expand Down Expand Up @@ -174,37 +173,6 @@ def get_dummy_custom_hf_index_retriever(self, from_disk: bool):
)
return retriever

def get_dummy_legacy_index_retriever(self):
    """Build a `RagRetriever` that uses a two-document "legacy" index stored in `self.tmpdirname`.

    Three files are written to `self.tmpdirname` from a small in-memory dataset: a FAISS index
    (`.index.dpr`), a pickled list of passage ids (`.index_meta.dpr`), and a pickled dict mapping
    each passage id to `[text, title]` (`psgs_w100.tsv.pkl`).

    Returns:
        A `RagRetriever` configured with `index_name="legacy"` and `index_path=self.tmpdirname`.
    """
    dataset = Dataset.from_dict(
        {
            "id": ["0", "1"],
            "text": ["foo", "bar"],
            "title": ["Foo", "Bar"],
            "embeddings": [np.ones(self.retrieval_vector_size + 1), 2 * np.ones(self.retrieval_vector_size + 1)],
        }
    )
    dataset.add_faiss_index("embeddings", string_factory="Flat", metric_type=faiss.METRIC_INNER_PRODUCT)

    index_file_name = os.path.join(self.tmpdirname, "hf_bert_base.hnswSQ8_correct_phi_128.c_index")
    dataset.save_faiss_index("embeddings", index_file_name + ".index.dpr")
    # Context managers guarantee the pickle files are flushed and closed before anything reads them
    # back; passing a bare `open(...)` to `pickle.dump` would leave closing to the garbage collector.
    with open(index_file_name + ".index_meta.dpr", "wb") as meta_file:
        pickle.dump(dataset["id"], meta_file)

    passages_file_name = os.path.join(self.tmpdirname, "psgs_w100.tsv.pkl")
    passages = {sample["id"]: [sample["text"], sample["title"]] for sample in dataset}
    with open(passages_file_name, "wb") as passages_file:
        pickle.dump(passages, passages_file)

    config = RagConfig(
        retrieval_vector_size=self.retrieval_vector_size,
        question_encoder=DPRConfig().to_dict(),
        generator=BartConfig().to_dict(),
        index_name="legacy",
        index_path=self.tmpdirname,
    )
    retriever = RagRetriever(
        config, question_encoder_tokenizer=self.get_dpr_tokenizer(), generator_tokenizer=self.get_bart_tokenizer()
    )
    return retriever

def test_canonical_hf_index_retriever_retrieve(self):
n_docs = 1
retriever = self.get_dummy_canonical_hf_index_retriever()
Expand Down Expand Up @@ -288,33 +256,6 @@ def test_custom_hf_index_retriever_save_and_from_pretrained_from_disk(self):
out = retriever.retrieve(hidden_states, n_docs=1)
self.assertTrue(out is not None)

def test_legacy_index_retriever_retrieve(self):
n_docs = 1
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we no longer need to test the "legacy" case once #27748 (not merged yet) lands.

retriever = self.get_dummy_legacy_index_retriever()
hidden_states = np.array(
[np.ones(self.retrieval_vector_size), -np.ones(self.retrieval_vector_size)], dtype=np.float32
)
retrieved_doc_embeds, doc_ids, doc_dicts = retriever.retrieve(hidden_states, n_docs=n_docs)
self.assertEqual(retrieved_doc_embeds.shape, (2, n_docs, self.retrieval_vector_size))
self.assertEqual(len(doc_dicts), 2)
self.assertEqual(sorted(doc_dicts[0]), ["text", "title"])
self.assertEqual(len(doc_dicts[0]["text"]), n_docs)
self.assertEqual(doc_dicts[0]["text"][0], "bar") # max inner product is reached with second doc
self.assertEqual(doc_dicts[1]["text"][0], "foo") # max inner product is reached with first doc
self.assertListEqual(doc_ids.tolist(), [[1], [0]])

# Round-trip test: save a legacy-index retriever with `save_pretrained`, reload it with
# `RagRetriever.from_pretrained`, and check that the reloaded object can still retrieve.
def test_legacy_hf_index_retriever_save_and_from_pretrained(self):
retriever = self.get_dummy_legacy_index_retriever()
# A temporary directory receives the saved retriever files.
with tempfile.TemporaryDirectory() as tmp_dirname:
retriever.save_pretrained(tmp_dirname)
# Rebind `retriever` to the instance reloaded from disk; later checks exercise the reloaded one.
retriever = RagRetriever.from_pretrained(tmp_dirname)
self.assertIsInstance(retriever, RagRetriever)
# Two query vectors of length `retrieval_vector_size`: all ones and all minus ones.
hidden_states = np.array(
[np.ones(self.retrieval_vector_size), -np.ones(self.retrieval_vector_size)], dtype=np.float32
)
# Only asserts that retrieval returns something; the retrieved contents are not inspected here.
out = retriever.retrieve(hidden_states, n_docs=1)
self.assertTrue(out is not None)

@require_torch
@require_tokenizers
@require_sentencepiece
Expand Down
Loading