diff --git a/caikit_nlp/modules/text_generation/peft_prompt_tuning.py b/caikit_nlp/modules/text_generation/peft_prompt_tuning.py
index 435e439d..712d3945 100644
--- a/caikit_nlp/modules/text_generation/peft_prompt_tuning.py
+++ b/caikit_nlp/modules/text_generation/peft_prompt_tuning.py
@@ -300,6 +300,8 @@ def train(
         torch_dtype: Optional[str] = None,  # TODO: Optional[Union[torch.dtype, str]]
         silence_progress_bars: Optional[bool] = True,
         seed: int = RANDOM_SEED,
+        train_on_completion: bool = False,
+        response_template: Optional[str] = None,
         **kwargs,
     ) -> "PeftPromptTuning":
         """Run prompt tuning (vanilla or MPT) through PEFT on a CausalLM or Seq2seq model
@@ -347,6 +349,13 @@ def train(
                 Silences TQDM progress bars at train time. Default: True.
             seed: int
                 Integer to be used as random seed for training.
+            train_on_completion: bool
+                If True, compute the training loss on the completion (response) only.
+                Only applicable to causal LMs.
+                Default: False.
+            response_template: Optional[str]
+                Required when train_on_completion is True; template string used to
+                locate the start of the response within each training example.
         Returns:
             PeftPromptTuning
                 Instance of this class with tuned prompt vectors.
@@ -355,6 +364,13 @@ def train(
             "", len(train_stream) > 0, "train_stream cannot be empty"
         )
 
+        if train_on_completion:
+            error.value_check(
+                "",
+                response_template is not None,
+                "response_template is required when train_on_completion is True",
+            )
+
         # Configure random seed
         transformers.set_seed(seed)
         # NOTE: Following can be uncommented to allow full determinism
@@ -505,6 +521,8 @@ def train(
             training_args,
             checkpoint_dir,
             base_model,
+            train_on_completion=train_on_completion,
+            response_template=response_template,
         )
 
         # Wrap up the trained model in a class instance
diff --git a/caikit_nlp/resources/pretrained_model/base.py b/caikit_nlp/resources/pretrained_model/base.py
index 6adad903..56b76d32 100644
--- a/caikit_nlp/resources/pretrained_model/base.py
+++ b/caikit_nlp/resources/pretrained_model/base.py
@@ -280,6 +280,8 @@ def get_trainer(
         eval_dataset: Union[IterableDataset, None] = None,
         optimizers=(None, None),
         model=None,
+        train_on_completion=False,
+        response_template=None,
         **kwargs,
     ):
         """
@@ -296,6 +298,15 @@ def get_trainer(
 
         training_args = TrainingArguments(**kwargs)
 
+        if train_on_completion:
+            error.value_check(
+                "",
+                response_template is not None,
+                "response_template must be set to use completion-only training",
+            )
+            kwargs["train_on_completion"] = train_on_completion
+            kwargs["response_template"] = response_template
+
         data_collator = self._get_data_collator(**kwargs)
 
         trainer_arguments = {
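Note: the changes above only plumb the two new kwargs from train() down to the resource's get_trainer(). Below is a hedged sketch of how a caller might use them once this lands; the base model name, tuning-config values, and response template are illustrative assumptions (not part of this PR), and the remaining arguments follow the existing prompt-tuning flow.

import caikit_nlp
from caikit.core.data_model import DataStream
from caikit_nlp.data_model import GenerationTrainRecord, TuningConfig

# Toy data in the same shape the existing tests use.
train_stream = DataStream.from_iterable(
    [
        GenerationTrainRecord(input="@foo what a cute dog!", output="no complaint"),
        GenerationTrainRecord(input="@bar this is the worst idea ever.", output="complaint"),
    ]
)

model = caikit_nlp.modules.text_generation.PeftPromptTuning.train(
    base_model="facebook/opt-125m",  # illustrative base model
    train_stream=train_stream,
    tuning_config=TuningConfig(num_virtual_tokens=8),  # illustrative config
    num_epochs=1,
    verbalizer="Tweet text : {{input}} Label : ",
    train_on_completion=True,  # new: loss is computed on the response only
    response_template="Label :",  # new: marks where the response begins
)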
""" + if "train_on_completion" in kwargs and kwargs["train_on_completion"]: + applicable_args = ["mlm", "response_template", "instruction_template"] + collator_kwargs = { + key: kwargs[key] for key in applicable_args if key in kwargs + } + + if "mlm" not in collator_kwargs: + collator_kwargs["mlm"] = False + + return DataCollatorForCompletionOnlyLM( + tokenizer=self._tokenizer, return_tensors="pt", **collator_kwargs + ) applicable_args = ["mlm", "pad_to_multiple_of"] collator_kwargs = {key: kwargs[key] for key in applicable_args if key in kwargs} diff --git a/caikit_nlp/toolkit/text_generation/training_utils.py b/caikit_nlp/toolkit/text_generation/training_utils.py index ac905fad..c69892a0 100644 --- a/caikit_nlp/toolkit/text_generation/training_utils.py +++ b/caikit_nlp/toolkit/text_generation/training_utils.py @@ -174,12 +174,18 @@ def launch_training( checkpoint_dir, caikit_resource=None, tokenizer=None, + train_on_completion=False, + response_template=None, ) -> None: """Utility function to wrap trainer and execute training""" # If we have a caikit resource, grab the trainer through it if caikit_resource is not None: trainer = caikit_resource.get_trainer( - train_dataset=training_dataset, model=base_model, **training_args + train_dataset=training_dataset, + model=base_model, + train_on_completion=train_on_completion, + response_template=response_template, + **training_args ) else: # If trainer is not provided fetch it from base_model diff --git a/examples/run_peft_tuning.py b/examples/run_peft_tuning.py index 305bd779..9d012134 100644 --- a/examples/run_peft_tuning.py +++ b/examples/run_peft_tuning.py @@ -241,6 +241,18 @@ def register_common_arguments(subparsers: Tuple[argparse.ArgumentParser]) -> Non default="float32", choices=["float16", "bfloat16", "float32"], ) + subparser.add_argument( + "--train_on_completion", + help="Train on completion True or False", + default=False, + type=bool, + choices=[True, False], + ) + subparser.add_argument( + "--response_template", + help="Response template to identify response", + default=None, + ) def register_multitask_prompt_tuning_args(subparser: argparse.ArgumentParser): @@ -414,6 +426,8 @@ def show_experiment_configuration(args, dataset_info, model_type) -> None: silence_progress_bars=not args.verbose, accumulate_steps=args.accumulate_steps, torch_dtype=args.torch_dtype, + train_on_completion=args.train_on_completion, + response_template=args.response_template, ) model.save(args.output_dir, save_base_model=not args.prompt_only) print_colored("[Training Complete]") diff --git a/pyproject.toml b/pyproject.toml index 7df014ce..038b71b5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -28,6 +28,7 @@ dependencies = [ "torch>=2.0.1", "tqdm>=4.65.0", "transformers>=4.32.0", + "trl>=0.7.2", # GK-AUG-25-2023 NOTE: mpt branch on Mayank's fork was merged to peft main on Aug 24 and it got deleted # which broke caikit-nlp build. peft hasn't released newer version yet, so to get # the build fix, we pulling peft from main branch commit. 
diff --git a/tests/modules/text_generation/test_peft_prompt_tuning.py b/tests/modules/text_generation/test_peft_prompt_tuning.py
index 5cf82439..1ae58638 100644
--- a/tests/modules/text_generation/test_peft_prompt_tuning.py
+++ b/tests/modules/text_generation/test_peft_prompt_tuning.py
@@ -162,6 +162,37 @@ def test_train_model(causal_lm_train_kwargs, set_cpu_device):
     assert isinstance(pred, GeneratedTextResult)
 
 
+def test_train_model_on_completion(causal_lm_train_kwargs, set_cpu_device):
+    """Ensure we can train with completion-only loss on toy data & run inference."""
+    patch_kwargs = {
+        "num_epochs": 1,
+        "verbalizer": "Tweet text : {{input}} Label : ",
+        "train_stream": caikit.core.data_model.DataStream.from_iterable(
+            [
+                caikit_nlp.data_model.GenerationTrainRecord(
+                    input="@foo what a cute dog!", output="no complaint"
+                ),
+                caikit_nlp.data_model.GenerationTrainRecord(
+                    input="@bar this is the worst idea ever.", output="complaint"
+                ),
+            ]
+        ),
+        "torch_dtype": torch.bfloat16,
+        "device": "cpu",
+        "train_on_completion": True,
+        "response_template": "Label :",
+    }
+    causal_lm_train_kwargs.update(patch_kwargs)
+    model = caikit_nlp.modules.text_generation.PeftPromptTuning.train(
+        **causal_lm_train_kwargs
+    )
+    # Test fallback to float32 behavior if this machine doesn't support bfloat16
+    assert model.model.dtype is torch.float32
+    # Ensure that we can get something out of it
+    pred = model.run("@bar what a cute cat!")
+    assert isinstance(pred, GeneratedTextResult)
+
+
 def test_gen_trained_mpt(causal_lm_train_kwargs, set_cpu_device):
     """Ensure that we are able to do generation on causal-lm model trained using MPT."""
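For reference, a hedged sketch (not part of the PR) of roughly what one training example looks like after the verbalizer used in the test is applied; the response template has to appear in this rendered text for completion-only masking to take effect. The string substitution below mimics caikit's {{input}} templating for illustration only.

verbalizer = "Tweet text : {{input}} Label : "
record = {"input": "@foo what a cute dog!", "output": "no complaint"}

rendered = verbalizer.replace("{{input}}", record["input"]) + record["output"]
print(rendered)
# Tweet text : @foo what a cute dog! Label : no complaint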