Add data collator for causal completion training #292

Open · wants to merge 5 commits into main
18 changes: 18 additions & 0 deletions caikit_nlp/modules/text_generation/peft_prompt_tuning.py
@@ -300,6 +300,8 @@ def train(
torch_dtype: Optional[str] = None, # TODO: Optional[Union[torch.dtype, str]]
silence_progress_bars: Optional[bool] = True,
seed: int = RANDOM_SEED,
train_on_completion: bool = False,
response_template: str = None,
@Ssukriti (Collaborator, Author) commented on Dec 7, 2023:

The decision to expose the response template flag may change, just FYI. I won't be merging this code right away, but I want to ensure the rest of it looks OK, assuming it is a user-provided argument.

**kwargs,
) -> "PeftPromptTuning":
"""Run prompt tuning (vanilla or MPT) through PEFT on a CausalLM or Seq2seq model
@@ -347,6 +349,13 @@ def train(
Silences TQDM progress bars at train time. Default: True.
seed: int
Integer to be used as random seed for training.
train_on_completion: bool
If True, compute the training loss only on the completion (response) portion
of each example rather than on the full prompt.
Only applicable to causal LMs.
Default: False.
response_template: Optional[str] = None
Required when train_on_completion is True; the template string used to locate
the start of the response within each example.
Returns:
PeftPromptTuning
Instance of this class with tuned prompt vectors.
@@ -355,6 +364,13 @@
"<NLP46653367E>", len(train_stream) > 0, "train_stream cannot be empty"
)

if train_on_completion:
    error.value_check(
        "<NLP41651387E>",
        response_template is not None,
        "Response template is needed for train on completion",
    )

# Configure random seed
transformers.set_seed(seed)
# NOTE: Following can be uncommented to allow full determinism
@@ -505,6 +521,8 @@ def train(
training_args,
checkpoint_dir,
base_model,
train_on_completion=train_on_completion,
response_template=response_template,
)

# Wrap up the trained model in a class instance
11 changes: 11 additions & 0 deletions caikit_nlp/resources/pretrained_model/base.py
@@ -280,6 +280,8 @@ def get_trainer(
eval_dataset: Union[IterableDataset, None] = None,
optimizers=(None, None),
model=None,
train_on_completion=False,
response_template=None,
**kwargs,
):
"""
@@ -296,6 +298,15 @@

training_args = TrainingArguments(**kwargs)

if train_on_completion:
    if response_template is None:
        error(
            "<NLP19348182E>",
            ValueError(
                "Response template needs to be set to use completion-only training"
            ),
        )
    kwargs["train_on_completion"] = train_on_completion
    kwargs["response_template"] = response_template

data_collator = self._get_data_collator(**kwargs)

trainer_arguments = {
13 changes: 13 additions & 0 deletions caikit_nlp/resources/pretrained_model/hf_auto_causal_lm.py
@@ -25,6 +25,7 @@
DataCollatorForLanguageModeling,
)
from transformers.models.auto import modeling_auto
from trl import DataCollatorForCompletionOnlyLM
import torch

# First Party
@@ -168,6 +169,18 @@ def _get_data_collator(self, **kwargs) -> "transformers.DataCollator":
Collator to be used for causal language modeling.
"""

if "train_on_completion" in kwargs and kwargs["train_on_completion"]:
applicable_args = ["mlm", "response_template", "instruction_template"]
Collaborator commented:

I think it's okay for now, but we might want to link to the place where we get the applicable args from for different collator types; otherwise this might get confusing eventually.
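
One way to keep these args in sync with the collator rather than hard-coding them, as a rough sketch (the printed list depends on the installed trl version):

import inspect

from trl import DataCollatorForCompletionOnlyLM

# Derive the applicable kwargs from the collator's own signature instead of
# maintaining a hard-coded list here.
applicable_args = [
    name
    for name in inspect.signature(DataCollatorForCompletionOnlyLM.__init__).parameters
    if name not in ("self", "args", "kwargs")
]
print(applicable_args)  # e.g. ['response_template', 'instruction_template', 'mlm', 'ignore_index']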

collator_kwargs = {
key: kwargs[key] for key in applicable_args if key in kwargs
}

if "mlm" not in collator_kwargs:
collator_kwargs["mlm"] = False

return DataCollatorForCompletionOnlyLM(
Collaborator commented:

The SFT Trainer hasn't been integrated quite yet, correct? Could you add a comment here for a TODO to validate that this can't be used if the trainer is initialized with packing=True so that we don't miss that edge case in the future? https://huggingface.co/docs/trl/v0.7.4/en/sft_trainer#train-on-completions-only

@Ssukriti (Collaborator, Author) replied on Dec 7, 2023:

SFTTrainer won't be integrated in this codebase :) This is probably one of the last PRs to go in for tuning; only LoRA after this.

Collaborator Author replied:

Background: we want to enable completion training for the current prompt tuning and LoRA, which will go through this codebase.

Collaborator Author replied:

I will add the comment though, thanks.

tokenizer=self._tokenizer, return_tensors="pt", **collator_kwargs
Collaborator commented:

Have you tested this with multi-GPU by any chance? In the past, we've seen some strange behavior from some of the collators around dynamic padding; curious whether that's been observed here.

Collaborator Author replied:

Raghu's team has tested it with multi-GPU, but the plan is to use this codebase only for single GPU going forward. Multi-GPU will go through the non-caikit path.

)
applicable_args = ["mlm", "pad_to_multiple_of"]
collator_kwargs = {key: kwargs[key] for key in applicable_args if key in kwargs}

8 changes: 7 additions & 1 deletion caikit_nlp/toolkit/text_generation/training_utils.py
@@ -174,12 +174,18 @@ def launch_training(
checkpoint_dir,
caikit_resource=None,
tokenizer=None,
train_on_completion=False,
response_template=None,
) -> None:
"""Utility function to wrap trainer and execute training"""
# If we have a caikit resource, grab the trainer through it
if caikit_resource is not None:
trainer = caikit_resource.get_trainer(
train_dataset=training_dataset, model=base_model, **training_args
train_dataset=training_dataset,
model=base_model,
train_on_completion=train_on_completion,
response_template=response_template,
**training_args
)
else:
# If trainer is not provided fetch it from base_model
14 changes: 14 additions & 0 deletions examples/run_peft_tuning.py
@@ -241,6 +241,18 @@ def register_common_arguments(subparsers: Tuple[argparse.ArgumentParser]) -> None:
default="float32",
choices=["float16", "bfloat16", "float32"],
)
subparser.add_argument(
"--train_on_completion",
help="Train on completion True or False",
default=False,
type=bool,
choices=[True, False],
Collaborator commented:

Rather than choices, I think the encouraged pattern for bool flags in argparse is action=store_true or action=store_false. In general, using bool as a type converter can do some weird stuff in argparse, because it converts any nonempty string to True, so I think this might not work quite as expected:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--train_on_completion",
    help="Train on completion True or False",
    default=False,
    type=bool,
    choices=[True, False],
)
args = parser.parse_args()
print(args)

running

python3 testit.py --train_on_completion=False

produces Namespace(train_on_completion=True)
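
A sketch of the action="store_true" pattern suggested above, keeping the flag name from the diff:

import argparse

parser = argparse.ArgumentParser()
# Boolean flag: present -> True, absent -> False; no type conversion involved.
parser.add_argument(
    "--train_on_completion",
    help="Compute the training loss on the completion only",
    action="store_true",
)

print(parser.parse_args([]))                         # Namespace(train_on_completion=False)
print(parser.parse_args(["--train_on_completion"]))  # Namespace(train_on_completion=True)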

Collaborator Author replied:

Thanks, I didn't even mean to commit this example actually. I will make this change and test that the example actually works.

)
subparser.add_argument(
"--response_template",
help="Response template to identify response",
default=None,
)


def register_multitask_prompt_tuning_args(subparser: argparse.ArgumentParser):
@@ -414,6 +426,8 @@ def show_experiment_configuration(args, dataset_info, model_type) -> None:
silence_progress_bars=not args.verbose,
accumulate_steps=args.accumulate_steps,
torch_dtype=args.torch_dtype,
train_on_completion=args.train_on_completion,
response_template=args.response_template,
)
model.save(args.output_dir, save_base_model=not args.prompt_only)
print_colored("[Training Complete]")
1 change: 1 addition & 0 deletions pyproject.toml
@@ -28,6 +28,7 @@ dependencies = [
"torch>=2.0.1",
"tqdm>=4.65.0",
"transformers>=4.32.0",
"trl>=0.7.2",
Collaborator commented:

Should we pin an upper bound, given that trl is in 0.x?

Collaborator Author replied:

Oh yes, thanks, will do.

# GK-AUG-25-2023 NOTE: mpt branch on Mayank's fork was merged to peft main on Aug 24 and it got deleted
# which broke caikit-nlp build. peft hasn't released newer version yet, so to get
# the build fix, we pulling peft from main branch commit. In future, we will pull PEFT from
Expand Down
31 changes: 31 additions & 0 deletions tests/modules/text_generation/test_peft_prompt_tuning.py
@@ -162,6 +162,37 @@ def test_train_model(causal_lm_train_kwargs, set_cpu_device):
assert isinstance(pred, GeneratedTextResult)


def test_train_model_on_completion(causal_lm_train_kwargs, set_cpu_device):
"""Ensure that we can train a model on some toy data for 1+ steps & run inference."""
patch_kwargs = {
"num_epochs": 1,
"verbalizer": "Tweet text : {{input}} Label : ",
"train_stream": caikit.core.data_model.DataStream.from_iterable(
[
caikit_nlp.data_model.GenerationTrainRecord(
input="@foo what a cute dog!", output="no complaint"
),
caikit_nlp.data_model.GenerationTrainRecord(
input="@bar this is the worst idea ever.", output="complaint"
),
]
),
"torch_dtype": torch.bfloat16,
"device": "cpu",
"train_on_completion": True,
"response_template": "#answer:",
}
causal_lm_train_kwargs.update(patch_kwargs)
model = caikit_nlp.modules.text_generation.PeftPromptTuning.train(
**causal_lm_train_kwargs
)
# Test fallback to float32 behavior if this machine doesn't support bfloat16
assert model.model.dtype is torch.float32
# Ensure that we can get something out of it
pred = model.run("@bar what a cute cat!")
assert isinstance(pred, GeneratedTextResult)


def test_gen_trained_mpt(causal_lm_train_kwargs, set_cpu_device):
"""Ensure that we are able to do generation on causal-lm model trained
using MPT."""