Add an examples folder for code downstream tasks #18679

Merged: loubnabnl merged 10 commits into huggingface:main from loubnabnl:add-examples-downstream-codeparrot on Aug 18, 2022.

Commits (10):
8f21968  add examples subfolder (loubnabnl)
f8263d1  reformat file (loubnabnl)
3794de6  mention examples in codeparrot readme (loubnabnl)
c961253  reformat imports (loubnabnl)
7ab673e  use Trainer optimizer and scheduler type and add output_dir as argument (loubnabnl)
1015a06  add example of text-to-python and python-to-text models (loubnabnl)
fa8a0df  reformat imports (loubnabnl)
da21da8  mention the downstream examples in the readme (loubnabnl)
17e508d  reformat code (loubnabnl)
6f49ca4  fix typo (loubnabnl)
examples/research_projects/codeparrot/examples/README.md
@@ -0,0 +1,58 @@
# Examples

In this folder we showcase some examples of using code models for downstream tasks.

## Complexity prediction

In this task we want to predict the complexity of Java programs in the [CodeComplex](https://huggingface.co/datasets/codeparrot/codecomplex) dataset. Using the Hugging Face `Trainer`, we fine-tuned [multilingual CodeParrot](https://huggingface.co/codeparrot/codeparrot-small-multi) and [UniXcoder](https://huggingface.co/microsoft/unixcoder-base-nine) on it, and we used the latter to build this Java complexity prediction [space](https://huggingface.co/spaces/codeparrot/code-complexity-predictor) on the Hugging Face Hub.

To fine-tune a model on this dataset you can use the following command:

```bash
python train_complexity_predictor.py \
--model_ckpt microsoft/unixcoder-base-nine \
--num_epochs 60 \
--num_warmup_steps 10 \
--batch_size 8 \
--learning_rate 5e-4
```
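
After training, the saved checkpoint can be loaded for prediction. Below is a minimal inference sketch; the checkpoint path is a placeholder that depends on your `--output_dir` and save steps, and the model returns an index into the 7 complexity classes:

```python
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer

checkpoint = "./results/checkpoint-500"  # placeholder: pick a checkpoint saved under --output_dir
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = AutoModelForSequenceClassification.from_pretrained(checkpoint)

java_code = "for (int i = 0; i < n; i++) { sum += i; }"
inputs = tokenizer(java_code, return_tensors="pt", truncation=True, max_length=1024)
with torch.no_grad():
    logits = model(**inputs).logits
print(logits.argmax(dim=-1).item())  # index of the predicted complexity class
```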

## Code generation: text to python

In this task we want to train a model to generate code from English text. We fine-tuned CodeParrot-small on [github-jupyter-text-to-code](https://huggingface.co/datasets/codeparrot/github-jupyter-text-to-code), a dataset where the samples are a succession of docstrings and their Python code, originally extracted from Jupyter notebooks parsed in this [dataset](https://huggingface.co/datasets/codeparrot/github-jupyter-parsed).

To fine-tune a model on this dataset we use the same [script](https://github.com/huggingface/transformers/blob/main/examples/research_projects/codeparrot/scripts/codeparrot_training.py) as for the pretraining of CodeParrot:

```bash
accelerate launch scripts/codeparrot_training.py \
--model_ckpt codeparrot/codeparrot-small \
--dataset_name_train codeparrot/github-jupyter-text-to-code \
--dataset_name_valid codeparrot/github-jupyter-text-to-code \
--train_batch_size 12 \
--valid_batch_size 12 \
--learning_rate 5e-4 \
--num_warmup_steps 100 \
--gradient_accumulation 1 \
--gradient_checkpointing False \
--max_train_steps 3000 \
--save_checkpoint_steps 200 \
--save_dir jupyter-text-to-python
```
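
Once training is done, the checkpoint saved under `--save_dir` can complete a docstring prompt with Python code. A quick sketch (the local path and the fallback tokenizer are assumptions, adjust them to your setup):

```python
from transformers import pipeline

pipe = pipeline(
    "text-generation",
    model="jupyter-text-to-python",  # assumed: the --save_dir used above
    tokenizer="codeparrot/codeparrot-small",  # base tokenizer, in case none was saved with the checkpoint
)

prompt = '"""\nPlot a histogram of the "age" column of the dataframe df\n"""\n'
print(pipe(prompt, max_new_tokens=64)[0]["generated_text"])
```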

## Code explanation: python to text

In this task we want to train a model to explain Python code. We fine-tuned CodeParrot-small on [github-jupyter-code-to-text](https://huggingface.co/datasets/codeparrot/github-jupyter-code-to-text), a dataset where the samples are a succession of Python code and its explanation as a docstring. It was built by inverting the order of the text and code pairs in the github-jupyter-text-to-code dataset and adding the delimiters "Explanation:" and "End of explanation" inside the docstrings.

To fine-tune a model on this dataset we use the same [script](https://github.com/huggingface/transformers/blob/main/examples/research_projects/codeparrot/scripts/codeparrot_training.py) as for the pretraining of CodeParrot:

```bash
accelerate launch scripts/codeparrot_training.py \
--model_ckpt codeparrot/codeparrot-small \
--dataset_name_train codeparrot/github-jupyter-code-to-text \
--dataset_name_valid codeparrot/github-jupyter-code-to-text \
--train_batch_size 12 \
--valid_batch_size 12 \
--learning_rate 5e-4 \
--num_warmup_steps 100 \
--gradient_accumulation 1 \
--gradient_checkpointing False \
--max_train_steps 3000 \
--save_checkpoint_steps 200 \
--save_dir jupyter-python-to-text
```
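
With this checkpoint you can prompt the model with a code snippet followed by the opening delimiter and let it generate the explanation. A sketch under the same assumptions as above (local `--save_dir` path, base tokenizer as fallback):

```python
from transformers import pipeline

pipe = pipeline(
    "text-generation",
    model="jupyter-python-to-text",  # assumed: the --save_dir used above
    tokenizer="codeparrot/codeparrot-small",  # base tokenizer, in case none was saved with the checkpoint
)

code = "df.groupby('country')['gdp'].mean().sort_values(ascending=False).head(10)\n"
prompt = code + '"""\nExplanation: '
output = pipe(prompt, max_new_tokens=64)[0]["generated_text"]
print(output)  # the text up to "End of explanation" is the model's description of the code
```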
examples/research_projects/codeparrot/examples/requirements.txt
5 changes: 5 additions & 0 deletions
@@ -0,0 +1,5 @@
datasets==2.3.2
transformers==4.21.1
wandb==0.13.1
evaluate==0.2.2
scikit-learn==1.1.2
examples/research_projects/codeparrot/examples/train_complexity_predictor.py
132 changes: 132 additions & 0 deletions
@@ -0,0 +1,132 @@
import argparse
from copy import deepcopy

import numpy as np
from datasets import ClassLabel, DatasetDict, load_dataset

from evaluate import load
from transformers import (
    AutoModelForSequenceClassification,
    AutoTokenizer,
    DataCollatorWithPadding,
    Trainer,
    TrainerCallback,
    TrainingArguments,
    set_seed,
)


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_ckpt", type=str, default="microsoft/unixcoder-base-nine")
    parser.add_argument("--num_epochs", type=int, default=5)
    parser.add_argument("--batch_size", type=int, default=6)
    parser.add_argument("--gradient_accumulation_steps", type=int, default=1)
    parser.add_argument("--freeze", type=bool, default=True)
    parser.add_argument("--learning_rate", type=float, default=5e-4)
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--lr_scheduler_type", type=str, default="cosine")
    parser.add_argument("--num_warmup_steps", type=int, default=10)
    parser.add_argument("--weight_decay", type=float, default=0.01)
    parser.add_argument("--output_dir", type=str, default="./results")
    return parser.parse_args()


metric = load("accuracy")


def compute_metrics(eval_pred):
    predictions, labels = eval_pred
    predictions = np.argmax(predictions, axis=1)
    return metric.compute(predictions=predictions, references=labels)


class CustomCallback(TrainerCallback):
    def __init__(self, trainer) -> None:
        super().__init__()
        self._trainer = trainer

    def on_epoch_end(self, args, state, control, **kwargs):
        if control.should_evaluate:
            control_copy = deepcopy(control)
            self._trainer.evaluate(eval_dataset=self._trainer.train_dataset, metric_key_prefix="train")
            return control_copy


def main():
    args = get_args()
    set_seed(args.seed)

    dataset = load_dataset("codeparrot/codecomplex", split="train")
    train_test = dataset.train_test_split(test_size=0.2)
    test_validation = train_test["test"].train_test_split(test_size=0.5)
    train_test_validation = DatasetDict(
        {
            "train": train_test["train"],
            "test": test_validation["train"],
            "valid": test_validation["test"],
        }
    )

    print("Loading tokenizer and model")
    tokenizer = AutoTokenizer.from_pretrained(args.model_ckpt)
    tokenizer.pad_token = tokenizer.eos_token
    model = AutoModelForSequenceClassification.from_pretrained(args.model_ckpt, num_labels=7)
    model.config.pad_token_id = model.config.eos_token_id

    if args.freeze:
        for param in model.roberta.parameters():
            param.requires_grad = False

    labels = ClassLabel(num_classes=7, names=list(set(train_test_validation["train"]["complexity"])))

    def tokenize(example):
        inputs = tokenizer(example["src"], truncation=True, max_length=1024)
        label = labels.str2int(example["complexity"])
        return {
            "input_ids": inputs["input_ids"],
            "attention_mask": inputs["attention_mask"],
            "label": label,
        }

    tokenized_datasets = train_test_validation.map(
        tokenize,
        batched=True,
        remove_columns=train_test_validation["train"].column_names,
    )
    data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

    training_args = TrainingArguments(
        output_dir=args.output_dir,
        learning_rate=args.learning_rate,
        lr_scheduler_type=args.lr_scheduler_type,
        evaluation_strategy="epoch",
        save_strategy="epoch",
        logging_strategy="epoch",
        per_device_train_batch_size=args.batch_size,
        per_device_eval_batch_size=args.batch_size,
        num_train_epochs=args.num_epochs,
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        weight_decay=0.01,
        metric_for_best_model="accuracy",
        run_name="complexity-java",
        report_to="wandb",
    )

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=tokenized_datasets["train"],
        eval_dataset=tokenized_datasets["valid"],
        tokenizer=tokenizer,
        data_collator=data_collator,
        compute_metrics=compute_metrics,
    )

    print("Training...")
    trainer.add_callback(CustomCallback(trainer))
    trainer.train()


if __name__ == "__main__":
    main()
Review discussion (on the `CustomCallback` evaluation call):

Reviewer: Why is this needed? Isn't this the same as `evaluation_strategy="epoch"` in the training arguments? Also, why do you evaluate on the train set?

Author: I added this because I wanted to monitor the gap in accuracy between the training set and the evaluation set.
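
One way to read that gap back out after training is from the trainer's log history; this is a minimal sketch, assuming the `trainer` instance from `train_complexity_predictor.py` is still in scope:

```python
# CustomCallback logs train-set metrics with the "train_" prefix; the regular evaluation uses "eval_".
train_acc = [log["train_accuracy"] for log in trainer.state.log_history if "train_accuracy" in log]
eval_acc = [log["eval_accuracy"] for log in trainer.state.log_history if "eval_accuracy" in log]
for epoch, (t, e) in enumerate(zip(train_acc, eval_acc), start=1):
    print(f"epoch {epoch}: train accuracy {t:.3f}, eval accuracy {e:.3f}, gap {t - e:.3f}")
```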