diff --git a/src/transformers/data/datasets/squad.py b/src/transformers/data/datasets/squad.py
index e1589cdf6363a7..b6ec86e4aef02b 100644
--- a/src/transformers/data/datasets/squad.py
+++ b/src/transformers/data/datasets/squad.py
@@ -103,6 +103,7 @@ def __init__(
         mode: Union[str, Split] = Split.train,
         is_language_sensitive: Optional[bool] = False,
         cache_dir: Optional[str] = None,
+        dataset_format: Optional[str] = "pt",
     ):
         self.args = args
         self.is_language_sensitive = is_language_sensitive
@@ -128,28 +129,49 @@ def __init__(
         with FileLock(lock_path):
             if os.path.exists(cached_features_file) and not args.overwrite_cache:
                 start = time.time()
-                self.features = torch.load(cached_features_file)
+                self.old_features = torch.load(cached_features_file)
+
+                # New cache files store a dict holding features, dataset and examples.
+                # Legacy cache files store only the bare features object (no dict wrapper),
+                # so they must not be indexed by key and provide neither dataset nor examples.
+                if isinstance(self.old_features, dict):
+                    self.features = self.old_features["features"]
+                    self.dataset = self.old_features.get("dataset", None)
+                    self.examples = self.old_features.get("examples", None)
+                else:
+                    self.features = self.old_features
+                    self.dataset = None
+                    self.examples = None
                 logger.info(
                     f"Loading features from cached file {cached_features_file} [took %.3f s]", time.time() - start
                 )
+
+                if self.dataset is None or self.examples is None:
+                    logger.warning(
+                        f"Deleting cached file {cached_features_file} will allow dataset and examples to be cached in future run"
+                    )
             else:
                 if mode == Split.dev:
-                    examples = self.processor.get_dev_examples(args.data_dir)
+                    self.examples = self.processor.get_dev_examples(args.data_dir)
                 else:
-                    examples = self.processor.get_train_examples(args.data_dir)
+                    self.examples = self.processor.get_train_examples(args.data_dir)
 
-                self.features = squad_convert_examples_to_features(
-                    examples=examples,
+                self.features, self.dataset = squad_convert_examples_to_features(
+                    examples=self.examples,
                     tokenizer=tokenizer,
                     max_seq_length=args.max_seq_length,
                     doc_stride=args.doc_stride,
                     max_query_length=args.max_query_length,
                     is_training=mode == Split.train,
                     threads=args.threads,
+                    return_dataset=dataset_format,
                 )
 
                 start = time.time()
-                torch.save(self.features, cached_features_file)
+                torch.save(
+                    {"features": self.features, "dataset": self.dataset, "examples": self.examples},
+                    cached_features_file,
+                )
                 # ^ This seems to take a lot of time so I want to investigate why and how we can improve.
                 logger.info(
                     "Saving features into cached file %s [took %.3f s]", cached_features_file, time.time() - start