-
Notifications
You must be signed in to change notification settings - Fork 1.1k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
8 changed files
with
226 additions
and
36 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,2 +1,3 @@ | ||
tmp* | ||
tabby_data_pipeline.egg-info | ||
tabby_data_pipeline.egg-info | ||
log.txt |
This file was deleted.
Oops, something went wrong.
This file was deleted.
Oops, something went wrong.
13 changes: 0 additions & 13 deletions
13
python/tabby-eval/tabby_data_pipeline.egg-info/SOURCES.txt
This file was deleted.
Oops, something went wrong.
1 change: 0 additions & 1 deletion
1
python/tabby-eval/tabby_data_pipeline.egg-info/dependency_links.txt
This file was deleted.
Oops, something went wrong.
This file was deleted.
Oops, something went wrong.
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,224 @@ | ||
import os | ||
import glob | ||
from dataclasses import dataclass, field | ||
from typing import List | ||
|
||
import peft | ||
import torch | ||
from transformers import ( | ||
AutoModelForCausalLM, | ||
AutoTokenizer, | ||
HfArgumentParser, | ||
Trainer, | ||
TrainingArguments, | ||
) | ||
from datasets import Dataset, load_dataset | ||
|
||
|
||
class ConstantLengthDataset: | ||
""" | ||
Iterable dataset that returns constant length chunks of tokens from stream of text files. | ||
Args: | ||
tokenizer (Tokenizer): The processor used for proccessing the data. | ||
dataset (dataset.Dataset): Dataset with text files. | ||
infinite (bool): If True the iterator is reset after dataset reaches end else stops. | ||
seq_length (int): Length of token sequences to return. | ||
num_of_sequences (int): Number of token sequences to keep in buffer. | ||
chars_per_token (int): Number of characters per token used to estimate number of tokens in text buffer. | ||
""" | ||
|
||
def __init__( | ||
self, | ||
tokenizer, | ||
dataset, | ||
infinite=False, | ||
seq_length=1024, | ||
num_of_sequences=1024, | ||
chars_per_token=3.6, | ||
content_field="content", | ||
): | ||
self.tokenizer = tokenizer | ||
self.concat_token_id = tokenizer.eos_token_id | ||
self.dataset = dataset | ||
self.seq_length = seq_length | ||
self.infinite = infinite | ||
self.current_size = 0 | ||
self.max_buffer_size = seq_length * chars_per_token * num_of_sequences | ||
self.content_field = content_field | ||
|
||
def __call__(self): | ||
def gen(): | ||
for x in self: | ||
yield x | ||
|
||
return gen() | ||
|
||
def __iter__(self): | ||
for buffer in self._read_dataset_into_buffer(): | ||
yield from self._tokenize(buffer) | ||
|
||
def _tokenize(self, buffer): | ||
tokenized_inputs = self.tokenizer(buffer, truncation=False)["input_ids"] | ||
|
||
all_token_ids = [] | ||
for tokenized_input in tokenized_inputs: | ||
all_token_ids.extend(tokenized_input + [self.concat_token_id]) | ||
|
||
for i in range(0, len(all_token_ids), self.seq_length): | ||
input_ids = all_token_ids[i : i + self.seq_length] | ||
|
||
if len(input_ids) < self.seq_length: | ||
input_ids = all_token_ids[-self.seq_length :] | ||
|
||
if len(input_ids) == self.seq_length: | ||
self.current_size += 1 | ||
yield dict(input_ids=input_ids, labels=input_ids) | ||
|
||
def _read_dataset_into_buffer(self): | ||
iterator = iter(self.dataset) | ||
more_examples = True | ||
while more_examples: | ||
buffer, buffer_len = [], 0 | ||
while True: | ||
if buffer_len >= self.max_buffer_size: | ||
break | ||
try: | ||
buffer.append(next(iterator)[self.content_field]) | ||
buffer_len += len(buffer[-1]) | ||
except StopIteration: | ||
if self.infinite: | ||
iterator = iter(self.dataset) | ||
else: | ||
more_examples = False | ||
break | ||
yield buffer | ||
|
||
|
||
@dataclass | ||
class TrainLoraArguments: | ||
data_path: str = field(metadata={"help": "Dataset dir for training / eval "}) | ||
output_dir: str = field(metadata={"help": "Output dir for checkpoint"}) | ||
base_model: str = field( | ||
default="TabbyML/J-350M", metadata={"help": "Base model for fine-tuning"} | ||
) | ||
|
||
batch_size: int = 128 | ||
micro_batch_size: int = 4 | ||
num_epochs: int = 3 | ||
learning_rate: float = 3e-4 | ||
cutoff_len: int = 256 | ||
|
||
# Evaluations | ||
val_set_size: int = 2000 | ||
eval_steps: int = 200 | ||
|
||
# Lora Hyperparams | ||
lora_r: int = 8 | ||
lora_alpha: int = 16 | ||
lora_dropout: float = 0.05 | ||
lora_target_modules: List[str] = ( | ||
[ | ||
"q_proj", | ||
"v_proj", | ||
], | ||
) | ||
resume_from_checkpoint: str = None # either training checkpoint or final adapter | ||
half: bool = True | ||
|
||
|
||
def parse_args() -> TrainLoraArguments: | ||
parser = HfArgumentParser(TrainLoraArguments) | ||
return parser.parse_args() | ||
|
||
|
||
def train(args: TrainLoraArguments): | ||
gradient_accumulation_steps = args.batch_size // args.micro_batch_size | ||
|
||
model = AutoModelForCausalLM.from_pretrained( | ||
args.base_model, torch_dtype=torch.float16 if args.half else torch.float32 | ||
) | ||
|
||
tokenizer = AutoTokenizer.from_pretrained(args.base_model) | ||
|
||
config = peft.LoraConfig( | ||
r=args.lora_r, | ||
lora_alpha=args.lora_alpha, | ||
target_modules=args.lora_target_modules, | ||
lora_dropout=args.lora_dropout, | ||
bias="none", | ||
task_type=peft.TaskType.CAUSAL_LM, | ||
) | ||
model = peft.get_peft_model(model, config) | ||
|
||
data_files = glob.glob(os.path.join(args.data_path, "*.jsonl")) | ||
print("Collected data files...", data_files) | ||
dataset = load_dataset("json", data_files=data_files)["train"] | ||
data = Dataset.from_generator(ConstantLengthDataset(tokenizer, dataset)) | ||
|
||
resume_from_checkpoint = args.resume_from_checkpoint | ||
if resume_from_checkpoint: | ||
# Check the available weights and load them | ||
checkpoint_name = os.path.join( | ||
resume_from_checkpoint, "pytorch_model.bin" | ||
) # Full checkpoint | ||
if not os.path.exists(checkpoint_name): | ||
checkpoint_name = os.path.join( | ||
resume_from_checkpoint, "adapter_model.bin" | ||
) # only LoRA model - LoRA config above has to fit | ||
resume_from_checkpoint = False # So the trainer won't try loading its state | ||
# The two files above have a different name depending on how they were saved, but are actually the same. | ||
if os.path.exists(checkpoint_name): | ||
print(f"Restarting from {checkpoint_name}") | ||
adapters_weights = torch.load(checkpoint_name) | ||
model = peft.set_peft_model_state_dict(model, adapters_weights) | ||
else: | ||
print(f"Checkpoint {checkpoint_name} not found") | ||
|
||
model.print_trainable_parameters() # Be more transparent about the % of trainable params. | ||
|
||
train_val = data.train_test_split( | ||
test_size=args.val_set_size, shuffle=True, seed=42 | ||
) | ||
train_data = train_val["train"].shuffle() | ||
val_data = train_val["test"].shuffle() | ||
|
||
trainer = Trainer( | ||
model=model, | ||
train_dataset=train_data, | ||
eval_dataset=val_data, | ||
args=TrainingArguments( | ||
per_device_train_batch_size=args.micro_batch_size, | ||
gradient_accumulation_steps=gradient_accumulation_steps, | ||
warmup_steps=100, | ||
num_train_epochs=args.num_epochs, | ||
learning_rate=args.learning_rate, | ||
fp16=args.half, | ||
logging_steps=10, | ||
evaluation_strategy="steps", | ||
save_strategy="steps", | ||
eval_steps=args.eval_steps, | ||
save_steps=args.eval_steps, | ||
output_dir=args.output_dir, | ||
save_total_limit=3, | ||
load_best_model_at_end=True, | ||
), | ||
) | ||
model.config.use_cache = False | ||
|
||
old_state_dict = model.state_dict | ||
model.state_dict = ( | ||
lambda self, *_, **__: peft.get_peft_model_state_dict(self, old_state_dict()) | ||
).__get__(model, type(model)) | ||
|
||
model = torch.compile(model) | ||
|
||
trainer.train(resume_from_checkpoint=resume_from_checkpoint) | ||
|
||
model.save_pretrained(args.output_dir) | ||
|
||
print("\n If there's a warning about missing keys above, please disregard :)") | ||
|
||
|
||
if __name__ == "__main__": | ||
args = parse_args() | ||
train(args) |