Add data collator for causal completion training #292
base: main
```diff
@@ -25,6 +25,7 @@
     DataCollatorForLanguageModeling,
 )
 from transformers.models.auto import modeling_auto
+from trl import DataCollatorForCompletionOnlyLM
 import torch

 # First Party
```
```diff
@@ -168,6 +169,18 @@ def _get_data_collator(self, **kwargs) -> "transformers.DataCollator":
             Collator to be used for causal language modeling.
         """

+        if "train_on_completion" in kwargs and kwargs["train_on_completion"]:
+            applicable_args = ["mlm", "response_template", "instruction_template"]
+            collator_kwargs = {
+                key: kwargs[key] for key in applicable_args if key in kwargs
+            }
+
+            if "mlm" not in collator_kwargs:
+                collator_kwargs["mlm"] = False
+
+            return DataCollatorForCompletionOnlyLM(
+                tokenizer=self._tokenizer, return_tensors="pt", **collator_kwargs
+            )
         applicable_args = ["mlm", "pad_to_multiple_of"]
         collator_kwargs = {key: kwargs[key] for key in applicable_args if key in kwargs}
```

Review comment on the `applicable_args` list: I think it's okay for now, but we might want to link the place where we get applicable args from for different collator types - otherwise this might get confusing eventually.

Review comment on the `DataCollatorForCompletionOnlyLM` call: The SFT Trainer hasn't been integrated quite yet, correct? Could you add a comment here for a TODO to validate that this can't be used if the trainer is initialized with …

Reply: SFTTrainer won't be integrated in this codebase :) This is probably one of the last PRs to go in for tuning; only LoRA after this.

Reply: Background being, we want to enable completion training for current PT and LoRA, which will be through this codebase.

Reply: I will add the comment though, thanks.

Review comment on the collator kwargs: Have you tested this with multi-GPU by any chance? In the past, we've seen some kind of strange behavior for some of the collators around dynamic padding; curious if that's been observed here.

Reply: Raghu's team has tested it with multi-GPU, but the plan is to only use this codebase for single GPU going forward. Multi-GPU will go through the non-caikit path.
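For context on what the new branch does: `DataCollatorForCompletionOnlyLM` masks everything up to and including the response template in the labels, so the loss is computed only on completion tokens. Below is a minimal sketch following trl's documented usage; the tokenizer and template are illustrative (the PR itself passes `self._tokenizer` and a user-provided `response_template`):

```python
from transformers import AutoTokenizer
from trl import DataCollatorForCompletionOnlyLM

# Illustrative tokenizer; in this PR the collator receives self._tokenizer.
tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token  # gpt2 has no pad token by default

collator = DataCollatorForCompletionOnlyLM(
    response_template=" ### Answer:",  # hypothetical template, not from this repo
    tokenizer=tokenizer,
    mlm=False,
    return_tensors="pt",
)

example = tokenizer("### Question: What is 2+2?\n ### Answer: 4")["input_ids"]
batch = collator([example])

# Everything up to and including the response template is labeled -100
# (ignored by the loss), so the training signal comes only from the completion.
print(batch["labels"])
```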
```diff
@@ -241,6 +241,18 @@ def register_common_arguments(subparsers: Tuple[argparse.ArgumentParser]) -> None:
         default="float32",
         choices=["float16", "bfloat16", "float32"],
     )
+    subparser.add_argument(
+        "--train_on_completion",
+        help="Train on completion True or False",
+        default=False,
+        type=bool,
+        choices=[True, False],
+    )
+    subparser.add_argument(
+        "--response_template",
+        help="Response template to identify response",
+        default=None,
+    )


 def register_multitask_prompt_tuning_args(subparser: argparse.ArgumentParser):
```

Review comment on `type=bool`: Rather than choices, I think the encouraged pattern for bool flags in argparse is `action="store_true"`; `type=bool` doesn't parse strings the way you'd expect, since `bool()` of any non-empty string is `True`. For example:

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--train_on_completion",
    help="Train on completion True or False",
    default=False,
    type=bool,
    choices=[True, False],
)
args = parser.parse_args()
print(args)
```

Running `python3 testit.py --train_on_completion=False` produces `Namespace(train_on_completion=True)`.

Reply: Thanks, I didn't even mean to commit this example actually. I will make this change and test that the example actually works.
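For reference, a minimal sketch of the `store_true` pattern the reviewer is pointing to (flag name taken from the PR; the help text wording is an assumption, not the final form):

```python
import argparse

parser = argparse.ArgumentParser()
# store_true makes this a plain switch: passing --train_on_completion
# sets it to True; omitting it leaves the implicit False default.
parser.add_argument(
    "--train_on_completion",
    help="Train on completion (off unless the flag is passed)",
    action="store_true",
)

args = parser.parse_args(["--train_on_completion"])
print(args)  # Namespace(train_on_completion=True)
```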
```diff
@@ -414,6 +426,8 @@ def show_experiment_configuration(args, dataset_info, model_type) -> None:
         silence_progress_bars=not args.verbose,
         accumulate_steps=args.accumulate_steps,
         torch_dtype=args.torch_dtype,
+        train_on_completion=args.train_on_completion,
+        response_template=args.response_template,
     )
     model.save(args.output_dir, save_base_model=not args.prompt_only)
     print_colored("[Training Complete]")
```
```diff
@@ -28,6 +28,7 @@ dependencies = [
     "torch>=2.0.1",
     "tqdm>=4.65.0",
     "transformers>=4.32.0",
+    "trl>=0.7.2",
     # GK-AUG-25-2023 NOTE: mpt branch on Mayank's fork was merged to peft main on Aug 24 and it got deleted
     # which broke caikit-nlp build. peft hasn't released newer version yet, so to get
     # the build fix, we pulling peft from main branch commit. In future, we will pull PEFT from
```

Review comment on the `trl` dependency: Should we pin an upper bound, given that trl is still on pre-1.0 (`0.x`) releases?

Reply: Oh yes, thanks, will do.
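If the upper bound is pinned as discussed, the dependency line might look something like the following; the exact `<0.8` cap is an assumption, not settled in this thread:

```toml
# pyproject.toml (sketch): cap trl below the next minor release to avoid
# breaking changes from a pre-1.0 dependency.
dependencies = [
    "trl>=0.7.2,<0.8",
]
```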
Review comment (PR-level): Decision to expose the response template flag may change, just FYI. I won't be merging this code right away, but I want to ensure the rest of it looks OK, assuming it is a user-provided argument.