diff --git a/docs/source/en/model_doc/transfo-xl.md b/docs/source/en/model_doc/transfo-xl.md index 05afc76f1114dd..dae7e532be66f3 100644 --- a/docs/source/en/model_doc/transfo-xl.md +++ b/docs/source/en/model_doc/transfo-xl.md @@ -22,11 +22,17 @@ This model is in maintenance mode only, so we won't accept any new PRs changing We recommend switching to more recent models for improved security. -In case you would still like to use `TransfoXL` in your experiments, we recommend using the [Hub checkpoint](https://huggingface.co/transfo-xl-wt103) with a specific revision to ensure you are downloading safe files from the Hub: +In case you would still like to use `TransfoXL` in your experiments, we recommend using the [Hub checkpoint](https://huggingface.co/transfo-xl-wt103) with a specific revision to ensure you are downloading safe files from the Hub. -``` +You will need to set the environment variable `TRUST_REMOTE_CODE` to `True` in order to allow the +usage of `pickle.load()`: + +```python +import os from transformers import TransfoXLTokenizer, TransfoXLLMHeadModel +os.environ["TRUST_REMOTE_CODE"] = "True" + checkpoint = 'transfo-xl-wt103' revision = '40a186da79458c9f9de846edfaea79c412137f97' diff --git a/src/transformers/models/deprecated/transfo_xl/tokenization_transfo_xl.py b/src/transformers/models/deprecated/transfo_xl/tokenization_transfo_xl.py index 93f0bfedc0400e..cea74e76bc15a6 100644 --- a/src/transformers/models/deprecated/transfo_xl/tokenization_transfo_xl.py +++ b/src/transformers/models/deprecated/transfo_xl/tokenization_transfo_xl.py @@ -34,6 +34,7 @@ is_torch_available, logging, requires_backends, + strtobool, torch_only_method, ) @@ -212,6 +213,14 @@ def __init__( vocab_dict = None if pretrained_vocab_file is not None: # Priority on pickle files (support PyTorch and TF) + if not strtobool(os.environ.get("TRUST_REMOTE_CODE", "False")): + raise ValueError( + "This part uses `pickle.load` which is insecure and will execute arbitrary code that is " + "potentially malicious. It's recommended to never unpickle data that could have come from an " + "untrusted source, or that could have been tampered with. If you already verified the pickle " + "data and decided to use it, you can set the environment variable " + "`TRUST_REMOTE_CODE` to `True` to allow it." + ) with open(pretrained_vocab_file, "rb") as f: vocab_dict = pickle.load(f) @@ -790,6 +799,13 @@ def get_lm_corpus(datadir, dataset): corpus = torch.load(fn_pickle) elif os.path.exists(fn): logger.info("Loading cached dataset from pickle...") + if not strtobool(os.environ.get("TRUST_REMOTE_CODE", "False")): + raise ValueError( + "This part uses `pickle.load` which is insecure and will execute arbitrary code that is potentially " + "malicious. It's recommended to never unpickle data that could have come from an untrusted source, or " + "that could have been tampered with. If you already verified the pickle data and decided to use it, " + "you can set the environment variable `TRUST_REMOTE_CODE` to `True` to allow it." + ) with open(fn, "rb") as fp: corpus = pickle.load(fp) else: diff --git a/src/transformers/models/rag/retrieval_rag.py b/src/transformers/models/rag/retrieval_rag.py index 88cb54115bf548..76f6231ec28fbb 100644 --- a/src/transformers/models/rag/retrieval_rag.py +++ b/src/transformers/models/rag/retrieval_rag.py @@ -23,7 +23,7 @@ from ...tokenization_utils import PreTrainedTokenizer from ...tokenization_utils_base import BatchEncoding -from ...utils import cached_file, is_datasets_available, is_faiss_available, logging, requires_backends +from ...utils import cached_file, is_datasets_available, is_faiss_available, logging, requires_backends, strtobool from .configuration_rag import RagConfig from .tokenization_rag import RagTokenizer @@ -131,6 +131,13 @@ def _resolve_path(self, index_path, filename): def _load_passages(self): logger.info(f"Loading passages from {self.index_path}") passages_path = self._resolve_path(self.index_path, self.PASSAGE_FILENAME) + if not strtobool(os.environ.get("TRUST_REMOTE_CODE", "False")): + raise ValueError( + "This part uses `pickle.load` which is insecure and will execute arbitrary code that is potentially " + "malicious. It's recommended to never unpickle data that could have come from an untrusted source, or " + "that could have been tampered with. If you already verified the pickle data and decided to use it, " + "you can set the environment variable `TRUST_REMOTE_CODE` to `True` to allow it." + ) with open(passages_path, "rb") as passages_file: passages = pickle.load(passages_file) return passages @@ -140,6 +147,13 @@ def _deserialize_index(self): resolved_index_path = self._resolve_path(self.index_path, self.INDEX_FILENAME + ".index.dpr") self.index = faiss.read_index(resolved_index_path) resolved_meta_path = self._resolve_path(self.index_path, self.INDEX_FILENAME + ".index_meta.dpr") + if not strtobool(os.environ.get("TRUST_REMOTE_CODE", "False")): + raise ValueError( + "This part uses `pickle.load` which is insecure and will execute arbitrary code that is potentially " + "malicious. It's recommended to never unpickle data that could have come from an untrusted source, or " + "that could have been tampered with. If you already verified the pickle data and decided to use it, " + "you can set the environment variable `TRUST_REMOTE_CODE` to `True` to allow it." + ) with open(resolved_meta_path, "rb") as metadata_file: self.index_id_to_db_id = pickle.load(metadata_file) assert ( diff --git a/tests/models/rag/test_retrieval_rag.py b/tests/models/rag/test_retrieval_rag.py index d4c119815c96f4..1cf155b39a9e2d 100644 --- a/tests/models/rag/test_retrieval_rag.py +++ b/tests/models/rag/test_retrieval_rag.py @@ -14,7 +14,6 @@ import json import os -import pickle import shutil import tempfile from unittest import TestCase @@ -174,37 +173,6 @@ def get_dummy_custom_hf_index_retriever(self, from_disk: bool): ) return retriever - def get_dummy_legacy_index_retriever(self): - dataset = Dataset.from_dict( - { - "id": ["0", "1"], - "text": ["foo", "bar"], - "title": ["Foo", "Bar"], - "embeddings": [np.ones(self.retrieval_vector_size + 1), 2 * np.ones(self.retrieval_vector_size + 1)], - } - ) - dataset.add_faiss_index("embeddings", string_factory="Flat", metric_type=faiss.METRIC_INNER_PRODUCT) - - index_file_name = os.path.join(self.tmpdirname, "hf_bert_base.hnswSQ8_correct_phi_128.c_index") - dataset.save_faiss_index("embeddings", index_file_name + ".index.dpr") - pickle.dump(dataset["id"], open(index_file_name + ".index_meta.dpr", "wb")) - - passages_file_name = os.path.join(self.tmpdirname, "psgs_w100.tsv.pkl") - passages = {sample["id"]: [sample["text"], sample["title"]] for sample in dataset} - pickle.dump(passages, open(passages_file_name, "wb")) - - config = RagConfig( - retrieval_vector_size=self.retrieval_vector_size, - question_encoder=DPRConfig().to_dict(), - generator=BartConfig().to_dict(), - index_name="legacy", - index_path=self.tmpdirname, - ) - retriever = RagRetriever( - config, question_encoder_tokenizer=self.get_dpr_tokenizer(), generator_tokenizer=self.get_bart_tokenizer() - ) - return retriever - def test_canonical_hf_index_retriever_retrieve(self): n_docs = 1 retriever = self.get_dummy_canonical_hf_index_retriever() @@ -288,33 +256,6 @@ def test_custom_hf_index_retriever_save_and_from_pretrained_from_disk(self): out = retriever.retrieve(hidden_states, n_docs=1) self.assertTrue(out is not None) - def test_legacy_index_retriever_retrieve(self): - n_docs = 1 - retriever = self.get_dummy_legacy_index_retriever() - hidden_states = np.array( - [np.ones(self.retrieval_vector_size), -np.ones(self.retrieval_vector_size)], dtype=np.float32 - ) - retrieved_doc_embeds, doc_ids, doc_dicts = retriever.retrieve(hidden_states, n_docs=n_docs) - self.assertEqual(retrieved_doc_embeds.shape, (2, n_docs, self.retrieval_vector_size)) - self.assertEqual(len(doc_dicts), 2) - self.assertEqual(sorted(doc_dicts[0]), ["text", "title"]) - self.assertEqual(len(doc_dicts[0]["text"]), n_docs) - self.assertEqual(doc_dicts[0]["text"][0], "bar") # max inner product is reached with second doc - self.assertEqual(doc_dicts[1]["text"][0], "foo") # max inner product is reached with first doc - self.assertListEqual(doc_ids.tolist(), [[1], [0]]) - - def test_legacy_hf_index_retriever_save_and_from_pretrained(self): - retriever = self.get_dummy_legacy_index_retriever() - with tempfile.TemporaryDirectory() as tmp_dirname: - retriever.save_pretrained(tmp_dirname) - retriever = RagRetriever.from_pretrained(tmp_dirname) - self.assertIsInstance(retriever, RagRetriever) - hidden_states = np.array( - [np.ones(self.retrieval_vector_size), -np.ones(self.retrieval_vector_size)], dtype=np.float32 - ) - out = retriever.retrieve(hidden_states, n_docs=1) - self.assertTrue(out is not None) - @require_torch @require_tokenizers @require_sentencepiece