From 1dbe4bf482add3c7008afff611ec516f5c8de2d1 Mon Sep 17 00:00:00 2001 From: Tomo Lazovich Date: Wed, 15 Jul 2020 13:58:48 +0000 Subject: [PATCH 1/2] [squad] make examples and dataset accessible from SquadDataset object --- src/transformers/data/datasets/squad.py | 20 ++++++++++++++------ 1 file changed, 14 insertions(+), 6 deletions(-) diff --git a/src/transformers/data/datasets/squad.py b/src/transformers/data/datasets/squad.py index e1589cdf6363a7..5943190dc64134 100644 --- a/src/transformers/data/datasets/squad.py +++ b/src/transformers/data/datasets/squad.py @@ -103,6 +103,7 @@ def __init__( mode: Union[str, Split] = Split.train, is_language_sensitive: Optional[bool] = False, cache_dir: Optional[str] = None, + dataset_format: Optional[str] = "pt", ): self.args = args self.is_language_sensitive = is_language_sensitive @@ -128,28 +129,35 @@ def __init__( with FileLock(lock_path): if os.path.exists(cached_features_file) and not args.overwrite_cache: start = time.time() - self.features = torch.load(cached_features_file) + cache_dict = torch.load(cached_features_file) + self.features = cache_dict["features"] + self.dataset = cache_dict["dataset"] + self.examples = cache_dict["examples"] logger.info( f"Loading features from cached file {cached_features_file} [took %.3f s]", time.time() - start ) else: if mode == Split.dev: - examples = self.processor.get_dev_examples(args.data_dir) + self.examples = self.processor.get_dev_examples(args.data_dir) else: - examples = self.processor.get_train_examples(args.data_dir) + self.examples = self.processor.get_train_examples(args.data_dir) - self.features = squad_convert_examples_to_features( - examples=examples, + self.features, self.dataset = squad_convert_examples_to_features( + examples=self.examples, tokenizer=tokenizer, max_seq_length=args.max_seq_length, doc_stride=args.doc_stride, max_query_length=args.max_query_length, is_training=mode == Split.train, threads=args.threads, + return_dataset=dataset_format, ) start = time.time() - torch.save(self.features, cached_features_file) + torch.save( + {"features": self.features, "dataset": self.dataset, "examples": self.examples}, + cached_features_file, + ) # ^ This seems to take a lot of time so I want to investigate why and how we can improve. logger.info( "Saving features into cached file %s [took %.3f s]", cached_features_file, time.time() - start From 20d0228d0fefeda9929674be4d502d442ea4c59b Mon Sep 17 00:00:00 2001 From: Tomo Lazovich Date: Tue, 25 Aug 2020 12:33:19 -0400 Subject: [PATCH 2/2] [squad] add support for legacy cache files --- src/transformers/data/datasets/squad.py | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/src/transformers/data/datasets/squad.py b/src/transformers/data/datasets/squad.py index 5943190dc64134..b6ec86e4aef02b 100644 --- a/src/transformers/data/datasets/squad.py +++ b/src/transformers/data/datasets/squad.py @@ -129,13 +129,21 @@ def __init__( with FileLock(lock_path): if os.path.exists(cached_features_file) and not args.overwrite_cache: start = time.time() - cache_dict = torch.load(cached_features_file) - self.features = cache_dict["features"] - self.dataset = cache_dict["dataset"] - self.examples = cache_dict["examples"] + self.old_features = torch.load(cached_features_file) + + # Legacy cache files have only features, while new cache files + # will have dataset and examples also. + self.features = self.old_features["features"] + self.dataset = self.old_features.get("dataset", None) + self.examples = self.old_features.get("examples", None) logger.info( f"Loading features from cached file {cached_features_file} [took %.3f s]", time.time() - start ) + + if self.dataset is None or self.examples is None: + logger.warn( + f"Deleting cached file {cached_features_file} will allow dataset and examples to be cached in future run" + ) else: if mode == Split.dev: self.examples = self.processor.get_dev_examples(args.data_dir)