Add data collator for causal completion training #292
base: main
Conversation
Signed-off-by: Sukriti-Sharma4 <sukriti.sharma4@ibm.com>
@@ -300,6 +300,8 @@ def train(
        torch_dtype: Optional[str] = None,  # TODO: Optional[Union[torch.dtype, str]]
        silence_progress_bars: Optional[bool] = True,
        seed: int = RANDOM_SEED,
+       train_on_completion: bool = False,
+       response_template: str = None,
The decision to expose the response template flag may change, just FYI. I won't be merging this code right away, but I want to ensure the rest of it looks OK, assuming it is a user-provided argument.
I think this looks good, Sukriti! Some thoughts below, thanks!
collator_kwargs["mlm"] = False | ||
|
||
return DataCollatorForCompletionOnlyLM( | ||
tokenizer=self._tokenizer, return_tensors="pt", **collator_kwargs |
Have you tested this with multi-GPU by any chance? In the past we've seen some strange behavior from some of the collators around dynamic padding; curious if that's been observed here.
Raghu's team has tested it with multi-GPU, but the plan is to use this codebase only for single-GPU going forward. Multi-GPU will go through the non-caikit path.
if "mlm" not in collator_kwargs: | ||
collator_kwargs["mlm"] = False | ||
|
||
return DataCollatorForCompletionOnlyLM( |
The SFT Trainer hasn't been integrated quite yet, correct? Could you add a TODO comment here to validate that this can't be used if the trainer is initialized with packing=True, so that we don't miss that edge case in the future? https://huggingface.co/docs/trl/v0.7.4/en/sft_trainer#train-on-completions-only
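A minimal sketch of what such a guard could look like, assuming a hypothetical helper name and a hypothetical "packing" kwarg (this codebase does not currently pass one); the real integration point would be _get_data_collator in this PR:

def _validate_collator_kwargs(**kwargs) -> None:
    """Reject completion-only collation when example packing is requested."""
    # TODO: DataCollatorForCompletionOnlyLM requires unpacked examples, so this
    # collator must never be combined with a trainer initialized with packing=True.
    if kwargs.get("train_on_completion") and kwargs.get("packing"):
        raise ValueError(
            "train_on_completion=True is incompatible with packing=True; "
            "completion-only collation needs one example per sequence."
        )

# Example: the combination below would raise a ValueError.
# _validate_collator_kwargs(train_on_completion=True, packing=True)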
SFTTrainer won't be integrated in this codebase :) This is probably one of the last PRs to go in for tuning; only LoRA after this.
Background being: we want to enable completion training for the current PT and LoRA, which will go through this codebase.
I will add the comment though, thanks.
@@ -168,6 +169,18 @@ def _get_data_collator(self, **kwargs) -> "transformers.DataCollator":
            Collator to be used for causal language modeling.
        """

+       if "train_on_completion" in kwargs and kwargs["train_on_completion"]:
+           applicable_args = ["mlm", "response_template", "instruction_template"]
I think it's okay for now, but we might want to link the place where we get the applicable args from for different collator types; otherwise this might get confusing eventually.
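For reference, a rough sketch of the filtering pattern this hunk implements, using the kwarg names from the PR (mlm, response_template, instruction_template); linking the TRL collator docs next to the list would address the note above. The helper name is illustrative only:

from typing import Any, Dict

# Kwargs accepted by DataCollatorForCompletionOnlyLM (see the TRL docs); any
# other training kwargs are dropped before the collator is constructed.
APPLICABLE_COLLATOR_ARGS = ["mlm", "response_template", "instruction_template"]

def filter_collator_kwargs(kwargs: Dict[str, Any]) -> Dict[str, Any]:
    """Keep only the kwargs the completion-only collator understands."""
    collator_kwargs = {k: v for k, v in kwargs.items() if k in APPLICABLE_COLLATOR_ARGS}
    # Completion-only training is causal LM training, never masked LM.
    collator_kwargs.setdefault("mlm", False)
    return collator_kwargs

# filter_collator_kwargs({"response_template": "### Answer:", "seed": 42})
# -> {"response_template": "### Answer:", "mlm": False}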
@@ -28,6 +28,7 @@ dependencies = [
    "torch>=2.0.1",
    "tqdm>=4.65.0",
    "transformers>=4.32.0",
+   "trl>=0.7.2",
Should we pin an upper bound, given that trl is in 0.x?
Oh yes, thanks, will do.
help="Train on completion True or False", | ||
default=False, | ||
type=bool, | ||
choices=[True, False], |
Rather than choices, I think the encouraged pattern for bool flags in argparse is to use action=store_true or action=store_false. In general, using bool as a type converter can do some weird stuff in argparse because it converts any nonempty string to True, so I think this might not work quite as expected:
import argparse
parser = argparse.ArgumentParser()
parser.add_argument(
"--train_on_completion",
help="Train on completion True or False",
default=False,
type=bool,
choices=[True, False],
)
args = parser.parse_args()
print(args)
Running python3 testit.py --train_on_completion=False produces Namespace(train_on_completion=True).
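For comparison, a short sketch of the store_true pattern suggested above (the flag name is kept from this PR; passing the flag enables completion-only training, omitting it leaves the default of False):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--train_on_completion",
    help="Train on completions only",
    action="store_true",  # default is False; passing the flag sets it to True
)
args = parser.parse_args()
print(args)

# python3 testit.py                        -> Namespace(train_on_completion=False)
# python3 testit.py --train_on_completion  -> Namespace(train_on_completion=True)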
Thanks, I didn't actually mean to commit this example. I will make this change and test that the example actually works.
This effort adds the ability for completion-only LM training as an optional flag. How it should be exposed to users in the product is still being discussed, but we want to add the ability to support it in the library for causal LMs.
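For readers unfamiliar with completion-only collation, here is a minimal sketch (not part of this PR) of what TRL's DataCollatorForCompletionOnlyLM does: every label before the response template is set to -100, so the loss is computed only on the completion. The model name and template string below are illustrative assumptions:

from transformers import AutoTokenizer
from trl import DataCollatorForCompletionOnlyLM

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # any causal LM tokenizer
tokenizer.pad_token = tokenizer.eos_token  # gpt2 has no pad token by default

collator = DataCollatorForCompletionOnlyLM(
    response_template="### Answer:",  # hypothetical template, for illustration only
    tokenizer=tokenizer,
    mlm=False,
)

example = tokenizer("### Question: What is 2 + 2?\n### Answer: 4")
batch = collator([example])
# Prompt tokens get label -100 (ignored by the loss); only the tokens after
# "### Answer:" contribute to the training loss.
print(batch["labels"])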