diff --git a/applications/DeepSpeed-Chat/training/step1_supervised_finetuning/main.py b/applications/DeepSpeed-Chat/training/step1_supervised_finetuning/main.py
index cd956fb50..84623ae76 100755
--- a/applications/DeepSpeed-Chat/training/step1_supervised_finetuning/main.py
+++ b/applications/DeepSpeed-Chat/training/step1_supervised_finetuning/main.py
@@ -146,9 +146,11 @@ def parse_args():
     parser.add_argument('--offload',
                         action='store_true',
                         help='Enable ZeRO Offload techniques.')
-    parser.add_argument('--dtype', type=str, default='fp16',
+    parser.add_argument('--dtype',
+                        type=str,
+                        default='fp16',
                         choices=['fp16', 'bf16'],
-                        help = 'Training data type')
+                        help='Training data type')
     parser.add_argument(
         '--zero_stage',
         type=int,
diff --git a/applications/DeepSpeed-Chat/training/step2_reward_model_finetuning/main.py b/applications/DeepSpeed-Chat/training/step2_reward_model_finetuning/main.py
index 0bb78cd23..b84ccdbaf 100644
--- a/applications/DeepSpeed-Chat/training/step2_reward_model_finetuning/main.py
+++ b/applications/DeepSpeed-Chat/training/step2_reward_model_finetuning/main.py
@@ -145,9 +145,11 @@ def parse_args():
     parser.add_argument('--offload',
                         action='store_true',
                         help='Enable ZeRO Offload techniques.')
-    parser.add_argument('--dtype', type=str, default='fp16',
+    parser.add_argument('--dtype',
+                        type=str,
+                        default='fp16',
                         choices=['fp16', 'bf16'],
-                        help = 'Training data type')
+                        help='Training data type')
     parser.add_argument(
         '--zero_stage',
         type=int,
diff --git a/applications/DeepSpeed-Chat/training/step3_rlhf_finetuning/main.py b/applications/DeepSpeed-Chat/training/step3_rlhf_finetuning/main.py
index 9263b6d8d..342eeea4a 100644
--- a/applications/DeepSpeed-Chat/training/step3_rlhf_finetuning/main.py
+++ b/applications/DeepSpeed-Chat/training/step3_rlhf_finetuning/main.py
@@ -241,9 +241,11 @@ def parse_args():
     parser.add_argument('--offload',
                         action='store_true',
                         help='Enable ZeRO Offload techniques.')
-    parser.add_argument('--dtype', type=str, default='fp16',
+    parser.add_argument('--dtype',
+                        type=str,
+                        default='fp16',
                         choices=['fp16', 'bf16'],
-                        help = 'Training data type')
+                        help='Training data type')
     parser.add_argument(
         '--offload_reference_model',
         action='store_true',
diff --git a/applications/DeepSpeed-Chat/training/step3_rlhf_finetuning/rlhf_engine.py b/applications/DeepSpeed-Chat/training/step3_rlhf_finetuning/rlhf_engine.py
index 187a36efe..96b3ad632 100755
--- a/applications/DeepSpeed-Chat/training/step3_rlhf_finetuning/rlhf_engine.py
+++ b/applications/DeepSpeed-Chat/training/step3_rlhf_finetuning/rlhf_engine.py
@@ -140,8 +140,7 @@ def _init_ref(self, actor_model_name_or_path):
             # If actor is ZeRO-3 then we use it for everything, otherwise assume we have enough memory for ref model
             zero_stage = 0
         ds_config = get_eval_ds_config(self.args.offload_reference_model,
-                                       self.args.dtype,
-                                       zero_stage)
+                                       self.args.dtype, zero_stage)
         ds_config[
             'train_micro_batch_size_per_gpu'] = self.args.per_device_training_batch_size
         #TODO(jeff): we should probably set grad accumlation steps here as well for clarity
@@ -167,8 +166,7 @@ def _init_ema(self, actor_model_name_or_path):
             # If actor is ZeRO-3 then we use it for everything, otherwise assume we have enough memory
             zero_stage = 0
         ds_config = get_eval_ds_config(self.args.offload_reference_model,
-                                       self.args.dtype,
-                                       zero_stage)
+                                       self.args.dtype, zero_stage)
         ds_config[
             'train_micro_batch_size_per_gpu'] = self.args.per_device_training_batch_size
         #TODO(jeff): we should probably set grad accumlation steps here as well for clarity
@@ -279,7 +277,9 @@ def _init_reward(self, critic_model_name_or_path):
             'train_batch_size'] = self.args.per_device_training_batch_size * torch.distributed.get_world_size(
             ) * self.args.gradient_accumulation_steps
 
