Support for Llama3-8B-Instruct model #4

KwokhoTsui opened this issue Sep 28, 2024 · 2 comments

KwokhoTsui commented Sep 28, 2024

Great work!

I am trying to run PyramidInfer with a Llama3-8B-Instruct model, but the pinned version of transformers seems to be too old to load the Llama3-8B weights correctly.

I used this command to run the Llama3-8B demo:

python simple_infer_comparison.py --model_name_or_path meta-llama/Meta-Llama-3-8B-Instruct --pyramid_config configs/llama3_8b.json

and I got this warning:

Loading checkpoint shards: 100%|███████████████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:07<00:00,  1.83s/it]
Some weights of LlamaForCausalLM were not initialized from the model checkpoint at /xxxxx/xxxx/Meta-Llama-3-8B-Instruct and are newly initialized: ['model.layers.8.self_attn.rotary_emb.inv_freq', 'model.layers.24.self_attn.rotary_emb.inv_freq', 'model.layers.11.self_attn.rotary_emb.inv_freq', 'model.layers.13.self_attn.rotary_emb.inv_freq', 'model.layers.22.self_attn.rotary_emb.inv_freq', 'model.layers.12.self_attn.rotary_emb.inv_freq', 'model.layers.6.self_attn.rotary_emb.inv_freq', 'model.layers.18.self_attn.rotary_emb.inv_freq', 'model.layers.31.self_attn.rotary_emb.inv_freq', 'model.layers.10.self_attn.rotary_emb.inv_freq', 'model.layers.3.self_attn.rotary_emb.inv_freq', 'model.layers.0.self_attn.rotary_emb.inv_freq', 'model.layers.2.self_attn.rotary_emb.inv_freq', 'model.layers.9.self_attn.rotary_emb.inv_freq', 'model.layers.17.self_attn.rotary_emb.inv_freq', 'model.layers.29.self_attn.rotary_emb.inv_freq', 'model.layers.1.self_attn.rotary_emb.inv_freq', 'model.layers.15.self_attn.rotary_emb.inv_freq', 'model.layers.21.self_attn.rotary_emb.inv_freq', 'model.layers.23.self_attn.rotary_emb.inv_freq', 'model.layers.4.self_attn.rotary_emb.inv_freq', 'model.layers.27.self_attn.rotary_emb.inv_freq', 'model.layers.28.self_attn.rotary_emb.inv_freq', 'model.layers.30.self_attn.rotary_emb.inv_freq', 'model.layers.25.self_attn.rotary_emb.inv_freq', 'model.layers.16.self_attn.rotary_emb.inv_freq', 'model.layers.20.self_attn.rotary_emb.inv_freq', 'model.layers.5.self_attn.rotary_emb.inv_freq', 'model.layers.7.self_attn.rotary_emb.inv_freq', 'model.layers.26.self_attn.rotary_emb.inv_freq', 'model.layers.19.self_attn.rotary_emb.inv_freq', 'model.layers.14.self_attn.rotary_emb.inv_freq']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Pyramidinfer Model GPU Memory Per GPU (MB):  16446.523
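The inv_freq entries above are rotary-embedding buffers that get recomputed at initialization rather than learned weights, so the warning itself is likely harmless; the installed transformers version is the more relevant thing to check. A minimal sanity check (hypothetical, not part of the repo):

```python
# Hypothetical sanity check, not part of the repo: print the installed
# transformers version and confirm that the "newly initialized" entries are
# rotary inv_freq buffers rather than learned weights.
import transformers
from transformers import AutoModelForCausalLM

print(transformers.__version__)

model = AutoModelForCausalLM.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct")
print([name for name, _ in model.named_buffers() if "inv_freq" in name][:3])
```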

Could you please provide more details on how to run pyramidinfer with a Llama3-8B model?

mutonix commented Oct 17, 2024

Thanks for your attention, and sorry for the delayed reply! Since transformers iterates quickly, if you want to use PyramidInfer with a newer version you will need to adapt the KV cache handling, which differs significantly from older versions. Once you have access to the KV cache, you can modify the source code here:

```python
# main implementation for pyramidinfer
if use_cache:
    next_decoder_cache += [layer_outputs[2]]
    with torch.no_grad():
        ### prefilling stage ###
        if past_key_values is None:
            # determine the decay ratio schedule
            if prefill_decay_strategy == "linear":
                schedule_prefill_decay_ratio = (1.0 - prefill_decay_ratio) * (idx / self.config.num_hidden_layers) + prefill_decay_ratio
            elif prefill_decay_strategy == "cosine":
                schedule_prefill_decay_ratio = (1.0 - prefill_decay_ratio) * (math.cos(math.pi * idx / self.config.num_hidden_layers) + 1) / 2 + prefill_decay_ratio
            else:
                schedule_prefill_decay_ratio = prefill_decay_ratio

            if (idx % layerwise_downsample_interval) == 0:
                attn_weights = layer_outputs[1].mean(dim=1)  # average over attention heads
                # attention from the recent window onto the earlier context tokens
                recent2context_attn_weights = attn_weights[:, -(1 + recent_length):, :-(1 + recent_length)]
                # weight the recent2context attention weights by distance
                recent2context_attn_weights *= torch.linspace(1.0, distance_weight, recent2context_attn_weights.shape[1], device=recent2context_attn_weights.device)[None, :, None]
                recent2context_attn_weights = recent2context_attn_weights.mean(dim=-2)
                # always keep the sink tokens
                recent2context_attn_weights[:, :streamingllm_sink_len] = torch.finfo(recent2context_attn_weights.dtype).max

                context_length = recent2context_attn_weights.shape[-1]
                if context_length > min_context_length and schedule_prefill_decay_ratio < 1.0:
                    topk = int(context_length * schedule_prefill_decay_ratio) if int(context_length * schedule_prefill_decay_ratio) > min_context_length else context_length
                    recent2context_topk_indices = torch.topk(recent2context_attn_weights, topk, dim=-1, largest=True, sorted=False).indices.sort(dim=-1).values
                    ### slower version ###
                    # bottomk = context_length - topk
                    # recent2context_bottomk_indices = torch.topk(recent2context_attn_weights, bottomk, dim=-1, largest=False).indices
                    # recent2context_topk_indices_mask = recent2context_attn_weights.fill_(1.0).scatter_(-1, recent2context_bottomk_indices, 0.0).bool()
                    # recent2context_topk_indices = position_ids[:, :context_length][recent2context_topk_indices_mask].view(batch_size, -1)

                    # gather the original position ids for the selected topk indices
                    selected_position_ids = selected_position_ids.to(recent2context_topk_indices.device)
                    selected_position_ids = torch.cat([
                        torch.gather(selected_position_ids[:, :-(1 + recent_length)], dim=-1, index=recent2context_topk_indices),
                        selected_position_ids[:, -(1 + recent_length):],
                    ], dim=-1)
                    # keep only the selected context tokens plus the recent window
                    compressed_hidden_states = torch.gather(hidden_states[:, :-(1 + recent_length)], dim=-2, index=recent2context_topk_indices.unsqueeze(-1).expand(-1, -1, hidden_states.shape[-1]))
                    hidden_states = torch.cat([compressed_hidden_states, hidden_states[:, -(1 + recent_length):]], dim=1)

            if past_kv_seq_lens is not None:
                past_kv_seq_lens.append(layer_outputs[2][0].shape[-2])
            if recent_attn_weights is not None:
                recent_attn_weights.append(attn_weights[:, -(1 + recent_length):])
        else:
            ### generation stage ###
            attn_weights = layer_outputs[1].mean(dim=1)  # average over attention heads
            # extend the cached recent-window attention with a column for the new token
            recent_attn_weights[idx] = torch.cat([recent_attn_weights[idx].to(attn_weights.device), torch.zeros((recent_attn_weights[idx].shape[0], recent_attn_weights[idx].shape[1], 1), device=attn_weights.device)], dim=-1)
            attn_weights = torch.cat([recent_attn_weights[idx], attn_weights], dim=-2)

            past_kv_seq_len = past_kv_seq_lens[idx]
            current_kv_seq_len = next_decoder_cache[-1][0].shape[-2]

            if gen_decay_strategy == "linear":
                schedule_gen_decay_ratio = (1.0 - gen_decay_ratio) * (idx / self.config.num_hidden_layers) + gen_decay_ratio
            elif gen_decay_strategy == "cosine":
                schedule_gen_decay_ratio = (1.0 - gen_decay_ratio) * (math.cos(math.pi * idx / self.config.num_hidden_layers) + 1) / 2 + gen_decay_ratio
            else:
                schedule_gen_decay_ratio = gen_decay_ratio

            if current_kv_seq_len - recent_length - past_kv_seq_len >= exceed_length_to_compress:
                recent2context_attn_weights = attn_weights[:, -(1 + recent_length):, -(1 + recent_length + exceed_length_to_compress):-(1 + recent_length)]
                recent2context_attn_weights *= torch.linspace(1.0, distance_weight, recent2context_attn_weights.shape[1], device=recent2context_attn_weights.device)[None, :, None]
                recent2context_attn_weights = recent2context_attn_weights.mean(dim=-2)

                context_length = recent2context_attn_weights.shape[-1] * gen_compress_ratio
                topk = max(int(context_length * schedule_gen_decay_ratio), 1)
                recent2context_topk_indices = torch.topk(
                    recent2context_attn_weights,
                    topk,
                    dim=-1, largest=True
                ).indices.sort(dim=-1).values

                key_states, value_states = next_decoder_cache[-1]
                # gather key_states / value_states at recent2context_topk_indices
                gather_indices = recent2context_topk_indices[:, None, :, None].expand(-1, key_states.shape[1], -1, key_states.shape[3])
                key_states = torch.cat([
                    key_states[:, :, :-(1 + recent_length + exceed_length_to_compress)],
                    torch.gather(key_states[:, :, -(1 + recent_length + exceed_length_to_compress):-(1 + recent_length)], dim=-2, index=gather_indices),
                    key_states[:, :, -(1 + recent_length):],
                ], dim=-2)
                value_states = torch.cat([
                    value_states[:, :, :-(1 + recent_length + exceed_length_to_compress)],
                    torch.gather(value_states[:, :, -(1 + recent_length + exceed_length_to_compress):-(1 + recent_length)], dim=-2, index=gather_indices),
                    value_states[:, :, -(1 + recent_length):],
                ], dim=-2)
                # attention weights [bsz, 1 + recent_length, seq_len]
                attn_weights = torch.cat([
                    attn_weights[:, :, :-(1 + recent_length + exceed_length_to_compress)],
                    torch.gather(attn_weights[:, :, -(1 + recent_length + exceed_length_to_compress):-(1 + recent_length)], dim=-1, index=recent2context_topk_indices[:, None, :].expand(-1, attn_weights.shape[1], -1)),
                    attn_weights[:, :, -(1 + recent_length):],
                ], dim=-1)

                next_decoder_cache[-1] = (key_states, value_states)
                past_kv_seq_lens[idx] = key_states.shape[-2] - recent_length

            recent_attn_weights[idx] = attn_weights[:, -(1 + recent_length):]
```
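If you are adapting this to a recent transformers release, note that past_key_values there is a Cache object rather than a tuple of per-layer (key, value) tensors like next_decoder_cache[-1] above. A minimal sketch (not the repository's actual adaptation) of converting between the two representations, assuming transformers >= 4.36 where DynamicCache is available:

```python
# Minimal sketch, assuming transformers >= 4.36 (DynamicCache available); this is
# NOT the repository's actual adaptation, only an illustration of moving between
# the new Cache objects and the legacy per-layer (key, value) tuples indexed above.
from transformers.cache_utils import DynamicCache

def cache_to_legacy(past_key_values):
    # Cache object -> tuple of (key, value) pairs, one per layer, each shaped
    # [batch, num_kv_heads, seq_len, head_dim], which the pruning code can slice.
    if isinstance(past_key_values, DynamicCache):
        return past_key_values.to_legacy_cache()
    return past_key_values  # already legacy tuples on older transformers

def legacy_to_cache(legacy_kv):
    # Tuple of pruned (key, value) pairs -> DynamicCache, so the surrounding
    # model code that expects a Cache object keeps working.
    return DynamicCache.from_legacy_cache(legacy_kv)
```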

mutonix closed this as completed Nov 1, 2024

mutonix commented Nov 19, 2024

We now support transformers 4.46.3.
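For reference, assuming the updated code on the main branch, upgrading the dependency and rerunning the demo command from this thread should be a reasonable starting point:

pip install transformers==4.46.3
python simple_infer_comparison.py --model_name_or_path meta-llama/Meta-Llama-3-8B-Instruct --pyramid_config configs/llama3_8b.json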

mutonix reopened this Nov 19, 2024