fix SmoothQuantGatedMLP ffn_hidden_size bug #1712

Closed · wants to merge 3 commits
2 changes: 1 addition & 1 deletion tensorrt_llm/quantization/quantize.py
@@ -216,7 +216,7 @@ def smooth_quantize_plugin(model, quant_mode):
elif isinstance(layer.mlp, MLP):
mlp_norm_cls = SmoothQuantMLP

-        mlp_hidden_size = config.hidden_size * 4 if config.intermediate_size is None else config.intermediate_size
+        mlp_hidden_size = layer.mlp.ffn_hidden_size
layer.mlp = mlp_norm_cls(hidden_size=config.hidden_size,
ffn_hidden_size=mlp_hidden_size,
hidden_act=config.hidden_act,
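The one-line change above stops re-deriving the MLP width from the config and instead reuses the ffn_hidden_size the existing layer was built with. A rough standalone illustration of how the two values can disagree (the Fake* classes are hypothetical stand-ins, not TensorRT-LLM types):

# --- illustrative sketch (not part of the diff) ---
from dataclasses import dataclass
from typing import Optional


@dataclass
class FakeConfig:
    hidden_size: int = 4096
    intermediate_size: Optional[int] = None  # unset for some checkpoints


@dataclass
class FakeGatedMLP:
    # Built elsewhere with its own intermediate width (e.g. a LLaMA-style 11008),
    # which is neither hidden_size * 4 nor config.intermediate_size here.
    ffn_hidden_size: int = 11008


config, mlp = FakeConfig(), FakeGatedMLP()

# Old derivation: falls back to hidden_size * 4 when intermediate_size is unset.
old_size = (config.hidden_size * 4
            if config.intermediate_size is None else config.intermediate_size)
# New derivation: read the width directly from the layer being replaced.
new_size = mlp.ffn_hidden_size

print(old_size, new_size)  # 16384 vs. 11008 -> the old value builds a mismatched layer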
190 changes: 181 additions & 9 deletions tensorrt_llm/runtime/generation.py
@@ -609,7 +609,14 @@ def __init__(self,
debug_mode=False,
debug_tensors_to_save=None,
cuda_graph_mode=False,
-                 stream: torch.cuda.Stream = None):
+                 stream: torch.cuda.Stream = None,
+                 use_cached_tensor=False,
+                 cached_max_batch_size=None,
+                 cached_max_context_length=None,
+                 cached_max_tokens=None,
+                 cached_beam_width=1,
+                 cached_sink_token_length=0,
+                 cached_max_attention_window_size=None):
assert isinstance(model_config, ModelConfig)
self._model_config = model_config
self.mapping = mapping
@@ -827,6 +834,21 @@ def __init__(self,
self.debug_tensors = list(
set(found_tensor_names) - set(expected_tensor_names))

self.cached_buffer = {}
self.use_cached_tensor = use_cached_tensor
self.cached_max_batch_size = cached_max_batch_size
self.cached_max_context_length = cached_max_context_length
self.cached_max_tokens = cached_max_tokens
self.cached_beam_width = cached_beam_width
self.cached_sink_token_length = cached_sink_token_length
self.cached_max_attention_window_size = cached_max_attention_window_size
if use_cached_tensor:
if cached_max_batch_size is None:
self.cached_max_batch_size = model_config.max_batch_size
            assert self.cached_max_context_length is not None, \
                "cached_max_context_length cannot be None"
            assert self.cached_max_tokens is not None, \
                "cached_max_tokens cannot be None"
self._init_cache_buffer()

@property
def context_mem_size(self) -> int:
return self.runtime.context_mem_size
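For reference, the argument contract added to __init__ above, restated as a small self-contained sketch (the CachedTensorArgs class is illustrative only, not part of TensorRT-LLM): when use_cached_tensor is enabled, cached_max_context_length and cached_max_tokens are required, while cached_max_batch_size falls back to the engine's build-time maximum.

# --- illustrative sketch (not part of the diff) ---
class CachedTensorArgs:

    def __init__(self,
                 engine_max_batch_size,
                 use_cached_tensor=False,
                 cached_max_batch_size=None,
                 cached_max_context_length=None,
                 cached_max_tokens=None):
        self.use_cached_tensor = use_cached_tensor
        self.cached_max_batch_size = cached_max_batch_size
        self.cached_max_context_length = cached_max_context_length
        self.cached_max_tokens = cached_max_tokens
        if use_cached_tensor:
            if cached_max_batch_size is None:
                # Same fallback as the diff: use the engine's build-time limit.
                self.cached_max_batch_size = engine_max_batch_size
            assert self.cached_max_context_length is not None, \
                "cached_max_context_length cannot be None"
            assert self.cached_max_tokens is not None, \
                "cached_max_tokens cannot be None"


args = CachedTensorArgs(engine_max_batch_size=8,
                        use_cached_tensor=True,
                        cached_max_context_length=2048,
                        cached_max_tokens=8 * 2048)
assert args.cached_max_batch_size == 8  # fell back to the engine limit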
@@ -1364,6 +1386,156 @@ def _get_num_paged_blocks(self, max_attention_window_size,

return num_blocks, max_blocks_per_seq

    def _init_cache_buffer(self):
        """Preallocate worst-case buffers for the cached_* limits passed to
        __init__; they are stored flattened in self.cached_buffer and reused
        by _get_cached_buffer() during setup()."""
        if not self.use_cached_tensor:
            return
batch_size = self.cached_max_batch_size
max_context_length = self.cached_max_context_length
max_tokens = self.cached_max_tokens
beam_width = self.cached_beam_width
max_attention_window_size = max_context_length \
if self.cached_max_attention_window_size is None else self.cached_max_attention_window_size
if self.mapping.is_last_pp_rank():
if self.is_medusa_mode:
self.cached_buffer['logits'] = torch.empty(
(batch_size * (self.num_medusa_tokens + 1) * self.vocab_size_padded)
if not self.gather_context_logits else
(max_tokens * self.vocab_size_padded),
dtype=self._tensor_dtype('logits'),
device=self.device)
                # The flattened element count is the same with or without input
                # padding removed, so one worst-case size covers both paths.
                medusa_logits_shape = (self.num_medusa_heads * batch_size *
                                       (self.num_medusa_tokens + 1) *
                                       self.vocab_size_padded)

self.cached_buffer['medusa_logits'] = torch.empty(
medusa_logits_shape if not self.gather_context_logits else
(self.num_medusa_heads * max_tokens * self.vocab_size_padded),
dtype=self._tensor_dtype('medusa_logits'),
device=self.device)
else:
self.cached_buffer['logits'] = torch.empty(
(batch_size * self.vocab_size_padded)
if not self.gather_context_logits else
(max_tokens * self.vocab_size_padded),
dtype=self._tensor_dtype('logits'),
device=self.device)

cross_cache_shape = [1]
sink_token_length = self.cached_sink_token_length

if self.quant_mode.has_kv_cache_quant():
            # Since torch does not support fp8 yet, use int8 here.
kv_cache_type = torch.int8
else:
if self.has_attn_layers:
first_atten_layer = self.layer_types.index('attention')
kv_cache_type = self.dtype if self.paged_kv_cache else self._tensor_dtype(
f'present_key_value_{first_atten_layer}')
else:
kv_cache_type = None

self.use_one_more_block = (
self.paged_kv_cache and beam_width > 1
and max_context_length > max_attention_window_size)
if self.paged_kv_cache and self.has_attn_layers:
num_blocks, _ = self._get_num_paged_blocks(
max_attention_window_size, sink_token_length,
self.use_one_more_block)
cache_shape = (
num_blocks,
self.num_attn_layers,
2,
self.num_heads_kv,
self.tokens_per_block,
self.head_size,
)
self.cached_buffer["kv_cache_pool"] = torch.empty(cache_shape,
dtype=kv_cache_type,
device=self.device)
if self.cross_attention: # As for now we enable cross paged kv and self paged kv to share the same tokens_per_block
cross_num_blocks, _ = self._get_num_paged_blocks(
self.encoder_max_input_length,
sink_token_length=0,
use_one_more_block=False)
cross_cache_shape = (
cross_num_blocks,
self.num_layers,
2,
self.num_heads_kv,
self.tokens_per_block,
self.head_size,
)
self.cached_buffer["cross_kv_cache_pool"] = torch.empty(cross_cache_shape,
dtype=kv_cache_type,
device=self.device)
elif self.has_attn_layers:
cache_shape = (
batch_size,
2,
self.num_heads_kv,
                max_attention_window_size,  # cached bound computed above
self.head_size,
)
for i in range(self.first_layer, self.last_layer):
if self.layer_types[i] == 'attention':
                    self.cached_buffer[f'present_key_value_{i}'] = torch.empty(
                        cache_shape, dtype=kv_cache_type, device=self.device)

if self.cross_attention:
cross_cache_shape = (
batch_size,
2,
self.num_heads_kv,
self.encoder_max_input_length,
self.head_size,
)
for i in range(self.first_layer, self.last_layer):
if self.layer_types[i] == 'attention':
self.cached_buffer[
f'cross_present_key_value_{i}'] = torch.empty(
cross_cache_shape,
dtype=kv_cache_type,
device=self.device)

if not self.use_gpt_attention_plugin:
                # Without the plugin, we need two sets of kv cache buffers,
                # one for inputs and the other for outputs; they take turns
                # acting as input and output buffers each step.
                # Not applicable to cross KV buffers, which are constant.
for i in range(self.first_layer, self.last_layer):
trt_dtype = self.runtime.engine.get_tensor_dtype(
f'present_key_value_{i}')
if trt_dtype == trt.fp8:
                        # PyTorch doesn't support the fp8 datatype; use int8
                        # instead, since the two types have the same size.
                        # TODO: remove this once PyTorch supports fp8.
dtype = torch.int8
else:
dtype = self._tensor_dtype(f'present_key_value_{i}')
self.cached_buffer[f'1_present_key_value_{i}'] = torch.empty(
cache_shape, dtype=dtype, device=self.device)

        # Flatten everything up front: _get_cached_buffer() only needs the
        # total element count and reshapes per request inside setup().
        for key in list(self.cached_buffer.keys()):
            self.cached_buffer[key] = self.cached_buffer[key].flatten()

    def _get_cached_buffer(self, key, shape=None, dtype=None, device=None):
        """Return a reshaped view into the preallocated buffer for `key`, or a
        freshly allocated tensor when caching is off or `key` is unknown."""
        assert shape is not None, "shape is None"
        assert dtype is not None, "dtype is None"
        assert device is not None, "device is None"
if not self.use_cached_tensor or key not in self.cached_buffer:
return torch.empty(shape, dtype=dtype, device=device)
cached_tensor = self.cached_buffer[key]
if cached_tensor.dtype != dtype:
raise RuntimeError(f"cached_tensor dtype is not match, key:{key}, "
f"cached_tensor.dtype:{cached_tensor.dtype}, expected_dtype:{dtype}")
from functools import reduce
expected_size = reduce(lambda x, y: x * y, shape)
cache_size = reduce(lambda x, y: x * y, cached_tensor.shape)
if expected_size > cache_size:
raise RuntimeError(f"cache_tensor.size is not enough, key:{key}, shape:{shape}, "
f"cached_size:{cache_size}, expected_size:{expected_size}")
return cached_tensor[:expected_size].reshape(shape)

def setup(self,
batch_size: int,
max_context_length: int,
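The _get_cached_buffer helper added above lets setup() reuse one preallocated pool instead of calling torch.empty per request. A standalone sketch of the same slice-and-reshape pattern (illustrative names and deliberately small sizes; not TensorRT-LLM APIs):

# --- illustrative sketch (not part of the diff) ---
import math

import torch

# One flat worst-case buffer, allocated once (kept tiny here for illustration).
cached = torch.empty(4 * 128 * 1000, dtype=torch.float16)


def get_view(buf, shape):
    # Return a reshaped view of the flat buffer; no new allocation happens.
    needed = math.prod(shape)
    assert needed <= buf.numel(), "preallocated buffer is too small"
    return buf[:needed].reshape(shape)


small = get_view(cached, (2, 1000))      # a small decode-step logits shape
large = get_view(cached, (4, 64, 1000))  # a larger context-logits shape
# Both views reuse the storage of the cached buffer.
assert small.data_ptr() == cached.data_ptr() == large.data_ptr()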
@@ -1448,7 +1620,7 @@ def setup(self,
self.buffer = {}
if self.mapping.is_last_pp_rank():
if self.is_medusa_mode:
-                self.buffer['logits'] = torch.empty(
+                self.buffer['logits'] = self._get_cached_buffer("logits",
(batch_size, self.num_medusa_tokens + 1,
self.vocab_size_padded)
if not self.gather_context_logits else
@@ -1463,14 +1635,14 @@
(self.num_medusa_tokens + 1),
self.vocab_size_padded)

-                self.buffer['medusa_logits'] = torch.empty(
+                self.buffer['medusa_logits'] = self._get_cached_buffer('medusa_logits',
medusa_logits_shape if not self.gather_context_logits else
(self.num_medusa_heads, batch_size, max_context_length,
self.vocab_size_padded),
dtype=self._tensor_dtype('medusa_logits'),
device=self.device)
else:
-                self.buffer['logits'] = torch.empty(
+                self.buffer['logits'] = self._get_cached_buffer('logits',
(batch_size, self.vocab_size_padded)
if not self.gather_context_logits else
(batch_size, max_context_length, self.vocab_size_padded),
@@ -1507,7 +1679,7 @@ def setup(self,
self.tokens_per_block,
self.head_size,
)
-            self.kv_cache_pool = torch.empty(cache_shape,
+            self.kv_cache_pool = self._get_cached_buffer("kv_cache_pool", cache_shape,
dtype=kv_cache_type,
device=self.device)
if self.cross_attention: # As for now we enable cross paged kv and self paged kv to share the same tokens_per_block
@@ -1523,7 +1695,7 @@
self.tokens_per_block,
self.head_size,
)
-                self.cross_kv_cache_pool = torch.empty(cross_cache_shape,
+                self.cross_kv_cache_pool = self._get_cached_buffer("cross_kv_cache_pool", cross_cache_shape,
dtype=kv_cache_type,
device=self.device)
elif self.has_attn_layers:
@@ -1536,7 +1708,7 @@
)
for i in range(self.first_layer, self.last_layer):
if self.layer_types[i] == 'attention':
-                    self.buffer[f'present_key_value_{i}'] = torch.empty(
+                    self.buffer[f'present_key_value_{i}'] = self._get_cached_buffer(f'present_key_value_{i}',
cache_shape, dtype=kv_cache_type, device=self.device)

if self.cross_attention:
@@ -1550,7 +1722,7 @@
for i in range(self.first_layer, self.last_layer):
if self.layer_types[i] == 'attention':
self.buffer[
-                        f'cross_present_key_value_{i}'] = torch.empty(
+                        f'cross_present_key_value_{i}'] = self._get_cached_buffer(f'cross_present_key_value_{i}',
cross_cache_shape,
dtype=kv_cache_type,
device=self.device)
@@ -1574,7 +1746,7 @@
dtype = torch.int8
else:
dtype = self._tensor_dtype(f'present_key_value_{i}')
-                self.buffer[f'1_present_key_value_{i}'] = torch.empty(
+                self.buffer[f'1_present_key_value_{i}'] = self._get_cached_buffer(f'1_present_key_value_{i}',
cache_shape, dtype=dtype, device=self.device)
if os.getenv('TRTLLM_DISABLE_OOTB_KVCACHE_REUSE') != 'ON':
# We can do reuse between different layers' inputs and outputs, i.e. current layer's output can
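Finally, on the comment about needing two sets of KV cache buffers when the GPT attention plugin is disabled: the two buffers simply alternate between input and output roles each generation step, roughly as in this standalone sketch (illustrative only, not TensorRT-LLM code):

# --- illustrative sketch (not part of the diff) ---
import torch

kv_a = torch.zeros(4)
kv_b = torch.zeros(4)

past, present = kv_a, kv_b
for step in range(3):
    # A real kernel would read `past` and write `present` here.
    present.copy_(past + 1)
    # Swap roles for the next step instead of allocating new buffers.
    past, present = present, past

assert past.tolist() == [3.0, 3.0, 3.0]  # the last-written buffer feeds the next step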