Skip to content

Commit

Permalink
Disallow pickle.load unless TRUST_REMOTE_CODE=True (#27776)
Browse files Browse the repository at this point in the history
* fix

* fix

* Use TRUST_REMOTE_CODE

* fix doc

* fix

---------

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
  • Loading branch information
ydshieh and ydshieh authored Dec 4, 2023
1 parent e0d2e69 commit 1d63b0e
Show file tree
Hide file tree
Showing 4 changed files with 39 additions and 62 deletions.
10 changes: 8 additions & 2 deletions docs/source/en/model_doc/transfo-xl.md
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,17 @@ This model is in maintenance mode only, so we won't accept any new PRs changing

We recommend switching to more recent models for improved security.

In case you would still like to use `TransfoXL` in your experiments, we recommend using the [Hub checkpoint](https://huggingface.co/transfo-xl-wt103) with a specific revision to ensure you are downloading safe files from the Hub:
In case you would still like to use `TransfoXL` in your experiments, we recommend using the [Hub checkpoint](https://huggingface.co/transfo-xl-wt103) with a specific revision to ensure you are downloading safe files from the Hub.

```
You will need to set the environment variable `TRUST_REMOTE_CODE` to `True` in order to allow the
usage of `pickle.load()`:

```python
import os
from transformers import TransfoXLTokenizer, TransfoXLLMHeadModel

os.environ["TRUST_REMOTE_CODE"] = "True"

checkpoint = 'transfo-xl-wt103'
revision = '40a186da79458c9f9de846edfaea79c412137f97'

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
is_torch_available,
logging,
requires_backends,
strtobool,
torch_only_method,
)

Expand Down Expand Up @@ -212,6 +213,14 @@ def __init__(
vocab_dict = None
if pretrained_vocab_file is not None:
# Priority on pickle files (support PyTorch and TF)
if not strtobool(os.environ.get("TRUST_REMOTE_CODE", "False")):
raise ValueError(
"This part uses `pickle.load` which is insecure and will execute arbitrary code that is "
"potentially malicious. It's recommended to never unpickle data that could have come from an "
"untrusted source, or that could have been tampered with. If you already verified the pickle "
"data and decided to use it, you can set the environment variable "
"`TRUST_REMOTE_CODE` to `True` to allow it."
)
with open(pretrained_vocab_file, "rb") as f:
vocab_dict = pickle.load(f)

Expand Down Expand Up @@ -790,6 +799,13 @@ def get_lm_corpus(datadir, dataset):
corpus = torch.load(fn_pickle)
elif os.path.exists(fn):
logger.info("Loading cached dataset from pickle...")
if not strtobool(os.environ.get("TRUST_REMOTE_CODE", "False")):
raise ValueError(
"This part uses `pickle.load` which is insecure and will execute arbitrary code that is potentially "
"malicious. It's recommended to never unpickle data that could have come from an untrusted source, or "
"that could have been tampered with. If you already verified the pickle data and decided to use it, "
"you can set the environment variable `TRUST_REMOTE_CODE` to `True` to allow it."
)
with open(fn, "rb") as fp:
corpus = pickle.load(fp)
else:
Expand Down
16 changes: 15 additions & 1 deletion src/transformers/models/rag/retrieval_rag.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@

from ...tokenization_utils import PreTrainedTokenizer
from ...tokenization_utils_base import BatchEncoding
from ...utils import cached_file, is_datasets_available, is_faiss_available, logging, requires_backends
from ...utils import cached_file, is_datasets_available, is_faiss_available, logging, requires_backends, strtobool
from .configuration_rag import RagConfig
from .tokenization_rag import RagTokenizer

Expand Down Expand Up @@ -131,6 +131,13 @@ def _resolve_path(self, index_path, filename):
def _load_passages(self):
logger.info(f"Loading passages from {self.index_path}")
passages_path = self._resolve_path(self.index_path, self.PASSAGE_FILENAME)
if not strtobool(os.environ.get("TRUST_REMOTE_CODE", "False")):
raise ValueError(
"This part uses `pickle.load` which is insecure and will execute arbitrary code that is potentially "
"malicious. It's recommended to never unpickle data that could have come from an untrusted source, or "
"that could have been tampered with. If you already verified the pickle data and decided to use it, "
"you can set the environment variable `TRUST_REMOTE_CODE` to `True` to allow it."
)
with open(passages_path, "rb") as passages_file:
passages = pickle.load(passages_file)
return passages
Expand All @@ -140,6 +147,13 @@ def _deserialize_index(self):
resolved_index_path = self._resolve_path(self.index_path, self.INDEX_FILENAME + ".index.dpr")
self.index = faiss.read_index(resolved_index_path)
resolved_meta_path = self._resolve_path(self.index_path, self.INDEX_FILENAME + ".index_meta.dpr")
if not strtobool(os.environ.get("TRUST_REMOTE_CODE", "False")):
raise ValueError(
"This part uses `pickle.load` which is insecure and will execute arbitrary code that is potentially "
"malicious. It's recommended to never unpickle data that could have come from an untrusted source, or "
"that could have been tampered with. If you already verified the pickle data and decided to use it, "
"you can set the environment variable `TRUST_REMOTE_CODE` to `True` to allow it."
)
with open(resolved_meta_path, "rb") as metadata_file:
self.index_id_to_db_id = pickle.load(metadata_file)
assert (
Expand Down
59 changes: 0 additions & 59 deletions tests/models/rag/test_retrieval_rag.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@

import json
import os
import pickle
import shutil
import tempfile
from unittest import TestCase
Expand Down Expand Up @@ -174,37 +173,6 @@ def get_dummy_custom_hf_index_retriever(self, from_disk: bool):
)
return retriever

def get_dummy_legacy_index_retriever(self):
dataset = Dataset.from_dict(
{
"id": ["0", "1"],
"text": ["foo", "bar"],
"title": ["Foo", "Bar"],
"embeddings": [np.ones(self.retrieval_vector_size + 1), 2 * np.ones(self.retrieval_vector_size + 1)],
}
)
dataset.add_faiss_index("embeddings", string_factory="Flat", metric_type=faiss.METRIC_INNER_PRODUCT)

index_file_name = os.path.join(self.tmpdirname, "hf_bert_base.hnswSQ8_correct_phi_128.c_index")
dataset.save_faiss_index("embeddings", index_file_name + ".index.dpr")
pickle.dump(dataset["id"], open(index_file_name + ".index_meta.dpr", "wb"))

passages_file_name = os.path.join(self.tmpdirname, "psgs_w100.tsv.pkl")
passages = {sample["id"]: [sample["text"], sample["title"]] for sample in dataset}
pickle.dump(passages, open(passages_file_name, "wb"))

config = RagConfig(
retrieval_vector_size=self.retrieval_vector_size,
question_encoder=DPRConfig().to_dict(),
generator=BartConfig().to_dict(),
index_name="legacy",
index_path=self.tmpdirname,
)
retriever = RagRetriever(
config, question_encoder_tokenizer=self.get_dpr_tokenizer(), generator_tokenizer=self.get_bart_tokenizer()
)
return retriever

def test_canonical_hf_index_retriever_retrieve(self):
n_docs = 1
retriever = self.get_dummy_canonical_hf_index_retriever()
Expand Down Expand Up @@ -288,33 +256,6 @@ def test_custom_hf_index_retriever_save_and_from_pretrained_from_disk(self):
out = retriever.retrieve(hidden_states, n_docs=1)
self.assertTrue(out is not None)

def test_legacy_index_retriever_retrieve(self):
n_docs = 1
retriever = self.get_dummy_legacy_index_retriever()
hidden_states = np.array(
[np.ones(self.retrieval_vector_size), -np.ones(self.retrieval_vector_size)], dtype=np.float32
)
retrieved_doc_embeds, doc_ids, doc_dicts = retriever.retrieve(hidden_states, n_docs=n_docs)
self.assertEqual(retrieved_doc_embeds.shape, (2, n_docs, self.retrieval_vector_size))
self.assertEqual(len(doc_dicts), 2)
self.assertEqual(sorted(doc_dicts[0]), ["text", "title"])
self.assertEqual(len(doc_dicts[0]["text"]), n_docs)
self.assertEqual(doc_dicts[0]["text"][0], "bar") # max inner product is reached with second doc
self.assertEqual(doc_dicts[1]["text"][0], "foo") # max inner product is reached with first doc
self.assertListEqual(doc_ids.tolist(), [[1], [0]])

def test_legacy_hf_index_retriever_save_and_from_pretrained(self):
retriever = self.get_dummy_legacy_index_retriever()
with tempfile.TemporaryDirectory() as tmp_dirname:
retriever.save_pretrained(tmp_dirname)
retriever = RagRetriever.from_pretrained(tmp_dirname)
self.assertIsInstance(retriever, RagRetriever)
hidden_states = np.array(
[np.ones(self.retrieval_vector_size), -np.ones(self.retrieval_vector_size)], dtype=np.float32
)
out = retriever.retrieve(hidden_states, n_docs=1)
self.assertTrue(out is not None)

@require_torch
@require_tokenizers
@require_sentencepiece
Expand Down

0 comments on commit 1d63b0e

Please sign in to comment.