Skip to content

Commit

Permalink
Fix training from scratch in new scripts (#8623)
Browse files (browse the repository at this point in the history)
  • Loading branch information
sgugger authored Nov 18, 2020
1 parent 1e62e99 commit a0c62d2
Show file tree
Hide file tree
Showing 5 changed files with 29 additions and 8 deletions.
7 changes: 5 additions & 2 deletions examples/language-modeling/run_clm.py
Original file line number Diff line number Diff line change
Expand Up @@ -313,9 +313,12 @@ def group_texts(examples):

# Training
if training_args.do_train:
trainer.train(
model_path=model_args.model_name_or_path if os.path.isdir(model_args.model_name_or_path) else None
model_path = (
model_args.model_name_or_path
if (model_args.model_name_or_path is not None and os.path.isdir(model_args.model_name_or_path))
else None
)
trainer.train(model_path=model_path)
trainer.save_model() # Saves the tokenizer too for easy upload

# Evaluation
Expand Down
7 changes: 5 additions & 2 deletions examples/language-modeling/run_mlm.py
Original file line number Diff line number Diff line change
Expand Up @@ -354,9 +354,12 @@ def group_texts(examples):

# Training
if training_args.do_train:
trainer.train(
model_path=model_args.model_name_or_path if os.path.isdir(model_args.model_name_or_path) else None
model_path = (
model_args.model_name_or_path
if (model_args.model_name_or_path is not None and os.path.isdir(model_args.model_name_or_path))
else None
)
trainer.train(model_path=model_path)
trainer.save_model() # Saves the tokenizer too for easy upload

# Evaluation
Expand Down
7 changes: 5 additions & 2 deletions examples/language-modeling/run_mlm_wwm.py
Original file line number Diff line number Diff line change
Expand Up @@ -302,9 +302,12 @@ def tokenize_function(examples):

# Training
if training_args.do_train:
trainer.train(
model_path=model_args.model_name_or_path if os.path.isdir(model_args.model_name_or_path) else None
model_path = (
model_args.model_name_or_path
if (model_args.model_name_or_path is not None and os.path.isdir(model_args.model_name_or_path))
else None
)
trainer.train(model_path=model_path)
trainer.save_model() # Saves the tokenizer too for easy upload

# Evaluation
Expand Down
7 changes: 5 additions & 2 deletions examples/language-modeling/run_plm.py
Original file line number Diff line number Diff line change
Expand Up @@ -344,9 +344,12 @@ def group_texts(examples):

# Training
if training_args.do_train:
trainer.train(
model_path=model_args.model_name_or_path if os.path.isdir(model_args.model_name_or_path) else None
model_path = (
model_args.model_name_or_path
if (model_args.model_name_or_path is not None and os.path.isdir(model_args.model_name_or_path))
else None
)
trainer.train(model_path=model_path)
trainer.save_model() # Saves the tokenizer too for easy upload

# Evaluation
Expand Down
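9 changes: 9 additions & 0 deletions (fifth changed file, a cookiecutter example-script template; its path is missing from this extract)
Original file line number Diff line number Diff line change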
Expand Up @@ -307,9 +307,18 @@ def tokenize_function(examples):

# Training
if training_args.do_train:
{%- if cookiecutter.can_train_from_scratch == "False" %}
trainer.train(
model_path=model_args.model_name_or_path if os.path.isdir(model_args.model_name_or_path) else None
)
{%- elif cookiecutter.can_train_from_scratch == "True" %}
model_path = (
model_args.model_name_or_path
if (model_args.model_name_or_path is not None and os.path.isdir(model_args.model_name_or_path))
else None
)
trainer.train(model_path=model_path)
{% endif %}
trainer.save_model() # Saves the tokenizer too for easy upload

# Evaluation
Expand Down

0 comments on commit a0c62d2

Please sign in to comment.