
smooth-quant with tp=2, and build llama 7b with tp=2, pp=2 failed #267

@forrestjgq

Description


I want to build the llama-7b-hf model with TP size 2 and PP size 2, with smooth-quant. Here is the procedure:

python3 hf_llama_convert.py -i /data/jgq/lmmodels/7b/ \
    -o /data/jgq/lmmodels/trt_engines/7b/sq0.8/2-gpu \
    -sq 0.8 \
    --tensor-parallelism 2 \
    --storage-type fp16

Then I build it with:

python build.py --ft_model_dir /data/jgq/lmmodels/trt_engines/7b/sq0.8/2-gpu \
    --dtype float16 \
    --remove_input_padding \
    --enable_context_fmha \
    --output_dir /data/jgq/lmmodels/trt_engines/7b/sq/b32_i1024_o2048/4-gpu \
    --max_batch_size 32 \
    --max_input_len 1024 \
    --max_output_len 2048 \
    --world_size 4 \
    --tp_size 2 \
    --pp_size 2 \
    --use_smooth_quant \
    --per_token \
    --per_channel
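
For reference, this is the parallel layout I expect these flags to produce (a minimal sketch assuming a tp-major rank grouping, i.e. pp_rank = rank // tp_size; this grouping is an assumption for illustration, not taken from the TensorRT-LLM code):

tp_size, pp_size = 2, 2
world_size = tp_size * pp_size              # matches --world_size 4
num_layers = 32                             # llama-7b-hf has 32 decoder layers
layers_per_pp_rank = num_layers // pp_size  # 16 layers built per pipeline stage

for rank in range(world_size):
    pp_rank, tp_rank = rank // tp_size, rank % tp_size  # assumed grouping
    first = pp_rank * layers_per_pp_rank
    print(f"rank {rank}: tp_rank={tp_rank}, pp_rank={pp_rank}, "
          f"layers {first}..{first + layers_per_pp_rank - 1}")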

This leads to the following failure:

[11/03/2023-07:15:08] [TRT-LLM] [I] Loading weights from FT...
╭─────────────────────────────── Traceback (most recent call last) ────────────────────────────────╮
│ /data/jgq/trtllm/examples/llama/build.py:718 in <module>                                         │
│                                                                                                  │
│   715 │   else:                                                                                  │
│   716 │   │   args.parallel_build = False                                                        │
│   717 │   │   logger.info('Serially build TensorRT engines.')                                    │
│ ❱ 718 │   │   build(0, args)                                                                     │
│   719 │                                                                                          │
│   720 │   tok = time.time()                                                                      │
│   721 │   t = time.strftime('%H:%M:%S', time.gmtime(tok - tik))                                  │
│                                                                                                  │
│ /data/jgq/trtllm/examples/llama/build.py:689 in build                                            │
│                                                                                                  │
│   686 │   │   │   opt_level=args.builder_opt)                                                    │
│   687 │   │   engine_name = get_engine_name(MODEL_NAME, args.dtype, args.tp_size,                │
│   688 │   │   │   │   │   │   │   │   │     args.pp_size, cur_rank)                              │
│ ❱ 689 │   │   engine = build_rank_engine(builder, builder_config, engine_name,                   │
│   690 │   │   │   │   │   │   │   │      cur_rank, args)                                         │
│   691 │   │   assert engine is not None, f'Failed to build engine for rank {cur_rank}'           │
│   692                                                                                            │
│                                                                                                  │
│ /data/jgq/trtllm/examples/llama/build.py:569 in build_rank_engine                                │
│                                                                                                  │
│   566 │   │   │   │   │   │      dtype=args.dtype)                                               │
│   567 │   │   del hf_llama                                                                       │
│   568 │   elif args.ft_model_dir is not None:                                                    │
│ ❱ 569 │   │   load_from_binary(tensorrt_llm_llama,                                               │
│   570 │   │   │   │   │   │    args.ft_model_dir,                                                │
│   571 │   │   │   │   │   │    mapping,                                                          │
│   572 │   │   │   │   │   │    fp16=(args.dtype == 'float16'),                                   │
│                                                                                                  │
│ /data/jgq/trtllm/examples/llama/weight.py:671 in load_from_binary                                │
│                                                                                                  │
│    668 │   │   │   │   n_embd // mapping.tp_size +                                               │
│    669 │   │   │   │   (n_embd // n_head * n_groups) // mapping.tp_size * 2)                     │
│    670 │   │   idx = i - mapping.pp_rank * tensorrt_llm_llama.num_layers                         │
│ ❱  671 │   │   tensorrt_llm_llama.layers[idx].input_layernorm.weight.value = (fromfile(          │
│    672 │   │   │   dir_path, 'model.layers.' + str(i) + '.input_layernorm.weight.bin'))          │
│    673 │   │   t = fromfile(                                                                     │
│    674 │   │   │   dir_path, 'model.layers.' + str(i) +                                          │
│                                                                                                  │
│ /usr/local/lib/python3.10/dist-packages/tensorrt_llm/module.py:171 in __getitem__                │
│                                                                                                  │
│   168 │   │   if isinstance(idx, slice):                                                         │
│   169 │   │   │   return self.__class__(list(self._modules.values())[idx])                       │
│   170 │   │   else:                                                                              │
│ ❱ 171 │   │   │   return self._modules[self._get_abs_string_index(idx)]                          │
│   172 │                                                                                          │
│   173 │   def __setitem__(self, idx, module) -> None:                                            │
│   174 │   │   idx = self._get_abs_string_index(idx)                                              │
│                                                                                                  │
│ /usr/local/lib/python3.10/dist-packages/tensorrt_llm/module.py:162 in _get_abs_string_index      │
│                                                                                                  │
│   159 │   │   """Get the absolute index for the list of modules"""                               │
│   160 │   │   idx = operator.index(idx)                                                          │
│   161 │   │   if not (-len(self) <= idx < len(self)):                                            │
│ ❱ 162 │   │   │   raise IndexError('index {} is out of range'.format(idx))                       │
│   163 │   │   if idx < 0:                                                                        │
│   164 │   │   │   idx += len(self)                                                               │
│   165 │   │   return str(idx)                                                                    │
╰──────────────────────────────────────────────────────────────────────────────────────────────────╯
IndexError: index 16 is out of range

Note that if pipeline parallelism is not applied (pp_size=1), the build succeeds.
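
The out-of-range index looks consistent with load_from_binary iterating over all 32 global layer indices while each pipeline stage only instantiates num_layers // pp_size = 16 layer modules. A minimal sketch of the arithmetic (variable names are illustrative, not the actual weight.py code):

num_layers_total = 32                                # llama-7b-hf
pp_size, pp_rank = 2, 0                              # first pipeline stage
layers_on_this_rank = num_layers_total // pp_size    # 16 entries in layers[]

for i in range(num_layers_total):                    # loop over all 32 global layers
    idx = i - pp_rank * layers_on_this_rank
    if not (0 <= idx < layers_on_this_rank):
        # Without a bounds check like this, i == 16 on pp_rank 0 gives
        # idx == 16, i.e. the "index 16 is out of range" error above.
        continue
    # ... load model.layers.{i}.* into layers[idx] ...

With pp_size=1 the subtraction is a no-op and idx always stays within range, which matches the observation above.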

Labels

Low Precision, triaged
