Merge pull request #85 from jianzhnie/dev
update save_peft_model_callback
jianzhnie authored Aug 8, 2023
2 parents ea2c181 + 89d5f15 commit f0b8a2a
Showing 10 changed files with 105 additions and 86 deletions.
8 changes: 5 additions & 3 deletions chatllms/model/load_pretrain_model.py
@@ -147,7 +147,7 @@ def load_model_tokenizer(
if not args.full_finetune:
if checkpoint_dir is not None:
# Load pre-trained adapters from checkpoint directory.
logger.info('Loading adapters from checkpoint... ')
logger.info(f'Loading adapters from {checkpoint_dir}... ')
adapter_model_path = join(checkpoint_dir, 'adapter_model')
assert exists(join(adapter_model_path, CONFIG_NAME)) and exists(
join(adapter_model_path, WEIGHTS_NAME)), ValueError(
@@ -162,9 +162,11 @@ def load_model_tokenizer(

else:
# Add LoRA modules to the model.
logger.info('No checkpoint_dir founded, will init adapters...')
logger.info('Adding LoRA modules...')
logger.info(
'No pretrained adapter checkpoints found, will initialize adapters...'
)
modules = find_all_linear_names(args, model)
logger.info(f'Adding LoRA modules: ({modules}) ...')
config = LoraConfig(
r=args.lora_r, # LoRA rank: the column size of the A matrix and the row size of the B matrix
lora_alpha=args.lora_alpha, # scaling factor
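For context only (not part of this diff), here is a minimal sketch of the kind of LoraConfig being built at this point. The hunk is truncated, so every field beyond r and lora_alpha, and all concrete values, are illustrative assumptions; find_all_linear_names, args, and model come from the surrounding code.

from peft import LoraConfig, get_peft_model

modules = find_all_linear_names(args, model)  # linear-layer names, as in the hunk above
config = LoraConfig(
    r=args.lora_r,               # LoRA rank (shared dimension of the A/B matrices)
    lora_alpha=args.lora_alpha,  # scaling factor
    target_modules=modules,      # which linear layers receive LoRA adapters
    lora_dropout=0.05,           # hypothetical value, not shown in the diff
    bias='none',                 # hypothetical value, not shown in the diff
    task_type='CAUSAL_LM',       # hypothetical value, not shown in the diff
)
model = get_peft_model(model, config)  # wraps the base model with the LoRA adapters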
64 changes: 32 additions & 32 deletions chatllms/model/save_peft_model_callback.py
@@ -1,88 +1,91 @@
import os
from typing import Any, Dict

import transformers
from transformers import (PreTrainedModel, TrainerCallback, TrainerControl,
TrainingArguments)
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR


class SavePeftModelCallback(transformers.TrainerCallback):
class SavePeftModelCallback(TrainerCallback):
"""
A TrainerCallback that saves the PEFT model checkpoint during training.
Callback to save PEFT model checkpoints during training.
Saves both the full model and the adapter model to separate directories
within the checkpoint directory.
"""
def save_model(self, args: Any, state: transformers.TrainingArguments,
def save_model(self, args: Any, state: TrainingArguments,
kwargs: Dict[str, Any]) -> None:
"""
Saves the PEFT model checkpoint.
Args:
args (Any): The command line arguments passed to the script.
state (transformers.TrainingArguments): The current state of training.
state (TrainingArguments): The current state of training.
kwargs (Dict[str, Any]): A dictionary of additional keyword arguments.
Raises:
TypeError: If `state` is not an instance of `transformers.TrainingArguments`.
TypeError: If `state` is not an instance of `TrainingArguments`.
"""
print('Saving PEFT checkpoint...')
print('+' * 20, 'Saving PEFT Model Checkpoint Callback', '+' * 20)

# Get the checkpoint directory for saving models.
if state.best_model_checkpoint is not None:
# If best model checkpoint exists, use its directory as the checkpoint folder
checkpoint_folder = os.path.join(state.best_model_checkpoint,
'adapter_model')
checkpoint_dir = os.path.join(state.best_model_checkpoint,
'adapter_model')
else:
# Otherwise, create a new checkpoint folder using the output directory and current global step
checkpoint_folder = os.path.join(
checkpoint_dir = os.path.join(
args.output_dir,
f'{PREFIX_CHECKPOINT_DIR}-{state.global_step}')

# Create path for the PEFT model
peft_model_path = os.path.join(checkpoint_folder, 'adapter_model')
kwargs['model'].save_pretrained(peft_model_path)
peft_model_path = os.path.join(checkpoint_dir, 'adapter_model')
model: PreTrainedModel = kwargs['model']
model.save_pretrained(peft_model_path)

# Create path for the PyTorch model binary file and remove it if it already exists
pytorch_model_path = os.path.join(checkpoint_folder,
'pytorch_model.bin')
pytorch_model_path = os.path.join(checkpoint_dir, 'pytorch_model.bin')
if os.path.exists(pytorch_model_path):
os.remove(pytorch_model_path)

def on_save(
self, args: Any, state: transformers.TrainingArguments,
control: transformers.trainer_callback.TrainerControl,
**kwargs: Dict[str,
Any]) -> transformers.trainer_callback.TrainerControl:
def on_save(self, args: Any, state: TrainingArguments,
control: TrainerControl,
**kwargs: Dict[str, Any]) -> TrainerControl:
"""
Callback method that calls save_model() and returns `control` argument.
Args:
args (Any): The command line arguments passed to the script.
state (transformers.TrainingArguments): The current state of training.
control (transformers.trainer_callback.TrainerControl): \
state (TrainingArguments): The current state of training.
control (trainer_callback.TrainerControl): \
The current state of the TrainerCallback's control flow.
kwargs (Dict[str, Any]): A dictionary of additional keyword arguments.
Returns:
transformers.trainer_callback.TrainerControl: The current state of the TrainerCallback's control flow.
trainer_callback.TrainerControl: The current state of the TrainerCallback's control flow.
Raises:
TypeError: If `state` is not an instance of `transformers.TrainingArguments`.
TypeError: If `state` is not an instance of `TrainingArguments`.
"""
self.save_model(args, state, kwargs)
return control

def on_train_end(self, args: Any, state: transformers.TrainingArguments,
control: transformers.trainer_callback.TrainerControl,
**kwargs: Dict[str, Any]) -> None:
def on_train_end(self, args: Any, state: TrainingArguments,
control: TrainerControl, **kwargs: Dict[str,
Any]) -> None:
"""
Callback method that saves the model checkpoint and creates a 'completed' file in the output directory.
Args:
args (Any): The command line arguments passed to the script.
state (transformers.TrainingArguments): The current state of training.
control (transformers.trainer_callback.TrainerControl): \
state (TrainingArguments): The current state of training.
control (trainer_callback.TrainerControl): \
The current state of the TrainerCallback's control flow.
kwargs (Dict[str, Any]): A dictionary of additional keyword arguments.
Raises:
TypeError: If `state` is not an instance of `transformers.TrainingArguments`.
TypeError: If `state` is not an instance of `TrainingArguments`.
"""

# Define a helper function to create a 'completed' file in the output directory
@@ -92,6 +95,3 @@ def touch(fname, times=None):

# Create the 'completed' file in the output directory
touch(os.path.join(args.output_dir, 'completed'))

# Save the model checkpoint
self.save_model(args, state, kwargs)
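For context only (not part of this diff), a minimal sketch of how this callback is typically attached to a Trainer. The model, tokenizer, and dataset objects are assumed to come from the surrounding training script (e.g. the load_model_tokenizer call shown in train_qlora.py below).

from transformers import Trainer, TrainingArguments
from chatllms.model.save_peft_model_callback import SavePeftModelCallback

training_args = TrainingArguments(output_dir='./work_dir', save_steps=500)
trainer = Trainer(
    model=model,                          # PEFT-wrapped model
    args=training_args,
    train_dataset=train_dataset,
    tokenizer=tokenizer,
    callbacks=[SavePeftModelCallback()],  # writes <checkpoint>/adapter_model on every save;
                                          # after this change, on_train_end only creates the 'completed' marker
)
trainer.train()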
12 changes: 11 additions & 1 deletion chatllms/train/training.py
@@ -1,5 +1,6 @@
import argparse
import json
import math
import os
from typing import Any, Dict

@@ -28,9 +29,11 @@ def train_and_evaluate(trainer: transformers.Trainer, args: argparse.Namespace,
logger.info('*** Train ***')
logger.info('=' * 80)
train_result = trainer.train(
resume_from_checkpoint=args.checkpoint_dir)
resume_from_checkpoint=args.resume_checkpoint)
metrics = train_result.metrics

metrics['train_samples'] = len(trainer.train_dataset)

# Log and save training metrics
trainer.log_metrics('train', metrics)
trainer.save_metrics('train', metrics)
@@ -48,6 +51,13 @@ def train_and_evaluate(trainer: transformers.Trainer, args: argparse.Namespace,
# Evaluate the trained model and obtain evaluation metrics
metrics = trainer.evaluate(metric_key_prefix='eval')

try:
perplexity = math.exp(metrics['eval_loss'])
except OverflowError:
perplexity = float('inf')

metrics['perplexity'] = perplexity
metrics['eval_samples'] = len(trainer.eval_dataset)
# Log and save evaluation metrics
trainer.log_metrics('eval', metrics)
trainer.save_metrics('eval', metrics)
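As a quick sanity check on the perplexity metric added above (context only, not part of this diff): perplexity is the exponential of the mean token-level cross-entropy loss, so the computation reduces to

import math

eval_loss = 2.0                   # hypothetical value of metrics['eval_loss']
perplexity = math.exp(eval_loss)  # ~= 7.39; math.exp raises OverflowError for very large losses,
                                  # which is why the code above falls back to float('inf')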
80 changes: 39 additions & 41 deletions chatllms/utils/model_utils.py
@@ -6,8 +6,7 @@
import bitsandbytes as bnb
import torch
from transformers import PreTrainedModel, PreTrainedTokenizer, Trainer
from transformers.generation.logits_process import LogitsProcessor
from transformers.generation.utils import LogitsProcessorList
from transformers.trainer_utils import get_last_checkpoint

from chatllms.data.data_utils import (DEFAULT_BOS_TOKEN, DEFAULT_EOS_TOKEN,
DEFAULT_PAD_TOKEN, DEFAULT_UNK_TOKEN)
@@ -221,7 +220,8 @@ def verify_dtypes(model: torch.nn.Module) -> None:
return None


def get_last_checkpoint(checkpoint_dir: str) -> Tuple[str, bool]:
def check_training_finished(args: argparse.Namespace,
logger=None) -> Tuple[str, bool]:
"""
Given a directory containing previous saved checkpoints, returns the path to the last checkpoint
if available along with a boolean flag indicating whether training has already been completed.
@@ -234,30 +234,44 @@ def get_last_checkpoint(checkpoint_dir: str) -> Tuple[str, bool]:
whether training has already been completed.
"""
# Check if provided directory exists
if isdir(checkpoint_dir):

if isdir(args.output_dir) and not args.overwrite_output_dir:
last_checkpoint = get_last_checkpoint(args.output_dir)
if last_checkpoint:
logger.info(
f'Found latest checkpoint: ({last_checkpoint}) in ({args.output_dir})'
)
# Check if 'completed' file exists in the directory - indicates training has completed
is_completed = exists(join(checkpoint_dir, 'completed'))
if is_completed:
return None, True # Already finished

# Find the latest checkpoint by checking all subdirectories named 'checkpoint-*'
max_step = 0
for filename in os.listdir(checkpoint_dir):
if isdir(join(checkpoint_dir,
filename)) and filename.startswith('checkpoint'):
max_step = max(max_step,
int(filename.replace('checkpoint-', '')))
if max_step == 0:
return None, is_completed # Training started, but no checkpoint found

# Return path to the latest checkpoint directory
checkpoint_dir = join(checkpoint_dir, f'checkpoint-{max_step}')
print(f'Found a previous checkpoint at: {checkpoint_dir}')
return checkpoint_dir, is_completed

is_completed = exists(join(args.output_dir, 'completed'))
if last_checkpoint and is_completed:
raise AssertionError(
f'Detected that training was already completed! Output directory ({args.output_dir}) already exists and is not empty. '
'Use --overwrite_output_dir to overcome.')

elif last_checkpoint:
# Return path to the latest checkpoint directory
logger.info(
f'Checkpoint detected, resuming training at ({last_checkpoint}). To avoid this behavior, change '
'the `--output_dir` or add `--overwrite_output_dir` to train from scratch.'
)
return last_checkpoint, is_completed
# The directory does not exist, meaning this is the first time the training is being run
return None, False
logger.info(
f'The output directory ({args.output_dir}) does not exist, is empty, or --overwrite_output_dir is set; will train from scratch'
)
return None, False # first training


def find_last_checkpoint(checkpoint_dir):
# Find the latest checkpoint by checking all subdirectories named 'checkpoint-*'
max_step = 0
last_checkpoint = None
for filename in os.listdir(checkpoint_dir):
if isdir(join(checkpoint_dir,
filename)) and filename.startswith('checkpoint'):
max_step = max(max_step, int(filename.replace('checkpoint-', '')))
if max_step > 0:
last_checkpoint = join(checkpoint_dir, f'checkpoint-{max_step}')
return last_checkpoint


def safe_save_model_for_hf_trainer(trainer: Trainer, output_dir: str):
Expand All @@ -270,19 +284,3 @@ def safe_save_model_for_hf_trainer(trainer: Trainer, output_dir: str):
}
del state_dict
trainer._save(output_dir, state_dict=cpu_state_dict) # noqa


# Avoid runtime error in model.generate(do_sample=True).
class InvalidScoreLogitsProcessor(LogitsProcessor):
def __call__(self, input_ids: torch.LongTensor,
scores: torch.FloatTensor) -> torch.FloatTensor:
if torch.isnan(scores).any() or torch.isinf(scores).any():
scores.zero_()
scores[..., 0] = 1.0
return scores


def get_logits_processor() -> LogitsProcessorList:
logits_processor = LogitsProcessorList()
logits_processor.append(InvalidScoreLogitsProcessor())
return logits_processor
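For context only (not part of this diff), a small sketch of how the new find_last_checkpoint helper behaves, assuming a hypothetical output directory layout:

# work_dir/
#   checkpoint-500/
#   checkpoint-1000/
#   completed           <- regular file, skipped by the isdir() check
last = find_last_checkpoint('work_dir')
print(last)  # 'work_dir/checkpoint-1000'; None if no checkpoint-* directories exist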
2 changes: 1 addition & 1 deletion data/alpaca_zh.yaml
@@ -39,4 +39,4 @@ safety_prompt_part2:
hf_hub_url: ''
local_path: /home/robin/prompt_data/Safety-Prompts/safety_scenarios_alpaca.json
dataset_format: alpaca
multi_turn: False
multi_turn: False
2 changes: 1 addition & 1 deletion data/alpaca_zh_pcyn.yaml
@@ -39,4 +39,4 @@ safety_prompt_part2:
hf_hub_url: ''
local_path: /userhome/jianzhnie/prompt_data/Safety-Prompts/safety_scenarios_alpaca.json
dataset_format: alpaca
multi_turn: False
multi_turn: False
5 changes: 5 additions & 0 deletions data/run_test.yaml
@@ -0,0 +1,5 @@
100PoisonMpts:
hf_hub_url: 'damo/100PoisonMpts'
local_path: /home/robin/prompt_data/100PoisonMpts/train_alpaca.json
dataset_format: alpaca
multi_turn: False
2 changes: 1 addition & 1 deletion data/vicuna_zh.yaml
@@ -39,4 +39,4 @@ safety_prompt_part2:
hf_hub_url: ''
local_path: /home/robin/prompt_data/Safety-Prompts/safety_scenarios_vicuna.json
dataset_format: sharegpt
multi_turn: True
multi_turn: True
2 changes: 1 addition & 1 deletion data/vicuna_zh_pcyn.yaml
@@ -39,4 +39,4 @@ safety_prompt_part2:
hf_hub_url: ''
local_path: /userhome/jianzhnie/prompt_data/Safety-Prompts/safety_scenarios_vicuna.json
dataset_format: sharegpt
multi_turn: True
multi_turn: True
14 changes: 9 additions & 5 deletions train_qlora.py
@@ -14,7 +14,7 @@
SavePeftModelCallback, load_model_tokenizer)
from chatllms.train.training import train_and_evaluate
from chatllms.utils.logger_utils import get_root_logger
from chatllms.utils.model_utils import (get_last_checkpoint,
from chatllms.utils.model_utils import (check_training_finished,
print_trainable_parameters,
verify_dtypes)

@@ -41,12 +41,16 @@ def main():
log_file = os.path.join(args.output_dir, f'{timestamp}.log')
logger = get_root_logger(log_file=log_file, log_level='INFO')

# Log on each process the small summary:
logger.info(
f'Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}'
+
f', distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}'
)
logger.info('Training/evaluation parameters %s', args)
# Check if training was already completed.
checkpoint_dir, completed_training = get_last_checkpoint(args.output_dir)
args.checkpoint_dir = checkpoint_dir
if completed_training:
logger.warning('Detected that training was already completed!')
checkpoint_dir, completed_training = check_training_finished(args, logger)
args.resume_checkpoint = checkpoint_dir

# load model and tokenizer
model, tokenizer = load_model_tokenizer(
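Taken together (context only, not part of this diff), the resume logic after this PR is wired roughly as follows, using the names introduced in the hunks above:

# train_qlora.py
checkpoint_dir, completed_training = check_training_finished(args, logger)
args.resume_checkpoint = checkpoint_dir  # None on a fresh run

# chatllms/train/training.py
train_result = trainer.train(resume_from_checkpoint=args.resume_checkpoint)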
