diff --git a/generate.py b/generate.py
index 36c0d3f7..dec270a8 100644
--- a/generate.py
+++ b/generate.py
@@ -45,9 +45,6 @@ def main(
         prompts=[(instruction, input)],
     )
 
-    if max_seq_len is None:
-        max_seq_len = model.config_.max_seq_len_
-
     output = mlora.generate(
         model,
         tokenizer,
diff --git a/inference.py b/inference.py
index b3d32990..9fc8899e 100644
--- a/inference.py
+++ b/inference.py
@@ -168,7 +168,7 @@ def generate_with_streaming(**kwargs):
             minimum=1,
             maximum=model.config_.max_seq_len_,
             step=1,
-            value=128,
+            value=1024,
             label="Max Tokens",
         ),
         gr.components.Checkbox(label="Stream Output", value=True),
diff --git a/mlora.py b/mlora.py
index a051a917..44e6d15c 100644
--- a/mlora.py
+++ b/mlora.py
@@ -234,22 +234,25 @@ def inference_callback(cur_pos, outputs):
 
 
 def inference(
-    llm_model: mlora.LLMModel,
+    model: mlora.LLMModel,
     tokenizer: mlora.Tokenizer,
-    adapters: List[mlora.GenerateConfig],
+    configs: List[mlora.GenerateConfig],
+    concurrent_jobs: int,
 ):
     while True:
         input_raw = input("INPUT WITHOUT PROMPT: ")
         if input_raw == "QUIT":
             return
-        for config in adapters:
+        for config in configs:
            config.prompts = [input_raw]
         callback = None if args.disable_log else inference_callback
         outputs = mlora.generate(
-            llm_model,
+            model,
             tokenizer,
-            adapters,
+            configs,
+            max_gen_len=128,
             use_cache=args.disable_cache,
+            concurrent_jobs=concurrent_jobs,
             cache_implementation=args.cache_implementation,
             stream_callback=callback,
         )
@@ -298,7 +301,12 @@ def inference(
     mlora_backend.empty_cache()
 
     if args.inference:
-        inference(model, tokenizer, adapters)
+        inference(
+            model=model,
+            tokenizer=tokenizer,
+            configs=adapters,
+            concurrent_jobs=config.get("inference_lora_simultaneously_num", 2),
+        )
     elif args.evaluate:
         mlora.evaluate(
             model=model,
diff --git a/mlora/evaluator.py b/mlora/evaluator.py
index 404ede66..178744ae 100644
--- a/mlora/evaluator.py
+++ b/mlora/evaluator.py
@@ -80,11 +80,7 @@ def _dispatch_task_in(tokenizer, configs, concurrent_jobs, max_seq_len):
             if len(tokens) > max_seq_len:
                 tokens = tokens[:max_seq_len]
             max_tokens_len = max(len(tokens), max_tokens_len)
-            # sequence_lengths.append(len(tokens))
-            # while len(tokens) < max_seq_len:
-            #     tokens.append(tokenizer.pad_id_)
             batch_tokens.append(tokens)
-            # atten_masks.append(tokenizer.mask_from(tokens))
             batch_labels.append(labels.copy())
 
         config.batch_start_idx_ = config.batch_end_idx_
diff --git a/mlora/generator.py b/mlora/generator.py
index beaebab1..fae4cf59 100644
--- a/mlora/generator.py
+++ b/mlora/generator.py
@@ -1,21 +1,32 @@
 import logging
+import sys
 from dataclasses import dataclass
-from typing import Callable, List, Optional, Tuple, Union
+from typing import Callable, Dict, List, Optional, Tuple, Union
 
 import torch
 
+from mlora.backends import get_backend
 from mlora.common import LLMBatchConfig, LLMModelInput, Tokens, cache_factory
 from mlora.model import LLMModel
 from mlora.prompter import Prompter
 from mlora.tokenizer import Tokenizer
 
 
+@dataclass
+class GenerateData:
+    adapter_name_: str = None
+    prompt_index_: int = None
+    prefix_length_: int = None
+    raw_tokens_: Tokens = None
+
+
 @dataclass
 class GenerateConfig:
     adapter_name: str = None
     prompts: List[Union[str, Tuple[str, str]]] = None
     prompt_template: str = None
     # Generate Arguments
+    batch_size: int = 8
     stop_token: str = None
     temperature: float = 1
     top_p: float = 0.9
@@ -24,10 +35,9 @@ class GenerateConfig:
     repetition_penalty: float = 1.1
     renormalize_logits: bool = True
     # Do not set these manually
-    stop_token_: torch.Tensor = None
-    batch_start_idx_: int = -1
-    batch_end_idx_: int = -1
     prompter_: Prompter = None
+    stop_token_: torch.Tensor = None
+    data_: List[GenerateData] = None
 
     # Set prompt_template_ to enable the prompter
     def generate_prompt(self, instruction: str, input: str = None) -> str:
@@ -50,6 +60,11 @@ def get_response(self, output: str) -> str:
         else:
             return self.prompter_.get_response(output)
 
+    def reset_parameters(self):
+        self.prompter_ = Prompter(self.prompt_template)
+        self.stop_token_ = None
+        self.data_ = []
+
 
 def _logits_sample_top_p(probs, p, filter_value=float("-inf"), min_tokens_to_keep=1):
     sorted_logits, sorted_indices = torch.sort(probs, descending=False)
@@ -119,79 +134,172 @@ def logits_process(
     return next_token.reshape(-1)
 
 
-def gen_outputs(configs, tokenizer, prompts, tokens, max_gen_len):
-    outputs = []
-    for i, toks in enumerate(tokens.tolist()):
-        start = len(prompts[i])
-        toks = toks[start : start + max_gen_len]
-        if tokenizer.pad_id_ in toks:
-            pad_idx = toks.index(tokenizer.pad_id_)
-            toks = toks[:pad_idx]
+def _extract_effective_tokens(
+    tokenizer: Tokenizer,
+    prefix_length: int,
+    tokens: Tokens,
+    remove_prefix=True,
+    remove_pad=True,
+    remove_eos=True,
+):
+    if remove_prefix:
+        tokens = tokens[prefix_length:]
+
+    if remove_pad and tokenizer.pad_id_ in tokens:
+        pad_idx = tokens.index(tokenizer.pad_id_)
+        tokens = tokens[:pad_idx]
 
-        if tokenizer.eos_id_ in toks:
-            stop_idx = toks.index(tokenizer.eos_id_)
-            toks = toks[:stop_idx]
+    if remove_eos and tokenizer.eos_id_ in tokens:
+        stop_idx = tokens.index(tokenizer.eos_id_)
+        tokens = tokens[:stop_idx]
 
-        outputs.append(tokenizer.decode(toks))
+    return tokens
 
-    packed_outputs = {}
-    for config in configs:
-        packed_outputs[config.adapter_name] = [
-            config.get_response(output)
-            for output in outputs[config.batch_start_idx_ : config.batch_end_idx_]
-        ]
+
+def _gen_outputs(
+    tokenizer: Tokenizer,
+    config_dict: Dict[str, GenerateConfig],
+    current_jobs: List[GenerateData],
+    tokens: torch.Tensor,
+):
+    tokens = tokens.tolist()
+    packed_outputs: Dict[str, List[str]] = {}
+    for idx, data in enumerate(current_jobs):
+        output = config_dict[data.adapter_name_].get_response(
+            tokenizer.decode(
+                _extract_effective_tokens(
+                    tokenizer,
+                    data.prefix_length_,
+                    tokens[idx],
+                    remove_prefix=True,
+                    remove_pad=True,
+                    remove_eos=True,
+                )
+            )
+        )
+        if data.adapter_name_ in packed_outputs:
+            packed_outputs[data.adapter_name_].append(output)
+        else:
+            packed_outputs[data.adapter_name_] = [output]
 
     return packed_outputs
 
 
-@torch.inference_mode()
-def generate(
-    model: LLMModel,
-    tokenizer: Tokenizer,
+def _dispatch_task_in(
     configs: List[GenerateConfig],
-    max_gen_len: int = None,
-    use_cache: bool = True,
-    cache_implementation: Optional[str] = None,
-    stream_callback: Optional[Callable] = None,
+    concurrent_jobs: int,
+    strategy: str = "fair",
 ):
-
-    device = torch.device(model.device_)
-    raw_prompts: List[Tokens] = []
-    batch_data_config: List[LLMBatchConfig] = []
-    config_dict = {}
+    assert strategy in ["fair", "fifo"], f"Unknown dispatch strategy {strategy}"
+    current_jobs = []
+    batch_config = []
+    input_tokens = []
+    max_tokens_len = 0
+    min_tokens_len = sys.maxsize
     for config in configs:
-        config_dict[config.adapter_name] = config
-        if config.stop_token:
-            stop_token = tokenizer.encode(" " + config.stop_token, False)[-1]
+        if len(batch_config) >= concurrent_jobs:
+            break
+
+        if len(config.data_) == 0:
+            continue
+
+        if strategy == "fair":
+            per_task_jobs = max(concurrent_jobs // len(configs), 1)
         else:
-            stop_token = tokenizer.eos_id_
-        config.stop_token_ = torch.tensor(
-            [stop_token], dtype=torch.int64, device=device
-        )
-        tokens = [tokenizer.encode(prompt) for prompt in config.get_prompts()]
-        config.batch_start_idx_ = len(raw_prompts)
-        config.batch_end_idx_ = config.batch_start_idx_ + len(tokens)
-        batch_data_config.append(
+            per_task_jobs = concurrent_jobs
+
+        per_task_jobs = min(per_task_jobs, config.batch_size)
+
+        batch_start_idx = len(input_tokens)
+        while per_task_jobs > 0 and len(config.data_) > 0:
+            per_task_jobs = per_task_jobs - 1
+            data = config.data_.pop(0)
+            current_jobs.append(data)
+            tokens = data.raw_tokens_
+            max_tokens_len = max(len(tokens), max_tokens_len)
+            min_tokens_len = min(len(tokens), min_tokens_len)
+            input_tokens.append(tokens)
+
+        batch_config.append(
             LLMBatchConfig(
-                config.adapter_name, config.batch_start_idx_, config.batch_end_idx_
+                adapter_name_=config.adapter_name,
+                batch_start_idx_=batch_start_idx,
+                batch_end_idx_=len(input_tokens),
             )
         )
-        raw_prompts.extend(tokens)
 
-    batch_size = len(raw_prompts)
-    min_tokens_len = min(len(t) for t in raw_prompts)
-    max_tokens_len = max(len(t) for t in raw_prompts)
-    assert max_tokens_len <= model.config_.max_seq_len_
-    if max_gen_len is None:
-        max_gen_len = model.config_.max_seq_len_ - max_tokens_len
+    return (
+        current_jobs,
+        batch_config,
+        input_tokens,
+        max_tokens_len,
+        min_tokens_len,
+    )
 
-    total_len = min(model.config_.max_seq_len_, max_gen_len + max_tokens_len)
 
-    if cache_implementation is not None:
-        use_cache = True
+def _dispatch_task_out(
+    tokenizer: Tokenizer,
+    config_dict: Dict[str, GenerateConfig],
+    current_jobs: List[GenerateData],
+    tokens: torch.Tensor,
+    stop_reached: torch.Tensor,
+):
+    tokens = tokens.tolist()
+    stop_reached = stop_reached.view(-1).tolist()
+    packed_outputs: Dict[str, List[str]] = {}
+    running_jobs: List[GenerateData] = []
+    for idx, data in enumerate(current_jobs):
+        if stop_reached[idx]:
+            output = config_dict[data.adapter_name_].get_response(
+                tokenizer.decode(
+                    _extract_effective_tokens(
+                        tokenizer,
+                        data.prefix_length_,
+                        tokens[idx],
+                        remove_prefix=True,
+                        remove_pad=True,
+                        remove_eos=True,
+                    )
+                )
+            )
+            if data.adapter_name_ in packed_outputs:
+                packed_outputs[data.adapter_name_].append(output)
+            else:
+                packed_outputs[data.adapter_name_] = [output]
+        else:
+            data.raw_tokens_ = _extract_effective_tokens(
+                tokenizer,
+                data.prefix_length_,
+                tokens[idx],
+                remove_prefix=False,
+                remove_pad=True,
+                remove_eos=False,
+            )
+            running_jobs.append(data)
 
-    if use_cache and cache_implementation is None:
-        cache_implementation = model.model_.cache_implementation()
+    return packed_outputs, running_jobs
+
+
+def _batch_generate(
+    model: LLMModel,
+    tokenizer: Tokenizer,
+    max_gen_len: Optional[int],
+    use_cache: bool,
+    cache_implementation: Optional[str],
+    stream_callback: Optional[Callable],
+    config_dict: Dict[str, GenerateConfig],
+    current_jobs: List[GenerateData],
+    batch_config: List[LLMBatchConfig],
+    input_tokens: List[Tokens],
+    max_tokens_len: int,
+    min_tokens_len: int,
+):
+    get_backend().empty_cache()
+    device = torch.device(model.device_)
+    batch_size = len(input_tokens)
+    if max_gen_len is None:
+        max_gen_len = model.config_.max_seq_len_ - max_tokens_len
+    total_len = min(model.config_.max_seq_len_, max_gen_len + max_tokens_len)
 
     past_key_values = (
         cache_factory(
@@ -207,7 +315,7 @@
     tokens = torch.full(
         (batch_size, total_len), tokenizer.pad_id_, dtype=torch.int64, device=device
     )
-    for k, t in enumerate(raw_prompts):
+    for k, t in enumerate(input_tokens):
         tokens[k, : len(t)] = torch.tensor(t, dtype=torch.int64, device=device)
 
     prev_pos = 0
@@ -215,7 +323,7 @@
     input_text_mask = tokens != tokenizer.pad_id_
     for cur_pos in range(min_tokens_len, total_len):
         input_data = LLMModelInput(
-            batch_configs_=batch_data_config,
+            batch_configs_=batch_config,
             batch_tokens_=tokens[:, prev_pos:cur_pos].tolist(),
             inference_mode_=True,
         )
@@ -242,20 +350,113 @@
                 next_token,
             ).to(torch.int64)
             tokens[start_idx:end_idx, cur_pos] = next_token
-            stop_reached |= (~input_text_mask[:, cur_pos]) & (
+            stop_criteria = (~input_text_mask[start_idx:end_idx, cur_pos]) & (
                 next_token == config.stop_token_
            )
+            stop_reached[start_idx:end_idx] |= stop_criteria
+
+        stop_reached |= total_len - cur_pos == 1
+
+        if any(stop_reached):
+            break
 
         if stream_callback is not None:
             stream_callback(
                 cur_pos,
-                gen_outputs(configs, tokenizer, raw_prompts, tokens, max_gen_len),
+                _gen_outputs(
+                    tokenizer,
+                    config_dict,
+                    current_jobs,
+                    tokens,
+                ),
             )
 
-        if all(stop_reached):
-            break
-
         if use_cache:
             prev_pos = cur_pos
 
-    return gen_outputs(configs, tokenizer, raw_prompts, tokens, max_gen_len)
+    return _dispatch_task_out(
+        tokenizer, config_dict, current_jobs, tokens, stop_reached
+    )
+
+
+@torch.inference_mode()
+def generate(
+    model: LLMModel,
+    tokenizer: Tokenizer,
+    configs: List[GenerateConfig],
+    max_gen_len: Optional[int] = None,
+    use_cache: bool = True,
+    dispatch_strategy: str = "fair",
+    concurrent_jobs: Optional[int] = None,
+    cache_implementation: Optional[str] = None,
+    stream_callback: Optional[Callable] = None,
+):
+    if concurrent_jobs is None:
+        concurrent_jobs = len(configs)
+        logging.info(f"Setting concurrent jobs to {concurrent_jobs} automatically")
+
+    assert concurrent_jobs > 0
+
+    # prepare for generation
+    device = torch.device(model.device_)
+    config_dict = {}
+    for config in configs:
+        config.reset_parameters()
+        config_dict[config.adapter_name] = config
+        if config.stop_token is not None:
+            stop_token = tokenizer.encode(" " + config.stop_token, False)[-1]
+        else:
+            stop_token = tokenizer.eos_id_
+        config.stop_token_ = torch.tensor(
+            [stop_token], dtype=torch.int64, device=device
+        )
+        for idx, prompt in enumerate(config.prompts):
+            args = prompt if isinstance(prompt, Tuple) else (prompt, None)
+            tokens = tokenizer.encode(config.generate_prompt(*args))
+            assert (
+                len(tokens) < model.config_.max_seq_len_
+            ), "Inputs exceeded max sequence length of model."
+            config.data_.append(
+                GenerateData(
+                    adapter_name_=config.adapter_name,
+                    prompt_index_=idx,
+                    prefix_length_=len(tokens),
+                    raw_tokens_=tokens,
+                )
+            )
+
+    if cache_implementation is not None:
+        use_cache = True
+
+    if use_cache and cache_implementation is None:
+        cache_implementation = model.model_.cache_implementation()
+
+    packed_outputs: Dict[str, List] = {}
+
+    while True:
+        dispatch_args = _dispatch_task_in(configs, concurrent_jobs, dispatch_strategy)
+
+        if len(dispatch_args[0]) == 0:
+            break
+
+        outputs, running_jobs = _batch_generate(
+            model,
+            tokenizer,
+            max_gen_len,
+            use_cache,
+            cache_implementation,
+            stream_callback,
+            config_dict,
+            *dispatch_args,
+        )
+
+        for name, output in outputs.items():
+            if name in packed_outputs:
+                packed_outputs[name].extend(output)
+            else:
+                packed_outputs[name] = output
+
+        for data in running_jobs:
+            config_dict[data.adapter_name_].data_.append(data)
+
+    return packed_outputs
diff --git a/tests/dummy_train.py b/tests/dummy_train.py
index a0f4e8fa..92b0a1f2 100644
--- a/tests/dummy_train.py
+++ b/tests/dummy_train.py
@@ -12,7 +12,7 @@ def main(
 ):
     mlora.setup_logging("INFO")
 
-    model = mlora.LLMModel.from_pretrained(
+    model: mlora.LLMModel = mlora.LLMModel.from_pretrained(
         base_model,
         device=mlora.get_backend().default_device_name(),
         load_dtype=torch.bfloat16,
@@ -47,23 +47,34 @@ def main(
 
     lora_config, lora_weight = model.unload_adapter(adapter_name)
     model.init_adapter(lora_config, lora_weight)
+    model.init_adapter(mlora.AdapterConfig(adapter_name="default"))
 
-    generate_config = mlora.GenerateConfig(
-        adapter_name=adapter_name,
-        prompts=[test_prompt],
-        stop_token="\n",
-    )
+    generate_configs = [
+        mlora.GenerateConfig(
+            adapter_name=adapter_name,
+            prompts=[test_prompt],
+            stop_token="\n",
+        ),
+        mlora.GenerateConfig(
+            adapter_name="default",
+            prompts=[test_prompt],
+            stop_token="\n",
+        ),
+    ]
 
-    output = mlora.generate(
+    outputs = mlora.generate(
         model=model,
         tokenizer=tokenizer,
-        configs=[generate_config],
+        configs=generate_configs,
+        max_gen_len=128,
     )
 
-    for prompt in output[adapter_name]:
-        print(f"\n{'='*10}\n")
-        print(prompt)
-        print(f"\n{'='*10}\n")
+    print(f"\n{'='*10}\n")
+    print(f"PROMPT: {test_prompt}")
+    for adapter_name, output in outputs.items():
+        print(f"{adapter_name} OUTPUT:")
+        print(output[0])
+        print(f"\n{'='*10}\n")
 
 
 if __name__ == "__main__":
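Usage sketch (reviewer note, not part of the patch): with this change a single mlora.generate() call can serve several adapter configs at once; _dispatch_task_in batches at most concurrent_jobs prompts per step, finished sequences are returned early, and still-running jobs are re-batched on the next loop iteration. The snippet below only illustrates the new call signature; it assumes `model` and `tokenizer` are already loaded as in tests/dummy_train.py, and the adapter name "lora_0" is a placeholder.

    import mlora

    # "lora_0" is a hypothetical adapter; "default" mirrors the base-weight
    # adapter registered in tests/dummy_train.py.
    configs = [
        mlora.GenerateConfig(
            adapter_name="lora_0",
            prompts=["Tell me a joke."],
            stop_token="\n",
            batch_size=8,  # cap on jobs drawn from this config per dispatch
        ),
        mlora.GenerateConfig(
            adapter_name="default",
            prompts=["Tell me a joke."],
            stop_token="\n",
        ),
    ]

    outputs = mlora.generate(
        model=model,
        tokenizer=tokenizer,
        configs=configs,
        max_gen_len=128,
        concurrent_jobs=2,          # decode at most two sequences per batch
        dispatch_strategy="fair",   # or "fifo"
    )

    for adapter_name, texts in outputs.items():
        print(adapter_name, texts[0])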