diff --git a/python/tabby-eval/.gitignore b/python/tabby-eval/.gitignore
index ee393bdda82c..32491554bbf3 100644
--- a/python/tabby-eval/.gitignore
+++ b/python/tabby-eval/.gitignore
@@ -1,2 +1,3 @@
 tmp*
-tabby_data_pipeline.egg-info
\ No newline at end of file
+tabby_data_pipeline.egg-info
+log.txt
\ No newline at end of file
diff --git a/python/tabby-eval/log.txt b/python/tabby-eval/log.txt
deleted file mode 100644
index 32d5597a48f6..000000000000
--- a/python/tabby-eval/log.txt
+++ /dev/null
@@ -1,6 +0,0 @@
-model: TabbyML/StarCoder-1B; language: python; file: line_completion.jsonlSkipped 0 rows, 10 rows with predictions, 0 rows with errors
-
-model: TabbyML/StarCoder-1B; language: python; file: line_completion_rg1_bm25.jsonlSkipped 0 rows, 10 rows with predictions, 0 rows with errors
-
-model: TabbyML/StarCoder-1B; language: python; file: line_completion_oracle_bm25.jsonlSkipped 0 rows, 10 rows with predictions, 0 rows with errors
-
diff --git a/python/tabby-eval/tabby_data_pipeline.egg-info/PKG-INFO b/python/tabby-eval/tabby_data_pipeline.egg-info/PKG-INFO
deleted file mode 100644
index df90e2417d21..000000000000
--- a/python/tabby-eval/tabby_data_pipeline.egg-info/PKG-INFO
+++ /dev/null
@@ -1,8 +0,0 @@
-Metadata-Version: 2.1
-Name: tabby-data-pipeline
-Version: 0.0.0
-Requires-Dist: dagster
-Requires-Dist: dagster-cloud
-Provides-Extra: dev
-Requires-Dist: dagster-webserver; extra == "dev"
-Requires-Dist: pytest; extra == "dev"
diff --git a/python/tabby-eval/tabby_data_pipeline.egg-info/SOURCES.txt b/python/tabby-eval/tabby_data_pipeline.egg-info/SOURCES.txt
deleted file mode 100644
index 3766c8b5d294..000000000000
--- a/python/tabby-eval/tabby_data_pipeline.egg-info/SOURCES.txt
+++ /dev/null
@@ -1,13 +0,0 @@
-README.md
-pyproject.toml
-setup.cfg
-setup.py
-tabby_data_pipeline/__init__.py
-tabby_data_pipeline/analyze.py
-tabby_data_pipeline/assets.py
-tabby_data_pipeline/predict.py
-tabby_data_pipeline.egg-info/PKG-INFO
-tabby_data_pipeline.egg-info/SOURCES.txt
-tabby_data_pipeline.egg-info/dependency_links.txt
-tabby_data_pipeline.egg-info/requires.txt
-tabby_data_pipeline.egg-info/top_level.txt
\ No newline at end of file
diff --git a/python/tabby-eval/tabby_data_pipeline.egg-info/dependency_links.txt b/python/tabby-eval/tabby_data_pipeline.egg-info/dependency_links.txt
deleted file mode 100644
index 8b137891791f..000000000000
--- a/python/tabby-eval/tabby_data_pipeline.egg-info/dependency_links.txt
+++ /dev/null
@@ -1 +0,0 @@
-
diff --git a/python/tabby-eval/tabby_data_pipeline.egg-info/requires.txt b/python/tabby-eval/tabby_data_pipeline.egg-info/requires.txt
deleted file mode 100644
index 520bc0b6a76e..000000000000
--- a/python/tabby-eval/tabby_data_pipeline.egg-info/requires.txt
+++ /dev/null
@@ -1,6 +0,0 @@
-dagster
-dagster-cloud
-
-[dev]
-dagster-webserver
-pytest
diff --git a/python/tabby-eval/tabby_data_pipeline.egg-info/top_level.txt b/python/tabby-eval/tabby_data_pipeline.egg-info/top_level.txt
deleted file mode 100644
index 01ca25ad7613..000000000000
--- a/python/tabby-eval/tabby_data_pipeline.egg-info/top_level.txt
+++ /dev/null
@@ -1 +0,0 @@
-tabby_data_pipeline
diff --git a/python/tabby/trainer.py b/python/tabby/trainer.py
new file mode 100644
index 000000000000..210373b6fb44
--- /dev/null
+++ b/python/tabby/trainer.py
@@ -0,0 +1,224 @@
+import os
+import glob
+from dataclasses import dataclass, field
+from typing import List
+
+import peft
+import torch
+from transformers import (
+    AutoModelForCausalLM,
+    AutoTokenizer,
+    HfArgumentParser,
+    Trainer,
+    TrainingArguments,
+)
+from datasets import Dataset, load_dataset
+
+
+class ConstantLengthDataset:
+    """
+    Iterable dataset that returns constant-length chunks of tokens from a stream of text files.
+
+    Args:
+        tokenizer (Tokenizer): The processor used for processing the data.
+        dataset (dataset.Dataset): Dataset with text files.
+        infinite (bool): If True, the iterator is reset when the dataset reaches its end; otherwise it stops.
+        seq_length (int): Length of token sequences to return.
+        num_of_sequences (int): Number of token sequences to keep in the buffer.
+        chars_per_token (float): Number of characters per token, used to estimate the number of tokens in the text buffer.
+        content_field (str): Name of the field in each example that holds the text.
+    """
+
+    def __init__(
+        self,
+        tokenizer,
+        dataset,
+        infinite=False,
+        seq_length=1024,
+        num_of_sequences=1024,
+        chars_per_token=3.6,
+        content_field="content",
+    ):
+        self.tokenizer = tokenizer
+        self.concat_token_id = tokenizer.eos_token_id
+        self.dataset = dataset
+        self.seq_length = seq_length
+        self.infinite = infinite
+        self.current_size = 0
+        self.max_buffer_size = seq_length * chars_per_token * num_of_sequences
+        self.content_field = content_field
+
+    def __call__(self):
+        def gen():
+            for x in self:
+                yield x
+
+        return gen()
+
+    def __iter__(self):
+        for buffer in self._read_dataset_into_buffer():
+            yield from self._tokenize(buffer)
+
+    def _tokenize(self, buffer):
+        tokenized_inputs = self.tokenizer(buffer, truncation=False)["input_ids"]
+
+        all_token_ids = []
+        for tokenized_input in tokenized_inputs:
+            all_token_ids.extend(tokenized_input + [self.concat_token_id])
+
+        for i in range(0, len(all_token_ids), self.seq_length):
+            input_ids = all_token_ids[i : i + self.seq_length]
+
+            if len(input_ids) < self.seq_length:
+                input_ids = all_token_ids[-self.seq_length :]
+
+            if len(input_ids) == self.seq_length:
+                self.current_size += 1
+                yield dict(input_ids=input_ids, labels=input_ids)
+
+    def _read_dataset_into_buffer(self):
+        iterator = iter(self.dataset)
+        more_examples = True
+        while more_examples:
+            buffer, buffer_len = [], 0
+            while True:
+                if buffer_len >= self.max_buffer_size:
+                    break
+                try:
+                    buffer.append(next(iterator)[self.content_field])
+                    buffer_len += len(buffer[-1])
+                except StopIteration:
+                    if self.infinite:
+                        iterator = iter(self.dataset)
+                    else:
+                        more_examples = False
+                        break
+            yield buffer
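+
+# Illustrative usage sketch (commentary only, not wired into the training flow below;
+# "sample.jsonl" and the variable names are placeholders). An instance is passed to
+# `datasets.Dataset.from_generator` as the generator callable:
+#
+#   raw = load_dataset("json", data_files=["sample.jsonl"])["train"]
+#   packed = Dataset.from_generator(ConstantLengthDataset(tokenizer, raw, seq_length=1024))
+#
+# Each yielded example is a dict with `input_ids` and `labels` of exactly `seq_length`
+# tokens, with documents joined by the tokenizer's EOS token and the text read from
+# the `content` field by default.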
+
+
+@dataclass
+class TrainLoraArguments:
+    data_path: str = field(metadata={"help": "Dataset dir for training / eval"})
+    output_dir: str = field(metadata={"help": "Output dir for checkpoint"})
+    base_model: str = field(
+        default="TabbyML/J-350M", metadata={"help": "Base model for fine-tuning"}
+    )
+
+    batch_size: int = 128
+    micro_batch_size: int = 4
+    num_epochs: int = 3
+    learning_rate: float = 3e-4
+    cutoff_len: int = 256
+
+    # Evaluations
+    val_set_size: int = 2000
+    eval_steps: int = 200
+
+    # LoRA hyperparams
+    lora_r: int = 8
+    lora_alpha: int = 16
+    lora_dropout: float = 0.05
+    lora_target_modules: List[str] = field(
+        default_factory=lambda: [
+            "q_proj",
+            "v_proj",
+        ]
+    )
+    resume_from_checkpoint: str = None  # either a training checkpoint or the final adapter
+    half: bool = True
+
+
+def parse_args() -> TrainLoraArguments:
+    parser = HfArgumentParser(TrainLoraArguments)
+    return parser.parse_args_into_dataclasses()[0]
+
+
+def train(args: TrainLoraArguments):
+    gradient_accumulation_steps = args.batch_size // args.micro_batch_size
+
+    model = AutoModelForCausalLM.from_pretrained(
+        args.base_model, torch_dtype=torch.float16 if args.half else torch.float32
+    )
+
+    tokenizer = AutoTokenizer.from_pretrained(args.base_model)
+
+    config = peft.LoraConfig(
+        r=args.lora_r,
+        lora_alpha=args.lora_alpha,
+        target_modules=args.lora_target_modules,
+        lora_dropout=args.lora_dropout,
+        bias="none",
+        task_type=peft.TaskType.CAUSAL_LM,
+    )
+    model = peft.get_peft_model(model, config)
+
+    data_files = glob.glob(os.path.join(args.data_path, "*.jsonl"))
+    print("Collected data files...", data_files)
+    dataset = load_dataset("json", data_files=data_files)["train"]
+    data = Dataset.from_generator(ConstantLengthDataset(tokenizer, dataset))
+
+    resume_from_checkpoint = args.resume_from_checkpoint
+    if resume_from_checkpoint:
+        # Check the available weights and load them.
+        checkpoint_name = os.path.join(
+            resume_from_checkpoint, "pytorch_model.bin"
+        )  # Full checkpoint
+        if not os.path.exists(checkpoint_name):
+            checkpoint_name = os.path.join(
+                resume_from_checkpoint, "adapter_model.bin"
+            )  # only LoRA model - LoRA config above has to fit
+            resume_from_checkpoint = False  # So the trainer won't try loading its state
+        # The two files above have a different name depending on how they were saved, but are actually the same.
+        if os.path.exists(checkpoint_name):
+            print(f"Restarting from {checkpoint_name}")
+            adapters_weights = torch.load(checkpoint_name)
+            peft.set_peft_model_state_dict(model, adapters_weights)
+        else:
+            print(f"Checkpoint {checkpoint_name} not found")
+
+    model.print_trainable_parameters()  # Be more transparent about the % of trainable params.
+
+    train_val = data.train_test_split(
+        test_size=args.val_set_size, shuffle=True, seed=42
+    )
+    train_data = train_val["train"].shuffle()
+    val_data = train_val["test"].shuffle()
+
+    trainer = Trainer(
+        model=model,
+        train_dataset=train_data,
+        eval_dataset=val_data,
+        args=TrainingArguments(
+            per_device_train_batch_size=args.micro_batch_size,
+            gradient_accumulation_steps=gradient_accumulation_steps,
+            warmup_steps=100,
+            num_train_epochs=args.num_epochs,
+            learning_rate=args.learning_rate,
+            fp16=args.half,
+            logging_steps=10,
+            evaluation_strategy="steps",
+            save_strategy="steps",
+            eval_steps=args.eval_steps,
+            save_steps=args.eval_steps,
+            output_dir=args.output_dir,
+            save_total_limit=3,
+            load_best_model_at_end=True,
+        ),
+    )
+    model.config.use_cache = False
+
+    # Replace state_dict so that checkpoints saved during training contain only the LoRA weights.
+    old_state_dict = model.state_dict
+    model.state_dict = (
+        lambda self, *_, **__: peft.get_peft_model_state_dict(self, old_state_dict())
+    ).__get__(model, type(model))
+
+    model = torch.compile(model)
+
+    trainer.train(resume_from_checkpoint=resume_from_checkpoint)
+
+    model.save_pretrained(args.output_dir)
+
+    print("\n If there's a warning about missing keys above, please disregard :)")
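+
+# Example invocation (illustrative; HfArgumentParser derives the CLI flags from the
+# TrainLoraArguments fields, and the paths below are placeholders):
+#
+#   python trainer.py \
+#       --data_path ./dataset \
+#       --output_dir ./lora-checkpoint \
+#       --base_model TabbyML/J-350M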
+
+
+if __name__ == "__main__":
+    args = parse_args()
+    train(args)