
Add an examples folder for code downstream tasks #18679

Merged

Conversation

loubnabnl (Contributor)

What does this PR do?

This PR adds a folder to the CodeParrot directory to store examples of downstream tasks for code models.

@loubnabnl loubnabnl requested a review from lvwerra August 18, 2022 09:13

HuggingFaceDocBuilderDev commented Aug 18, 2022

The documentation is not available anymore as the PR was closed or merged.


@lvwerra lvwerra left a comment


Looks good, I mostly left comments about simplifying the training code a bit.

Comment on lines 114 to 120
optimizer = AdamW(get_grouped_params(model, args), lr=args.learning_rate)
lr_scheduler = get_scheduler(
    name=args.lr_scheduler_type,
    optimizer=optimizer,
    num_training_steps=args.num_epochs,
    num_warmup_steps=args.num_warmup_steps,
)
lvwerra (Member)

I think you can define all that in the training arguments, no? No need to pass an optimizer/lr_scheduler explicitly?

loubnabnl (Contributor, Author)

I think the linear scheduler is used by default and I needed the cosine scheduler here, but you're right that we don't need to specify the optimizer.
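For reference, a minimal sketch of what selecting the cosine schedule through `TrainingArguments` alone could look like; the values below are placeholders, not the settings used in this PR:

```python
from transformers import TrainingArguments

# Sketch only: Trainer builds the optimizer and LR scheduler from these
# arguments, so no explicit get_scheduler() call is needed.
training_args = TrainingArguments(
    output_dir="./checkpoints",   # placeholder path
    learning_rate=5e-4,           # placeholder value
    lr_scheduler_type="cosine",   # cosine instead of the default linear schedule
    warmup_steps=100,             # placeholder value
    num_train_epochs=3,           # placeholder value
    weight_decay=0.1,             # placeholder value
)
```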


Comment on lines 45 to 55
def get_grouped_params(model, args, no_decay=["bias", "ln_1.weight", "ln_2.weight", "ln_f.weight"]):
    params_with_wd, params_without_wd = [], []
    for n, p in model.named_parameters():
        if any(nd in n for nd in no_decay):
            params_without_wd.append(p)
        else:
            params_with_wd.append(p)
    return [
        {"params": params_with_wd, "weight_decay": args.weight_decay},
        {"params": params_without_wd, "weight_decay": 0.0},
    ]
lvwerra (Member)

I think if you don't pass the optimizer explicitly, the Trainer will take care of that for you, no?
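A minimal sketch of that route, assuming `model`, `training_args`, and the datasets are already defined as in the script: with no `optimizers` argument, the `Trainer` creates an AdamW optimizer from `training_args` and, by default, excludes bias and LayerNorm weights from weight decay.

```python
from transformers import Trainer

# Sketch only: no optimizer/scheduler is passed, so Trainer derives both
# from training_args (learning_rate, weight_decay, lr_scheduler_type, ...).
trainer = Trainer(
    model=model,                  # assumed to be defined as in the script
    args=training_args,           # e.g. the TrainingArguments sketched above
    train_dataset=train_dataset,  # placeholder dataset names
    eval_dataset=eval_dataset,
)
trainer.train()
```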

Comment on lines +58 to +53
class CustomCallback(TrainerCallback):
    def __init__(self, trainer) -> None:
        super().__init__()
        self._trainer = trainer

    def on_epoch_end(self, args, state, control, **kwargs):
        if control.should_evaluate:
            control_copy = deepcopy(control)
            self._trainer.evaluate(eval_dataset=self._trainer.train_dataset, metric_key_prefix="train")
            return control_copy
lvwerra (Member)

Why is this needed? Isn't this the same as evaluation_strategy="epoch" in the training arguments? Also, why do you evaluate on the train set?

loubnabnl (Contributor, Author)

I added this because I wanted to monitor the gap in accuracy between the training set and the evaluation set.
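A rough sketch of how that fits together, assuming the usual `Trainer` setup from the script: `evaluation_strategy="epoch"` covers the validation split, while the callback re-runs evaluation on the training split.

```python
from transformers import Trainer, TrainingArguments

# Sketch only: object names other than CustomCallback are assumed from the script.
training_args = TrainingArguments(
    output_dir="./checkpoints",   # placeholder path
    evaluation_strategy="epoch",  # evaluate on the validation set every epoch
)
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
)
# The callback from the diff above adds a second evaluate() call on the
# training set, logged under the "train_" metric prefix.
trainer.add_callback(CustomCallback(trainer))
trainer.train()
```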

@@ -12,7 +12,7 @@ This is an open-source effort to train and evaluate code generation models. Code
- continuously push checkpoints to the hub with `huggingface_hub`
- stream the dataset with `datasets` during training to avoid disk bottlenecks
- apply the `code_eval` metric in `datasets` to evaluate on [OpenAI's _HumanEval_ benchmark](https://huggingface.co/datasets/openai_humaneval)

- showcase examples for downstream tasks with code models in the [examples](https://github.com/huggingface/transformers/tree/main/examples/research_projects/codeparrot/examples) folder
lvwerra (Member)

Maybe we can say what examples we show there. Should we also add the code for the text2py and py2text models here?

loubnabnl (Contributor, Author)

OK, but the text2py and py2text examples use a script similar to the one for pretraining CodeParrot, just with a different dataset and model checkpoint. Maybe I can just mention that in the README?
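A hedged sketch of what that swap amounts to; the dataset id below is a placeholder, and only `codeparrot/codeparrot-small` is an existing checkpoint on the Hub:

```python
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer

# Sketch only: reuse the CodeParrot pretraining script, but start from an
# existing CodeParrot checkpoint and stream a task-specific dataset instead.
tokenizer = AutoTokenizer.from_pretrained("codeparrot/codeparrot-small")
model = AutoModelForCausalLM.from_pretrained("codeparrot/codeparrot-small")
train_data = load_dataset(
    "your-org/your-text-to-python-dataset",  # placeholder dataset id
    split="train",
    streaming=True,
)
```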

lvwerra (Member)

Sounds good. I think documenting the settings used would also be useful.

@loubnabnl loubnabnl force-pushed the add-examples-downstream-codeparrot branch from 12762f5 to 17e508d on August 18, 2022 13:15
@loubnabnl loubnabnl merged commit bbbb453 into huggingface:main Aug 18, 2022
oneraghavan pushed a commit to oneraghavan/transformers that referenced this pull request Sep 26, 2022
* add examples subfolder

* mention examples in codeparrot readme

* use Trainer optimizer and scheduler type and add output_dir as argument

* add example of text-to-python and python-to-text models

* mention the downstream examples in the readme

* fix typo