-        ds_eval_config = get_eval_ds_config(offload=False, dtype=self.args.dtype, stage=zero_stage)
+        ds_eval_config = get_eval_ds_config(offload=False,
+                                            dtype=self.args.dtype,
+                                            stage=zero_stage)
 
         # We need to set train batch size and micro batch size here to pass the sanity check of DeepSpeed engine.
         ds_eval_config[
diff --git a/applications/DeepSpeed-Chat/training/utils/ds_utils.py b/applications/DeepSpeed-Chat/training/utils/ds_utils.py
index f042283e7..9c15e5143 100644
--- a/applications/DeepSpeed-Chat/training/utils/ds_utils.py
+++ b/applications/DeepSpeed-Chat/training/utils/ds_utils.py
@@ -3,7 +3,6 @@
 
 # DeepSpeed Team
 
-import torch
 import deepspeed.comm as dist
 from deepspeed.accelerator import get_accelerator
 
@@ -28,15 +27,10 @@ def get_train_ds_config(offload,
     device = "cpu" if offload else "none"
     if dtype == "fp16":
         data_type = "fp16"
-        dtype_config = {
-            "enabled": True,
-            "loss_scale_window": 100
-        }
+        dtype_config = {"enabled": True, "loss_scale_window": 100}
     elif dtype == "bf16":
         data_type = "bfloat16"
-        dtype_config = {
-            "enabled": True
-        }
+        dtype_config = {"enabled": True}
     zero_opt_dict = {
         "stage": stage,
         "offload_param": {
@@ -85,13 +79,11 @@ def get_eval_ds_config(offload, dtype, stage=0):
     if dtype == "fp16":
         data_type = "fp16"
         dtype_config = {
-                "enabled": True,
-                }
+            "enabled": True,
+        }
     elif dtype == "bf16":
         data_type = "bfloat16"
-        dtype_config = {
-            "enabled": True
-        }
+        dtype_config = {"enabled": True}
     zero_opt_dict = {
         "stage": stage,
         "stage3_param_persistence_threshold": 1e4,
diff --git a/inference/huggingface/zero_inference/run_model.py b/inference/huggingface/zero_inference/run_model.py
index 5aa28fd7f..fea8e0be1 100644
--- a/inference/huggingface/zero_inference/run_model.py
+++ b/inference/huggingface/zero_inference/run_model.py
@@ -20,14 +20,28 @@
     BloomForCausalLM,
     OPTForCausalLM,
     LlamaForCausalLM,
 )
 from transformers.deepspeed import HfDeepSpeedConfig
-from utils import (GB, add_model_hooks, cache_bytes, disable_torch_init,
+from utils import (GB, add_model_hooks, cache_bytes,
                    get_filename, get_quant_config, hidden_bytes, meta_to_cpu,
                    model_bytes, write_benchmark_log)
 from packaging import version
-
 assert version.parse(deepspeed.__version__) >= version.parse("0.10.3"), "ZeRO-Inference with weight quantization and kv cache offloading is available only in DeepSpeed 0.10.3+, please upgrade DeepSpeed"
+def get_tokenizer(model_name, config):
+    if config.model_type == "opt":
+        # opt175b is not available on HF (at this time),
+        # so as a hack we use opt66b which has similar tokenizer.
+        tokenizer = AutoTokenizer.from_pretrained(
+            model_name.replace("175b", "66b"),
+            padding_side="left"
+        )
+    else:
+        tokenizer = AutoTokenizer.from_pretrained(model_name)
+
+    tokenizer.pad_token = tokenizer.eos_token
+
+    return tokenizer
+
 def get_model_config(model_name):
     if "175b" in model_name:
         config = AutoConfig.from_pretrained("facebook/opt-66b")
@@ -46,7 +60,6 @@ def get_model_config(model_name):
 
 def get_ds_model(
     model_name,
-    dtype,
     cpu_offload,
     disk_offload,
     offload_dir,
@@ -58,9 +71,13 @@ def get_ds_model(
     config = get_model_config(model_name)
     hidden_size = config.hidden_size
     deepspeed.init_distributed("nccl")
-    rank = dist.get_rank()
     pin_memory = bool(args.pin_memory)
 
+    if getattr(config, 'torch_dtype', None) is None:
+        dtype = torch.float16
+    else:
+        dtype = config.torch_dtype
+
     ds_config = {
         "fp16": {
             "enabled": dtype == torch.float16,
@@ -155,32 +172,12 @@ def run_generation(
     quant_group_size,
     pin_kv_cache,
     async_kv_offload,
+    loops,
 ):
     # Load tokenizer
-    config = get_model_config(model_name)
-    return_token_type_ids = True
-    padding_side = "left" if config.model_type in ["opt"] else "right"
+    config = get_model_config(model_name)
 
-    if config.model_type == "opt":
-        tokenizer = AutoTokenizer.from_pretrained(
-            model_name.replace("175b", "66b"),
-            return_token_type_ids=return_token_type_ids,
-            padding_side=padding_side
-        )
-    else:
-        tokenizer = AutoTokenizer.from_pretrained(
-            model_name,
-            return_token_type_ids=return_token_type_ids,
-            padding_side=padding_side
-        )
-
-
-    tokenizer.pad_token = tokenizer.eos_token
-
-    if hasattr(config, 'torch_dtype'):
-        dtype = config.torch_dtype
-    else:
-        dtype = torch.float
+    tokenizer = get_tokenizer(model_name, config)
 
     if dummy:
         filename = os.path.join(
@@ -208,7 +205,6 @@ def run_generation(
     with torch.no_grad():
         model = get_ds_model(
             model_name,
-            dtype,
             cpu_offload,
             disk_offload,
             offload_dir,
@@ -221,14 +217,14 @@ def run_generation(
     execute_gen_len = gen_len
     prompts = ["Paris is the capital city of"] * (batch_size // dist.get_world_size())
 
-    def _batch_encode(prompts, return_token_type_ids):
-        input_tokens = tokenizer.batch_encode_plus(prompts, return_tensors="pt", padding="max_length", max_length=prompt_len, return_token_type_ids=return_token_type_ids)
+    def _batch_encode(prompts):
+        input_tokens = tokenizer.batch_encode_plus(prompts, return_tensors="pt", padding="max_length", max_length=prompt_len)
         for t in input_tokens:
             if torch.is_tensor(input_tokens[t]):
                 input_tokens[t] = input_tokens[t].to(torch.cuda.current_device())
         return input_tokens
 
-    input_tokens = _batch_encode(prompts, return_token_type_ids)
+    input_tokens = _batch_encode(prompts)
 
     if kv_offload:
         model.set_kv_cache_offload(True, gen_len, pin_kv_cache, async_kv_offload)
@@ -247,11 +243,10 @@ def set_model_stage(model, stage):
     generate_kwargs = dict(max_new_tokens=execute_gen_len, do_sample=False)
     prefill_timings = []
     timer = timers("generate-forward")
-    for _ in range(2):
+    for _ in range(loops):
        timer.start(sync_func=get_accelerator().synchronize)
         with torch.no_grad():
             set_model_stage(model, "prefill")
-            # output_ids = model.generate(input_ids=input_ids, **generate_kwargs)
             output_ids = model.generate(**input_tokens, **generate_kwargs)
             prefill_timings.append(model.__duration__)
         timer.stop(sync_func=get_accelerator().synchronize)
@@ -343,6 +338,7 @@ def remove_model_hooks(module):
     parser = argparse.ArgumentParser()
     parser.add_argument("--model", type=str, default="facebook/opt-1.3b", help="model name or path; currently only supports OPT and BLOOM models")
parser.add_argument("--dummy", action="store_true", help="Use dummy weights for benchmark purposes.") + parser.add_argument("--loops", type=int, default=3, help="Number of token generation iterations") parser.add_argument("--batch-size", type=int, default=1) parser.add_argument("--prompt-len", type=int, default=512, help="prompt length") parser.add_argument("--gen-len", type=int, default=32, help="number of tokens to generate") @@ -383,4 +379,5 @@ def remove_model_hooks(module): args.quant_group_size, args.pin_kv_cache, args.async_kv_offload, + args.loops )