Skip to content

Commit

Permalink
[squad] make examples and dataset accessible from SquadDataset object (
Browse files Browse the repository at this point in the history
…huggingface#6710)

* [squad] make examples and dataset accessible from SquadDataset object

* [squad] add support for legacy cache files
  • Loading branch information
Tomo Lazovich authored and Zigur committed Oct 26, 2020
1 parent ab394e1 commit 50a5edb
Showing 1 changed file with 22 additions and 6 deletions.
28 changes: 22 additions & 6 deletions src/transformers/data/datasets/squad.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,7 @@ def __init__(
mode: Union[str, Split] = Split.train,
is_language_sensitive: Optional[bool] = False,
cache_dir: Optional[str] = None,
dataset_format: Optional[str] = "pt",
):
self.args = args
self.is_language_sensitive = is_language_sensitive
Expand All @@ -128,28 +129,43 @@ def __init__(
with FileLock(lock_path):
if os.path.exists(cached_features_file) and not args.overwrite_cache:
start = time.time()
self.features = torch.load(cached_features_file)
self.old_features = torch.load(cached_features_file)

# Legacy cache files have only features, while new cache files
# will have dataset and examples also.
self.features = self.old_features["features"]
self.dataset = self.old_features.get("dataset", None)
self.examples = self.old_features.get("examples", None)
logger.info(
f"Loading features from cached file {cached_features_file} [took %.3f s]", time.time() - start
)

if self.dataset is None or self.examples is None:
logger.warn(
f"Deleting cached file {cached_features_file} will allow dataset and examples to be cached in future run"
)
else:
if mode == Split.dev:
examples = self.processor.get_dev_examples(args.data_dir)
self.examples = self.processor.get_dev_examples(args.data_dir)
else:
examples = self.processor.get_train_examples(args.data_dir)
self.examples = self.processor.get_train_examples(args.data_dir)

self.features = squad_convert_examples_to_features(
examples=examples,
self.features, self.dataset = squad_convert_examples_to_features(
examples=self.examples,
tokenizer=tokenizer,
max_seq_length=args.max_seq_length,
doc_stride=args.doc_stride,
max_query_length=args.max_query_length,
is_training=mode == Split.train,
threads=args.threads,
return_dataset=dataset_format,
)

start = time.time()
torch.save(self.features, cached_features_file)
torch.save(
{"features": self.features, "dataset": self.dataset, "examples": self.examples},
cached_features_file,
)
# ^ This seems to take a lot of time so I want to investigate why and how we can improve.
logger.info(
"Saving features into cached file %s [took %.3f s]", cached_features_file, time.time() - start
Expand Down

0 comments on commit 50a5edb

Please sign in to comment.