[examples template] added max_sample args and metrics changes #10602

Merged · 1 commit merged on Mar 9, 2021
@@ -144,6 +144,20 @@ class DataTrainingArguments:
         default=None,
         metadata={"help": "The number of processes to use for the preprocessing."},
     )
+    max_train_samples: Optional[int] = field(
+        default=None,
+        metadata={
+            "help": "For debugging purposes or quicker training, truncate the number of training examples to this "
+            "value if set."
+        },
+    )
+    max_val_samples: Optional[int] = field(
+        default=None,
+        metadata={
+            "help": "For debugging purposes or quicker training, truncate the number of validation examples to this "
+            "value if set."
+        },
+    )
 
     def __post_init__(self):
         if self.dataset_name is None and self.train_file is None and self.validation_file is None:
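The two new fields follow the template's existing `field(default=None, metadata={"help": ...})` pattern, so `HfArgumentParser` exposes them as `--max_train_samples` and `--max_val_samples` command-line flags. Below is a minimal, self-contained sketch of that mechanism, with `DataTrainingArguments` reduced to just the two new fields (the real template defines many more):

```python
# Sketch only: how the new dataclass fields surface as CLI flags via HfArgumentParser.
from dataclasses import dataclass, field
from typing import Optional

from transformers import HfArgumentParser


@dataclass
class DataTrainingArguments:
    # Reduced stand-in for the template's DataTrainingArguments.
    max_train_samples: Optional[int] = field(
        default=None,
        metadata={"help": "Truncate the number of training examples to this value if set."},
    )
    max_val_samples: Optional[int] = field(
        default=None,
        metadata={"help": "Truncate the number of validation examples to this value if set."},
    )


if __name__ == "__main__":
    # e.g. python sketch.py --max_train_samples 100 --max_val_samples 100
    parser = HfArgumentParser(DataTrainingArguments)
    (data_args,) = parser.parse_args_into_dataclasses()
    print(data_args.max_train_samples, data_args.max_val_samples)
```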
@@ -317,13 +331,37 @@ def main():
     def tokenize_function(examples):
         return tokenizer(examples[text_column_name], padding="max_length", truncation=True)
 
-    tokenized_datasets = datasets.map(
-        tokenize_function,
-        batched=True,
-        num_proc=data_args.preprocessing_num_workers,
-        remove_columns=[text_column_name],
-        load_from_cache_file=not data_args.overwrite_cache,
-    )
+    if training_args.do_train:
+        if "train" not in datasets:
+            raise ValueError("--do_train requires a train dataset")
+        train_dataset = datasets["train"]
+        if data_args.max_train_samples is not None:
+            # Select a subset of the training samples
+            train_dataset = train_dataset.select(range(data_args.max_train_samples))
+        # Tokenize the training dataset in batches
+        train_dataset = train_dataset.map(
+            tokenize_function,
+            batched=True,
+            num_proc=data_args.preprocessing_num_workers,
+            remove_columns=[text_column_name],
+            load_from_cache_file=not data_args.overwrite_cache,
+        )
+
+    if training_args.do_eval:
+        if "validation" not in datasets:
+            raise ValueError("--do_eval requires a validation dataset")
+        eval_dataset = datasets["validation"]
+        # Select a subset of the validation samples
+        if data_args.max_val_samples is not None:
+            eval_dataset = eval_dataset.select(range(data_args.max_val_samples))
+        # Tokenize the validation dataset
+        eval_dataset = eval_dataset.map(
+            tokenize_function,
+            batched=True,
+            num_proc=data_args.preprocessing_num_workers,
+            remove_columns=[text_column_name],
+            load_from_cache_file=not data_args.overwrite_cache,
+        )
 
     # Data collator
     data_collator = default_data_collator if not training_args.fp16 else DataCollatorWithPadding(tokenizer, pad_to_multiple_of=8)
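The new per-split blocks call `Dataset.select` before `Dataset.map`, so when `--max_train_samples` or `--max_val_samples` is set only the truncated split gets tokenized (and cached), which is the point of the change for quick debugging runs. A small standalone sketch of that pattern, using a public dataset and tokenizer chosen purely for illustration:

```python
# Sketch of the truncate-then-tokenize pattern; dataset/tokenizer names are placeholders.
from datasets import load_dataset
from transformers import AutoTokenizer

raw_datasets = load_dataset("wikitext", "wikitext-2-raw-v1")
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
max_train_samples = 100  # stand-in for data_args.max_train_samples

train_dataset = raw_datasets["train"]
if max_train_samples is not None:
    # .select keeps only the first N rows; doing it before .map means the
    # expensive tokenization runs on the truncated split only.
    train_dataset = train_dataset.select(range(max_train_samples))

train_dataset = train_dataset.map(
    lambda examples: tokenizer(examples["text"], padding="max_length", truncation=True),
    batched=True,
    remove_columns=["text"],
)
print(len(train_dataset))  # 100
```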
@@ -332,8 +370,8 @@ def tokenize_function(examples):
     trainer = Trainer(
         model=model,
         args=training_args,
-        train_dataset=tokenized_datasets["train"] if training_args.do_train else None,
-        eval_dataset=tokenized_datasets["validation"] if training_args.do_eval else None,
+        train_dataset=train_dataset if training_args.do_train else None,
+        eval_dataset=eval_dataset if training_args.do_eval else None,
         tokenizer=tokenizer,
         data_collator=data_collator,
     )
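Unchanged context worth noting: with `--fp16` the template swaps in `DataCollatorWithPadding(tokenizer, pad_to_multiple_of=8)`, which rounds padded sequence lengths up to a multiple of 8 so they map well onto tensor cores. A tiny standalone illustration (the model name is just a placeholder):

```python
# Illustration of the collator choice used by the template under fp16.
from transformers import AutoTokenizer, DataCollatorWithPadding

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")  # placeholder checkpoint
collator = DataCollatorWithPadding(tokenizer, pad_to_multiple_of=8)

features = [
    tokenizer("a short sentence"),
    tokenizer("a somewhat longer example sentence that needs more padding"),
]
batch = collator(features)
# Sequence length is rounded up to the next multiple of 8.
print(batch["input_ids"].shape)
```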
@@ -358,33 +396,27 @@ def tokenize_function(examples):
         train_result = trainer.train(resume_from_checkpoint=checkpoint)
         trainer.save_model()  # Saves the tokenizer too for easy upload
 
-        output_train_file = os.path.join(training_args.output_dir, "train_results.txt")
-        if trainer.is_world_process_zero():
-            with open(output_train_file, "w") as writer:
-                logger.info("***** Train results *****")
-                for key, value in sorted(train_result.metrics.items()):
-                    logger.info(f" {key} = {value}")
-                    writer.write(f"{key} = {value}\n")
+        metrics = train_result.metrics
+        max_train_samples = (
+            data_args.max_train_samples if data_args.max_train_samples is not None else len(train_dataset)
+        )
+        metrics["train_samples"] = min(max_train_samples, len(train_dataset))
 
-        # Need to save the state, since Trainer.save_model saves only the tokenizer with the model
-        trainer.state.save_to_json(os.path.join(training_args.output_dir, "trainer_state.json"))
+        trainer.log_metrics("train", metrics)
+        trainer.save_metrics("train", metrics)
+        trainer.save_state()
 
     # Evaluation
-    results = {}
     if training_args.do_eval:
         logger.info("*** Evaluate ***")
 
-        results = trainer.evaluate()
+        metrics = trainer.evaluate()
 
-        output_eval_file = os.path.join(training_args.output_dir, "eval_results_{{cookiecutter.example_shortcut}}.txt")
-        if trainer.is_world_process_zero():
-            with open(output_eval_file, "w") as writer:
-                logger.info("***** Eval results *****")
-                for key, value in sorted(results.items()):
-                    logger.info(f" {key} = {value}")
-                    writer.write(f"{key} = {value}\n")
+        max_val_samples = data_args.max_val_samples if data_args.max_val_samples is not None else len(eval_dataset)
+        metrics["eval_samples"] = min(max_val_samples, len(eval_dataset))
 
-    return results
+        trainer.log_metrics("eval", metrics)
+        trainer.save_metrics("eval", metrics)
 
 
 def _mp_fn(index):
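The hand-rolled `train_results.txt` / `eval_results_*.txt` writers are replaced by the `Trainer` helpers `log_metrics`, `save_metrics`, and `save_state`, which handle the world-process-zero check and write JSON files (e.g. `train_results.json`, `eval_results.json`, `all_results.json`, `trainer_state.json`) under `output_dir`. A minimal sketch of those calls; the toy model exists only so a `Trainer` can be constructed, and the metrics dict stands in for a real `train_result.metrics`:

```python
# Sketch of the metrics helpers that replace the manual results files.
import os

import torch
from transformers import Trainer, TrainingArguments


class ToyModel(torch.nn.Module):  # placeholder so a Trainer can be built
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(4, 1)

    def forward(self, x, labels=None):
        preds = self.linear(x)
        loss = torch.nn.functional.mse_loss(preds, labels) if labels is not None else None
        return {"loss": loss, "logits": preds}


os.makedirs("toy_output", exist_ok=True)
training_args = TrainingArguments(output_dir="toy_output")
trainer = Trainer(model=ToyModel(), args=training_args)

metrics = {"train_loss": 0.52, "train_samples": 100}  # stand-in for train_result.metrics
trainer.log_metrics("train", metrics)   # pretty-prints "***** train metrics *****"
trainer.save_metrics("train", metrics)  # writes train_results.json (and all_results.json) in output_dir
trainer.save_state()                    # writes trainer_state.json, replacing the manual state.save_to_json call
```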