
Add data collator for causal completion training #292

Open
wants to merge 5 commits into main
Conversation

@Ssukriti Ssukriti commented Dec 7, 2023

This PR adds completion-only LM training as an optional flag. How it should be exposed to users in the product is still being discussed, but we want to add the ability to support it in the library for causal LMs.
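For context, completion-only training works by setting the labels of all prompt tokens (everything up to and including the response template) to -100 so the loss ignores them. A hand-rolled sketch of the masking that trl's DataCollatorForCompletionOnlyLM performs (the token ids here are illustrative, not from this PR):

```python
# Completion-only label masking: tokens before and including the response
# template are set to -100, the index ignored by the cross-entropy loss.
IGNORE_INDEX = -100

def mask_prompt(input_ids, response_template_ids):
    labels = list(input_ids)
    n = len(response_template_ids)
    for start in range(len(input_ids) - n + 1):
        if input_ids[start:start + n] == response_template_ids:
            # Mask everything up to and including the template occurrence.
            for i in range(start + n):
                labels[i] = IGNORE_INDEX
            return labels
    # Template not found: mask the whole example so it contributes no loss.
    return [IGNORE_INDEX] * len(labels)

# Toy token ids; [9, 9] stands in for the tokenized response template.
example = [5, 6, 7, 9, 9, 3, 4]
print(mask_prompt(example, [9, 9]))  # [-100, -100, -100, -100, -100, 3, 4]
```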

Signed-off-by: Sukriti-Sharma4 <sukriti.sharma4@ibm.com>
@Ssukriti Ssukriti changed the title Add data collator Add data collator for causal completion training Dec 7, 2023
@@ -300,6 +300,8 @@ def train(
torch_dtype: Optional[str] = None, # TODO: Optional[Union[torch.dtype, str]]
silence_progress_bars: Optional[bool] = True,
seed: int = RANDOM_SEED,
train_on_completion: bool = False,
response_template: str = None,
Collaborator Author
@Ssukriti Ssukriti Dec 7, 2023
The decision to expose the response template flag may change, just FYI. I won't be merging this code right away, but I want to make sure the rest of it looks OK, assuming it is a user-provided argument.

Collaborator
@alex-jw-brooks alex-jw-brooks left a comment

I think this looks good, Sukriti! Some thoughts below, thanks!

collator_kwargs["mlm"] = False

return DataCollatorForCompletionOnlyLM(
tokenizer=self._tokenizer, return_tensors="pt", **collator_kwargs
Collaborator

Have you tested this with multi-GPU by any chance? In the past, we've seen some strange behavior from some of the collators around dynamic padding; curious whether that's been observed here.

Collaborator Author

Raghu's team has tested it with multi-GPU, but the plan is to only use this codebase for single GPU going forward. Multi-GPU will go through the non-caikit path.

if "mlm" not in collator_kwargs:
collator_kwargs["mlm"] = False

return DataCollatorForCompletionOnlyLM(
Collaborator

The SFT Trainer hasn't been integrated quite yet, correct? Could you add a comment here for a TODO to validate that this can't be used if the trainer is initialized with packing=True so that we don't miss that edge case in the future? https://huggingface.co/docs/trl/v0.7.4/en/sft_trainer#train-on-completions-only

Collaborator Author
@Ssukriti Ssukriti Dec 7, 2023

SFTTrainer won't be integrated in this codebase :) This is probably one of the last PRs to go in for tuning; only LoRA after this.

Collaborator Author

Background being: we want to enable completion training for the current PT and LoRA paths, which go through this codebase.

Collaborator Author

I will add the comment though, thanks.
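A hypothetical shape for that guard (names are illustrative, not from this PR): completion-only collation assumes each sequence is a single prompt/response pair, so packed sequences should be rejected up front.

```python
def validate_collator_args(train_on_completion: bool, packing: bool = False) -> None:
    """Fail fast on incompatible options (sketch of the suggested TODO check)."""
    if train_on_completion and packing:
        raise ValueError(
            "train_on_completion cannot be used with packing=True; "
            "packed sequences interleave multiple prompt/response pairs."
        )
```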

@@ -168,6 +169,18 @@ def _get_data_collator(self, **kwargs) -> "transformers.DataCollator":
Collator to be used for causal language modeling.
"""

if "train_on_completion" in kwargs and kwargs["train_on_completion"]:
applicable_args = ["mlm", "response_template", "instruction_template"]
Collaborator

I think it's okay for now, but we might want to link the place where we get the applicable args from for the different collator types; otherwise this might get confusing eventually.
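One way to sketch that filtering (hypothetical helper; applicable_args here mirrors the constructor arguments of trl's DataCollatorForCompletionOnlyLM):

```python
# Kwargs accepted by the completion-only collator; other training kwargs
# (e.g. seed) are silently dropped before construction.
APPLICABLE_COMPLETION_ARGS = ("mlm", "response_template", "instruction_template")

def filter_collator_kwargs(kwargs: dict) -> dict:
    """Keep only the kwargs the completion-only collator accepts."""
    return {k: v for k, v in kwargs.items() if k in APPLICABLE_COMPLETION_ARGS}

print(filter_collator_kwargs({"mlm": False, "response_template": "### A:", "seed": 42}))
# {'mlm': False, 'response_template': '### A:'}
```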

@@ -28,6 +28,7 @@ dependencies = [
"torch>=2.0.1",
"tqdm>=4.65.0",
"transformers>=4.32.0",
"trl>=0.7.2",
Collaborator

Should we pin an upper bound, given that trl is in 0.x?

Collaborator Author

Oh yes, thanks, will do.
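An upper-bounded pin, as discussed, might look like this (the `<0.8` bound is illustrative; verify the actual compatible range):

```toml
dependencies = [
    "torch>=2.0.1",
    "tqdm>=4.65.0",
    "transformers>=4.32.0",
    "trl>=0.7.2,<0.8",
]
```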

help="Train on completion True or False",
default=False,
type=bool,
choices=[True, False],
Collaborator

Rather than choices, I think the encouraged pattern for bool flags in argparse is action="store_true" or action="store_false". In general, using bool as a type converter can do some weird stuff in argparse, because it converts any nonempty string to True. I think this might not work quite as expected:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--train_on_completion",
    help="Train on completion True or False",
    default=False,
    type=bool,
    choices=[True, False],
)
args = parser.parse_args()
print(args)

running

python3 testit.py --train_on_completion=False

produces Namespace(train_on_completion=True)
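The action="store_true" pattern the reviewer suggests, for comparison:

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--train_on_completion",
    help="Train on completion",
    action="store_true",  # defaults to False; passing the flag sets it to True
)

print(parser.parse_args([]))                         # Namespace(train_on_completion=False)
print(parser.parse_args(["--train_on_completion"]))  # Namespace(train_on_completion=True)
```

Because the flag takes no value, there is no string-to-bool conversion to get wrong.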

Collaborator Author

Thanks, I didn't even mean to commit this example, actually. I will make this change and test that the example works.
